Implement Multi-tier JIT (#1774)

Implement a 2-level multi-tier JIT engine that tiers up from Fast JIT to
LLVM JIT: Fast JIT provides a quick cold startup, and execution gradually
switches to LLVM JIT for better performance as the LLVM JIT functions are
compiled by the backend threads.

Refer to:
https://github.com/bytecodealliance/wasm-micro-runtime/issues/1302
This commit is contained in:
Wenyong Huang
2022-12-19 11:24:46 +08:00
parent 7db49db777
commit e8ce4c542e
21 changed files with 2180 additions and 338 deletions

View File

@ -261,6 +261,12 @@ destroy_wasm_type(WASMType *type)
return;
}
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT != 0 \
&& WASM_ENABLE_LAZY_JIT != 0
if (type->call_to_llvm_jit_from_fast_jit)
jit_code_cache_free(type->call_to_llvm_jit_from_fast_jit);
#endif
wasm_runtime_free(type);
}
@ -1783,93 +1789,78 @@ calculate_global_data_offset(WASMModule *module)
module->global_data_size = data_offset;
}
#if WASM_ENABLE_JIT != 0
/* Backend thread entry: compile the llvm jit functions of one group.
   `arg` is the OrcJitThreadArg of this thread; the return value is
   always NULL. */
static void *
orcjit_thread_callback(void *arg)
{
    typedef void (*WrapperFunc)(void);
    OrcJitThreadArg *targ = (OrcJitThreadArg *)arg;
    AOTCompContext *ctx = targ->comp_ctx;
    WASMModule *wasm_module = targ->module;
    const uint32 stride = WASM_ORC_JIT_BACKEND_THREAD_NUM;
    const uint32 total = wasm_module->function_count;
    uint32 idx;

    /* Walk the wrapper functions that belong to this group */
    for (idx = targ->group_idx; idx < total;
         idx += stride * WASM_ORC_JIT_COMPILE_THREAD_NUM) {
        LLVMOrcJITTargetAddress addr = 0;
        LLVMErrorRef lookup_err;
        char wrapper_name[48];
        union {
            WrapperFunc f;
            void *v;
        } entry;
        uint32 k;

        snprintf(wrapper_name, sizeof(wrapper_name), "%s%d%s", AOT_FUNC_PREFIX,
                 idx, "_wrapper");
        LOG_DEBUG("compile func %s", wrapper_name);

        lookup_err =
            LLVMOrcLLLazyJITLookup(ctx->orc_jit, &addr, wrapper_name);
        if (lookup_err != LLVMErrorSuccess) {
            /* Report the failure and move on to the next wrapper */
            char *msg = LLVMGetErrorMessage(lookup_err);
            os_printf("failed to compile orc jit function: %s", msg);
            LLVMDisposeErrorMessage(msg);
            continue;
        }

        /* Invoking the wrapper triggers its compilation, which in turn
           compiles the actual jit functions, as the latter are added to
           the function list in the PartitionFunction callback */
        entry.v = (void *)addr;
        entry.f();

        /* Mark every function handled by this wrapper as compiled */
        for (k = 0; k < WASM_ORC_JIT_COMPILE_THREAD_NUM; k++) {
            uint32 member = idx + k * stride;

            if (member < total)
                wasm_module->func_ptrs_compiled[member] = true;
        }

        if (wasm_module->orcjit_stop_compiling)
            break;
    }

    return NULL;
}
/* Request the backend compile threads of `module` to stop, then join
   every thread whose handle is set. */
