Refactor Orc JIT to enable lazy compilation (#974)

Refactor LLVM Orc JIT to actually enable the lazy compilation and speedup
the launching process:
  https://llvm.org/docs/ORCv2.html#laziness

Main modifications:
- Create LLVM module for each wasm function, wrap it with thread safe module
  so that the modules can be compiled parallelly
- Lookup function from aot module instance's func_ptrs but not directly call the
  function to decouple the module relationship
- Compile the function when it is first called and hasn't been compiled
- Create threads to pre-compile the WASM functions parallelly when loading
- Set Lazy JIT as default, update document and build/test scripts
This commit is contained in:
Wenyong Huang
2022-01-20 18:40:13 +08:00
committed by GitHub
parent 260d36a62d
commit 7636d86a76
27 changed files with 861 additions and 464 deletions

View File

@ -2760,6 +2760,90 @@ aot_load_from_aot_file(const uint8 *buf, uint32 size, char *error_buf,
}
#if WASM_ENABLE_JIT != 0
#if WASM_ENABLE_LAZY_JIT != 0
/* Orc JIT thread arguments */
typedef struct OrcJitThreadArg {
AOTCompData *comp_data;
AOTCompContext *comp_ctx;
AOTModule *module;
int32 group_idx;
int32 group_stride;
} OrcJitThreadArg;
static bool orcjit_stop_compiling = false;
static korp_tid orcjit_threads[WASM_LAZY_JIT_COMPILE_THREAD_NUM];
static OrcJitThreadArg orcjit_thread_args[WASM_LAZY_JIT_COMPILE_THREAD_NUM];
static void *
orcjit_thread_callback(void *arg)
{
LLVMErrorRef error;
LLVMOrcJITTargetAddress func_addr = 0;
OrcJitThreadArg *thread_arg = (OrcJitThreadArg *)arg;
AOTCompData *comp_data = thread_arg->comp_data;
AOTCompContext *comp_ctx = thread_arg->comp_ctx;
AOTModule *module = thread_arg->module;
char func_name[32];
int32 i;
/* Compile wasm functions of this group */
for (i = thread_arg->group_idx; i < (int32)comp_data->func_count;
i += thread_arg->group_stride) {
if (!module->func_ptrs[i]) {
snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX, i);
if ((error = LLVMOrcLLJITLookup(comp_ctx->orc_lazyjit, &func_addr,
func_name))) {
char *err_msg = LLVMGetErrorMessage(error);
os_printf("failed to compile orc jit function: %s", err_msg);
LLVMDisposeErrorMessage(err_msg);
break;
}
/**
* No need to lock the func_ptr[func_idx] here as it is basic
* data type, the load/store for it can be finished by one cpu
* instruction, and there can be only one cpu instruction
* loading/storing at the same time.
*/
module->func_ptrs[i] = (void *)func_addr;
}
if (orcjit_stop_compiling) {
break;
}
}
/* Try to compile functions that haven't been compiled by other threads */
for (i = (int32)comp_data->func_count - 1; i > 0; i--) {
if (orcjit_stop_compiling) {
break;
}
if (!module->func_ptrs[i]) {
snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX, i);
if ((error = LLVMOrcLLJITLookup(comp_ctx->orc_lazyjit, &func_addr,
func_name))) {
char *err_msg = LLVMGetErrorMessage(error);
os_printf("failed to compile orc jit function: %s", err_msg);
LLVMDisposeErrorMessage(err_msg);
break;
}
module->func_ptrs[i] = (void *)func_addr;
}
}
return NULL;
}
static void
orcjit_stop_compile_threads()
{
uint32 i;
orcjit_stop_compiling = true;
for (i = 0; i < WASM_LAZY_JIT_COMPILE_THREAD_NUM; i++) {
os_thread_join(orcjit_threads[i], NULL);
}
}
#endif
static AOTModule *
aot_load_from_comp_data(AOTCompData *comp_data, AOTCompContext *comp_ctx,
char *error_buf, uint32 error_buf_size)
@ -2769,13 +2853,6 @@ aot_load_from_comp_data(AOTCompData *comp_data, AOTCompContext *comp_ctx,
char func_name[32];
AOTModule *module;
#if WASM_ENABLE_LAZY_JIT != 0
LLVMOrcThreadSafeModuleRef ts_module;
LLVMOrcJITDylibRef main_dylib;
LLVMErrorRef error;
LLVMOrcJITTargetAddress func_addr = 0;
#endif
/* Allocate memory for module */
if (!(module =
loader_malloc(sizeof(AOTModule), error_buf, error_buf_size))) {
@ -2839,44 +2916,28 @@ aot_load_from_comp_data(AOTCompData *comp_data, AOTCompContext *comp_ctx,
}
#if WASM_ENABLE_LAZY_JIT != 0
bh_assert(comp_ctx->lazy_orcjit);
/* Create threads to compile the wasm functions */
for (i = 0; i < WASM_LAZY_JIT_COMPILE_THREAD_NUM; i++) {
orcjit_thread_args[i].comp_data = comp_data;
orcjit_thread_args[i].comp_ctx = comp_ctx;
orcjit_thread_args[i].module = module;
orcjit_thread_args[i].group_idx = (int32)i;
orcjit_thread_args[i].group_stride = WASM_LAZY_JIT_COMPILE_THREAD_NUM;
if (os_thread_create(&orcjit_threads[i], orcjit_thread_callback,
(void *)&orcjit_thread_args[i],
APP_THREAD_STACK_SIZE_DEFAULT)
!= 0) {
uint32 j;
main_dylib = LLVMOrcLLLazyJITGetMainJITDylib(comp_ctx->lazy_orcjit);
if (!main_dylib) {
set_error_buf(error_buf, error_buf_size,
"failed to get dynmaic library reference");
goto fail3;
}
ts_module = LLVMOrcCreateNewThreadSafeModule(comp_ctx->module,
comp_ctx->ts_context);
if (!ts_module) {
set_error_buf(error_buf, error_buf_size,
"failed to create thread safe module");
goto fail3;
}
if ((error = LLVMOrcLLLazyJITAddLLVMIRModule(comp_ctx->lazy_orcjit,
main_dylib, ts_module))) {
/*
* If adding the ThreadSafeModule fails then we need to clean it up
* ourselves. If adding it succeeds the JIT will manage the memory.
*/
aot_handle_llvm_errmsg(error_buf, error_buf_size,
"failed to addIRModule: ", error);
goto fail4;
}
for (i = 0; i < comp_data->func_count; i++) {
snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX, i);
if ((error = LLVMOrcLLLazyJITLookup(comp_ctx->lazy_orcjit, &func_addr,
func_name))) {
aot_handle_llvm_errmsg(error_buf, error_buf_size,
"cannot lookup: ", error);
set_error_buf(error_buf, error_buf_size,
"create orcjit compile thread failed");
/* Terminate the threads created */
orcjit_stop_compiling = true;
for (j = 0; j < i; j++) {
os_thread_join(orcjit_threads[j], NULL);
}
goto fail3;
}
module->func_ptrs[i] = (void *)func_addr;
func_addr = 0;
}
#else
/* Resolve function addresses */
@ -2897,7 +2958,7 @@ aot_load_from_comp_data(AOTCompData *comp_data, AOTCompContext *comp_ctx,
if (size > 0
&& !(module->func_type_indexes =
loader_malloc(size, error_buf, error_buf_size))) {
goto fail3;
goto fail4;
}
for (i = 0; i < comp_data->func_count; i++)
module->func_type_indexes[i] = comp_data->funcs[i]->func_type_index;
@ -2911,6 +2972,29 @@ aot_load_from_comp_data(AOTCompData *comp_data, AOTCompContext *comp_ctx,
< module->import_func_count + module->func_count);
/* TODO: fix issue that start func cannot be import func */
if (comp_data->start_func_index >= module->import_func_count) {
#if WASM_ENABLE_LAZY_JIT != 0
if (!module->func_ptrs[comp_data->start_func_index
- module->import_func_count]) {
LLVMErrorRef error;
LLVMOrcJITTargetAddress func_addr = 0;
snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX,
comp_data->start_func_index
- module->import_func_count);
if ((error = LLVMOrcLLJITLookup(comp_ctx->orc_lazyjit,
&func_addr, func_name))) {
char *err_msg = LLVMGetErrorMessage(error);
set_error_buf_v(error_buf, error_buf_size,
"failed to compile orc jit function: %s",
err_msg);
LLVMDisposeErrorMessage(err_msg);
goto fail5;
}
module->func_ptrs[comp_data->start_func_index
- module->import_func_count] =
(void *)func_addr;
}
#endif
module->start_function =
module->func_ptrs[comp_data->start_func_index
- module->import_func_count];
@ -2945,10 +3029,16 @@ aot_load_from_comp_data(AOTCompData *comp_data, AOTCompContext *comp_ctx,
return module;
#if WASM_ENABLE_LAZY_JIT != 0
fail4:
LLVMOrcDisposeThreadSafeModule(ts_module);
fail5:
if (module->func_type_indexes)
wasm_runtime_free(module->func_type_indexes);
#endif
fail4:
#if WASM_ENABLE_LAZY_JIT != 0
/* Terminate all threads before free module->func_ptrs */
orcjit_stop_compile_threads();
#endif
fail3:
if (module->func_ptrs)
wasm_runtime_free(module->func_ptrs);
@ -3034,6 +3124,10 @@ void
aot_unload(AOTModule *module)
{
#if WASM_ENABLE_JIT != 0
#if WASM_ENABLE_LAZY_JIT != 0
orcjit_stop_compile_threads();
#endif
if (module->comp_data)
aot_destroy_comp_data(module->comp_data);

View File

@ -1419,6 +1419,17 @@ aot_call_function(WASMExecEnv *exec_env, AOTFunctionInstance *function,
}
argc = func_type->param_cell_num;
#if WASM_ENABLE_LAZY_JIT != 0
if (!function->u.func.func_ptr) {
AOTModule *aot_module = (AOTModule *)module_inst->aot_module.ptr;
if (!(function->u.func.func_ptr =
aot_lookup_orcjit_func(aot_module->comp_ctx->orc_lazyjit,
module_inst, function->func_index))) {
return false;
}
}
#endif
/* set thread handle and stack boundary */
wasm_exec_env_set_thread_info(exec_env);
@ -2300,6 +2311,15 @@ aot_call_indirect(WASMExecEnv *exec_env, uint32 tbl_idx, uint32 table_elem_idx,
func_type_idx = func_type_indexes[func_idx];
func_type = aot_module->func_types[func_type_idx];
#if WASM_ENABLE_LAZY_JIT != 0
if (func_idx >= aot_module->import_func_count && !func_ptrs[func_idx]) {
if (!(func_ptr = aot_lookup_orcjit_func(
aot_module->comp_ctx->orc_lazyjit, module_inst, func_idx))) {
return false;
}
}
#endif
if (!(func_ptr = func_ptrs[func_idx])) {
bh_assert(func_idx < aot_module->import_func_count);
import_func = aot_module->import_funcs + func_idx;