AOT/JIT native stack bound check improvement (#2244)

Move the native stack overflow check from the caller to the callee because the
former doesn't work for call_indirect and imported functions.

Make the stack usage estimation more accurate. Instead of making a guess from
the number of wasm locals in the function, use the LLVM's idea of the stack size
of each MachineFunction. The former is inaccurate because a) it doesn't reflect
optimization passes, and b) wasm locals are not the only reason to use stack.

To use the post-compilation stack usage information without requiring 2-pass
compilation or machine-code imm rewriting, introduce a global array to store
stack consumption of each functions:
For JIT, use a custom IRCompiler with an extra pass to fill the array.
For AOT, use `clang -fstack-usage` equivalent because we support external llc.

Re-implement function call stack usage estimation to reflect the real calling
conventions better. (aot_estimate_stack_usage_for_function_call)

Re-implement stack estimation logic (--enable-memory-profiling) based on the new
machinery.

Discussions: #2105.
This commit is contained in:
YAMAMOTO Takashi
2023-06-22 08:27:07 +09:00
committed by GitHub
parent 8797c751a5
commit cd7941cc39
12 changed files with 1348 additions and 252 deletions

View File

@ -366,143 +366,6 @@ fail:
#endif /* end of (WASM_ENABLE_DUMP_CALL_STACK != 0) \
|| (WASM_ENABLE_PERF_PROFILING != 0) */
static bool
record_stack_usage(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
uint32 callee_cell_num)
{
LLVMBasicBlockRef block_curr = LLVMGetInsertBlock(comp_ctx->builder);
LLVMBasicBlockRef block_update;
LLVMBasicBlockRef block_after_update;
LLVMValueRef callee_local_size, new_sp, cmp;
LLVMValueRef native_stack_top_min;
LLVMTypeRef ptrdiff_type;
if (comp_ctx->pointer_size == sizeof(uint64_t)) {
ptrdiff_type = I64_TYPE;
}
else {
ptrdiff_type = I32_TYPE;
}
/*
* new_sp = last_alloca - callee_local_size;
* if (*native_stack_top_min_addr > new_sp) {
* *native_stack_top_min_addr = new_sp;
* }
*/
if (!(callee_local_size = LLVMConstInt(
ptrdiff_type, -(int64_t)callee_cell_num * 4, true))) {
aot_set_last_error("llvm build const failed.");
return false;
}
if (!(new_sp = LLVMBuildInBoundsGEP2(comp_ctx->builder, INT8_TYPE,
func_ctx->last_alloca,
&callee_local_size, 1, "new_sp"))) {
aot_set_last_error("llvm build gep failed");
return false;
}
if (!(native_stack_top_min = LLVMBuildLoad2(
comp_ctx->builder, OPQ_PTR_TYPE,
func_ctx->native_stack_top_min_addr, "native_stack_top_min"))) {
aot_set_last_error("llvm build load failed");
return false;
}
if (!(cmp = LLVMBuildICmp(comp_ctx->builder, LLVMIntULT, new_sp,
native_stack_top_min, "cmp"))) {
aot_set_last_error("llvm build icmp failed.");
return false;
}
if (!(block_update = LLVMAppendBasicBlockInContext(
comp_ctx->context, func_ctx->func, "block_update"))) {
aot_set_last_error("llvm add basic block failed.");
return false;
}
if (!(block_after_update = LLVMAppendBasicBlockInContext(
comp_ctx->context, func_ctx->func, "block_after_update"))) {
aot_set_last_error("llvm add basic block failed.");
return false;
}
LLVMMoveBasicBlockAfter(block_update, block_curr);
LLVMMoveBasicBlockAfter(block_after_update, block_update);
if (!LLVMBuildCondBr(comp_ctx->builder, cmp, block_update,
block_after_update)) {
aot_set_last_error("llvm build cond br failed.");
return false;
}
LLVMPositionBuilderAtEnd(comp_ctx->builder, block_update);
if (!LLVMBuildStore(comp_ctx->builder, new_sp,
func_ctx->native_stack_top_min_addr)) {
aot_set_last_error("llvm build store failed");
return false;
}
if (!LLVMBuildBr(comp_ctx->builder, block_after_update)) {
aot_set_last_error("llvm build br failed.");
return false;
}
LLVMPositionBuilderAtEnd(comp_ctx->builder, block_after_update);
return true;
}
static bool
check_stack_boundary(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
uint32 callee_cell_num)
{
LLVMBasicBlockRef block_curr = LLVMGetInsertBlock(comp_ctx->builder);
LLVMBasicBlockRef check_stack;
LLVMValueRef callee_local_size, stack_bound, cmp;
if (!(callee_local_size = I32_CONST(callee_cell_num * 4))) {
aot_set_last_error("llvm build const failed.");
return false;
}
if (!(stack_bound = LLVMBuildInBoundsGEP2(
comp_ctx->builder, INT8_TYPE, func_ctx->native_stack_bound,
&callee_local_size, 1, "stack_bound"))) {
aot_set_last_error("llvm build inbound gep failed.");
return false;
}
if (!(check_stack = LLVMAppendBasicBlockInContext(
comp_ctx->context, func_ctx->func, "check_stack"))) {
aot_set_last_error("llvm add basic block failed.");
return false;
}
LLVMMoveBasicBlockAfter(check_stack, block_curr);
if (!(cmp = LLVMBuildICmp(comp_ctx->builder, LLVMIntULT,
func_ctx->last_alloca, stack_bound, "cmp"))) {
aot_set_last_error("llvm build icmp failed.");
return false;
}
if (!aot_emit_exception(comp_ctx, func_ctx, EXCE_NATIVE_STACK_OVERFLOW,
true, cmp, check_stack)) {
return false;
}
LLVMPositionBuilderAtEnd(comp_ctx->builder, check_stack);
return true;
}
static bool
check_stack(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
uint32 callee_cell_num)
{
if (comp_ctx->enable_stack_estimation
&& !record_stack_usage(comp_ctx, func_ctx, callee_cell_num))
return false;
if (comp_ctx->enable_stack_bound_check
&& !check_stack_boundary(comp_ctx, func_ctx, callee_cell_num))
return false;
return true;
}
/**
* Check whether the app address and its buffer are inside the linear memory,
* if no, throw exception
@ -610,6 +473,30 @@ check_app_addr_and_convert(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
return true;
}
static void
aot_estimate_and_record_stack_usage_for_function_call(
const AOTCompContext *comp_ctx, AOTFuncContext *caller_func_ctx,
const AOTFuncType *callee_func_type)
{
unsigned int size;
if (!(comp_ctx->enable_stack_bound_check
|| comp_ctx->enable_stack_estimation)) {
return;
}
size =
aot_estimate_stack_usage_for_function_call(comp_ctx, callee_func_type);
/*
* only record the max value, assuming that LLVM emits machine code
* which rewinds the stack before making the next call in the
* function.
*/
if (caller_func_ctx->stack_consumption_for_func_call < size) {
caller_func_ctx->stack_consumption_for_func_call = size;
}
}
bool
aot_compile_op_call(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
uint32 func_idx, bool tail_call)
@ -620,7 +507,6 @@ aot_compile_op_call(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
uint32 ext_ret_cell_num = 0, cell_num = 0;
AOTFuncContext **func_ctxes = comp_ctx->func_ctxes;
AOTFuncType *func_type;
AOTFunc *aot_func;
LLVMTypeRef *param_types = NULL, ret_type;
LLVMTypeRef ext_ret_ptr_type;
LLVMValueRef *param_values = NULL, value_ret = NULL, func;
@ -628,7 +514,6 @@ aot_compile_op_call(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
LLVMValueRef ext_ret, ext_ret_ptr, ext_ret_idx;
int32 i, j = 0, param_count, result_count, ext_ret_count;
uint64 total_size;
uint32 callee_cell_num;
uint8 wasm_ret_type;
uint8 *ext_ret_types = NULL;
const char *signature = NULL;
@ -658,6 +543,8 @@ aot_compile_op_call(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
func_type =
func_ctxes[func_idx - import_func_count]->aot_func->func_type;
}
aot_estimate_and_record_stack_usage_for_function_call(comp_ctx, func_ctx,
func_type);
/* Get param cell number */
param_cell_num = func_type->param_cell_num;
@ -885,15 +772,17 @@ aot_compile_op_call(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
else {
if (func_ctxes[func_idx - import_func_count] == func_ctx) {
/* recursive call */
func = func_ctx->func;
func = func_ctx->precheck_func;
}
else {
if (!comp_ctx->is_jit_mode) {
func = func_ctxes[func_idx - import_func_count]->func;
func =
func_ctxes[func_idx - import_func_count]->precheck_func;
}
else {
#if !(WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_LAZY_JIT != 0)
func = func_ctxes[func_idx - import_func_count]->func;
func =
func_ctxes[func_idx - import_func_count]->precheck_func;
#else
/* JIT tier-up, load func ptr from func_ptrs[func_idx] */
LLVMValueRef func_ptr, func_idx_const;
@ -938,13 +827,6 @@ aot_compile_op_call(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
}
}
aot_func = func_ctxes[func_idx - import_func_count]->aot_func;
callee_cell_num =
aot_func->param_cell_num + aot_func->local_cell_num + 1;
if (!check_stack(comp_ctx, func_ctx, callee_cell_num))
goto fail;
#if LLVM_VERSION_MAJOR >= 14
llvm_func_type = func_ctxes[func_idx - import_func_count]->func_type;
#endif
@ -1213,6 +1095,8 @@ aot_compile_op_call_indirect(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
CHECK_LLVM_CONST(ftype_idx_const);
func_type = comp_ctx->comp_data->func_types[type_idx];
aot_estimate_and_record_stack_usage_for_function_call(comp_ctx, func_ctx,
func_type);
func_param_count = func_type->param_count;
func_result_count = func_type->result_count;
@ -1564,13 +1448,6 @@ aot_compile_op_call_indirect(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx,
/* Translate call non-import block */
LLVMPositionBuilderAtEnd(comp_ctx->builder, block_call_non_import);
if (!check_stack(comp_ctx, func_ctx,
param_cell_num + ext_cell_num
+ 1
/* Reserve some local variables */
+ 16))
goto fail;
/* Load function pointer */
if (!(func_ptr = LLVMBuildInBoundsGEP2(comp_ctx->builder, OPQ_PTR_TYPE,
func_ctx->func_ptrs, &func_idx, 1,