static void
orcjit_stop_compile_threads(WASMModule *module)
{
    const uint32 total = (uint32)(sizeof(module->orcjit_thread_args)
                                  / sizeof(OrcJitThreadArg));
    uint32 idx = 0;

    /* Raise the stop flag before joining so the threads can exit early */
    module->orcjit_stop_compiling = true;

    while (idx < total) {
        if (module->orcjit_threads[idx])
            os_thread_join(module->orcjit_threads[idx], NULL);
        idx++;
    }
}
#if WASM_ENABLE_FAST_JIT != 0
static bool
compile_llvm_jit_functions(WASMModule *module, char *error_buf,
uint32 error_buf_size)
init_fast_jit_functions(WASMModule *module, char *error_buf,
uint32 error_buf_size)
{
#if WASM_ENABLE_LAZY_JIT != 0
JitGlobals *jit_globals = jit_compiler_get_jit_globals();
#endif
uint32 i;
if (!module->function_count)
return true;
if (!(module->fast_jit_func_ptrs =
loader_malloc(sizeof(void *) * module->function_count, error_buf,
error_buf_size))) {
return false;
}
#if WASM_ENABLE_LAZY_JIT != 0
for (i = 0; i < module->function_count; i++) {
module->fast_jit_func_ptrs[i] =
jit_globals->compile_fast_jit_and_then_call;
}
#endif
for (i = 0; i < WASM_ORC_JIT_BACKEND_THREAD_NUM; i++) {
if (os_mutex_init(&module->fast_jit_thread_locks[i]) != 0) {
set_error_buf(error_buf, error_buf_size,
"init fast jit thread lock failed");
return false;
}
module->fast_jit_thread_locks_inited[i] = true;
}
return true;
}
#endif /* end of WASM_ENABLE_FAST_JIT != 0 */
#if WASM_ENABLE_JIT != 0
static bool
init_llvm_jit_functions_stage1(WASMModule *module, char *error_buf,
uint32 error_buf_size)
{
AOTCompOption option = { 0 };
char *aot_last_error;
uint64 size;
uint32 thread_num, i;
if (module->function_count > 0) {
size = sizeof(void *) * (uint64)module->function_count
+ sizeof(bool) * (uint64)module->function_count;
if (!(module->func_ptrs =
loader_malloc(size, error_buf, error_buf_size))) {
return false;
}
module->func_ptrs_compiled =
(bool *)((uint8 *)module->func_ptrs
+ sizeof(void *) * module->function_count);
if (module->function_count == 0)
return true;
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_LLVM_JIT != 0
if (os_mutex_init(&module->tierup_wait_lock) != 0) {
set_error_buf(error_buf, error_buf_size, "init jit tierup lock failed");
return false;
}
if (os_cond_init(&module->tierup_wait_cond) != 0) {
set_error_buf(error_buf, error_buf_size, "init jit tierup cond failed");
os_mutex_destroy(&module->tierup_wait_lock);
return false;
}
module->tierup_wait_lock_inited = true;
#endif
size = sizeof(void *) * (uint64)module->function_count
+ sizeof(bool) * (uint64)module->function_count;
if (!(module->func_ptrs = loader_malloc(size, error_buf, error_buf_size))) {
return false;
}
module->func_ptrs_compiled =
(bool *)((uint8 *)module->func_ptrs
+ sizeof(void *) * module->function_count);
module->comp_data = aot_create_comp_data(module);
if (!module->comp_data) {
@ -1910,6 +1901,19 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
return false;
}
return true;
}
static bool
init_llvm_jit_functions_stage2(WASMModule *module, char *error_buf,
uint32 error_buf_size)
{
char *aot_last_error;
uint32 i;
if (module->function_count == 0)
return true;
if (!aot_compile_wasm(module->comp_ctx)) {
aot_last_error = aot_get_last_error();
bh_assert(aot_last_error != NULL);
@ -1917,7 +1921,12 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
return false;
}
bh_print_time("Begin to lookup jit functions");
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_LAZY_JIT != 0
if (module->orcjit_stop_compiling)
return false;
#endif
bh_print_time("Begin to lookup llvm jit functions");
for (i = 0; i < module->function_count; i++) {
LLVMOrcJITTargetAddress func_addr = 0;
@ -1929,9 +1938,9 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
func_name);
if (error != LLVMErrorSuccess) {
char *err_msg = LLVMGetErrorMessage(error);
char buf[128];
char buf[96];
snprintf(buf, sizeof(buf),
"failed to compile orc jit function: %s\n", err_msg);
"failed to compile llvm jit function: %s", err_msg);
set_error_buf(error_buf, error_buf_size, buf);
LLVMDisposeErrorMessage(err_msg);
return false;
@ -1944,17 +1953,211 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
* loading/storing at the same time.
*/
module->func_ptrs[i] = (void *)func_addr;
module->functions[i]->llvm_jit_func_ptr = (void *)func_addr;
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_LAZY_JIT != 0
if (module->orcjit_stop_compiling)
return false;
#endif
}
bh_print_time("End lookup llvm jit functions");
return true;
}
#endif /* end of WASM_ENABLE_JIT != 0 */
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT != 0 \
&& WASM_ENABLE_LAZY_JIT != 0
static void *
init_llvm_jit_functions_stage2_callback(void *arg)
{
WASMModule *module = (WASMModule *)arg;
char error_buf[128];
uint32 error_buf_size = (uint32)sizeof(error_buf);
if (!init_llvm_jit_functions_stage2(module, error_buf, error_buf_size)) {
module->orcjit_stop_compiling = true;
return NULL;
}
os_mutex_lock(&module->tierup_wait_lock);
module->llvm_jit_inited = true;
os_cond_broadcast(&module->tierup_wait_cond);
os_mutex_unlock(&module->tierup_wait_lock);
return NULL;
}
#endif
#if WASM_ENABLE_FAST_JIT != 0 || WASM_ENABLE_JIT != 0
/**
 * Backend thread entry: compile the jit functions of one group.
 *
 * Depending on the build options, the thread:
 *   1. compiles the Fast JIT functions of its group,
 *   2. (Fast JIT + LLVM JIT + lazy mode) sets each llvm jit function of
 *      the group to call_to_fast_jit, then waits until
 *      init_llvm_jit_functions_stage2 has finished,
 *   3. compiles the LLVM JIT functions of its group and, in the tier-up
 *      build, records each llvm jit function pointer and tries to switch
 *      the fast jit code over to it.
 *
 * @param arg the OrcJitThreadArg of this thread
 *
 * @return always NULL; failures are reported with os_printf
 */
