Make wasi-nn backends separate shared libraries (#3509)

- All files under *core/iwasm/libraries/wasi-nn* are compiled as shared libraries
- *wasi-nn.c* is shared between backends
- Each backend is built as its own separate shared library
- If the wasi-nn feature is enabled, iwasm depends on the shared library
  libiwasm.so instead of linking the static library libvmlib.a
This commit is contained in:
liang.he
2024-06-14 12:06:56 +08:00
committed by GitHub
parent 1434c45283
commit f844b33b2d
20 changed files with 296 additions and 258 deletions

View File

@ -21,7 +21,7 @@
3 -> err
4 -> NO LOGS
*/
#define NN_LOG_LEVEL 0
#define NN_LOG_LEVEL 2
#endif
// Definition of the levels

View File

@ -13,7 +13,6 @@
#include "wasi_nn_private.h"
#include "wasi_nn_app_native.h"
#include "wasi_nn_tensorflowlite.hpp"
#include "logger.h"
#include "bh_platform.h"
@ -21,45 +20,14 @@
#define HASHMAP_INITIAL_SIZE 20
/* Definition of 'wasi_nn.h' structs in WASM app format (using offset) */
typedef wasi_nn_error (*LOAD)(void *, graph_builder_array *, graph_encoding,
execution_target, graph *);
typedef wasi_nn_error (*INIT_EXECUTION_CONTEXT)(void *, graph,
graph_execution_context *);
typedef wasi_nn_error (*SET_INPUT)(void *, graph_execution_context, uint32_t,
tensor *);
typedef wasi_nn_error (*COMPUTE)(void *, graph_execution_context);
typedef wasi_nn_error (*GET_OUTPUT)(void *, graph_execution_context, uint32_t,
tensor_data, uint32_t *);
typedef struct {
LOAD load;
INIT_EXECUTION_CONTEXT init_execution_context;
SET_INPUT set_input;
COMPUTE compute;
GET_OUTPUT get_output;
} api_function;
/* Global variables */
static api_function lookup[] = {
{ NULL, NULL, NULL, NULL, NULL },
{ NULL, NULL, NULL, NULL, NULL },
{ NULL, NULL, NULL, NULL, NULL },
{ NULL, NULL, NULL, NULL, NULL },
{ tensorflowlite_load, tensorflowlite_init_execution_context,
tensorflowlite_set_input, tensorflowlite_compute,
tensorflowlite_get_output }
};
static api_function lookup[backend_amount] = { 0 };
static HashMap *hashmap;
static void
wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx);
/* Get wasi-nn context from module instance */
static uint32
hash_func(const void *key)
{
@ -105,7 +73,16 @@ wasi_nn_initialize_context()
return NULL;
}
wasi_nn_ctx->is_model_loaded = false;
tensorflowlite_initialize(&wasi_nn_ctx->tflite_ctx);
/* only one backend can be registered */
{
unsigned i;
for (i = 0; i < sizeof(lookup) / sizeof(lookup[0]); i++) {
if (lookup[i].init) {
lookup[i].init(&wasi_nn_ctx->backend_ctx);
break;
}
}
}
return wasi_nn_ctx;
}
@ -123,6 +100,7 @@ wasi_nn_initialize()
return true;
}
/* Get wasi-nn context from module instance */
static WASINNContext *
wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance)
{
@ -155,16 +133,30 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
NN_DBG_PRINTF("Freeing wasi-nn");
NN_DBG_PRINTF("-> is_model_loaded: %d", wasi_nn_ctx->is_model_loaded);
NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->current_encoding);
tensorflowlite_destroy(wasi_nn_ctx->tflite_ctx);
/* only one backend can be registered */
{
unsigned i;
for (i = 0; i < sizeof(lookup) / sizeof(lookup[0]); i++) {
if (lookup[i].deinit) {
lookup[i].deinit(wasi_nn_ctx->backend_ctx);
break;
}
}
}
wasm_runtime_free(wasi_nn_ctx);
}
void
wasi_nn_destroy(wasm_module_inst_t instance)
static void
wasi_nn_ctx_destroy_helper(void *instance, void *wasi_nn_ctx, void *user_data)
{
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
bh_hash_map_remove(hashmap, (void *)instance, NULL, NULL);
wasi_nn_ctx_destroy(wasi_nn_ctx);
wasi_nn_ctx_destroy((WASINNContext *)wasi_nn_ctx);
}
/**
 * Tear down the module-wide wasi-nn state.
 *
 * Every entry still stored in `hashmap` is handed to
 * wasi_nn_ctx_destroy_helper, and afterwards the hashmap itself is destroyed.
 */
