Implement call Fast JIT function from LLVM JIT jitted code (#1714)

Basically implement the Multi-tier JIT engine.
And update document and wamr-test-suites script.
This commit is contained in:
Wenyong Huang
2022-11-21 10:42:18 +08:00
committed by GitHub
parent 3daa512925
commit cf7b01ad82
20 changed files with 1775 additions and 221 deletions

View File

@ -428,6 +428,12 @@ destroy_wasm_type(WASMType *type)
return;
}
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT != 0 \
&& WASM_ENABLE_LAZY_JIT != 0
if (type->call_to_llvm_jit_from_fast_jit)
jit_code_cache_free(type->call_to_llvm_jit_from_fast_jit);
#endif
wasm_runtime_free(type);
}
@ -2925,93 +2931,78 @@ calculate_global_data_offset(WASMModule *module)
module->global_data_size = data_offset;
}
#if WASM_ENABLE_JIT != 0
static void *
orcjit_thread_callback(void *arg)
{
LLVMOrcJITTargetAddress func_addr = 0;
OrcJitThreadArg *thread_arg = (OrcJitThreadArg *)arg;
AOTCompContext *comp_ctx = thread_arg->comp_ctx;
WASMModule *module = thread_arg->module;
uint32 group_idx = thread_arg->group_idx;
uint32 group_stride = WASM_ORC_JIT_BACKEND_THREAD_NUM;
uint32 func_count = module->function_count;
uint32 i, j;
typedef void (*F)(void);
LLVMErrorRef error;
char func_name[48];
union {
F f;
void *v;
} u;
/* Compile jit functions of this group */
for (i = group_idx; i < func_count;
i += group_stride * WASM_ORC_JIT_COMPILE_THREAD_NUM) {
snprintf(func_name, sizeof(func_name), "%s%d%s", AOT_FUNC_PREFIX, i,
"_wrapper");
LOG_DEBUG("compile func %s", func_name);
error =
LLVMOrcLLLazyJITLookup(comp_ctx->orc_jit, &func_addr, func_name);
if (error != LLVMErrorSuccess) {
char *err_msg = LLVMGetErrorMessage(error);
os_printf("failed to compile orc jit function: %s", err_msg);
LLVMDisposeErrorMessage(err_msg);
continue;
}
/* Call the jit wrapper function to trigger its compilation, so as
to compile the actual jit functions, since we add the latter to
function list in the PartitionFunction callback */
u.v = (void *)func_addr;
u.f();
for (j = 0; j < WASM_ORC_JIT_COMPILE_THREAD_NUM; j++) {
if (i + j * group_stride < func_count)
module->func_ptrs_compiled[i + j * group_stride] = true;
}
if (module->orcjit_stop_compiling) {
break;
}
}
return NULL;
}
static void
orcjit_stop_compile_threads(WASMModule *module)
{
uint32 i, thread_num = (uint32)(sizeof(module->orcjit_thread_args)
/ sizeof(OrcJitThreadArg));
module->orcjit_stop_compiling = true;
for (i = 0; i < thread_num; i++) {
if (module->orcjit_threads[i])
os_thread_join(module->orcjit_threads[i], NULL);
}
}
#if WASM_ENABLE_FAST_JIT != 0
static bool
compile_llvm_jit_functions(WASMModule *module, char *error_buf,
uint32 error_buf_size)
init_fast_jit_functions(WASMModule *module, char *error_buf,
uint32 error_buf_size)
{
#if WASM_ENABLE_LAZY_JIT != 0
JitGlobals *jit_globals = jit_compiler_get_jit_globals();
#endif
uint32 i;
if (!module->function_count)
return true;
if (!(module->fast_jit_func_ptrs =
loader_malloc(sizeof(void *) * module->function_count, error_buf,
error_buf_size))) {
return false;
}
#if WASM_ENABLE_LAZY_JIT != 0
for (i = 0; i < module->function_count; i++) {
module->fast_jit_func_ptrs[i] =
jit_globals->compile_fast_jit_and_then_call;
}
#endif
for (i = 0; i < WASM_ORC_JIT_BACKEND_THREAD_NUM; i++) {
if (os_mutex_init(&module->fast_jit_thread_locks[i]) != 0) {
set_error_buf(error_buf, error_buf_size,
"init fast jit thread lock failed");
return false;
}
module->fast_jit_thread_locks_inited[i] = true;
}
return true;
}
#endif /* end of WASM_ENABLE_FAST_JIT != 0 */
#if WASM_ENABLE_JIT != 0
static bool
init_llvm_jit_functions_stage1(WASMModule *module, char *error_buf,
uint32 error_buf_size)
{
AOTCompOption option = { 0 };
char *aot_last_error;
uint64 size;
uint32 thread_num, i;
if (module->function_count > 0) {
size = sizeof(void *) * (uint64)module->function_count
+ sizeof(bool) * (uint64)module->function_count;
if (!(module->func_ptrs =
loader_malloc(size, error_buf, error_buf_size))) {
return false;
}
module->func_ptrs_compiled =
(bool *)((uint8 *)module->func_ptrs
+ sizeof(void *) * module->function_count);
if (module->function_count == 0)
return true;
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_LLVM_JIT != 0
if (os_mutex_init(&module->tierup_wait_lock) != 0) {
set_error_buf(error_buf, error_buf_size, "init jit tierup lock failed");
return false;
}
if (os_cond_init(&module->tierup_wait_cond) != 0) {
set_error_buf(error_buf, error_buf_size, "init jit tierup cond failed");
os_mutex_destroy(&module->tierup_wait_lock);
return false;
}
module->tierup_wait_lock_inited = true;
#endif
size = sizeof(void *) * (uint64)module->function_count
+ sizeof(bool) * (uint64)module->function_count;
if (!(module->func_ptrs = loader_malloc(size, error_buf, error_buf_size))) {
return false;
}
module->func_ptrs_compiled =
(bool *)((uint8 *)module->func_ptrs
+ sizeof(void *) * module->function_count);
module->comp_data = aot_create_comp_data(module);
if (!module->comp_data) {
@ -3052,6 +3043,19 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
return false;
}
return true;
}
static bool
init_llvm_jit_functions_stage2(WASMModule *module, char *error_buf,
uint32 error_buf_size)
{
char *aot_last_error;
uint32 i;
if (module->function_count == 0)
return true;
if (!aot_compile_wasm(module->comp_ctx)) {
aot_last_error = aot_get_last_error();
bh_assert(aot_last_error != NULL);
@ -3059,7 +3063,7 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
return false;
}
bh_print_time("Begin to lookup jit functions");
bh_print_time("Begin to lookup llvm jit functions");
for (i = 0; i < module->function_count; i++) {
LLVMOrcJITTargetAddress func_addr = 0;
@ -3084,17 +3088,206 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
* loading/storing at the same time.
*/
module->func_ptrs[i] = (void *)func_addr;
module->functions[i]->llvm_jit_func_ptr = (void *)func_addr;
}
bh_print_time("End lookup llvm jit functions");
return true;
}
#endif /* end of WASM_ENABLE_JIT != 0 */
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT != 0 \
&& WASM_ENABLE_LAZY_JIT != 0
static void *
init_llvm_jit_functions_stage2_callback(void *arg)
{
WASMModule *module = (WASMModule *)arg;
char error_buf[128];
uint32 error_buf_size = (uint32)sizeof(error_buf);
if (!init_llvm_jit_functions_stage2(module, error_buf, error_buf_size)) {
module->orcjit_stop_compiling = true;
return NULL;
}
os_mutex_lock(&module->tierup_wait_lock);
module->llvm_jit_inited = true;
os_cond_broadcast(&module->tierup_wait_cond);
os_mutex_unlock(&module->tierup_wait_lock);
return NULL;
}
#endif
#if WASM_ENABLE_FAST_JIT != 0 || WASM_ENABLE_JIT != 0
/* The callback function to compile jit functions */
static void *
orcjit_thread_callback(void *arg)
{
OrcJitThreadArg *thread_arg = (OrcJitThreadArg *)arg;
#if WASM_ENABLE_JIT != 0
AOTCompContext *comp_ctx = thread_arg->comp_ctx;
#endif
WASMModule *module = thread_arg->module;
uint32 group_idx = thread_arg->group_idx;
uint32 group_stride = WASM_ORC_JIT_BACKEND_THREAD_NUM;
uint32 func_count = module->function_count;
uint32 i;
#if WASM_ENABLE_FAST_JIT != 0
/* Compile fast jit funcitons of this group */
for (i = group_idx; i < func_count; i += group_stride) {
if (!jit_compiler_compile(module, i + module->import_function_count)) {
os_printf("failed to compile fast jit function %u\n", i);
break;
}
if (module->orcjit_stop_compiling) {
return NULL;
}
}
#endif
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT != 0 \
&& WASM_ENABLE_LAZY_JIT != 0
/* For JIT tier-up, set each llvm jit func to call_to_fast_jit */
for (i = group_idx; i < func_count;
i += group_stride * WASM_ORC_JIT_COMPILE_THREAD_NUM) {
uint32 j;
for (j = 0; j < WASM_ORC_JIT_COMPILE_THREAD_NUM; j++) {
if (i + j * group_stride < func_count) {
if (!jit_compiler_set_call_to_fast_jit(
module,
i + j * group_stride + module->import_function_count)) {
os_printf(
"failed to compile call_to_fast_jit for func %u\n",
i + j * group_stride + module->import_function_count);
module->orcjit_stop_compiling = true;
return NULL;
}
}
if (module->orcjit_stop_compiling) {
return NULL;
}
}
}
/* Wait until init_llvm_jit_functions_stage2 finishes */
os_mutex_lock(&module->tierup_wait_lock);
while (!module->llvm_jit_inited) {
os_cond_reltimedwait(&module->tierup_wait_cond,
&module->tierup_wait_lock, 10);
if (module->orcjit_stop_compiling) {
/* init_llvm_jit_functions_stage2 failed */
os_mutex_unlock(&module->tierup_wait_lock);
return NULL;
}
}
os_mutex_unlock(&module->tierup_wait_lock);
#endif
#if WASM_ENABLE_JIT != 0
/* Compile llvm jit functions of this group */
for (i = group_idx; i < func_count;
i += group_stride * WASM_ORC_JIT_COMPILE_THREAD_NUM) {
LLVMOrcJITTargetAddress func_addr = 0;
LLVMErrorRef error;
char func_name[48];
typedef void (*F)(void);
union {
F f;
void *v;
} u;
uint32 j;
snprintf(func_name, sizeof(func_name), "%s%d%s", AOT_FUNC_PREFIX, i,
"_wrapper");
LOG_DEBUG("compile llvm jit func %s", func_name);
error =
LLVMOrcLLLazyJITLookup(comp_ctx->orc_jit, &func_addr, func_name);
if (error != LLVMErrorSuccess) {
char *err_msg = LLVMGetErrorMessage(error);
os_printf("failed to compile llvm jit function %u: %s", i, err_msg);
LLVMDisposeErrorMessage(err_msg);
break;
}
/* Call the jit wrapper function to trigger its compilation, so as
to compile the actual jit functions, since we add the latter to
function list in the PartitionFunction callback */
u.v = (void *)func_addr;
u.f();
for (j = 0; j < WASM_ORC_JIT_COMPILE_THREAD_NUM; j++) {
if (i + j * group_stride < func_count) {
module->func_ptrs_compiled[i + j * group_stride] = true;
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_LAZY_JIT != 0
snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX,
i + j * group_stride);
error = LLVMOrcLLLazyJITLookup(comp_ctx->orc_jit, &func_addr,
func_name);
if (error != LLVMErrorSuccess) {
char *err_msg = LLVMGetErrorMessage(error);
os_printf("failed to compile llvm jit function %u: %s", i,
err_msg);
LLVMDisposeErrorMessage(err_msg);
/* Ignore current llvm jit func, as its func ptr is
previous set to call_to_fast_jit, which also works */
continue;
}
jit_compiler_set_llvm_jit_func_ptr(
module,
i + j * group_stride + module->import_function_count,
(void *)func_addr);
/* Try to switch to call this llvm jit funtion instead of
fast jit function from fast jit jitted code */
jit_compiler_set_call_to_llvm_jit(
module,
i + j * group_stride + module->import_function_count);
#endif
}
}
if (module->orcjit_stop_compiling) {
break;
}
}
#endif
return NULL;
}
static void
orcjit_stop_compile_threads(WASMModule *module)
{
uint32 i, thread_num = (uint32)(sizeof(module->orcjit_thread_args)
/ sizeof(OrcJitThreadArg));
module->orcjit_stop_compiling = true;
for (i = 0; i < thread_num; i++) {
if (module->orcjit_threads[i])
os_thread_join(module->orcjit_threads[i], NULL);
}
}
static bool
compile_jit_functions(WASMModule *module, char *error_buf,
uint32 error_buf_size)
{
uint32 thread_num =
(uint32)(sizeof(module->orcjit_thread_args) / sizeof(OrcJitThreadArg));
uint32 i, j;
bh_print_time("Begin to compile jit functions");
thread_num =
(uint32)(sizeof(module->orcjit_thread_args) / sizeof(OrcJitThreadArg));
/* Create threads to compile the jit functions */
for (i = 0; i < thread_num; i++) {
for (i = 0; i < thread_num && i < module->function_count; i++) {
#if WASM_ENABLE_JIT != 0
module->orcjit_thread_args[i].comp_ctx = module->comp_ctx;
#endif
module->orcjit_thread_args[i].module = module;
module->orcjit_thread_args[i].group_idx = i;
@ -3102,8 +3295,6 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
(void *)&module->orcjit_thread_args[i],
APP_THREAD_STACK_SIZE_DEFAULT)
!= 0) {
uint32 j;
set_error_buf(error_buf, error_buf_size,
"create orcjit compile thread failed");
/* Terminate the threads created */
@ -3118,15 +3309,39 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
#if WASM_ENABLE_LAZY_JIT == 0
/* Wait until all jit functions are compiled for eager mode */
for (i = 0; i < thread_num; i++) {
os_thread_join(module->orcjit_threads[i], NULL);
if (module->orcjit_threads[i])
os_thread_join(module->orcjit_threads[i], NULL);
}
#if WASM_ENABLE_FAST_JIT != 0
/* Ensure all the fast-jit functions are compiled */
for (i = 0; i < module->function_count; i++) {
if (!jit_compiler_is_compiled(module,
i + module->import_function_count)) {
set_error_buf(error_buf, error_buf_size,
"failed to compile fast jit function");
return false;
}
}
#endif
#if WASM_ENABLE_JIT != 0
/* Ensure all the llvm-jit functions are compiled */
for (i = 0; i < module->function_count; i++) {
if (!module->func_ptrs_compiled[i]) {
set_error_buf(error_buf, error_buf_size,
"failed to compile llvm jit function");
return false;
}
}
#endif
#endif /* end of WASM_ENABLE_LAZY_JIT == 0 */
bh_print_time("End compile jit functions");
return true;
}
#endif /* end of WASM_ENABLE_JIT != 0 */
#endif /* end of WASM_ENABLE_FAST_JIT != 0 || WASM_ENABLE_JIT != 0 */
static bool
wasm_loader_prepare_bytecode(WASMModule *module, WASMFunction *func,
@ -3538,23 +3753,41 @@ load_from_sections(WASMModule *module, WASMSection *sections,
calculate_global_data_offset(module);
#if WASM_ENABLE_FAST_JIT != 0
if (module->function_count
&& !(module->fast_jit_func_ptrs =
loader_malloc(sizeof(void *) * module->function_count,
error_buf, error_buf_size))) {
return false;
}
if (!jit_compiler_compile_all(module)) {
set_error_buf(error_buf, error_buf_size, "fast jit compilation failed");
if (!init_fast_jit_functions(module, error_buf, error_buf_size)) {
return false;
}
#endif
#if WASM_ENABLE_JIT != 0
if (!compile_llvm_jit_functions(module, error_buf, error_buf_size)) {
if (!init_llvm_jit_functions_stage1(module, error_buf, error_buf_size)) {
return false;
}
#endif /* end of WASM_ENABLE_JIT != 0 */
#if !(WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_LAZY_JIT != 0)
if (!init_llvm_jit_functions_stage2(module, error_buf, error_buf_size)) {
return false;
}
#else
/* Run aot_compile_wasm in a backend thread, so as not to block the main
thread fast jit execution, since applying llvm optimizations in
aot_compile_wasm may cost a lot of time.
Create thread with enough native stack to apply llvm optimizations */
if (os_thread_create(&module->llvm_jit_init_thread,
init_llvm_jit_functions_stage2_callback,
(void *)module, APP_THREAD_STACK_SIZE_DEFAULT * 8)
!= 0) {
set_error_buf(error_buf, error_buf_size,
"create orcjit compile thread failed");
return false;
}
#endif
#endif
#if WASM_ENABLE_FAST_JIT != 0 || WASM_ENABLE_JIT != 0
/* Create threads to compile the jit functions */
if (!compile_jit_functions(module, error_buf, error_buf_size)) {
return false;
}
#endif
#if WASM_ENABLE_MEMORY_TRACING != 0
wasm_runtime_dump_module_mem_consumption((WASMModuleCommon *)module);
@ -3567,9 +3800,7 @@ create_module(char *error_buf, uint32 error_buf_size)
{
WASMModule *module =
loader_malloc(sizeof(WASMModule), error_buf, error_buf_size);
#if WASM_ENABLE_FAST_INTERP == 0
bh_list_status ret;
#endif
if (!module) {
return NULL;
@ -3584,19 +3815,31 @@ create_module(char *error_buf, uint32 error_buf_size)
module->br_table_cache_list = &module->br_table_cache_list_head;
ret = bh_list_init(module->br_table_cache_list);
bh_assert(ret == BH_LIST_SUCCESS);
(void)ret;
#endif
#if WASM_ENABLE_MULTI_MODULE != 0
module->import_module_list = &module->import_module_list_head;
ret = bh_list_init(module->import_module_list);
bh_assert(ret == BH_LIST_SUCCESS);
#endif
#if WASM_ENABLE_DEBUG_INTERP != 0
bh_list_init(&module->fast_opcode_list);
if (os_mutex_init(&module->ref_count_lock) != 0) {
ret = bh_list_init(&module->fast_opcode_list);
bh_assert(ret == BH_LIST_SUCCESS);
#endif
#if WASM_ENABLE_DEBUG_INTERP != 0 \
|| (WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT \
&& WASM_ENABLE_LAZY_JIT != 0)
if (os_mutex_init(&module->instance_list_lock) != 0) {
set_error_buf(error_buf, error_buf_size,
"init instance list lock failed");
wasm_runtime_free(module);
return NULL;
}
#endif
(void)ret;
return module;
}
@ -3964,10 +4207,19 @@ wasm_loader_unload(WASMModule *module)
if (!module)
return;
#if WASM_ENABLE_JIT != 0
/* Stop LLVM JIT compilation firstly to avoid accessing
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT && WASM_ENABLE_LAZY_JIT != 0
module->orcjit_stop_compiling = true;
if (module->llvm_jit_init_thread)
os_thread_join(module->llvm_jit_init_thread, NULL);
#endif
#if WASM_ENABLE_FAST_JIT != 0 || WASM_ENABLE_JIT != 0
/* Stop Fast/LLVM JIT compilation firstly to avoid accessing
module internal data after they were freed */
orcjit_stop_compile_threads(module);
#endif
#if WASM_ENABLE_JIT != 0
if (module->func_ptrs)
wasm_runtime_free(module->func_ptrs);
if (module->comp_ctx)
@ -3976,6 +4228,13 @@ wasm_loader_unload(WASMModule *module)
aot_destroy_comp_data(module->comp_data);
#endif
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT && WASM_ENABLE_LAZY_JIT != 0
if (module->tierup_wait_lock_inited) {
os_mutex_destroy(&module->tierup_wait_lock);
os_cond_destroy(&module->tierup_wait_cond);
}
#endif
if (module->types) {
for (i = 0; i < module->type_count; i++) {
if (module->types[i])
@ -3997,6 +4256,18 @@ wasm_loader_unload(WASMModule *module)
wasm_runtime_free(module->functions[i]->code_compiled);
if (module->functions[i]->consts)
wasm_runtime_free(module->functions[i]->consts);
#endif
#if WASM_ENABLE_FAST_JIT != 0
if (module->functions[i]->fast_jit_jitted_code) {
jit_code_cache_free(
module->functions[i]->fast_jit_jitted_code);
}
#if WASM_ENABLE_JIT != 0 && WASM_ENABLE_LAZY_JIT != 0
if (module->functions[i]->llvm_jit_func_ptr) {
jit_code_cache_free(
module->functions[i]->llvm_jit_func_ptr);
}
#endif
#endif
wasm_runtime_free(module->functions[i]);
}
@ -4084,7 +4355,12 @@ wasm_loader_unload(WASMModule *module)
wasm_runtime_free(fast_opcode);
fast_opcode = next;
}
os_mutex_destroy(&module->ref_count_lock);
#endif
#if WASM_ENABLE_DEBUG_INTERP != 0 \
|| (WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT \
&& WASM_ENABLE_LAZY_JIT != 0)
os_mutex_destroy(&module->instance_list_lock);
#endif
#if WASM_ENABLE_LOAD_CUSTOM_SECTION != 0
@ -4093,12 +4369,14 @@ wasm_loader_unload(WASMModule *module)
#if WASM_ENABLE_FAST_JIT != 0
if (module->fast_jit_func_ptrs) {
for (i = 0; i < module->function_count; i++) {
if (module->fast_jit_func_ptrs[i])
jit_code_cache_free(module->fast_jit_func_ptrs[i]);
}
wasm_runtime_free(module->fast_jit_func_ptrs);
}
for (i = 0; i < WASM_ORC_JIT_BACKEND_THREAD_NUM; i++) {
if (module->fast_jit_thread_locks_inited[i]) {
os_mutex_destroy(&module->fast_jit_thread_locks[i]);
}
}
#endif
wasm_runtime_free(module);