Support extended constant expressions (#4432)

* implement extended const expr (#4318)
* add a toggle to enable extended const on wamrc (#4412)
This commit is contained in:
TianlongLiang
2025-07-07 13:34:02 +08:00
committed by GitHub
parent be33a40ba7
commit 7d05dbc988
28 changed files with 1734 additions and 379 deletions

View File

@ -261,6 +261,9 @@ typedef struct InitValue {
uint8 type;
uint8 flag;
WASMValue value;
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
InitializerExpression *expr;
#endif
} InitValue;
typedef struct ConstExprContext {
@ -282,7 +285,11 @@ init_const_expr_stack(ConstExprContext *ctx, WASMModule *module)
static bool
push_const_expr_stack(ConstExprContext *ctx, uint8 flag, uint8 type,
WASMValue *value, char *error_buf, uint32 error_buf_size)
WASMValue *value,
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
InitializerExpression *expr,
#endif
char *error_buf, uint32 error_buf_size)
{
InitValue *cur_value;
@ -305,6 +312,9 @@ push_const_expr_stack(ConstExprContext *ctx, uint8 flag, uint8 type,
cur_value->type = type;
cur_value->flag = flag;
cur_value->value = *value;
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
cur_value->expr = expr;
#endif
return true;
fail:
@ -313,7 +323,11 @@ fail:
static bool
pop_const_expr_stack(ConstExprContext *ctx, uint8 *p_flag, uint8 type,
WASMValue *p_value, char *error_buf, uint32 error_buf_size)
WASMValue *p_value,
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
InitializerExpression **p_expr,
#endif
char *error_buf, uint32 error_buf_size)
{
InitValue *cur_value;
@ -331,18 +345,50 @@ pop_const_expr_stack(ConstExprContext *ctx, uint8 *p_flag, uint8 type,
*p_flag = cur_value->flag;
if (p_value)
*p_value = cur_value->value;
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
if (p_expr)
*p_expr = cur_value->expr;
#endif
return true;
}
static void
destroy_const_expr_stack(ConstExprContext *ctx)
destroy_const_expr_stack(ConstExprContext *ctx, bool free_exprs)
{
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
if (free_exprs) {
for (uint32 j = 0; j < ctx->sp; j++) {
if (is_expr_binary_op(ctx->stack[j].expr->init_expr_type)) {
destroy_init_expr_recursive(ctx->stack[j].expr);
ctx->stack[j].expr = NULL;
}
}
}
#endif
if (ctx->stack != ctx->data) {
wasm_runtime_free(ctx->stack);
}
}
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
static void
destroy_init_expr(InitializerExpression *expr)
{
// free left expr and right exprs for binary oprand
if (is_expr_binary_op(expr->init_expr_type)) {
return;
}
if (expr->u.binary.l_expr) {
destroy_init_expr_recursive(expr->u.binary.l_expr);
}
if (expr->u.binary.r_expr) {
destroy_init_expr_recursive(expr->u.binary.r_expr);
}
expr->u.binary.l_expr = expr->u.binary.r_expr = NULL;
}
#endif /* end of WASM_ENABLE_EXTENDED_CONST_EXPR != 0 */
static bool
load_init_expr(WASMModule *module, const uint8 **p_buf, const uint8 *buf_end,
InitializerExpression *init_expr, uint8 type, char *error_buf,
@ -353,6 +399,9 @@ load_init_expr(WASMModule *module, const uint8 **p_buf, const uint8 *buf_end,
uint32 i;
ConstExprContext const_expr_ctx = { 0 };
WASMValue cur_value = { 0 };
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
InitializerExpression *cur_expr = NULL;
#endif
init_const_expr_stack(&const_expr_ctx, module);
@ -367,8 +416,11 @@ load_init_expr(WASMModule *module, const uint8 **p_buf, const uint8 *buf_end,
if (!push_const_expr_stack(&const_expr_ctx, flag,
VALUE_TYPE_I32, &cur_value,
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
NULL,
#endif
error_buf, error_buf_size)) {
bh_assert(0);
goto fail;
}
break;
/* i64.const */
@ -377,8 +429,11 @@ load_init_expr(WASMModule *module, const uint8 **p_buf, const uint8 *buf_end,
if (!push_const_expr_stack(&const_expr_ctx, flag,
VALUE_TYPE_I64, &cur_value,
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
NULL,
#endif
error_buf, error_buf_size)) {
bh_assert(0);
goto fail;
}
break;
/* f32.const */
@ -390,8 +445,11 @@ load_init_expr(WASMModule *module, const uint8 **p_buf, const uint8 *buf_end,
if (!push_const_expr_stack(&const_expr_ctx, flag,
VALUE_TYPE_F32, &cur_value,
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
NULL,
#endif
error_buf, error_buf_size)) {
bh_assert(0);
goto fail;
}
break;
/* f64.const */
@ -403,8 +461,11 @@ load_init_expr(WASMModule *module, const uint8 **p_buf, const uint8 *buf_end,
if (!push_const_expr_stack(&const_expr_ctx, flag,
VALUE_TYPE_F64, &cur_value,
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
NULL,
#endif
error_buf, error_buf_size)) {
bh_assert(0);
goto fail;
}
break;
@ -417,13 +478,16 @@ load_init_expr(WASMModule *module, const uint8 **p_buf, const uint8 *buf_end,
cur_value.ref_index = func_idx;
if (!check_function_index(module, func_idx, error_buf,
error_buf_size)) {
bh_assert(0);
goto fail;
}
if (!push_const_expr_stack(&const_expr_ctx, flag,
VALUE_TYPE_FUNCREF, &cur_value,
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
NULL,
#endif
error_buf, error_buf_size)) {
bh_assert(0);
goto fail;
}
break;
}
@ -438,9 +502,12 @@ load_init_expr(WASMModule *module, const uint8 **p_buf, const uint8 *buf_end,
cur_value.ref_index = UINT32_MAX;
if (!push_const_expr_stack(&const_expr_ctx, flag, type1,
&cur_value, error_buf,
error_buf_size)) {
bh_assert(0);
&cur_value,
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
NULL,
#endif
error_buf, error_buf_size)) {
goto fail;
}
break;
}
@ -471,15 +538,93 @@ load_init_expr(WASMModule *module, const uint8 **p_buf, const uint8 *buf_end,
}
if (!push_const_expr_stack(&const_expr_ctx, flag, global_type,
&cur_value, error_buf,
error_buf_size))
bh_assert(0);
&cur_value,
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
NULL,
#endif
error_buf, error_buf_size))
goto fail;
break;
}
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
case INIT_EXPR_TYPE_I32_ADD:
case INIT_EXPR_TYPE_I64_ADD:
case INIT_EXPR_TYPE_I32_SUB:
case INIT_EXPR_TYPE_I64_SUB:
case INIT_EXPR_TYPE_I32_MUL:
case INIT_EXPR_TYPE_I64_MUL:
{
InitializerExpression *l_expr, *r_expr;
WASMValue l_value, r_value;
uint8 l_flag, r_flag;
uint8 value_type;
if (flag == INIT_EXPR_TYPE_I32_ADD
|| flag == INIT_EXPR_TYPE_I32_SUB
|| flag == INIT_EXPR_TYPE_I32_MUL) {
value_type = VALUE_TYPE_I32;
}
else {
value_type = VALUE_TYPE_I64;
}
/* If right flag indicates a binary operation, right expr will
* be popped from stack. Otherwise, allocate a new expr for
* right expr. Same for left expr.
*/
if (!(pop_const_expr_stack(&const_expr_ctx, &r_flag, value_type,
&r_value, &r_expr, error_buf,
error_buf_size))) {
goto fail;
}
if (!is_expr_binary_op(r_flag)) {
if (!(r_expr = loader_malloc(sizeof(InitializerExpression),
error_buf, error_buf_size))) {
goto fail;
}
r_expr->init_expr_type = r_flag;
r_expr->u.unary.v = r_value;
}
if (!(pop_const_expr_stack(&const_expr_ctx, &l_flag, value_type,
&l_value, &l_expr, error_buf,
error_buf_size))) {
destroy_init_expr_recursive(r_expr);
goto fail;
}
if (!is_expr_binary_op(l_flag)) {
if (!(l_expr = loader_malloc(sizeof(InitializerExpression),
error_buf, error_buf_size))) {
destroy_init_expr_recursive(r_expr);
goto fail;
}
l_expr->init_expr_type = l_flag;
l_expr->u.unary.v = l_value;
}
if (!(cur_expr = loader_malloc(sizeof(InitializerExpression),
error_buf, error_buf_size))) {
destroy_init_expr_recursive(l_expr);
destroy_init_expr_recursive(r_expr);
goto fail;
}
cur_expr->init_expr_type = flag;
cur_expr->u.binary.l_expr = l_expr;
cur_expr->u.binary.r_expr = r_expr;
if (!push_const_expr_stack(&const_expr_ctx, flag, value_type,
&cur_value, cur_expr, error_buf,
error_buf_size)) {
destroy_init_expr_recursive(cur_expr);
goto fail;
}
break;
}
#endif
default:
{
bh_assert(0);
goto fail;
}
}
@ -489,18 +634,42 @@ load_init_expr(WASMModule *module, const uint8 **p_buf, const uint8 *buf_end,
/* There should be only one value left on the init value stack */
if (!pop_const_expr_stack(&const_expr_ctx, &flag, type, &cur_value,
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
&cur_expr,
#endif
error_buf, error_buf_size)) {
bh_assert(0);
goto fail;
}
bh_assert(const_expr_ctx.sp == 0);
if (const_expr_ctx.sp != 0) {
set_error_buf(error_buf, error_buf_size,
"type mismatch: illegal constant opcode sequence");
goto fail;
}
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
if (cur_expr != NULL) {
bh_memcpy_s(init_expr, sizeof(InitializerExpression), cur_expr,
sizeof(InitializerExpression));
wasm_runtime_free(cur_expr);
}
else {
init_expr->init_expr_type = flag;
init_expr->u.unary.v = cur_value;
}
#else
init_expr->init_expr_type = flag;
init_expr->u = cur_value;
init_expr->u.unary.v = cur_value;
#endif /* end of WASM_ENABLE_EXTENDED_CONST_EXPR != 0 */
*p_buf = p;
destroy_const_expr_stack(&const_expr_ctx);
destroy_const_expr_stack(&const_expr_ctx, false);
return true;
fail:
destroy_const_expr_stack(&const_expr_ctx, true);
return false;
}
static bool
@ -1385,13 +1554,14 @@ load_global_section(const uint8 *buf, const uint8 *buf_end, WASMModule *module,
* global.get instructions are
* only allowed to refer to imported globals.
*/
uint32 target_global_index = global->init_expr.u.global_index;
uint32 target_global_index =
global->init_expr.u.unary.v.global_index;
bh_assert(target_global_index < module->import_global_count);
(void)target_global_index;
}
else if (INIT_EXPR_TYPE_FUNCREF_CONST
== global->init_expr.init_expr_type) {
bh_assert(global->init_expr.u.ref_index
bh_assert(global->init_expr.u.unary.v.ref_index
< module->import_function_count
+ module->function_count);
}
@ -1575,7 +1745,7 @@ load_func_index_vec(const uint8 **p_buf, const uint8 *buf_end,
}
init_expr->init_expr_type = INIT_EXPR_TYPE_FUNCREF_CONST;
init_expr->u.ref_index = function_index;
init_expr->u.unary.v.ref_index = function_index;
}
*p_buf = p;
@ -1890,6 +2060,9 @@ load_data_segment_section(const uint8 *buf, const uint8 *buf_end,
if (!(dataseg = module->data_segments[i] = loader_malloc(
sizeof(WASMDataSeg), error_buf, error_buf_size))) {
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
destroy_init_expr(&init_expr);
#endif
return false;
}
@ -2778,7 +2951,8 @@ load_from_sections(WASMModule *module, WASMSection *sections,
&& global->init_expr.init_expr_type
== INIT_EXPR_TYPE_I32_CONST) {
aux_heap_base_global = global;
aux_heap_base = (uint64)(uint32)global->init_expr.u.i32;
aux_heap_base =
(uint64)(uint32)global->init_expr.u.unary.v.i32;
aux_heap_base_global_index = export->index;
LOG_VERBOSE("Found aux __heap_base global, value: %" PRIu64,
aux_heap_base);
@ -2798,7 +2972,8 @@ load_from_sections(WASMModule *module, WASMSection *sections,
&& global->init_expr.init_expr_type
== INIT_EXPR_TYPE_I32_CONST) {
aux_data_end_global = global;
aux_data_end = (uint64)(uint32)global->init_expr.u.i32;
aux_data_end =
(uint64)(uint32)global->init_expr.u.unary.v.i32;
aux_data_end_global_index = export->index;
LOG_VERBOSE("Found aux __data_end global, value: %" PRIu64,
aux_data_end);
@ -2838,10 +3013,11 @@ load_from_sections(WASMModule *module, WASMSection *sections,
&& global->type.val_type == VALUE_TYPE_I32
&& global->init_expr.init_expr_type
== INIT_EXPR_TYPE_I32_CONST
&& (uint64)(uint32)global->init_expr.u.i32
&& (uint64)(uint32)global->init_expr.u.unary.v.i32
<= aux_heap_base) {
aux_stack_top_global = global;
aux_stack_top = (uint64)(uint32)global->init_expr.u.i32;
aux_stack_top =
(uint64)(uint32)global->init_expr.u.unary.v.i32;
module->aux_stack_top_global_index =
module->import_global_count + global_index;
module->aux_stack_bottom = aux_stack_top;
@ -3448,8 +3624,14 @@ wasm_loader_unload(WASMModule *module)
if (module->memories)
wasm_runtime_free(module->memories);
if (module->globals)
if (module->globals) {
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
for (i = 0; i < module->global_count; i++) {
destroy_init_expr(&module->globals[i].init_expr);
}
#endif
wasm_runtime_free(module->globals);
}
if (module->exports)
wasm_runtime_free(module->exports);
@ -3458,6 +3640,9 @@ wasm_loader_unload(WASMModule *module)
for (i = 0; i < module->table_seg_count; i++) {
if (module->table_segments[i].init_values)
wasm_runtime_free(module->table_segments[i].init_values);
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
destroy_init_expr(&module->table_segments[i].base_offset);
#endif
}
wasm_runtime_free(module->table_segments);
}
@ -3467,6 +3652,9 @@ wasm_loader_unload(WASMModule *module)
if (module->data_segments[i]) {
if (module->data_segments[i]->is_data_cloned)
wasm_runtime_free(module->data_segments[i]->data);
#if WASM_ENABLE_EXTENDED_CONST_EXPR != 0
destroy_init_expr(&module->data_segments[i]->base_offset);
#endif
wasm_runtime_free(module->data_segments[i]);
}
}
@ -7320,7 +7508,8 @@ re_scan:
== VALUE_TYPE_FUNCREF
&& module->globals[i].init_expr.init_expr_type
== INIT_EXPR_TYPE_FUNCREF_CONST
&& module->globals[i].init_expr.u.u32 == func_idx) {
&& module->globals[i].init_expr.u.unary.v.ref_index
== func_idx) {
func_declared = true;
break;
}
@ -7334,7 +7523,8 @@ re_scan:
i++, table_seg++) {
if (table_seg->elem_type == VALUE_TYPE_FUNCREF) {
for (j = 0; j < table_seg->value_count; j++) {
if (table_seg->init_values[j].u.ref_index
if (table_seg->init_values[j]
.u.unary.v.ref_index
== func_idx) {
func_declared = true;
break;