Refine AOT/JIT code call wasm-c-api import process (#2982)

Allow AOT/JIT code to invoke the quick call entry
wasm_runtime_quick_invoke_c_api_native to call wasm-c-api import functions,
which speeds up the calling process by reducing data copying.

Use `wamrc --invoke-c-api-import` to generate the optimized AOT code, and set
`jit_options->quick_invoke_c_api_import` true in wasm_engine_new when LLVM JIT
is enabled.
This commit is contained in:
Wenyong Huang
2024-01-10 18:37:02 +08:00
committed by GitHub
parent 7c7684819d
commit b21f17dd6d
20 changed files with 393 additions and 35 deletions

View File

@ -288,6 +288,213 @@ call_aot_invoke_native_func(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
return true;
}
/**
 * Emit LLVM IR that calls a wasm-c-api import function through the quick
 * entry wasm_runtime_quick_invoke_c_api_native().
 *
 * The emitted code:
 *   1. loads module_inst->e->common.c_api_func_imports and addresses the
 *      entry of the given import,
 *   2. fills a wasm_val_t array at the start of func_ctx->argv_buf with the
 *      kind byte and value of each argument,
 *   3. uses the area right after the parameter array in the same buffer as
 *      the wasm_val_t result array,
 *   4. calls the quick entry and, when bound check is enabled, checks its
 *      return value,
 *   5. loads each result value and pushes it onto the value stack.
 *
 * Only i32/i64/f32/f64 parameter and result types are handled; the caller
 * visible in this file filters the types and the required buffer size
 * before taking this path.
 *
 * @param comp_ctx the compilation context
 * @param func_ctx the context of the function being compiled
 * @param import_func_idx index used to address c_api_func_imports[]
 * @param aot_func_type signature of the import function
 * @param params LLVM values of the arguments, one per parameter
 *
 * @return true if success, false otherwise
 */
static bool
call_aot_invoke_c_api_native(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
                             uint32 import_func_idx, AOTFuncType *aot_func_type,
                             LLVMValueRef *params)
{
    /* Note: several locals below (func_type, func_ptr_type, func, value,
       param_types, ret_type) are presumably referenced by the
       GET_AOT_FUNCTION macro, so they are kept even where they look
       unused -- confirm against the macro before removing any of them */
    LLVMTypeRef int8_ptr_type, param_types[6], ret_type;
    LLVMTypeRef value_ptr_type = NULL, value_type = NULL;
    LLVMTypeRef func_type, func_ptr_type;
    LLVMValueRef param_values[6], res, func, value = NULL, offset;
    LLVMValueRef c_api_func_imports, c_api_func_import;
    LLVMValueRef c_api_params, c_api_results, value_ret;
    LLVMValueRef c_api_param_kind, c_api_param_value;
    LLVMValueRef c_api_result_value;
    uint32 offset_c_api_func_imports, i;
    uint32 offset_param_kind, offset_param_value;
    char buf[16];

    /* `int8 **` type */
    int8_ptr_type = LLVMPointerType(INT8_PTR_TYPE, 0);
    if (!int8_ptr_type) {
        aot_set_last_error("create llvm pointer type failed");
        return false;
    }

    /* Prototype of wasm_runtime_quick_invoke_c_api_native() */
    param_types[0] = INT8_PTR_TYPE; /* module_inst */
    param_types[1] = INT8_PTR_TYPE; /* CApiFuncImport *c_api_import */
    param_types[2] = INT8_PTR_TYPE; /* wasm_val_t *params */
    param_types[3] = I32_TYPE;      /* uint32 param_count */
    param_types[4] = INT8_PTR_TYPE; /* wasm_val_t *results */
    param_types[5] = I32_TYPE;      /* uint32 result_count */

    ret_type = INT8_TYPE;

    GET_AOT_FUNCTION(wasm_runtime_quick_invoke_c_api_native, 6);

    param_values[0] = func_ctx->aot_inst;

    /* Get module_inst->e->common.c_api_func_imports */
    offset_c_api_func_imports =
        get_module_inst_extra_offset(comp_ctx)
        + (comp_ctx->is_jit_mode
               ? offsetof(WASMModuleInstanceExtra, common.c_api_func_imports)
               /* offsetof(AOTModuleInstanceExtra, common.c_api_func_imports) */
               : sizeof(uint64));
    offset = I32_CONST(offset_c_api_func_imports);
    CHECK_LLVM_CONST(offset);
    c_api_func_imports =
        LLVMBuildInBoundsGEP2(comp_ctx->builder, INT8_TYPE, func_ctx->aot_inst,
                              &offset, 1, "c_api_func_imports_addr");
    c_api_func_imports =
        LLVMBuildBitCast(comp_ctx->builder, c_api_func_imports, int8_ptr_type,
                         "c_api_func_imports_ptr");
    c_api_func_imports =
        LLVMBuildLoad2(comp_ctx->builder, INT8_PTR_TYPE, c_api_func_imports,
                       "c_api_func_imports");

    /* Get &c_api_func_imports[func_idx], note size of CApiFuncImport
       is pointer_size * 3 */
    offset = I32_CONST((comp_ctx->pointer_size * 3) * import_func_idx);
    CHECK_LLVM_CONST(offset);
    c_api_func_import =
        LLVMBuildInBoundsGEP2(comp_ctx->builder, INT8_TYPE, c_api_func_imports,
                              &offset, 1, "c_api_func_import");

    param_values[1] = c_api_func_import;
    param_values[2] = c_api_params = func_ctx->argv_buf;
    param_values[3] = I32_CONST(aot_func_type->param_count);
    CHECK_LLVM_CONST(param_values[3]);

    /* The result array starts right after the parameter array.
       Ensure sizeof(wasm_val_t) is 16 bytes */
    offset = I32_CONST(sizeof(wasm_val_t) * aot_func_type->param_count);
    /* This constant was previously used unchecked, unlike every other
       constant created in this function */
    CHECK_LLVM_CONST(offset);
    c_api_results =
        LLVMBuildInBoundsGEP2(comp_ctx->builder, INT8_TYPE, func_ctx->argv_buf,
                              &offset, 1, "results");
    param_values[4] = c_api_results;

    param_values[5] = I32_CONST(aot_func_type->result_count);
    CHECK_LLVM_CONST(param_values[5]);

    /* Set each c api param */
    for (i = 0; i < aot_func_type->param_count; i++) {
        /* Select the wasm-c-api kind constant and the pointer type used
           to store the value of this parameter */
        switch (aot_func_type->types[i]) {
            case VALUE_TYPE_I32:
                value = I8_CONST(WASM_I32);
                value_ptr_type = INT32_PTR_TYPE;
                break;
            case VALUE_TYPE_F32:
                value = I8_CONST(WASM_F32);
                value_ptr_type = F32_PTR_TYPE;
                break;
            case VALUE_TYPE_I64:
                value = I8_CONST(WASM_I64);
                value_ptr_type = INT64_PTR_TYPE;
                break;
            case VALUE_TYPE_F64:
                value = I8_CONST(WASM_F64);
                value_ptr_type = F64_PTR_TYPE;
                break;
            default:
                /* Fail cleanly even if bh_assert is compiled out, rather
                   than storing through a NULL/stale kind or pointer type */
                bh_assert(0);
                aot_set_last_error("unsupported param type for quick "
                                   "invoke of c-api import");
                goto fail;
        }
        CHECK_LLVM_CONST(value);

        /* Store the kind byte at the start of params[i].
           Ensure sizeof(wasm_val_t) is 16 bytes */
        offset_param_kind = sizeof(wasm_val_t) * i;
        offset = I32_CONST(offset_param_kind);
        CHECK_LLVM_CONST(offset);
        c_api_param_kind =
            LLVMBuildInBoundsGEP2(comp_ctx->builder, INT8_TYPE, c_api_params,
                                  &offset, 1, "c_api_param_kind_addr");
        c_api_param_kind =
            LLVMBuildBitCast(comp_ctx->builder, c_api_param_kind, INT8_PTR_TYPE,
                             "c_api_param_kind_ptr");
        LLVMBuildStore(comp_ctx->builder, value, c_api_param_kind);

        /* Store the argument into params[i].of.
           Ensure offsetof(wasm_val_t, of) is 8 bytes */
        offset_param_value = offset_param_kind + offsetof(wasm_val_t, of);
        offset = I32_CONST(offset_param_value);
        CHECK_LLVM_CONST(offset);
        c_api_param_value =
            LLVMBuildInBoundsGEP2(comp_ctx->builder, INT8_TYPE, c_api_params,
                                  &offset, 1, "c_api_param_value_addr");
        c_api_param_value =
            LLVMBuildBitCast(comp_ctx->builder, c_api_param_value,
                             value_ptr_type, "c_api_param_value_ptr");
        LLVMBuildStore(comp_ctx->builder, params[i], c_api_param_value);
    }

    /* Call the function */
    if (!(res = LLVMBuildCall2(comp_ctx->builder, func_type, func, param_values,
                               6, "call"))) {
        aot_set_last_error("LLVM build call failed.");
        goto fail;
    }

    /* Check whether exception was thrown when executing the function.
       NOTE(review): the aot_invoke_native path in aot_compile_op_call also
       performs this check when is_win_platform(comp_ctx) is true; confirm
       whether the same condition is wanted here */
    if (comp_ctx->enable_bound_check
        && !check_call_return(comp_ctx, func_ctx, res)) {
        goto fail;
    }

    /* Load each result from results[i].of and push it to the value stack */
    for (i = 0; i < aot_func_type->result_count; i++) {
        /* Ensure sizeof(wasm_val_t) is 16 bytes and
           offsetof(wasm_val_t, of) is 8 bytes */
        uint32 offset_result_value =
            sizeof(wasm_val_t) * i + offsetof(wasm_val_t, of);

        offset = I32_CONST(offset_result_value);
        CHECK_LLVM_CONST(offset);

        c_api_result_value =
            LLVMBuildInBoundsGEP2(comp_ctx->builder, INT8_TYPE, c_api_results,
                                  &offset, 1, "c_api_result_value_addr");

        switch (aot_func_type->types[aot_func_type->param_count + i]) {
            case VALUE_TYPE_I32:
                value_type = I32_TYPE;
                value_ptr_type = INT32_PTR_TYPE;
                break;
            case VALUE_TYPE_F32:
                value_type = F32_TYPE;
                value_ptr_type = F32_PTR_TYPE;
                break;
            case VALUE_TYPE_I64:
                value_type = I64_TYPE;
                value_ptr_type = INT64_PTR_TYPE;
                break;
            case VALUE_TYPE_F64:
                value_type = F64_TYPE;
                value_ptr_type = F64_PTR_TYPE;
                break;
            default:
                /* Fail cleanly even if bh_assert is compiled out, rather
                   than loading with a NULL/stale value type */
                bh_assert(0);
                aot_set_last_error("unsupported result type for quick "
                                   "invoke of c-api import");
                goto fail;
        }

        c_api_result_value =
            LLVMBuildBitCast(comp_ctx->builder, c_api_result_value,
                             value_ptr_type, "c_api_result_value_ptr");
        snprintf(buf, sizeof(buf), "%s%u", "ret", i);
        value_ret = LLVMBuildLoad2(comp_ctx->builder, value_type,
                                   c_api_result_value, buf);

        PUSH(value_ret, aot_func_type->types[aot_func_type->param_count + i]);
    }

    return true;
fail:
    return false;
}
#if (WASM_ENABLE_DUMP_CALL_STACK != 0) || (WASM_ENABLE_PERF_PROFILING != 0)
static bool
call_aot_alloc_frame_func(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
@ -533,6 +740,7 @@ aot_compile_op_call(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
const char *signature = NULL;
bool ret = false;
char buf[32];
bool quick_invoke_c_api_import = false;
#if WASM_ENABLE_THREAD_MGR != 0
/* Insert suspend check point */
@ -702,17 +910,43 @@ aot_compile_op_call(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
}
if (!signature) {
/* call aot_invoke_native() */
if (!call_aot_invoke_native_func(
comp_ctx, func_ctx, import_func_idx, func_type,
param_types + 1, param_values + 1, param_count,
param_cell_num, ret_type, wasm_ret_type, &value_ret, &res))
goto fail;
/* Check whether there was exception thrown when executing
the function */
if ((comp_ctx->enable_bound_check || is_win_platform(comp_ctx))
&& !check_call_return(comp_ctx, func_ctx, res))
goto fail;
if (comp_ctx->quick_invoke_c_api_import) {
uint32 buf_size_needed =
sizeof(wasm_val_t) * (param_count + result_count);
/* length of exec_env->argv_buf is 64 */
if (buf_size_needed < sizeof(uint32) * 64) {
for (i = 0; i < param_count + result_count; i++) {
/* Only support i32/i64/f32/f64 now */
if (!(func_type->types[i] == VALUE_TYPE_I32
|| func_type->types[i] == VALUE_TYPE_I64
|| func_type->types[i] == VALUE_TYPE_F32
|| func_type->types[i] == VALUE_TYPE_F64))
break;
}
if (i == param_count + result_count)
quick_invoke_c_api_import = true;
}
}
if (quick_invoke_c_api_import) {
if (!call_aot_invoke_c_api_native(comp_ctx, func_ctx, func_idx,
func_type, param_values + 1))
goto fail;
}
else {
/* call aot_invoke_native() */
if (!call_aot_invoke_native_func(
comp_ctx, func_ctx, import_func_idx, func_type,
param_types + 1, param_values + 1, param_count,
param_cell_num, ret_type, wasm_ret_type, &value_ret,
&res))
goto fail;
/* Check whether there was exception thrown when executing
the function */
if ((comp_ctx->enable_bound_check || is_win_platform(comp_ctx))
&& !check_call_return(comp_ctx, func_ctx, res))
goto fail;
}
}
else { /* call native func directly */
LLVMTypeRef native_func_type, func_ptr_type;
@ -869,7 +1103,7 @@ aot_compile_op_call(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
goto fail;
}
if (func_type->result_count > 0) {
if (func_type->result_count > 0 && !quick_invoke_c_api_import) {
/* Push the first result to stack */
PUSH(value_ret, func_type->types[func_type->param_count]);
/* Load extra result from its address and push to stack */