Refactor interpreter/AOT module instance layout (#1559)

Refactor the layout of interpreter and AOT module instance:
- Unify the interp/AOT module instance, use the same WASMModuleInstance/
  WASMMemoryInstance/WASMTableInstance data structures for both interpreter
  and AOT
- Make the offset of most fields the same in module instance for both interpreter
  and AOT, append memory instance structure, global data and table instances to
  the end of module instance for interpreter mode (like AOT mode)
- For extra fields in WASM module instance, use WASMModuleInstanceExtra to
  create a field `e` for interpreter
- Change the LLVM JIT module instance creating process, LLVM JIT uses the WASM
  module and module instance same as interpreter/Fast-JIT mode. So that Fast JIT
  and LLVM JIT can access the same data structures, and make it possible to
  implement the Multi-tier JIT (tier-up from Fast JIT to LLVM JIT) in the future
- Unify some APIs: merge some APIs for module instance and memory instance's
  related operations (only implement one copy)

Note that the AOT ABI is same, the AOT file format, AOT relocation types, how AOT
code accesses the AOT module instance and so on are kept unchanged.

Refer to:
https://github.com/bytecodealliance/wasm-micro-runtime/issues/1384
This commit is contained in:
Wenyong Huang
2022-10-18 10:59:28 +08:00
committed by GitHub
parent dc4dcc3d6f
commit a182926a73
49 changed files with 3790 additions and 3274 deletions

View File

@ -10,10 +10,6 @@
#include "../common/wasm_runtime_common.h"
#include "../common/wasm_native.h"
#include "../compilation/aot.h"
#if WASM_ENABLE_JIT != 0
#include "../compilation/aot_llvm.h"
#include "../interpreter/wasm_loader.h"
#endif
#if WASM_ENABLE_DEBUG_AOT != 0
#include "debug/elf_parser.h"
@ -716,23 +712,19 @@ fail:
}
static void
destroy_import_memories(AOTImportMemory *import_memories, bool is_jit_mode)
destroy_import_memories(AOTImportMemory *import_memories)
{
if (!is_jit_mode)
wasm_runtime_free(import_memories);
wasm_runtime_free(import_memories);
}
static void
destroy_mem_init_data_list(AOTMemInitData **data_list, uint32 count,
bool is_jit_mode)
destroy_mem_init_data_list(AOTMemInitData **data_list, uint32 count)
{
if (!is_jit_mode) {
uint32 i;
for (i = 0; i < count; i++)
if (data_list[i])
wasm_runtime_free(data_list[i]);
wasm_runtime_free(data_list);
}
uint32 i;
for (i = 0; i < count; i++)
if (data_list[i])
wasm_runtime_free(data_list[i]);
wasm_runtime_free(data_list);
}
static bool
@ -828,30 +820,25 @@ fail:
}
static void
destroy_import_tables(AOTImportTable *import_tables, bool is_jit_mode)
destroy_import_tables(AOTImportTable *import_tables)
{
if (!is_jit_mode)
wasm_runtime_free(import_tables);
wasm_runtime_free(import_tables);
}
static void
destroy_tables(AOTTable *tables, bool is_jit_mode)
destroy_tables(AOTTable *tables)
{
if (!is_jit_mode)
wasm_runtime_free(tables);
wasm_runtime_free(tables);
}
static void
destroy_table_init_data_list(AOTTableInitData **data_list, uint32 count,
bool is_jit_mode)
destroy_table_init_data_list(AOTTableInitData **data_list, uint32 count)
{
if (!is_jit_mode) {
uint32 i;
for (i = 0; i < count; i++)
if (data_list[i])
wasm_runtime_free(data_list[i]);
wasm_runtime_free(data_list);
}
uint32 i;
for (i = 0; i < count; i++)
if (data_list[i])
wasm_runtime_free(data_list[i]);
wasm_runtime_free(data_list);
}
static bool
@ -1003,15 +990,13 @@ fail:
}
static void
destroy_func_types(AOTFuncType **func_types, uint32 count, bool is_jit_mode)
destroy_func_types(AOTFuncType **func_types, uint32 count)
{
if (!is_jit_mode) {
uint32 i;
for (i = 0; i < count; i++)
if (func_types[i])
wasm_runtime_free(func_types[i]);
wasm_runtime_free(func_types);
}
uint32 i;
for (i = 0; i < count; i++)
if (func_types[i])
wasm_runtime_free(func_types[i]);
wasm_runtime_free(func_types);
}
static bool
@ -1094,10 +1079,9 @@ fail:
}
static void
destroy_import_globals(AOTImportGlobal *import_globals, bool is_jit_mode)
destroy_import_globals(AOTImportGlobal *import_globals)
{
if (!is_jit_mode)
wasm_runtime_free(import_globals);
wasm_runtime_free(import_globals);
}
static bool
@ -1177,10 +1161,9 @@ fail:
}
static void
destroy_globals(AOTGlobal *globals, bool is_jit_mode)
destroy_globals(AOTGlobal *globals)
{
if (!is_jit_mode)
wasm_runtime_free(globals);
wasm_runtime_free(globals);
}
static bool
@ -1259,10 +1242,9 @@ fail:
}
static void
destroy_import_funcs(AOTImportFunc *import_funcs, bool is_jit_mode)
destroy_import_funcs(AOTImportFunc *import_funcs)
{
if (!is_jit_mode)
wasm_runtime_free(import_funcs);
wasm_runtime_free(import_funcs);
}
static bool
@ -1652,10 +1634,9 @@ fail:
}
static void
destroy_exports(AOTExport *exports, bool is_jit_mode)
destroy_exports(AOTExport *exports)
{
if (!is_jit_mode)
wasm_runtime_free(exports);
wasm_runtime_free(exports);
}
static bool
@ -2826,427 +2807,46 @@ aot_load_from_aot_file(const uint8 *buf, uint32 size, char *error_buf,
return module;
}
#if WASM_ENABLE_JIT != 0
#if WASM_ENABLE_LAZY_JIT != 0
/* Orc JIT thread arguments */
typedef struct OrcJitThreadArg {
AOTCompData *comp_data;
AOTCompContext *comp_ctx;
AOTModule *module;
int32 group_idx;
int32 group_stride;
} OrcJitThreadArg;
static bool orcjit_stop_compiling = false;
static korp_tid orcjit_threads[WASM_LAZY_JIT_COMPILE_THREAD_NUM];
static OrcJitThreadArg orcjit_thread_args[WASM_LAZY_JIT_COMPILE_THREAD_NUM];
static void *
orcjit_thread_callback(void *arg)
{
LLVMErrorRef error;
LLVMOrcJITTargetAddress func_addr = 0;
OrcJitThreadArg *thread_arg = (OrcJitThreadArg *)arg;
AOTCompData *comp_data = thread_arg->comp_data;
AOTCompContext *comp_ctx = thread_arg->comp_ctx;
AOTModule *module = thread_arg->module;
char func_name[32];
int32 i;
/* Compile wasm functions of this group */
for (i = thread_arg->group_idx; i < (int32)comp_data->func_count;
i += thread_arg->group_stride) {
if (!module->func_ptrs[i]) {
snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX, i);
if ((error = LLVMOrcLLJITLookup(comp_ctx->orc_lazyjit, &func_addr,
func_name))) {
char *err_msg = LLVMGetErrorMessage(error);
os_printf("failed to compile orc jit function: %s", err_msg);
LLVMDisposeErrorMessage(err_msg);
break;
}
/**
* No need to lock the func_ptr[func_idx] here as it is basic
* data type, the load/store for it can be finished by one cpu
* instruction, and there can be only one cpu instruction
* loading/storing at the same time.
*/
module->func_ptrs[i] = (void *)func_addr;
}
if (orcjit_stop_compiling) {
break;
}
}
/* Try to compile functions that haven't been compiled by other threads */
for (i = (int32)comp_data->func_count - 1; i > 0; i--) {
if (orcjit_stop_compiling) {
break;
}
if (!module->func_ptrs[i]) {
snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX, i);
if ((error = LLVMOrcLLJITLookup(comp_ctx->orc_lazyjit, &func_addr,
func_name))) {
char *err_msg = LLVMGetErrorMessage(error);
os_printf("failed to compile orc jit function: %s", err_msg);
LLVMDisposeErrorMessage(err_msg);
break;
}
module->func_ptrs[i] = (void *)func_addr;
}
}
return NULL;
}
static void
orcjit_stop_compile_threads()
{
uint32 i;
orcjit_stop_compiling = true;
for (i = 0; i < WASM_LAZY_JIT_COMPILE_THREAD_NUM; i++) {
os_thread_join(orcjit_threads[i], NULL);
}
}
#endif
static AOTModule *
aot_load_from_comp_data(AOTCompData *comp_data, AOTCompContext *comp_ctx,
char *error_buf, uint32 error_buf_size)
{
uint32 i;
uint64 size;
char func_name[32];
AOTModule *module;
/* Allocate memory for module */
if (!(module =
loader_malloc(sizeof(AOTModule), error_buf, error_buf_size))) {
return NULL;
}
module->module_type = Wasm_Module_AoT;
module->import_memory_count = comp_data->import_memory_count;
module->import_memories = comp_data->import_memories;
module->memory_count = comp_data->memory_count;
if (module->memory_count) {
size = sizeof(AOTMemory) * (uint64)module->memory_count;
if (!(module->memories =
loader_malloc(size, error_buf, error_buf_size))) {
goto fail1;
}
bh_memcpy_s(module->memories, (uint32)size, comp_data->memories,
(uint32)size);
}
module->mem_init_data_list = comp_data->mem_init_data_list;
module->mem_init_data_count = comp_data->mem_init_data_count;
module->import_table_count = comp_data->import_table_count;
module->import_tables = comp_data->import_tables;
module->table_count = comp_data->table_count;
module->tables = comp_data->tables;
module->table_init_data_list = comp_data->table_init_data_list;
module->table_init_data_count = comp_data->table_init_data_count;
module->func_type_count = comp_data->func_type_count;
module->func_types = comp_data->func_types;
module->import_global_count = comp_data->import_global_count;
module->import_globals = comp_data->import_globals;
module->global_count = comp_data->global_count;
module->globals = comp_data->globals;
module->global_count = comp_data->global_count;
module->globals = comp_data->globals;
module->global_data_size = comp_data->global_data_size;
module->import_func_count = comp_data->import_func_count;
module->import_funcs = comp_data->import_funcs;
module->func_count = comp_data->func_count;
/* Allocate memory for function pointers */
size = (uint64)module->func_count * sizeof(void *);
if (size > 0
&& !(module->func_ptrs =
loader_malloc(size, error_buf, error_buf_size))) {
goto fail2;
}
#if WASM_ENABLE_LAZY_JIT != 0
/* Create threads to compile the wasm functions */
for (i = 0; i < WASM_LAZY_JIT_COMPILE_THREAD_NUM; i++) {
orcjit_thread_args[i].comp_data = comp_data;
orcjit_thread_args[i].comp_ctx = comp_ctx;
orcjit_thread_args[i].module = module;
orcjit_thread_args[i].group_idx = (int32)i;
orcjit_thread_args[i].group_stride = WASM_LAZY_JIT_COMPILE_THREAD_NUM;
if (os_thread_create(&orcjit_threads[i], orcjit_thread_callback,
(void *)&orcjit_thread_args[i],
APP_THREAD_STACK_SIZE_DEFAULT)
!= 0) {
uint32 j;
set_error_buf(error_buf, error_buf_size,
"create orcjit compile thread failed");
/* Terminate the threads created */
orcjit_stop_compiling = true;
for (j = 0; j < i; j++) {
os_thread_join(orcjit_threads[j], NULL);
}
goto fail3;
}
}
#else
/* Resolve function addresses */
bh_assert(comp_ctx->exec_engine);
for (i = 0; i < comp_data->func_count; i++) {
snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX, i);
if (!(module->func_ptrs[i] = (void *)LLVMGetFunctionAddress(
comp_ctx->exec_engine, func_name))) {
set_error_buf(error_buf, error_buf_size,
"get function address failed");
goto fail3;
}
}
#endif /* WASM_ENABLE_LAZY_JIT != 0 */
/* Allocation memory for function type indexes */
size = (uint64)module->func_count * sizeof(uint32);
if (size > 0
&& !(module->func_type_indexes =
loader_malloc(size, error_buf, error_buf_size))) {
goto fail4;
}
for (i = 0; i < comp_data->func_count; i++)
module->func_type_indexes[i] = comp_data->funcs[i]->func_type_index;
module->export_count = comp_data->wasm_module->export_count;
module->exports = comp_data->wasm_module->exports;
module->start_func_index = comp_data->start_func_index;
if (comp_data->start_func_index != (uint32)-1) {
bh_assert(comp_data->start_func_index
< module->import_func_count + module->func_count);
/* TODO: fix issue that start func cannot be import func */
if (comp_data->start_func_index >= module->import_func_count) {
#if WASM_ENABLE_LAZY_JIT != 0
if (!module->func_ptrs[comp_data->start_func_index
- module->import_func_count]) {
LLVMErrorRef error;
LLVMOrcJITTargetAddress func_addr = 0;
snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX,
comp_data->start_func_index
- module->import_func_count);
if ((error = LLVMOrcLLJITLookup(comp_ctx->orc_lazyjit,
&func_addr, func_name))) {
char *err_msg = LLVMGetErrorMessage(error);
set_error_buf_v(error_buf, error_buf_size,
"failed to compile orc jit function: %s",
err_msg);
LLVMDisposeErrorMessage(err_msg);
goto fail5;
}
module->func_ptrs[comp_data->start_func_index
- module->import_func_count] =
(void *)func_addr;
}
#endif
module->start_function =
module->func_ptrs[comp_data->start_func_index
- module->import_func_count];
}
}
module->malloc_func_index = comp_data->malloc_func_index;
module->free_func_index = comp_data->free_func_index;
module->retain_func_index = comp_data->retain_func_index;
module->aux_data_end_global_index = comp_data->aux_data_end_global_index;
module->aux_data_end = comp_data->aux_data_end;
module->aux_heap_base_global_index = comp_data->aux_heap_base_global_index;
module->aux_heap_base = comp_data->aux_heap_base;
module->aux_stack_top_global_index = comp_data->aux_stack_top_global_index;
module->aux_stack_bottom = comp_data->aux_stack_bottom;
module->aux_stack_size = comp_data->aux_stack_size;
module->code = NULL;
module->code_size = 0;
module->is_jit_mode = true;
module->wasm_module = comp_data->wasm_module;
module->comp_ctx = comp_ctx;
module->comp_data = comp_data;
#if WASM_ENABLE_LIBC_WASI != 0
module->import_wasi_api = comp_data->wasm_module->import_wasi_api;
#endif
return module;
#if WASM_ENABLE_LAZY_JIT != 0
fail5:
if (module->func_type_indexes)
wasm_runtime_free(module->func_type_indexes);
#endif
fail4:
#if WASM_ENABLE_LAZY_JIT != 0
/* Terminate all threads before free module->func_ptrs */
orcjit_stop_compile_threads();
#endif
fail3:
if (module->func_ptrs)
wasm_runtime_free(module->func_ptrs);
fail2:
if (module->memory_count > 0)
wasm_runtime_free(module->memories);
fail1:
wasm_runtime_free(module);
return NULL;
}
AOTModule *
aot_convert_wasm_module(WASMModule *wasm_module, char *error_buf,
uint32 error_buf_size)
{
AOTCompData *comp_data;
AOTCompContext *comp_ctx;
AOTModule *aot_module;
AOTCompOption option = { 0 };
char *aot_last_error;
comp_data = aot_create_comp_data(wasm_module);
if (!comp_data) {
aot_last_error = aot_get_last_error();
bh_assert(aot_last_error != NULL);
set_error_buf(error_buf, error_buf_size, aot_last_error);
return NULL;
}
option.is_jit_mode = true;
option.opt_level = 3;
option.size_level = 3;
#if WASM_ENABLE_BULK_MEMORY != 0
option.enable_bulk_memory = true;
#endif
#if WASM_ENABLE_THREAD_MGR != 0
option.enable_thread_mgr = true;
#endif
#if WASM_ENABLE_TAIL_CALL != 0
option.enable_tail_call = true;
#endif
#if WASM_ENABLE_SIMD != 0
option.enable_simd = true;
#endif
#if WASM_ENABLE_REF_TYPES != 0
option.enable_ref_types = true;
#endif
option.enable_aux_stack_check = true;
#if (WASM_ENABLE_PERF_PROFILING != 0) || (WASM_ENABLE_DUMP_CALL_STACK != 0)
option.enable_aux_stack_frame = true;
#endif
comp_ctx = aot_create_comp_context(comp_data, &option);
if (!comp_ctx) {
aot_last_error = aot_get_last_error();
bh_assert(aot_last_error != NULL);
set_error_buf(error_buf, error_buf_size, aot_last_error);
goto fail1;
}
if (!aot_compile_wasm(comp_ctx)) {
aot_last_error = aot_get_last_error();
bh_assert(aot_last_error != NULL);
set_error_buf(error_buf, error_buf_size, aot_last_error);
goto fail2;
}
aot_module =
aot_load_from_comp_data(comp_data, comp_ctx, error_buf, error_buf_size);
if (!aot_module) {
goto fail2;
}
return aot_module;
fail2:
aot_destroy_comp_context(comp_ctx);
fail1:
aot_destroy_comp_data(comp_data);
return NULL;
}
#endif
void
aot_unload(AOTModule *module)
{
#if WASM_ENABLE_JIT != 0
#if WASM_ENABLE_LAZY_JIT != 0
orcjit_stop_compile_threads();
#endif
if (module->comp_data)
aot_destroy_comp_data(module->comp_data);
if (module->comp_ctx)
aot_destroy_comp_context(module->comp_ctx);
if (module->wasm_module)
wasm_loader_unload(module->wasm_module);
#endif
if (module->import_memories)
destroy_import_memories(module->import_memories, module->is_jit_mode);
destroy_import_memories(module->import_memories);
if (module->memories)
wasm_runtime_free(module->memories);
if (module->mem_init_data_list)
destroy_mem_init_data_list(module->mem_init_data_list,
module->mem_init_data_count,
module->is_jit_mode);
module->mem_init_data_count);
if (module->native_symbol_list)
wasm_runtime_free(module->native_symbol_list);
if (module->import_tables)
destroy_import_tables(module->import_tables, module->is_jit_mode);
destroy_import_tables(module->import_tables);
if (module->tables)
destroy_tables(module->tables, module->is_jit_mode);
destroy_tables(module->tables);
if (module->table_init_data_list)
destroy_table_init_data_list(module->table_init_data_list,
module->table_init_data_count,
module->is_jit_mode);
module->table_init_data_count);
if (module->func_types)
destroy_func_types(module->func_types, module->func_type_count,
module->is_jit_mode);
destroy_func_types(module->func_types, module->func_type_count);
if (module->import_globals)
destroy_import_globals(module->import_globals, module->is_jit_mode);
destroy_import_globals(module->import_globals);
if (module->globals)
destroy_globals(module->globals, module->is_jit_mode);
destroy_globals(module->globals);
if (module->import_funcs)
destroy_import_funcs(module->import_funcs, module->is_jit_mode);
destroy_import_funcs(module->import_funcs);
if (module->exports)
destroy_exports(module->exports, module->is_jit_mode);
destroy_exports(module->exports);
if (module->func_type_indexes)
wasm_runtime_free(module->func_type_indexes);