sync up with latest wasi-nn spec (#3530)

This commit is contained in:
liang.he
2024-06-17 14:58:09 +08:00
committed by GitHub
parent d3e89895be
commit db025e457a
9 changed files with 209 additions and 101 deletions

View File

@ -76,7 +76,7 @@ graph_builder_array_app_native(wasm_module_inst_t instance,
graph_builder *builder = (graph_builder *)wasm_runtime_malloc(
array_size * sizeof(graph_builder));
if (builder == NULL)
return missing_memory;
return too_large;
for (uint32_t i = 0; i < array_size; ++i) {
wasi_nn_error res;
@ -149,7 +149,7 @@ tensor_dimensions_app_native(wasm_module_inst_t instance,
*dimensions =
(tensor_dimensions *)wasm_runtime_malloc(sizeof(tensor_dimensions));
if (dimensions == NULL)
return missing_memory;
return too_large;
(*dimensions)->size = dimensions_wasm->size;
(*dimensions)->buf = (uint32_t *)wasm_runtime_addr_app_to_native(

View File

@ -16,12 +16,26 @@
#include "logger.h"
#include "bh_platform.h"
#include "wasi_nn_types.h"
#include "wasm_export.h"
#define HASHMAP_INITIAL_SIZE 20
/* Global variables */
static api_function lookup[backend_amount] = { 0 };
// if using `load_by_name`, there is no known `encoding` at the time of loading
// so, just keep one `api_function` is enough
static api_function lookup = { 0 };
/* Dispatch `func` on the single registered backend api_function table
   (`lookup`). On success, `wasi_error` receives the backend's return value;
   if the backend did not register `func`, logs an error and sets
   `wasi_error` to `unsupported_operation`. Wrapped in do/while(0) so the
   macro behaves as one statement. */
#define call_wasi_nn_func(wasi_error, func, ...) \
do { \
if (lookup.func) { \
wasi_error = lookup.func(__VA_ARGS__); \
} \
else { \
NN_ERR_PRINTF("Error: %s is not registered", #func); \
wasi_error = unsupported_operation; \
} \
} while (0)
static HashMap *hashmap;
@ -73,16 +87,16 @@ wasi_nn_initialize_context()
return NULL;
}
wasi_nn_ctx->is_model_loaded = false;
/* only one backend can be registered */
{
unsigned i;
for (i = 0; i < sizeof(lookup) / sizeof(lookup[0]); i++) {
if (lookup[i].init) {
lookup[i].init(&wasi_nn_ctx->backend_ctx);
break;
}
}
wasi_nn_error res;
call_wasi_nn_func(res, init, &wasi_nn_ctx->backend_ctx);
if (res != success) {
NN_ERR_PRINTF("Error while initializing backend");
wasm_runtime_free(wasi_nn_ctx);
return NULL;
}
return wasi_nn_ctx;
}
@ -90,6 +104,7 @@ static bool
wasi_nn_initialize()
{
NN_DBG_PRINTF("Initializing wasi-nn");
// hashmap { instance: wasi_nn_ctx }
hashmap = bh_hash_map_create(HASHMAP_INITIAL_SIZE, true, hash_func,
key_equal_func, key_destroy_func,
value_destroy_func);
@ -133,42 +148,26 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
NN_DBG_PRINTF("Freeing wasi-nn");
NN_DBG_PRINTF("-> is_model_loaded: %d", wasi_nn_ctx->is_model_loaded);
NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->current_encoding);
/* only one backend can be registered */
{
unsigned i;
for (i = 0; i < sizeof(lookup) / sizeof(lookup[0]); i++) {
if (lookup[i].deinit) {
lookup[i].deinit(wasi_nn_ctx->backend_ctx);
break;
}
}
}
wasm_runtime_free(wasi_nn_ctx);
}
static void
wasi_nn_ctx_destroy_helper(void *instance, void *wasi_nn_ctx, void *user_data)
{
wasi_nn_ctx_destroy((WASINNContext *)wasi_nn_ctx);
/* only one backend can be registered */
wasi_nn_error res;
call_wasi_nn_func(res, deinit, wasi_nn_ctx->backend_ctx);
if (res != success) {
NN_ERR_PRINTF("Error while destroying backend");
}
wasm_runtime_free(wasi_nn_ctx);
}
void
wasi_nn_destroy()
{
bh_hash_map_traverse(hashmap, wasi_nn_ctx_destroy_helper, NULL);
// destroy hashmap will destroy keys and values
bh_hash_map_destroy(hashmap);
}
/* Utils */
static bool
is_encoding_implemented(graph_encoding encoding)
{
return lookup[encoding].load && lookup[encoding].init_execution_context
&& lookup[encoding].set_input && lookup[encoding].compute
&& lookup[encoding].get_output;
}
static wasi_nn_error
is_model_initialized(WASINNContext *wasi_nn_ctx)
{
@ -195,13 +194,9 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
NN_DBG_PRINTF("Running wasi_nn_load [encoding=%d, target=%d]...", encoding,
target);
if (!is_encoding_implemented(encoding)) {
NN_ERR_PRINTF("Encoding not supported.");
return invalid_encoding;
}
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
bh_assert(instance);
if (!instance)
return runtime_error;
wasi_nn_error res;
graph_builder_array builder_native = { 0 };
@ -225,10 +220,11 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
}
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
res = lookup[encoding].load(wasi_nn_ctx->backend_ctx, &builder_native,
encoding, target, g);
call_wasi_nn_func(res, load, wasi_nn_ctx->backend_ctx, &builder_native,
encoding, target, g);
NN_DBG_PRINTF("wasi_nn_load finished with status %d [graph=%d]", res, *g);
if (res != success)
goto fail;
wasi_nn_ctx->current_encoding = encoding;
wasi_nn_ctx->is_model_loaded = true;
@ -241,6 +237,39 @@ fail:
return res;
}
/**
 * Load a model by name through the registered backend (`load_by_name`).
 *
 * @param exec_env  WASM execution environment of the calling instance.
 * @param name      Model name/path, a pointer into WASM linear memory.
 * @param name_len  Length of `name` in bytes (validated against the
 *                  instance's memory bounds).
 * @param g         Out parameter (in WASM memory) receiving the graph handle.
 * @return success, or a wasi_nn_error describing the failure.
 *
 * NOTE(review): `name` is bounds-checked but not guaranteed to be
 * NUL-terminated; backends that treat it as a C string must take care —
 * TODO confirm with the backend contract.
 */
wasi_nn_error
wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
                     graph *g)
{
    NN_DBG_PRINTF("Running wasi_nn_load_by_name ...");
    wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
    if (!instance) {
        return runtime_error;
    }

    /* Both app-supplied pointers must lie inside the instance's memory. */
    if (!wasm_runtime_validate_native_addr(instance, name, name_len)) {
        return invalid_argument;
    }

    if (!wasm_runtime_validate_native_addr(instance, g,
                                           (uint64)sizeof(graph))) {
        return invalid_argument;
    }

    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);

    wasi_nn_error res;
    call_wasi_nn_func(res, load_by_name, wasi_nn_ctx->backend_ctx, name,
                      name_len, g);
    /* Fixed: previously logged `*g` as the status; `res` is the status. */
    NN_DBG_PRINTF("wasi_nn_load_by_name finished with status %d", res);
    if (res != success)
        return res;

    /* Encoding is unknown when loading by name; mark it autodetect. */
    wasi_nn_ctx->current_encoding = autodetect;
    wasi_nn_ctx->is_model_loaded = true;
    return success;
}
wasi_nn_error
wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
graph_execution_context *ctx)
@ -248,7 +277,10 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
NN_DBG_PRINTF("Running wasi_nn_init_execution_context [graph=%d]...", g);
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
bh_assert(instance);
if (!instance) {
return runtime_error;
}
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
wasi_nn_error res;
@ -261,9 +293,8 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
return invalid_argument;
}
res = lookup[wasi_nn_ctx->current_encoding].init_execution_context(
wasi_nn_ctx->backend_ctx, g, ctx);
call_wasi_nn_func(res, init_execution_context, wasi_nn_ctx->backend_ctx, g,
ctx);
NN_DBG_PRINTF(
"wasi_nn_init_execution_context finished with status %d [ctx=%d]", res,
*ctx);
@ -278,7 +309,10 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
index);
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
bh_assert(instance);
if (!instance) {
return runtime_error;
}
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
wasi_nn_error res;
@ -291,9 +325,8 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
&input_tensor_native)))
return res;
res = lookup[wasi_nn_ctx->current_encoding].set_input(
wasi_nn_ctx->backend_ctx, ctx, index, &input_tensor_native);
call_wasi_nn_func(res, set_input, wasi_nn_ctx->backend_ctx, ctx, index,
&input_tensor_native);
// XXX: Free intermediate structure pointers
if (input_tensor_native.dimensions)
wasm_runtime_free(input_tensor_native.dimensions);
@ -308,15 +341,17 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
NN_DBG_PRINTF("Running wasi_nn_compute [ctx=%d]...", ctx);
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
bh_assert(instance);
if (!instance) {
return runtime_error;
}
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
wasi_nn_error res;
if (success != (res = is_model_initialized(wasi_nn_ctx)))
return res;
res = lookup[wasi_nn_ctx->current_encoding].compute(
wasi_nn_ctx->backend_ctx, ctx);
call_wasi_nn_func(res, compute, wasi_nn_ctx->backend_ctx, ctx);
NN_DBG_PRINTF("wasi_nn_compute finished with status %d", res);
return res;
}
@ -337,7 +372,10 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
index);
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
bh_assert(instance);
if (!instance) {
return runtime_error;
}
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
wasi_nn_error res;
@ -351,14 +389,12 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
}
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
res = lookup[wasi_nn_ctx->current_encoding].get_output(
wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
&output_tensor_len);
call_wasi_nn_func(res, get_output, wasi_nn_ctx->backend_ctx, ctx, index,
output_tensor, &output_tensor_len);
*output_tensor_size = output_tensor_len;
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
res = lookup[wasi_nn_ctx->current_encoding].get_output(
wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
output_tensor_size);
call_wasi_nn_func(res, get_output, wasi_nn_ctx->backend_ctx, ctx, index,
output_tensor, output_tensor_size);
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
NN_DBG_PRINTF("wasi_nn_get_output finished with status %d [data_size=%d]",
res, *output_tensor_size);
@ -375,6 +411,7 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
static NativeSymbol native_symbols_wasi_nn[] = {
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
REG_NATIVE_FUNC(load, "(*iii*)i"),
REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
REG_NATIVE_FUNC(set_input, "(ii*)i"),
REG_NATIVE_FUNC(compute, "(i)i"),
@ -429,15 +466,9 @@ deinit_native_lib()
}
__attribute__((used)) bool
wasi_nn_register_backend(graph_encoding backend_code, api_function apis)
wasi_nn_register_backend(api_function apis)
{
NN_DBG_PRINTF("--|> wasi_nn_register_backend");
if (backend_code >= sizeof(lookup) / sizeof(lookup[0])) {
NN_ERR_PRINTF("Invalid backend code");
return false;
}
lookup[backend_code] = apis;
lookup = apis;
return true;
}