void
wasi_nn_destroy()
{
/* Visit the remaining entries first, while the map is still alive. */
bh_hash_map_traverse(hashmap, wasi_nn_ctx_destroy_helper, NULL);
bh_hash_map_destroy(hashmap);
}
/* Utils */
@ -233,7 +225,7 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
}
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
res = lookup[encoding].load(wasi_nn_ctx->tflite_ctx, &builder_native,
res = lookup[encoding].load(wasi_nn_ctx->backend_ctx, &builder_native,
encoding, target, g);
NN_DBG_PRINTF("wasi_nn_load finished with status %d [graph=%d]", res, *g);
@ -270,7 +262,7 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
}
res = lookup[wasi_nn_ctx->current_encoding].init_execution_context(
wasi_nn_ctx->tflite_ctx, g, ctx);
wasi_nn_ctx->backend_ctx, g, ctx);
NN_DBG_PRINTF(
"wasi_nn_init_execution_context finished with status %d [ctx=%d]", res,
@ -300,7 +292,7 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
return res;
res = lookup[wasi_nn_ctx->current_encoding].set_input(
wasi_nn_ctx->tflite_ctx, ctx, index, &input_tensor_native);
wasi_nn_ctx->backend_ctx, ctx, index, &input_tensor_native);
// XXX: Free intermediate structure pointers
if (input_tensor_native.dimensions)
@ -323,8 +315,8 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
if (success != (res = is_model_initialized(wasi_nn_ctx)))
return res;
res = lookup[wasi_nn_ctx->current_encoding].compute(wasi_nn_ctx->tflite_ctx,
ctx);
res = lookup[wasi_nn_ctx->current_encoding].compute(
wasi_nn_ctx->backend_ctx, ctx);
NN_DBG_PRINTF("wasi_nn_compute finished with status %d", res);
return res;
}
@ -360,11 +352,13 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
res = lookup[wasi_nn_ctx->current_encoding].get_output(
wasi_nn_ctx->tflite_ctx, ctx, index, output_tensor, &output_tensor_len);
wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
&output_tensor_len);
*output_tensor_size = output_tensor_len;
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
res = lookup[wasi_nn_ctx->current_encoding].get_output(
wasi_nn_ctx->tflite_ctx, ctx, index, output_tensor, output_tensor_size);
wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
output_tensor_size);
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
NN_DBG_PRINTF("wasi_nn_get_output finished with status %d [data_size=%d]",
res, *output_tensor_size);
@ -397,17 +391,53 @@ static NativeSymbol native_symbols_wasi_nn[] = {
uint32_t
get_wasi_nn_export_apis(NativeSymbol **p_native_symbols)
{
if (!wasi_nn_initialize())
return 0;
*p_native_symbols = native_symbols_wasi_nn;
return sizeof(native_symbols_wasi_nn) / sizeof(NativeSymbol);
}
#if defined(WASI_NN_SHARED)
uint32_t
/**
 * Report the import module name and the native symbol table of this library.
 *
 * @param p_module_name     receives "wasi_ephemeral_nn" when
 *                          WASM_ENABLE_WASI_EPHEMERAL_NN is non-zero,
 *                          otherwise "wasi_nn"
 * @param p_native_symbols  forwarded to get_wasi_nn_export_apis()
 * @return whatever get_wasi_nn_export_apis() returns
 */
__attribute__((used)) uint32_t
get_native_lib(char **p_module_name, NativeSymbol **p_native_symbols)
{
    NN_DBG_PRINTF("--|> get_native_lib");

#if WASM_ENABLE_WASI_EPHEMERAL_NN == 0
    *p_module_name = "wasi_nn";
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
    *p_module_name = "wasi_ephemeral_nn";
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */

    return get_wasi_nn_export_apis(p_native_symbols);
}
#endif
/**
 * Initialization hook of the wasi-nn native library.
 *
 * @return 0 when wasi_nn_initialize() succeeds, 1 when it fails
 */
__attribute__((used)) int
init_native_lib()
{
    NN_DBG_PRINTF("--|> init_native_lib");

    /* Map the boolean result onto the 0 = success / 1 = failure convention. */
    return wasi_nn_initialize() ? 0 : 1;
}
/**
 * Teardown hook of the wasi-nn native library: emits a debug trace and then
 * calls wasi_nn_destroy().
 *
 * NOTE(review): presumably invoked by the runtime when the native library is
 * unloaded - confirm against the loader.
 */
__attribute__((used)) void
deinit_native_lib()
{
NN_DBG_PRINTF("--|> deinit_native_lib");
wasi_nn_destroy();
}
/**
 * Store a backend's function table in the `lookup` slot selected by
 * backend_code.
 *
 * @param backend_code  index of the slot in `lookup`
 * @param apis          function table copied into that slot
 * @return true when the table was stored, false when backend_code does not
 *         index a slot of `lookup`
 */
__attribute__((used)) bool
wasi_nn_register_backend(graph_encoding backend_code, api_function apis)
{
    NN_DBG_PRINTF("--|> wasi_nn_register_backend");

    if (backend_code < sizeof(lookup) / sizeof(lookup[0])) {
        lookup[backend_code] = apis;
        return true;
    }

    NN_ERR_PRINTF("Invalid backend code");
    return false;
}

View File

@ -12,15 +12,7 @@
typedef struct {
bool is_model_loaded;
graph_encoding current_encoding;
void *tflite_ctx;
void *backend_ctx;
} WASINNContext;
/**
* @brief Destroy wasi-nn on app exit
*
*/
void
wasi_nn_destroy(wasm_module_inst_t instance);
#endif

View File

@ -3,7 +3,6 @@
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*/
#include "wasi_nn_types.h"
#include "wasi_nn_tensorflowlite.hpp"
#include "logger.h"
@ -487,3 +486,18 @@ tensorflowlite_destroy(void *tflite_ctx)
delete tfl_ctx;
NN_DBG_PRINTF("Memory free'd.");
}
__attribute__((constructor(200))) void
tflite_register_backend()
{
api_function apis = {
.load = tensorflowlite_load,
.init_execution_context = tensorflowlite_init_execution_context,
.set_input = tensorflowlite_set_input,
.compute = tensorflowlite_compute,
.get_output = tensorflowlite_get_output,
.init = tensorflowlite_initialize,
.deinit = tensorflowlite_destroy,
};
wasi_nn_register_backend(tensorflowlite, apis);
}