Refactor LLVM JIT (#1613)

Refactor LLVM JIT for some purposes:
- To simplify the source code of JIT compilation
- To simplify the JIT modes
- To align with LLVM latest changes
- To prepare for the Multi-tier JIT compilation, refer to #1302

The changes mainly include:
- Remove the MCJIT mode, replace it with ORC JIT eager mode
- Remove the LLVM legacy pass manager (only keep the LLVM new pass manager)
- Change the lazy mode's LLVM module/function binding:
  change each function in an individual LLVM module into all functions in a single LLVM module
- Upgraded ORC JIT to ORCv2 JIT to enable lazy compilation

Refer to #1468
This commit is contained in:
Wenyong Huang
2022-10-18 20:17:34 +08:00
committed by GitHub
parent 84b1a6c10e
commit e87a554616
22 changed files with 1058 additions and 1104 deletions

View File

@ -1784,20 +1784,91 @@ calculate_global_data_offset(WASMModule *module)
}
#if WASM_ENABLE_JIT != 0
static void *
orcjit_thread_callback(void *arg)
{
LLVMOrcJITTargetAddress func_addr = 0;
OrcJitThreadArg *thread_arg = (OrcJitThreadArg *)arg;
AOTCompContext *comp_ctx = thread_arg->comp_ctx;
WASMModule *module = thread_arg->module;
uint32 group_idx = thread_arg->group_idx;
uint32 group_stride = WASM_ORC_JIT_BACKEND_THREAD_NUM;
uint32 func_count = module->function_count;
uint32 i, j;
typedef void (*F)(void);
LLVMErrorRef error;
char func_name[48];
union {
F f;
void *v;
} u;
/* Compile jit functions of this group */
for (i = group_idx; i < func_count;
i += group_stride * WASM_ORC_JIT_COMPILE_THREAD_NUM) {
snprintf(func_name, sizeof(func_name), "%s%d%s", AOT_FUNC_PREFIX, i,
"_wrapper");
LOG_DEBUG("compile func %s", func_name);
error =
LLVMOrcLLLazyJITLookup(comp_ctx->orc_jit, &func_addr, func_name);
if (error != LLVMErrorSuccess) {
char *err_msg = LLVMGetErrorMessage(error);
os_printf("failed to compile orc jit function: %s", err_msg);
LLVMDisposeErrorMessage(err_msg);
continue;
}
/* Call the jit wrapper function to trigger its compilation, so as
to compile the actual jit functions, since we add the latter to
function list in the PartitionFunction callback */
u.v = (void *)func_addr;
u.f();
for (j = 0; j < WASM_ORC_JIT_COMPILE_THREAD_NUM; j++) {
if (i + j * group_stride < func_count)
module->func_ptrs_compiled[i + j * group_stride] = true;
}
if (module->orcjit_stop_compiling) {
break;
}
}
return NULL;
}
static void
orcjit_stop_compile_threads(WASMModule *module)
{
uint32 i, thread_num = (uint32)(sizeof(module->orcjit_thread_args)
/ sizeof(OrcJitThreadArg));
module->orcjit_stop_compiling = true;
for (i = 0; i < thread_num; i++) {
if (module->orcjit_threads[i])
os_thread_join(module->orcjit_threads[i], NULL);
}
}
static bool
compile_llvm_jit_functions(WASMModule *module, char *error_buf,
uint32 error_buf_size)
{
AOTCompOption option = { 0 };
char func_name[32], *aot_last_error;
char *aot_last_error;
uint64 size;
uint32 i;
uint32 thread_num, i;
size = sizeof(void *) * (uint64)module->function_count;
if (size > 0
&& !(module->func_ptrs =
loader_malloc(size, error_buf, error_buf_size))) {
return false;
if (module->function_count > 0) {
size = sizeof(void *) * (uint64)module->function_count
+ sizeof(bool) * (uint64)module->function_count;
if (!(module->func_ptrs =
loader_malloc(size, error_buf, error_buf_size))) {
return false;
}
module->func_ptrs_compiled =
(bool *)((uint8 *)module->func_ptrs
+ sizeof(void *) * module->function_count);
}
module->comp_data = aot_create_comp_data(module);
@ -1846,20 +1917,26 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
return false;
}
#if WASM_ENABLE_LAZY_JIT != 0
for (i = 0; i < module->comp_data->func_count; i++) {
LLVMErrorRef error;
bh_print_time("Begin to lookup jit functions");
for (i = 0; i < module->function_count; i++) {
LLVMOrcJITTargetAddress func_addr = 0;
LLVMErrorRef error;
char func_name[48];
snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX, i);
if ((error = LLVMOrcLLJITLookup(module->comp_ctx->orc_lazyjit,
&func_addr, func_name))) {
error = LLVMOrcLLLazyJITLookup(module->comp_ctx->orc_jit, &func_addr,
func_name);
if (error != LLVMErrorSuccess) {
char *err_msg = LLVMGetErrorMessage(error);
set_error_buf_v(error_buf, error_buf_size,
"failed to compile orc jit function: %s", err_msg);
char buf[128];
snprintf(buf, sizeof(buf), "failed to compile orc jit function: %s",
err_msg);
set_error_buf(error_buf, error_buf_size, buf);
LLVMDisposeErrorMessage(err_msg);
return false;
}
/**
* No need to lock the func_ptr[func_idx] here as it is basic
* data type, the load/store for it can be finished by one cpu
@ -1869,20 +1946,43 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
module->func_ptrs[i] = (void *)func_addr;
module->functions[i]->llvm_jit_func_ptr = (void *)func_addr;
}
#else
/* Resolve function addresses */
bh_assert(module->comp_ctx->exec_engine);
for (i = 0; i < module->comp_data->func_count; i++) {
snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX, i);
if (!(module->func_ptrs[i] = (void *)LLVMGetFunctionAddress(
module->comp_ctx->exec_engine, func_name))) {
bh_print_time("Begin to compile jit functions");
thread_num =
(uint32)(sizeof(module->orcjit_thread_args) / sizeof(OrcJitThreadArg));
/* Create threads to compile the jit functions */
for (i = 0; i < thread_num; i++) {
module->orcjit_thread_args[i].comp_ctx = module->comp_ctx;
module->orcjit_thread_args[i].module = module;
module->orcjit_thread_args[i].group_idx = i;
if (os_thread_create(&module->orcjit_threads[i], orcjit_thread_callback,
(void *)&module->orcjit_thread_args[i],
APP_THREAD_STACK_SIZE_DEFAULT)
!= 0) {
uint32 j;
set_error_buf(error_buf, error_buf_size,
"failed to compile llvm mc jit function");
"create orcjit compile thread failed");
/* Terminate the threads created */
module->orcjit_stop_compiling = true;
for (j = 0; j < i; j++) {
os_thread_join(module->orcjit_threads[j], NULL);
}
return false;
}
module->functions[i]->llvm_jit_func_ptr = module->func_ptrs[i];
}
#endif /* end of WASM_ENABLE_LAZY_JIT != 0 */
#if WASM_ENABLE_LAZY_JIT == 0
/* Wait until all jit functions are compiled for eager mode */
for (i = 0; i < thread_num; i++) {
os_thread_join(module->orcjit_threads[i], NULL);
}
#endif
bh_print_time("End compile jit functions");
return true;
}
@ -2586,6 +2686,18 @@ wasm_loader_unload(WASMModule *module)
if (!module)
return;
#if WASM_ENABLE_JIT != 0
/* Stop LLVM JIT compilation firstly to avoid accessing
module internal data after they were freed */
orcjit_stop_compile_threads(module);
if (module->func_ptrs)
wasm_runtime_free(module->func_ptrs);
if (module->comp_ctx)
aot_destroy_comp_context(module->comp_ctx);
if (module->comp_data)
aot_destroy_comp_data(module->comp_data);
#endif
if (module->types) {
for (i = 0; i < module->type_count; i++) {
if (module->types[i])
@ -2673,15 +2785,6 @@ wasm_loader_unload(WASMModule *module)
}
#endif
#if WASM_ENABLE_JIT != 0
if (module->func_ptrs)
wasm_runtime_free(module->func_ptrs);
if (module->comp_ctx)
aot_destroy_comp_context(module->comp_ctx);
if (module->comp_data)
aot_destroy_comp_data(module->comp_data);
#endif
wasm_runtime_free(module);
}