Refactor LLVM JIT (#1613)
Refactor LLVM JIT for some purposes: - To simplify the source code of JIT compilation - To simplify the JIT modes - To align with LLVM latest changes - To prepare for the Multi-tier JIT compilation, refer to #1302 The changes mainly include: - Remove the MCJIT mode, replace it with ORC JIT eager mode - Remove the LLVM legacy pass manager (only keep the LLVM new pass manager) - Change the lazy mode's LLVM module/function binding: change each function in an individual LLVM module into all functions in a single LLVM module - Upgraded ORC JIT to ORCv2 JIT to enable lazy compilation Refer to #1468
This commit is contained in:
@ -369,6 +369,13 @@ typedef struct WASMCustomSection {
|
||||
#if WASM_ENABLE_JIT != 0
|
||||
struct AOTCompData;
|
||||
struct AOTCompContext;
|
||||
|
||||
/* Orc JIT thread arguments */
|
||||
typedef struct OrcJitThreadArg {
|
||||
struct AOTCompContext *comp_ctx;
|
||||
struct WASMModule *module;
|
||||
uint32 group_idx;
|
||||
} OrcJitThreadArg;
|
||||
#endif
|
||||
|
||||
struct WASMModule {
|
||||
@ -501,14 +508,20 @@ struct WASMModule {
|
||||
#endif
|
||||
|
||||
#if WASM_ENABLE_FAST_JIT != 0
|
||||
/* point to JITed functions */
|
||||
/* func pointers of Fast JITed (un-imported) functions */
|
||||
void **fast_jit_func_ptrs;
|
||||
#endif
|
||||
|
||||
#if WASM_ENABLE_JIT != 0
|
||||
struct AOTCompData *comp_data;
|
||||
struct AOTCompContext *comp_ctx;
|
||||
/* func pointers of LLVM JITed (un-imported) functions */
|
||||
void **func_ptrs;
|
||||
/* whether the func pointers are compiled */
|
||||
bool *func_ptrs_compiled;
|
||||
bool orcjit_stop_compiling;
|
||||
korp_tid orcjit_threads[WASM_ORC_JIT_BACKEND_THREAD_NUM];
|
||||
OrcJitThreadArg orcjit_thread_args[WASM_ORC_JIT_BACKEND_THREAD_NUM];
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
@ -2953,20 +2953,91 @@ calculate_global_data_offset(WASMModule *module)
|
||||
}
|
||||
|
||||
#if WASM_ENABLE_JIT != 0
|
||||
static void *
|
||||
orcjit_thread_callback(void *arg)
|
||||
{
|
||||
LLVMOrcJITTargetAddress func_addr = 0;
|
||||
OrcJitThreadArg *thread_arg = (OrcJitThreadArg *)arg;
|
||||
AOTCompContext *comp_ctx = thread_arg->comp_ctx;
|
||||
WASMModule *module = thread_arg->module;
|
||||
uint32 group_idx = thread_arg->group_idx;
|
||||
uint32 group_stride = WASM_ORC_JIT_BACKEND_THREAD_NUM;
|
||||
uint32 func_count = module->function_count;
|
||||
uint32 i, j;
|
||||
typedef void (*F)(void);
|
||||
LLVMErrorRef error;
|
||||
char func_name[48];
|
||||
union {
|
||||
F f;
|
||||
void *v;
|
||||
} u;
|
||||
|
||||
/* Compile jit functions of this group */
|
||||
for (i = group_idx; i < func_count;
|
||||
i += group_stride * WASM_ORC_JIT_COMPILE_THREAD_NUM) {
|
||||
snprintf(func_name, sizeof(func_name), "%s%d%s", AOT_FUNC_PREFIX, i,
|
||||
"_wrapper");
|
||||
LOG_DEBUG("compile func %s", func_name);
|
||||
error =
|
||||
LLVMOrcLLLazyJITLookup(comp_ctx->orc_jit, &func_addr, func_name);
|
||||
if (error != LLVMErrorSuccess) {
|
||||
char *err_msg = LLVMGetErrorMessage(error);
|
||||
os_printf("failed to compile orc jit function: %s", err_msg);
|
||||
LLVMDisposeErrorMessage(err_msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
/* Call the jit wrapper function to trigger its compilation, so as
|
||||
to compile the actual jit functions, since we add the latter to
|
||||
function list in the PartitionFunction callback */
|
||||
u.v = (void *)func_addr;
|
||||
u.f();
|
||||
|
||||
for (j = 0; j < WASM_ORC_JIT_COMPILE_THREAD_NUM; j++) {
|
||||
if (i + j * group_stride < func_count)
|
||||
module->func_ptrs_compiled[i + j * group_stride] = true;
|
||||
}
|
||||
|
||||
if (module->orcjit_stop_compiling) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void
|
||||
orcjit_stop_compile_threads(WASMModule *module)
|
||||
{
|
||||
uint32 i, thread_num = (uint32)(sizeof(module->orcjit_thread_args)
|
||||
/ sizeof(OrcJitThreadArg));
|
||||
|
||||
module->orcjit_stop_compiling = true;
|
||||
for (i = 0; i < thread_num; i++) {
|
||||
if (module->orcjit_threads[i])
|
||||
os_thread_join(module->orcjit_threads[i], NULL);
|
||||
}
|
||||
}
|
||||
|
||||
static bool
|
||||
compile_llvm_jit_functions(WASMModule *module, char *error_buf,
|
||||
uint32 error_buf_size)
|
||||
{
|
||||
AOTCompOption option = { 0 };
|
||||
char func_name[32], *aot_last_error;
|
||||
char *aot_last_error;
|
||||
uint64 size;
|
||||
uint32 i;
|
||||
uint32 thread_num, i;
|
||||
|
||||
size = sizeof(void *) * (uint64)module->function_count;
|
||||
if (size > 0
|
||||
&& !(module->func_ptrs =
|
||||
loader_malloc(size, error_buf, error_buf_size))) {
|
||||
return false;
|
||||
if (module->function_count > 0) {
|
||||
size = sizeof(void *) * (uint64)module->function_count
|
||||
+ sizeof(bool) * (uint64)module->function_count;
|
||||
if (!(module->func_ptrs =
|
||||
loader_malloc(size, error_buf, error_buf_size))) {
|
||||
return false;
|
||||
}
|
||||
module->func_ptrs_compiled =
|
||||
(bool *)((uint8 *)module->func_ptrs
|
||||
+ sizeof(void *) * module->function_count);
|
||||
}
|
||||
|
||||
module->comp_data = aot_create_comp_data(module);
|
||||
@ -3015,20 +3086,24 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
|
||||
return false;
|
||||
}
|
||||
|
||||
#if WASM_ENABLE_LAZY_JIT != 0
|
||||
for (i = 0; i < module->comp_data->func_count; i++) {
|
||||
LLVMErrorRef error;
|
||||
bh_print_time("Begin to lookup jit functions");
|
||||
|
||||
for (i = 0; i < module->function_count; i++) {
|
||||
LLVMOrcJITTargetAddress func_addr = 0;
|
||||
LLVMErrorRef error;
|
||||
char func_name[48];
|
||||
|
||||
snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX, i);
|
||||
if ((error = LLVMOrcLLJITLookup(module->comp_ctx->orc_lazyjit,
|
||||
&func_addr, func_name))) {
|
||||
error = LLVMOrcLLLazyJITLookup(module->comp_ctx->orc_jit, &func_addr,
|
||||
func_name);
|
||||
if (error != LLVMErrorSuccess) {
|
||||
char *err_msg = LLVMGetErrorMessage(error);
|
||||
set_error_buf_v(error_buf, error_buf_size,
|
||||
"failed to compile orc jit function: %s", err_msg);
|
||||
LLVMDisposeErrorMessage(err_msg);
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* No need to lock the func_ptr[func_idx] here as it is basic
|
||||
* data type, the load/store for it can be finished by one cpu
|
||||
@ -3038,20 +3113,43 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
|
||||
module->func_ptrs[i] = (void *)func_addr;
|
||||
module->functions[i]->llvm_jit_func_ptr = (void *)func_addr;
|
||||
}
|
||||
#else
|
||||
/* Resolve function addresses */
|
||||
bh_assert(module->comp_ctx->exec_engine);
|
||||
for (i = 0; i < module->comp_data->func_count; i++) {
|
||||
snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX, i);
|
||||
if (!(module->func_ptrs[i] = (void *)LLVMGetFunctionAddress(
|
||||
module->comp_ctx->exec_engine, func_name))) {
|
||||
|
||||
bh_print_time("Begin to compile jit functions");
|
||||
|
||||
thread_num =
|
||||
(uint32)(sizeof(module->orcjit_thread_args) / sizeof(OrcJitThreadArg));
|
||||
|
||||
/* Create threads to compile the jit functions */
|
||||
for (i = 0; i < thread_num; i++) {
|
||||
module->orcjit_thread_args[i].comp_ctx = module->comp_ctx;
|
||||
module->orcjit_thread_args[i].module = module;
|
||||
module->orcjit_thread_args[i].group_idx = i;
|
||||
|
||||
if (os_thread_create(&module->orcjit_threads[i], orcjit_thread_callback,
|
||||
(void *)&module->orcjit_thread_args[i],
|
||||
APP_THREAD_STACK_SIZE_DEFAULT)
|
||||
!= 0) {
|
||||
uint32 j;
|
||||
|
||||
set_error_buf(error_buf, error_buf_size,
|
||||
"failed to compile llvm mc jit function");
|
||||
"create orcjit compile thread failed");
|
||||
/* Terminate the threads created */
|
||||
module->orcjit_stop_compiling = true;
|
||||
for (j = 0; j < i; j++) {
|
||||
os_thread_join(module->orcjit_threads[j], NULL);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
module->functions[i]->llvm_jit_func_ptr = module->func_ptrs[i];
|
||||
}
|
||||
#endif /* end of WASM_ENABLE_LAZY_JIT != 0 */
|
||||
|
||||
#if WASM_ENABLE_LAZY_JIT == 0
|
||||
/* Wait until all jit functions are compiled for eager mode */
|
||||
for (i = 0; i < thread_num; i++) {
|
||||
os_thread_join(module->orcjit_threads[i], NULL);
|
||||
}
|
||||
#endif
|
||||
|
||||
bh_print_time("End compile jit functions");
|
||||
|
||||
return true;
|
||||
}
|
||||
@ -3893,6 +3991,18 @@ wasm_loader_unload(WASMModule *module)
|
||||
if (!module)
|
||||
return;
|
||||
|
||||
#if WASM_ENABLE_JIT != 0
|
||||
/* Stop LLVM JIT compilation firstly to avoid accessing
|
||||
module internal data after they were freed */
|
||||
orcjit_stop_compile_threads(module);
|
||||
if (module->func_ptrs)
|
||||
wasm_runtime_free(module->func_ptrs);
|
||||
if (module->comp_ctx)
|
||||
aot_destroy_comp_context(module->comp_ctx);
|
||||
if (module->comp_data)
|
||||
aot_destroy_comp_data(module->comp_data);
|
||||
#endif
|
||||
|
||||
if (module->types) {
|
||||
for (i = 0; i < module->type_count; i++) {
|
||||
if (module->types[i])
|
||||
@ -4018,15 +4128,6 @@ wasm_loader_unload(WASMModule *module)
|
||||
}
|
||||
#endif
|
||||
|
||||
#if WASM_ENABLE_JIT != 0
|
||||
if (module->func_ptrs)
|
||||
wasm_runtime_free(module->func_ptrs);
|
||||
if (module->comp_ctx)
|
||||
aot_destroy_comp_context(module->comp_ctx);
|
||||
if (module->comp_data)
|
||||
aot_destroy_comp_data(module->comp_data);
|
||||
#endif
|
||||
|
||||
wasm_runtime_free(module);
|
||||
}
|
||||
|
||||
|
||||
@ -1784,20 +1784,91 @@ calculate_global_data_offset(WASMModule *module)
|
||||
}
|
||||
|
||||
#if WASM_ENABLE_JIT != 0
|
||||
static void *
|
||||
orcjit_thread_callback(void *arg)
|
||||
{
|
||||
LLVMOrcJITTargetAddress func_addr = 0;
|
||||
OrcJitThreadArg *thread_arg = (OrcJitThreadArg *)arg;
|
||||
AOTCompContext *comp_ctx = thread_arg->comp_ctx;
|
||||
WASMModule *module = thread_arg->module;
|
||||
uint32 group_idx = thread_arg->group_idx;
|
||||
uint32 group_stride = WASM_ORC_JIT_BACKEND_THREAD_NUM;
|
||||
uint32 func_count = module->function_count;
|
||||
uint32 i, j;
|
||||
typedef void (*F)(void);
|
||||
LLVMErrorRef error;
|
||||
char func_name[48];
|
||||
union {
|
||||
F f;
|
||||
void *v;
|
||||
} u;
|
||||
|
||||
/* Compile jit functions of this group */
|
||||
for (i = group_idx; i < func_count;
|
||||
i += group_stride * WASM_ORC_JIT_COMPILE_THREAD_NUM) {
|
||||
snprintf(func_name, sizeof(func_name), "%s%d%s", AOT_FUNC_PREFIX, i,
|
||||
"_wrapper");
|
||||
LOG_DEBUG("compile func %s", func_name);
|
||||
error =
|
||||
LLVMOrcLLLazyJITLookup(comp_ctx->orc_jit, &func_addr, func_name);
|
||||
if (error != LLVMErrorSuccess) {
|
||||
char *err_msg = LLVMGetErrorMessage(error);
|
||||
os_printf("failed to compile orc jit function: %s", err_msg);
|
||||
LLVMDisposeErrorMessage(err_msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
/* Call the jit wrapper function to trigger its compilation, so as
|
||||
to compile the actual jit functions, since we add the latter to
|
||||
function list in the PartitionFunction callback */
|
||||
u.v = (void *)func_addr;
|
||||
u.f();
|
||||
|
||||
for (j = 0; j < WASM_ORC_JIT_COMPILE_THREAD_NUM; j++) {
|
||||
if (i + j * group_stride < func_count)
|
||||
module->func_ptrs_compiled[i + j * group_stride] = true;
|
||||
}
|
||||
|
||||
if (module->orcjit_stop_compiling) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void
|
||||
orcjit_stop_compile_threads(WASMModule *module)
|
||||
{
|
||||
uint32 i, thread_num = (uint32)(sizeof(module->orcjit_thread_args)
|
||||
/ sizeof(OrcJitThreadArg));
|
||||
|
||||
module->orcjit_stop_compiling = true;
|
||||
for (i = 0; i < thread_num; i++) {
|
||||
if (module->orcjit_threads[i])
|
||||
os_thread_join(module->orcjit_threads[i], NULL);
|
||||
}
|
||||
}
|
||||
|
||||
static bool
|
||||
compile_llvm_jit_functions(WASMModule *module, char *error_buf,
|
||||
uint32 error_buf_size)
|
||||
{
|
||||
AOTCompOption option = { 0 };
|
||||
char func_name[32], *aot_last_error;
|
||||
char *aot_last_error;
|
||||
uint64 size;
|
||||
uint32 i;
|
||||
uint32 thread_num, i;
|
||||
|
||||
size = sizeof(void *) * (uint64)module->function_count;
|
||||
if (size > 0
|
||||
&& !(module->func_ptrs =
|
||||
loader_malloc(size, error_buf, error_buf_size))) {
|
||||
return false;
|
||||
if (module->function_count > 0) {
|
||||
size = sizeof(void *) * (uint64)module->function_count
|
||||
+ sizeof(bool) * (uint64)module->function_count;
|
||||
if (!(module->func_ptrs =
|
||||
loader_malloc(size, error_buf, error_buf_size))) {
|
||||
return false;
|
||||
}
|
||||
module->func_ptrs_compiled =
|
||||
(bool *)((uint8 *)module->func_ptrs
|
||||
+ sizeof(void *) * module->function_count);
|
||||
}
|
||||
|
||||
module->comp_data = aot_create_comp_data(module);
|
||||
@ -1846,20 +1917,26 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
|
||||
return false;
|
||||
}
|
||||
|
||||
#if WASM_ENABLE_LAZY_JIT != 0
|
||||
for (i = 0; i < module->comp_data->func_count; i++) {
|
||||
LLVMErrorRef error;
|
||||
bh_print_time("Begin to lookup jit functions");
|
||||
|
||||
for (i = 0; i < module->function_count; i++) {
|
||||
LLVMOrcJITTargetAddress func_addr = 0;
|
||||
LLVMErrorRef error;
|
||||
char func_name[48];
|
||||
|
||||
snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX, i);
|
||||
if ((error = LLVMOrcLLJITLookup(module->comp_ctx->orc_lazyjit,
|
||||
&func_addr, func_name))) {
|
||||
error = LLVMOrcLLLazyJITLookup(module->comp_ctx->orc_jit, &func_addr,
|
||||
func_name);
|
||||
if (error != LLVMErrorSuccess) {
|
||||
char *err_msg = LLVMGetErrorMessage(error);
|
||||
set_error_buf_v(error_buf, error_buf_size,
|
||||
"failed to compile orc jit function: %s", err_msg);
|
||||
char buf[128];
|
||||
snprintf(buf, sizeof(buf), "failed to compile orc jit function: %s",
|
||||
err_msg);
|
||||
set_error_buf(error_buf, error_buf_size, buf);
|
||||
LLVMDisposeErrorMessage(err_msg);
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* No need to lock the func_ptr[func_idx] here as it is basic
|
||||
* data type, the load/store for it can be finished by one cpu
|
||||
@ -1869,20 +1946,43 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
|
||||
module->func_ptrs[i] = (void *)func_addr;
|
||||
module->functions[i]->llvm_jit_func_ptr = (void *)func_addr;
|
||||
}
|
||||
#else
|
||||
/* Resolve function addresses */
|
||||
bh_assert(module->comp_ctx->exec_engine);
|
||||
for (i = 0; i < module->comp_data->func_count; i++) {
|
||||
snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX, i);
|
||||
if (!(module->func_ptrs[i] = (void *)LLVMGetFunctionAddress(
|
||||
module->comp_ctx->exec_engine, func_name))) {
|
||||
|
||||
bh_print_time("Begin to compile jit functions");
|
||||
|
||||
thread_num =
|
||||
(uint32)(sizeof(module->orcjit_thread_args) / sizeof(OrcJitThreadArg));
|
||||
|
||||
/* Create threads to compile the jit functions */
|
||||
for (i = 0; i < thread_num; i++) {
|
||||
module->orcjit_thread_args[i].comp_ctx = module->comp_ctx;
|
||||
module->orcjit_thread_args[i].module = module;
|
||||
module->orcjit_thread_args[i].group_idx = i;
|
||||
|
||||
if (os_thread_create(&module->orcjit_threads[i], orcjit_thread_callback,
|
||||
(void *)&module->orcjit_thread_args[i],
|
||||
APP_THREAD_STACK_SIZE_DEFAULT)
|
||||
!= 0) {
|
||||
uint32 j;
|
||||
|
||||
set_error_buf(error_buf, error_buf_size,
|
||||
"failed to compile llvm mc jit function");
|
||||
"create orcjit compile thread failed");
|
||||
/* Terminate the threads created */
|
||||
module->orcjit_stop_compiling = true;
|
||||
for (j = 0; j < i; j++) {
|
||||
os_thread_join(module->orcjit_threads[j], NULL);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
module->functions[i]->llvm_jit_func_ptr = module->func_ptrs[i];
|
||||
}
|
||||
#endif /* end of WASM_ENABLE_LAZY_JIT != 0 */
|
||||
|
||||
#if WASM_ENABLE_LAZY_JIT == 0
|
||||
/* Wait until all jit functions are compiled for eager mode */
|
||||
for (i = 0; i < thread_num; i++) {
|
||||
os_thread_join(module->orcjit_threads[i], NULL);
|
||||
}
|
||||
#endif
|
||||
|
||||
bh_print_time("End compile jit functions");
|
||||
|
||||
return true;
|
||||
}
|
||||
@ -2586,6 +2686,18 @@ wasm_loader_unload(WASMModule *module)
|
||||
if (!module)
|
||||
return;
|
||||
|
||||
#if WASM_ENABLE_JIT != 0
|
||||
/* Stop LLVM JIT compilation firstly to avoid accessing
|
||||
module internal data after they were freed */
|
||||
orcjit_stop_compile_threads(module);
|
||||
if (module->func_ptrs)
|
||||
wasm_runtime_free(module->func_ptrs);
|
||||
if (module->comp_ctx)
|
||||
aot_destroy_comp_context(module->comp_ctx);
|
||||
if (module->comp_data)
|
||||
aot_destroy_comp_data(module->comp_data);
|
||||
#endif
|
||||
|
||||
if (module->types) {
|
||||
for (i = 0; i < module->type_count; i++) {
|
||||
if (module->types[i])
|
||||
@ -2673,15 +2785,6 @@ wasm_loader_unload(WASMModule *module)
|
||||
}
|
||||
#endif
|
||||
|
||||
#if WASM_ENABLE_JIT != 0
|
||||
if (module->func_ptrs)
|
||||
wasm_runtime_free(module->func_ptrs);
|
||||
if (module->comp_ctx)
|
||||
aot_destroy_comp_context(module->comp_ctx);
|
||||
if (module->comp_data)
|
||||
aot_destroy_comp_data(module->comp_data);
|
||||
#endif
|
||||
|
||||
wasm_runtime_free(module);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user