View File

@ -11,6 +11,7 @@
typedef struct {
bool is_model_loaded;
// Optional
graph_encoding current_encoding;
void *backend_ctx;
} WASINNContext;

View File

@ -7,6 +7,7 @@
#include "logger.h"
#include "bh_platform.h"
#include "wasi_nn_types.h"
#include "wasm_export.h"
#include <tensorflow/lite/interpreter.h>
@ -144,7 +145,7 @@ tensorflowlite_load(void *tflite_ctx, graph_builder_array *builder,
tfl_ctx->models[*g].model_pointer = (char *)wasm_runtime_malloc(size);
if (tfl_ctx->models[*g].model_pointer == NULL) {
NN_ERR_PRINTF("Error when allocating memory for model.");
return missing_memory;
return too_large;
}
bh_memcpy_s(tfl_ctx->models[*g].model_pointer, size, builder->buf[0].buf,
@ -159,7 +160,7 @@ tensorflowlite_load(void *tflite_ctx, graph_builder_array *builder,
NN_ERR_PRINTF("Loading model error.");
wasm_runtime_free(tfl_ctx->models[*g].model_pointer);
tfl_ctx->models[*g].model_pointer = NULL;
return missing_memory;
return too_large;
}
// Save target
@ -167,6 +168,30 @@ tensorflowlite_load(void *tflite_ctx, graph_builder_array *builder,
return success;
}
/**
 * Load a TFLite model directly from a file path.
 *
 * @param tflite_ctx   Backend context created by tensorflowlite_initialize().
 * @param filename     Path of the .tflite flatbuffer model.
 * @param filename_len Length of `filename` in bytes.
 * @param g            Out parameter receiving the new graph slot.
 * @return success, or a wasi_nn_error on failure.
 *
 * NOTE(review): `filename_len` is unused — BuildFromFile() requires a
 * NUL-terminated path, which the caller does not guarantee; consider
 * copying into a bounded, NUL-terminated buffer. TODO confirm caller
 * contract before relying on termination.
 */
wasi_nn_error
tensorflowlite_load_by_name(void *tflite_ctx, const char *filename,
                            uint32_t filename_len, graph *g)
{
    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;

    wasi_nn_error res = initialize_g(tfl_ctx, g);
    if (success != res)
        return res;

    // Load model. BuildFromFile returns a unique_ptr by value, so moving
    // from the prvalue was redundant (-Wpessimizing-move); assign directly.
    tfl_ctx->models[*g].model =
        tflite::FlatBufferModel::BuildFromFile(filename, NULL);
    if (tfl_ctx->models[*g].model == NULL) {
        NN_ERR_PRINTF("Loading model error.");
        return too_large;
    }

    // Use CPU as default
    tfl_ctx->models[*g].target = cpu;
    return success;
}
wasi_nn_error
tensorflowlite_init_execution_context(void *tflite_ctx, graph g,
graph_execution_context *ctx)
@ -187,7 +212,7 @@ tensorflowlite_init_execution_context(void *tflite_ctx, graph g,
tflite_builder(&tfl_ctx->interpreters[*ctx].interpreter);
if (tfl_ctx->interpreters[*ctx].interpreter == NULL) {
NN_ERR_PRINTF("Error when generating the interpreter.");
return missing_memory;
return too_large;
}
bool use_default = false;
@ -207,7 +232,7 @@ tensorflowlite_init_execution_context(void *tflite_ctx, graph g,
if (tfl_ctx->delegate == NULL) {
NN_ERR_PRINTF("Error when generating GPU delegate.");
use_default = true;
return missing_memory;
return too_large;
}
if (tfl_ctx->interpreters[*ctx]
.interpreter->ModifyGraphWithDelegate(tfl_ctx->delegate)
@ -232,7 +257,7 @@ tensorflowlite_init_execution_context(void *tflite_ctx, graph g,
if (tfl_ctx->delegate == NULL) {
NN_ERR_PRINTF("Error when generating External delegate.");
use_default = true;
return missing_memory;
return too_large;
}
if (tfl_ctx->interpreters[*ctx]
.interpreter->ModifyGraphWithDelegate(tfl_ctx->delegate)
@ -276,7 +301,7 @@ tensorflowlite_set_input(void *tflite_ctx, graph_execution_context ctx,
auto tensor = tfl_ctx->interpreters[ctx].interpreter->input_tensor(index);
if (tensor == NULL) {
NN_ERR_PRINTF("Missing memory");
return missing_memory;
return too_large;
}
uint32_t model_tensor_size = 1;
@ -363,7 +388,7 @@ tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
auto tensor = tfl_ctx->interpreters[ctx].interpreter->output_tensor(index);
if (tensor == NULL) {
NN_ERR_PRINTF("Missing memory");
return missing_memory;
return too_large;
}
uint32_t model_tensor_size = 1;
@ -372,7 +397,7 @@ tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
if (*output_tensor_size < model_tensor_size) {
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
return missing_memory;
return too_large;
}
if (tensor->quantization.type == kTfLiteNoQuantization) {
@ -409,13 +434,13 @@ tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
return success;
}
void
wasi_nn_error
tensorflowlite_initialize(void **tflite_ctx)
{
TFLiteContext *tfl_ctx = new TFLiteContext();
if (tfl_ctx == NULL) {
NN_ERR_PRINTF("Error when allocating memory for tensorflowlite.");
return;
return runtime_error;
}
NN_DBG_PRINTF("Initializing models.");
@ -433,9 +458,10 @@ tensorflowlite_initialize(void **tflite_ctx)
tfl_ctx->delegate = NULL;
*tflite_ctx = (void *)tfl_ctx;
return success;
}
void
wasi_nn_error
tensorflowlite_destroy(void *tflite_ctx)
{
/*
@ -485,6 +511,7 @@ tensorflowlite_destroy(void *tflite_ctx)
os_mutex_destroy(&tfl_ctx->g_lock);
delete tfl_ctx;
NN_DBG_PRINTF("Memory free'd.");
return success;
}
__attribute__((constructor(200))) void
@ -492,6 +519,7 @@ tflite_register_backend()
{
api_function apis = {
.load = tensorflowlite_load,
.load_by_name = tensorflowlite_load_by_name,
.init_execution_context = tensorflowlite_init_execution_context,
.set_input = tensorflowlite_set_input,
.compute = tensorflowlite_compute,
@ -499,5 +527,5 @@ tflite_register_backend()
.init = tensorflowlite_initialize,
.deinit = tensorflowlite_destroy,
};
wasi_nn_register_backend(tensorflowlite, apis);
wasi_nn_register_backend(apis);
}

View File

@ -32,10 +32,10 @@ tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
uint32_t index, tensor_data output_tensor,
uint32_t *output_tensor_size);
void
wasi_nn_error
tensorflowlite_initialize(void **tflite_ctx);
void
wasi_nn_error
tensorflowlite_destroy(void *tflite_ctx);
#ifdef __cplusplus