static void *
orcjit_thread_callback(void *arg)
{
    OrcJitThreadArg *thread_arg = (OrcJitThreadArg *)arg;
#if WASM_ENABLE_JIT != 0
    AOTCompContext *comp_ctx = thread_arg->comp_ctx;
#endif
    WASMModule *module = thread_arg->module;
    uint32 group_idx = thread_arg->group_idx;
    uint32 group_stride = WASM_ORC_JIT_BACKEND_THREAD_NUM;
    uint32 func_count = module->function_count;
    uint32 i;

#if WASM_ENABLE_FAST_JIT != 0
    /* Compile fast jit functions of this group */
    for (i = group_idx; i < func_count; i += group_stride) {
        if (!jit_compiler_compile(module, i + module->import_function_count)) {
            os_printf("failed to compile fast jit function %u\n", i);
            break;
        }

        if (module->orcjit_stop_compiling) {
            return NULL;
        }
    }
#endif

#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT != 0 \
    && WASM_ENABLE_LAZY_JIT != 0
    /* For JIT tier-up, set each llvm jit func to call_to_fast_jit */
    for (i = group_idx; i < func_count;
         i += group_stride * WASM_ORC_JIT_COMPILE_THREAD_NUM) {
        uint32 j;

        for (j = 0; j < WASM_ORC_JIT_COMPILE_THREAD_NUM; j++) {
            if (i + j * group_stride < func_count) {
                if (!jit_compiler_set_call_to_fast_jit(
                        module,
                        i + j * group_stride + module->import_function_count)) {
                    os_printf(
                        "failed to compile call_to_fast_jit for func %u\n",
                        i + j * group_stride + module->import_function_count);
                    module->orcjit_stop_compiling = true;
                    return NULL;
                }
            }
            if (module->orcjit_stop_compiling) {
                return NULL;
            }
        }
    }

    /* Wait until init_llvm_jit_functions_stage2 finishes. The wait is
       timed so that the stop flag is re-checked even if no broadcast
       arrives. */
    os_mutex_lock(&module->tierup_wait_lock);
    while (!module->llvm_jit_inited) {
        os_cond_reltimedwait(&module->tierup_wait_cond,
                             &module->tierup_wait_lock, 10);
        if (module->orcjit_stop_compiling) {
            /* init_llvm_jit_functions_stage2 failed */
            os_mutex_unlock(&module->tierup_wait_lock);
            return NULL;
        }
    }
    os_mutex_unlock(&module->tierup_wait_lock);
#endif

#if WASM_ENABLE_JIT != 0
    /* Compile llvm jit functions of this group */
    for (i = group_idx; i < func_count;
         i += group_stride * WASM_ORC_JIT_COMPILE_THREAD_NUM) {
        LLVMOrcJITTargetAddress func_addr = 0;
        LLVMErrorRef error;
        char func_name[48];
        typedef void (*F)(void);
        union {
            F f;
            void *v;
        } u;
        uint32 j;

        snprintf(func_name, sizeof(func_name), "%s%d%s", AOT_FUNC_PREFIX, i,
                 "_wrapper");
        LOG_DEBUG("compile llvm jit func %s", func_name);
        error =
            LLVMOrcLLLazyJITLookup(comp_ctx->orc_jit, &func_addr, func_name);
        if (error != LLVMErrorSuccess) {
            char *err_msg = LLVMGetErrorMessage(error);
            os_printf("failed to compile llvm jit function %u: %s\n", i,
                      err_msg);
            LLVMDisposeErrorMessage(err_msg);
            break;
        }

        /* Call the jit wrapper function to trigger its compilation, so as
           to compile the actual jit functions, since we add the latter to
           function list in the PartitionFunction callback */
        u.v = (void *)func_addr;
        u.f();

        for (j = 0; j < WASM_ORC_JIT_COMPILE_THREAD_NUM; j++) {
            if (i + j * group_stride < func_count) {
                module->func_ptrs_compiled[i + j * group_stride] = true;
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_LAZY_JIT != 0
                snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX,
                         i + j * group_stride);
                error = LLVMOrcLLLazyJITLookup(comp_ctx->orc_jit, &func_addr,
                                               func_name);
                if (error != LLVMErrorSuccess) {
                    char *err_msg = LLVMGetErrorMessage(error);
                    /* Report the index of the function that actually
                       failed, not the first index of the group */
                    os_printf("failed to compile llvm jit function %u: %s\n",
                              i + j * group_stride, err_msg);
                    LLVMDisposeErrorMessage(err_msg);
                    /* Ignore current llvm jit func, as its func ptr was
                       previously set to call_to_fast_jit, which also
                       works */
                    continue;
                }

                jit_compiler_set_llvm_jit_func_ptr(
                    module,
                    i + j * group_stride + module->import_function_count,
                    (void *)func_addr);

                /* Try to switch to call this llvm jit function instead of
                   fast jit function from fast jit jitted code */
                jit_compiler_set_call_to_llvm_jit(
                    module,
                    i + j * group_stride + module->import_function_count);
#endif
            }
        }

        if (module->orcjit_stop_compiling) {
            break;
        }
    }
#endif

    return NULL;
}
/* Request the backend compile threads of `module` to stop, then join
   every thread whose handle is set. */
static void
orcjit_stop_compile_threads(WASMModule *module)
{
    const uint32 total = (uint32)(sizeof(module->orcjit_thread_args)
                                  / sizeof(OrcJitThreadArg));
    uint32 idx = 0;

    /* Raise the stop flag before joining so the threads can exit early */
    module->orcjit_stop_compiling = true;

    while (idx < total) {
        if (module->orcjit_threads[idx])
            os_thread_join(module->orcjit_threads[idx], NULL);
        idx++;
    }
}
static bool
compile_jit_functions(WASMModule *module, char *error_buf,
uint32 error_buf_size)
{
uint32 thread_num =
(uint32)(sizeof(module->orcjit_thread_args) / sizeof(OrcJitThreadArg));
uint32 i, j;
bh_print_time("Begin to compile jit functions");
thread_num =
(uint32)(sizeof(module->orcjit_thread_args) / sizeof(OrcJitThreadArg));
/* Create threads to compile the jit functions */
for (i = 0; i < thread_num; i++) {
for (i = 0; i < thread_num && i < module->function_count; i++) {
#if WASM_ENABLE_JIT != 0
module->orcjit_thread_args[i].comp_ctx = module->comp_ctx;
#endif
module->orcjit_thread_args[i].module = module;
module->orcjit_thread_args[i].group_idx = i;
@ -1962,8 +2165,6 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
(void *)&module->orcjit_thread_args[i],
APP_THREAD_STACK_SIZE_DEFAULT)
!= 0) {
uint32 j;
set_error_buf(error_buf, error_buf_size,
"create orcjit compile thread failed");
/* Terminate the threads created */
@ -1978,15 +2179,39 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf,
#if WASM_ENABLE_LAZY_JIT == 0
/* Wait until all jit functions are compiled for eager mode */
for (i = 0; i < thread_num; i++) {
os_thread_join(module->orcjit_threads[i], NULL);
if (module->orcjit_threads[i])
os_thread_join(module->orcjit_threads[i], NULL);
}
#if WASM_ENABLE_FAST_JIT != 0
/* Ensure all the fast-jit functions are compiled */
for (i = 0; i < module->function_count; i++) {
if (!jit_compiler_is_compiled(module,
i + module->import_function_count)) {
set_error_buf(error_buf, error_buf_size,
"failed to compile fast jit function");
return false;
}
}
#endif
#if WASM_ENABLE_JIT != 0
/* Ensure all the llvm-jit functions are compiled */
for (i = 0; i < module->function_count; i++) {
if (!module->func_ptrs_compiled[i]) {
set_error_buf(error_buf, error_buf_size,
"failed to compile llvm jit function");
return false;
}
}
#endif
#endif /* end of WASM_ENABLE_LAZY_JIT == 0 */
bh_print_time("End compile jit functions");
return true;
}
#endif /* end of WASM_ENABLE_JIT != 0 */
#endif /* end of WASM_ENABLE_FAST_JIT != 0 || WASM_ENABLE_JIT != 0 */
#if WASM_ENABLE_REF_TYPES != 0
static bool
@ -2427,22 +2652,41 @@ load_from_sections(WASMModule *module, WASMSection *sections,
calculate_global_data_offset(module);
#if WASM_ENABLE_FAST_JIT != 0
if (!(module->fast_jit_func_ptrs =
loader_malloc(sizeof(void *) * module->function_count, error_buf,
error_buf_size))) {
return false;
}
if (!jit_compiler_compile_all(module)) {
set_error_buf(error_buf, error_buf_size, "fast jit compilation failed");
if (!init_fast_jit_functions(module, error_buf, error_buf_size)) {
return false;
}
#endif
#if WASM_ENABLE_JIT != 0
if (!compile_llvm_jit_functions(module, error_buf, error_buf_size)) {
if (!init_llvm_jit_functions_stage1(module, error_buf, error_buf_size)) {
return false;
}
#endif /* end of WASM_ENABLE_JIT != 0 */
#if !(WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_LAZY_JIT != 0)
if (!init_llvm_jit_functions_stage2(module, error_buf, error_buf_size)) {
return false;
}
#else
/* Run aot_compile_wasm in a backend thread, so as not to block the main
thread fast jit execution, since applying llvm optimizations in
aot_compile_wasm may cost a lot of time.
Create thread with enough native stack to apply llvm optimizations */
if (os_thread_create(&module->llvm_jit_init_thread,
init_llvm_jit_functions_stage2_callback,
(void *)module, APP_THREAD_STACK_SIZE_DEFAULT * 8)
!= 0) {
set_error_buf(error_buf, error_buf_size,
"create orcjit compile thread failed");
return false;
}
#endif
#endif
#if WASM_ENABLE_FAST_JIT != 0 || WASM_ENABLE_JIT != 0
/* Create threads to compile the jit functions */
if (!compile_jit_functions(module, error_buf, error_buf_size)) {
return false;
}
#endif
#if WASM_ENABLE_MEMORY_TRACING != 0
wasm_runtime_dump_module_mem_consumption(module);
@ -2455,9 +2699,7 @@ create_module(char *error_buf, uint32 error_buf_size)
{
WASMModule *module =
loader_malloc(sizeof(WASMModule), error_buf, error_buf_size);
#if WASM_ENABLE_FAST_INTERP == 0
bh_list_status ret;
#endif
if (!module) {
return NULL;
@ -2472,9 +2714,18 @@ create_module(char *error_buf, uint32 error_buf_size)
module->br_table_cache_list = &module->br_table_cache_list_head;
ret = bh_list_init(module->br_table_cache_list);
bh_assert(ret == BH_LIST_SUCCESS);
(void)ret;
#endif
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT && WASM_ENABLE_LAZY_JIT != 0
if (os_mutex_init(&module->instance_list_lock) != 0) {
set_error_buf(error_buf, error_buf_size,
"init instance list lock failed");
wasm_runtime_free(module);
return NULL;
}
#endif
(void)ret;
return module;
}
@ -2686,10 +2937,19 @@ wasm_loader_unload(WASMModule *module)
if (!module)
return;
#if WASM_ENABLE_JIT != 0
/* Stop LLVM JIT compilation firstly to avoid accessing
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT && WASM_ENABLE_LAZY_JIT != 0
module->orcjit_stop_compiling = true;
if (module->llvm_jit_init_thread)
os_thread_join(module->llvm_jit_init_thread, NULL);
#endif
#if WASM_ENABLE_FAST_JIT != 0 || WASM_ENABLE_JIT != 0
/* Stop Fast/LLVM JIT compilation firstly to avoid accessing
module internal data after they were freed */
orcjit_stop_compile_threads(module);
#endif
#if WASM_ENABLE_JIT != 0
if (module->func_ptrs)
wasm_runtime_free(module->func_ptrs);
if (module->comp_ctx)
@ -2698,6 +2958,13 @@ wasm_loader_unload(WASMModule *module)
aot_destroy_comp_data(module->comp_data);
#endif
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT && WASM_ENABLE_LAZY_JIT != 0
if (module->tierup_wait_lock_inited) {
os_mutex_destroy(&module->tierup_wait_lock);
os_cond_destroy(&module->tierup_wait_cond);
}
#endif
if (module->types) {
for (i = 0; i < module->type_count; i++) {
if (module->types[i])
@ -2719,6 +2986,18 @@ wasm_loader_unload(WASMModule *module)
wasm_runtime_free(module->functions[i]->code_compiled);
if (module->functions[i]->consts)
wasm_runtime_free(module->functions[i]->consts);
#endif
#if WASM_ENABLE_FAST_JIT != 0
if (module->functions[i]->fast_jit_jitted_code) {
jit_code_cache_free(
module->functions[i]->fast_jit_jitted_code);
}
#if WASM_ENABLE_JIT != 0 && WASM_ENABLE_LAZY_JIT != 0
if (module->functions[i]->llvm_jit_func_ptr) {
jit_code_cache_free(
module->functions[i]->llvm_jit_func_ptr);
}
#endif
#endif
wasm_runtime_free(module->functions[i]);
}
@ -2775,14 +3054,20 @@ wasm_loader_unload(WASMModule *module)
}
#endif
#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT && WASM_ENABLE_LAZY_JIT != 0
os_mutex_destroy(&module->instance_list_lock);
#endif
#if WASM_ENABLE_FAST_JIT != 0
if (module->fast_jit_func_ptrs) {
for (i = 0; i < module->function_count; i++) {
if (module->fast_jit_func_ptrs[i])
jit_code_cache_free(module->fast_jit_func_ptrs[i]);
}
wasm_runtime_free(module->fast_jit_func_ptrs);
}
for (i = 0; i < WASM_ORC_JIT_BACKEND_THREAD_NUM; i++) {
if (module->fast_jit_thread_locks_inited[i]) {
os_mutex_destroy(&module->fast_jit_thread_locks[i]);
}
}
#endif
wasm_runtime_free(module);
@ -5572,12 +5857,6 @@ re_scan:
if (depth > 255) {
/* The depth cannot be stored in one byte,
create br_table cache to store each depth */
#if WASM_ENABLE_DEBUG_INTERP != 0
if (!record_fast_op(module, p_org, *p_org,
error_buf, error_buf_size)) {
goto fail;
}
#endif
if (!(br_table_cache = loader_malloc(
offsetof(BrTableCache, br_depths)
+ sizeof(uint32)