Refactor WASI-NN to simplify the support for multiple frameworks (#1834)
- Reorganize the library structure
- Use the latest version of `wasi-nn` wit (Oct 25, 2022):
0f77c48ec1/wasi-nn.wit.md
- Split logic that converts WASM structs to native structs in a separate file
- Simplify addition of new frameworks
This commit is contained in:
63
core/iwasm/libraries/wasi-nn/src/utils/logger.h
Normal file
63
core/iwasm/libraries/wasi-nn/src/utils/logger.h
Normal file
@ -0,0 +1,63 @@
|
||||
/*
|
||||
* Copyright (C) 2019 Intel Corporation. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
*/
|
||||
|
||||
#ifndef WASI_NN_LOGGER_H
|
||||
#define WASI_NN_LOGGER_H
|
||||
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
/*
 * Base name of the current source file: everything after the last '/' in
 * __FILE__, or __FILE__ itself when it contains no '/'.
 */
#define __FILENAME__ \
    (!strrchr(__FILE__, '/') ? __FILE__ : strrchr(__FILE__, '/') + 1)
|
||||
|
||||
/* Disable a level by removing the define */
|
||||
#define ENABLE_ERR_LOG
|
||||
#define ENABLE_WARN_LOG
|
||||
#define ENABLE_DBG_LOG
|
||||
#define ENABLE_INFO_LOG
|
||||
|
||||
// Definition of the levels
|
||||
/*
 * Error-level log. When ENABLE_ERR_LOG is defined it prints
 * "[file:line] <formatted message>" followed by a newline to stdout and
 * flushes the stream; otherwise it expands to nothing.
 * `fmt` must be a string literal (it is concatenated with the prefix).
 */
#ifndef ENABLE_ERR_LOG
#define NN_ERR_PRINTF(fmt, ...)
#else
#define NN_ERR_PRINTF(fmt, ...)                                        \
    do {                                                               \
        printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
        putchar('\n');                                                 \
        fflush(stdout);                                                \
    } while (0)
#endif
|
||||
/*
 * Warning-level log. When ENABLE_WARN_LOG is defined it prints
 * "[file:line] <formatted message>" followed by a newline to stdout and
 * flushes the stream; otherwise it expands to nothing.
 * `fmt` must be a string literal (it is concatenated with the prefix).
 */
#ifndef ENABLE_WARN_LOG
#define NN_WARN_PRINTF(fmt, ...)
#else
#define NN_WARN_PRINTF(fmt, ...)                                       \
    do {                                                               \
        printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
        putchar('\n');                                                 \
        fflush(stdout);                                                \
    } while (0)
#endif
|
||||
/*
 * Debug-level log. When ENABLE_DBG_LOG is defined it prints
 * "[file:line] <formatted message>" followed by a newline to stdout and
 * flushes the stream; otherwise it expands to nothing.
 * `fmt` must be a string literal (it is concatenated with the prefix).
 */
#ifndef ENABLE_DBG_LOG
#define NN_DBG_PRINTF(fmt, ...)
#else
#define NN_DBG_PRINTF(fmt, ...)                                        \
    do {                                                               \
        printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
        putchar('\n');                                                 \
        fflush(stdout);                                                \
    } while (0)
#endif
|
||||
/*
 * Info-level log. When ENABLE_INFO_LOG is defined it prints
 * "[file:line] <formatted message>" followed by a newline to stdout and
 * flushes the stream; otherwise it expands to nothing.
 * `fmt` must be a string literal (it is concatenated with the prefix).
 */
#ifndef ENABLE_INFO_LOG
#define NN_INFO_PRINTF(fmt, ...)
#else
#define NN_INFO_PRINTF(fmt, ...)                                       \
    do {                                                               \
        printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
        putchar('\n');                                                 \
        fflush(stdout);                                                \
    } while (0)
#endif
|
||||
|
||||
#endif
|
||||
163
core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.c
Normal file
163
core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.c
Normal file
@ -0,0 +1,163 @@
|
||||
/*
|
||||
* Copyright (C) 2019 Intel Corporation. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
*/
|
||||
|
||||
#include "wasi_nn_app_native.h"
|
||||
|
||||
static error
|
||||
graph_builder_app_native(wasm_module_inst_t instance,
|
||||
graph_builder_wasm *builder_wasm,
|
||||
graph_builder *builder)
|
||||
{
|
||||
if (!wasm_runtime_validate_app_addr(instance, builder_wasm->buf_offset,
|
||||
builder_wasm->size * sizeof(uint8_t))) {
|
||||
NN_ERR_PRINTF("builder_wasm->buf_offset is invalid");
|
||||
return invalid_argument;
|
||||
}
|
||||
|
||||
builder->buf = (uint8_t *)wasm_runtime_addr_app_to_native(
|
||||
instance, builder_wasm->buf_offset);
|
||||
builder->size = builder_wasm->size;
|
||||
return success;
|
||||
}
|
||||
|
||||
error
|
||||
graph_builder_array_app_native(wasm_module_inst_t instance,
|
||||
graph_builder_array_wasm *builder_array_wasm,
|
||||
graph_builder_array *builder_array)
|
||||
{
|
||||
if (!wasm_runtime_validate_native_addr(instance, builder_array_wasm,
|
||||
sizeof(graph_builder_array_wasm))) {
|
||||
NN_ERR_PRINTF("builder_array_wasm is invalid");
|
||||
return invalid_argument;
|
||||
}
|
||||
|
||||
NN_DBG_PRINTF("Graph builder array contains %d elements",
|
||||
builder_array_wasm->size);
|
||||
|
||||
if (!wasm_runtime_validate_app_addr(
|
||||
instance, builder_array_wasm->buf_offset,
|
||||
builder_array_wasm->size * sizeof(graph_builder_wasm))) {
|
||||
NN_ERR_PRINTF("builder_array_wasm->buf_offset is invalid");
|
||||
return invalid_argument;
|
||||
}
|
||||
|
||||
graph_builder_wasm *builder_wasm =
|
||||
(graph_builder_wasm *)wasm_runtime_addr_app_to_native(
|
||||
instance, builder_array_wasm->buf_offset);
|
||||
|
||||
graph_builder *builder = (graph_builder *)wasm_runtime_malloc(
|
||||
builder_array_wasm->size * sizeof(graph_builder));
|
||||
if (builder == NULL)
|
||||
return missing_memory;
|
||||
|
||||
for (uint32_t i = 0; i < builder_array_wasm->size; ++i) {
|
||||
error res;
|
||||
if (success
|
||||
!= (res = graph_builder_app_native(instance, &builder_wasm[i],
|
||||
&builder[i]))) {
|
||||
wasm_runtime_free(builder);
|
||||
return res;
|
||||
}
|
||||
|
||||
NN_DBG_PRINTF("Graph builder %d contains %d elements", i,
|
||||
builder->size);
|
||||
}
|
||||
|
||||
builder_array->buf = builder;
|
||||
builder_array->size = builder_array_wasm->size;
|
||||
return success;
|
||||
}
|
||||
|
||||
static error
|
||||
tensor_data_app_native(wasm_module_inst_t instance, uint32_t total_elements,
|
||||
tensor_wasm *input_tensor_wasm, tensor_data *data)
|
||||
{
|
||||
if (!wasm_runtime_validate_app_addr(
|
||||
instance, input_tensor_wasm->data_offset, total_elements)) {
|
||||
NN_ERR_PRINTF("input_tensor_wasm->data_offset is invalid");
|
||||
return invalid_argument;
|
||||
}
|
||||
*data = (tensor_data)wasm_runtime_addr_app_to_native(
|
||||
instance, input_tensor_wasm->data_offset);
|
||||
return success;
|
||||
}
|
||||
|
||||
static error
|
||||
tensor_dimensions_app_native(wasm_module_inst_t instance,
|
||||
tensor_wasm *input_tensor_wasm,
|
||||
tensor_dimensions **dimensions)
|
||||
{
|
||||
if (!wasm_runtime_validate_app_addr(instance,
|
||||
input_tensor_wasm->dimensions_offset,
|
||||
sizeof(tensor_dimensions_wasm))) {
|
||||
NN_ERR_PRINTF("input_tensor_wasm->dimensions_offset is invalid");
|
||||
return invalid_argument;
|
||||
}
|
||||
|
||||
tensor_dimensions_wasm *dimensions_wasm =
|
||||
(tensor_dimensions_wasm *)wasm_runtime_addr_app_to_native(
|
||||
instance, input_tensor_wasm->dimensions_offset);
|
||||
|
||||
if (!wasm_runtime_validate_app_addr(instance, dimensions_wasm->buf_offset,
|
||||
sizeof(tensor_dimensions))) {
|
||||
NN_ERR_PRINTF("dimensions_wasm->buf_offset is invalid");
|
||||
return invalid_argument;
|
||||
}
|
||||
|
||||
*dimensions =
|
||||
(tensor_dimensions *)wasm_runtime_malloc(sizeof(tensor_dimensions));
|
||||
if (dimensions == NULL)
|
||||
return missing_memory;
|
||||
|
||||
(*dimensions)->size = dimensions_wasm->size;
|
||||
(*dimensions)->buf = (uint32_t *)wasm_runtime_addr_app_to_native(
|
||||
instance, dimensions_wasm->buf_offset);
|
||||
|
||||
NN_DBG_PRINTF("Number of dimensions: %d", (*dimensions)->size);
|
||||
return success;
|
||||
}
|
||||
|
||||
error
|
||||
tensor_app_native(wasm_module_inst_t instance, tensor_wasm *input_tensor_wasm,
|
||||
tensor *input_tensor)
|
||||
{
|
||||
NN_DBG_PRINTF("Converting tensor_wasm to tensor");
|
||||
if (!wasm_runtime_validate_native_addr(instance, input_tensor_wasm,
|
||||
sizeof(tensor_wasm))) {
|
||||
NN_ERR_PRINTF("input_tensor_wasm is invalid");
|
||||
return invalid_argument;
|
||||
}
|
||||
|
||||
error res;
|
||||
|
||||
tensor_dimensions *dimensions = NULL;
|
||||
if (success
|
||||
!= (res = tensor_dimensions_app_native(instance, input_tensor_wasm,
|
||||
&dimensions))) {
|
||||
NN_ERR_PRINTF("error when parsing dimensions");
|
||||
return res;
|
||||
}
|
||||
|
||||
uint32_t total_elements = 1;
|
||||
for (uint32_t i = 0; i < dimensions->size; ++i) {
|
||||
total_elements *= dimensions->buf[i];
|
||||
NN_DBG_PRINTF("Dimension %d: %d", i, dimensions->buf[i]);
|
||||
}
|
||||
NN_DBG_PRINTF("Tensor type: %d", input_tensor_wasm->type);
|
||||
NN_DBG_PRINTF("Total number of elements: %d", total_elements);
|
||||
|
||||
tensor_data data = NULL;
|
||||
if (success
|
||||
!= (res = tensor_data_app_native(instance, total_elements,
|
||||
input_tensor_wasm, &data))) {
|
||||
wasm_runtime_free(dimensions);
|
||||
return res;
|
||||
}
|
||||
|
||||
input_tensor->type = input_tensor_wasm->type;
|
||||
input_tensor->dimensions = dimensions;
|
||||
input_tensor->data = data;
|
||||
return success;
|
||||
}
|
||||
51
core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.h
Normal file
51
core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.h
Normal file
@ -0,0 +1,51 @@
|
||||
/*
|
||||
* Copyright (C) 2019 Intel Corporation. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
*/
|
||||
|
||||
#ifndef WASI_NN_APP_NATIVE
|
||||
#define WASI_NN_APP_NATIVE
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <assert.h>
|
||||
#include <errno.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "wasi_nn.h"
|
||||
#include "logger.h"
|
||||
|
||||
#include "bh_platform.h"
|
||||
#include "wasm_export.h"
|
||||
|
||||
/* Guest-memory layout of a single graph builder (one model buffer). */
typedef struct {
    uint32_t buf_offset; /* app offset of the first byte of the buffer */
    uint32_t size;       /* length of the buffer in bytes */
} graph_builder_wasm;
|
||||
|
||||
/* Guest-memory layout of an array of graph_builder_wasm entries. */
typedef struct {
    uint32_t buf_offset; /* app offset of the first graph_builder_wasm */
    uint32_t size;       /* number of entries in the array */
} graph_builder_array_wasm;
|
||||
|
||||
/* Guest-memory layout of a tensor's dimension list. */
typedef struct {
    uint32_t buf_offset; /* app offset of the first uint32_t dimension */
    uint32_t size;       /* number of dimensions */
} tensor_dimensions_wasm;
|
||||
|
||||
typedef struct {
|
||||
uint32_t dimensions_offset;
|
||||
tensor_type type;
|
||||
uint32_t data_offset;
|
||||
} tensor_wasm;
|
||||
|
||||
error
|
||||
graph_builder_array_app_native(wasm_module_inst_t instance,
|
||||
graph_builder_array_wasm *builder,
|
||||
graph_builder_array *builder_native);
|
||||
|
||||
error
|
||||
tensor_app_native(wasm_module_inst_t instance, tensor_wasm *input_tensor,
|
||||
tensor *input_tensor_native);
|
||||
|
||||
#endif
|
||||
302
core/iwasm/libraries/wasi-nn/src/wasi_nn.c
Normal file
302
core/iwasm/libraries/wasi-nn/src/wasi_nn.c
Normal file
@ -0,0 +1,302 @@
|
||||
/*
|
||||
* Copyright (C) 2019 Intel Corporation. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <stdbool.h>
|
||||
#include <assert.h>
|
||||
#include <errno.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "wasi_nn.h"
|
||||
#include "wasi_nn_app_native.h"
|
||||
#include "logger.h"
|
||||
#include "wasi_nn_tensorflowlite.hpp"
|
||||
|
||||
#include "bh_platform.h"
|
||||
#include "wasm_export.h"
|
||||
#include "wasm_runtime.h"
|
||||
#include "aot_runtime.h"
|
||||
|
||||
/* Per-framework backend interface: function-pointer types and dispatch table */
|
||||
|
||||
typedef error (*LOAD)(graph_builder_array *, graph_encoding, execution_target,
|
||||
graph *);
|
||||
typedef error (*INIT_EXECUTION_CONTEXT)(graph, graph_execution_context *);
|
||||
typedef error (*SET_INPUT)(graph_execution_context, uint32_t, tensor *);
|
||||
typedef error (*COMPUTE)(graph_execution_context);
|
||||
typedef error (*GET_OUTPUT)(graph_execution_context, uint32_t, tensor_data,
|
||||
uint32_t *);
|
||||
|
||||
typedef struct {
|
||||
LOAD load;
|
||||
INIT_EXECUTION_CONTEXT init_execution_context;
|
||||
SET_INPUT set_input;
|
||||
COMPUTE compute;
|
||||
GET_OUTPUT get_output;
|
||||
} api_function;
|
||||
|
||||
/* Global variables */
|
||||
|
||||
static api_function lookup[] = {
|
||||
{ NULL, NULL, NULL, NULL, NULL },
|
||||
{ NULL, NULL, NULL, NULL, NULL },
|
||||
{ NULL, NULL, NULL, NULL, NULL },
|
||||
{ NULL, NULL, NULL, NULL, NULL },
|
||||
{ tensorflowlite_load, tensorflowlite_init_execution_context,
|
||||
tensorflowlite_set_input, tensorflowlite_compute,
|
||||
tensorflowlite_get_output }
|
||||
};
|
||||
|
||||
/* Utils */
|
||||
|
||||
static bool
|
||||
is_encoding_implemented(graph_encoding encoding)
|
||||
{
|
||||
return lookup[encoding].load && lookup[encoding].init_execution_context
|
||||
&& lookup[encoding].set_input && lookup[encoding].compute
|
||||
&& lookup[encoding].get_output;
|
||||
}
|
||||
|
||||
static error
|
||||
is_model_initialized(WASINNContext *wasi_nn_ctx)
|
||||
{
|
||||
if (!wasi_nn_ctx->is_initialized) {
|
||||
NN_ERR_PRINTF("Model not initialized.");
|
||||
return runtime_error;
|
||||
}
|
||||
return success;
|
||||
}
|
||||
|
||||
WASINNContext *
|
||||
wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance)
|
||||
{
|
||||
WASINNContext *wasi_nn_ctx = NULL;
|
||||
#if WASM_ENABLE_INTERP != 0
|
||||
if (instance->module_type == Wasm_Module_Bytecode) {
|
||||
NN_DBG_PRINTF("Getting ctx from WASM");
|
||||
WASMModuleInstance *module_inst = (WASMModuleInstance *)instance;
|
||||
wasi_nn_ctx = ((WASMModuleInstanceExtra *)module_inst->e)->wasi_nn_ctx;
|
||||
}
|
||||
#endif
|
||||
#if WASM_ENABLE_AOT != 0
|
||||
if (instance->module_type == Wasm_Module_AoT) {
|
||||
NN_DBG_PRINTF("Getting ctx from AOT");
|
||||
AOTModuleInstance *module_inst = (AOTModuleInstance *)instance;
|
||||
wasi_nn_ctx = ((AOTModuleInstanceExtra *)module_inst->e)->wasi_nn_ctx;
|
||||
}
|
||||
#endif
|
||||
bh_assert(wasi_nn_ctx != NULL);
|
||||
NN_DBG_PRINTF("Returning ctx");
|
||||
return wasi_nn_ctx;
|
||||
}
|
||||
|
||||
/* WASI-NN implementation */
|
||||
|
||||
error
|
||||
wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
|
||||
graph_encoding encoding, execution_target target, graph *g)
|
||||
{
|
||||
NN_DBG_PRINTF("Running wasi_nn_load [encoding=%d, target=%d]...", encoding,
|
||||
target);
|
||||
|
||||
if (!is_encoding_implemented(encoding)) {
|
||||
NN_ERR_PRINTF("Encoding not supported.");
|
||||
return invalid_encoding;
|
||||
}
|
||||
|
||||
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
||||
bh_assert(instance);
|
||||
|
||||
error res;
|
||||
graph_builder_array builder_native = { 0 };
|
||||
if (success
|
||||
!= (res = graph_builder_array_app_native(instance, builder,
|
||||
&builder_native)))
|
||||
return res;
|
||||
|
||||
if (!wasm_runtime_validate_native_addr(instance, g, sizeof(graph))) {
|
||||
NN_ERR_PRINTF("graph is invalid");
|
||||
res = invalid_argument;
|
||||
goto fail;
|
||||
}
|
||||
|
||||
res = lookup[encoding].load(&builder_native, encoding, target, g);
|
||||
|
||||
NN_DBG_PRINTF("wasi_nn_load finished with status %d [graph=%d]", res, *g);
|
||||
|
||||
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
||||
|
||||
wasi_nn_ctx->current_encoding = encoding;
|
||||
wasi_nn_ctx->is_initialized = true;
|
||||
|
||||
fail:
|
||||
// XXX: Free intermediate structure pointers
|
||||
if (builder_native.buf)
|
||||
wasm_runtime_free(builder_native.buf);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
error
|
||||
wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
|
||||
graph_execution_context *ctx)
|
||||
{
|
||||
NN_DBG_PRINTF("Running wasi_nn_init_execution_context [graph=%d]...", g);
|
||||
|
||||
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
||||
bh_assert(instance);
|
||||
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
||||
|
||||
error res;
|
||||
if (success != (res = is_model_initialized(wasi_nn_ctx)))
|
||||
return res;
|
||||
|
||||
if (!wasm_runtime_validate_native_addr(instance, ctx,
|
||||
sizeof(graph_execution_context))) {
|
||||
NN_ERR_PRINTF("ctx is invalid");
|
||||
return invalid_argument;
|
||||
}
|
||||
|
||||
res = lookup[wasi_nn_ctx->current_encoding].init_execution_context(g, ctx);
|
||||
*ctx = g;
|
||||
NN_DBG_PRINTF(
|
||||
"wasi_nn_init_execution_context finished with status %d [ctx=%d]", res,
|
||||
*ctx);
|
||||
return res;
|
||||
}
|
||||
|
||||
error
|
||||
wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
||||
uint32_t index, tensor_wasm *input_tensor)
|
||||
{
|
||||
NN_DBG_PRINTF("Running wasi_nn_set_input [ctx=%d, index=%d]...", ctx,
|
||||
index);
|
||||
|
||||
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
||||
bh_assert(instance);
|
||||
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
||||
|
||||
error res;
|
||||
if (success != (res = is_model_initialized(wasi_nn_ctx)))
|
||||
return res;
|
||||
|
||||
tensor input_tensor_native = { 0 };
|
||||
if (success
|
||||
!= (res = tensor_app_native(instance, input_tensor,
|
||||
&input_tensor_native)))
|
||||
return res;
|
||||
|
||||
res = lookup[wasi_nn_ctx->current_encoding].set_input(ctx, index,
|
||||
&input_tensor_native);
|
||||
|
||||
// XXX: Free intermediate structure pointers
|
||||
if (input_tensor_native.dimensions)
|
||||
wasm_runtime_free(input_tensor_native.dimensions);
|
||||
|
||||
NN_DBG_PRINTF("wasi_nn_set_input finished with status %d", res);
|
||||
return res;
|
||||
}
|
||||
|
||||
error
|
||||
wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
|
||||
{
|
||||
NN_DBG_PRINTF("Running wasi_nn_compute [ctx=%d]...", ctx);
|
||||
|
||||
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
||||
bh_assert(instance);
|
||||
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
||||
|
||||
error res;
|
||||
if (success != (res = is_model_initialized(wasi_nn_ctx)))
|
||||
return res;
|
||||
|
||||
res = lookup[wasi_nn_ctx->current_encoding].compute(ctx);
|
||||
NN_DBG_PRINTF("wasi_nn_compute finished with status %d", res);
|
||||
return res;
|
||||
}
|
||||
|
||||
error
|
||||
wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
||||
uint32_t index, tensor_data output_tensor,
|
||||
uint32_t *output_tensor_size)
|
||||
{
|
||||
NN_DBG_PRINTF("Running wasi_nn_get_output [ctx=%d, index=%d]...", ctx,
|
||||
index);
|
||||
|
||||
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
||||
bh_assert(instance);
|
||||
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
||||
|
||||
error res;
|
||||
if (success != (res = is_model_initialized(wasi_nn_ctx)))
|
||||
return res;
|
||||
|
||||
if (!wasm_runtime_validate_native_addr(instance, output_tensor_size,
|
||||
sizeof(uint32_t))) {
|
||||
NN_ERR_PRINTF("output_tensor_size is invalid");
|
||||
return invalid_argument;
|
||||
}
|
||||
|
||||
res = lookup[wasi_nn_ctx->current_encoding].get_output(
|
||||
ctx, index, output_tensor, output_tensor_size);
|
||||
NN_DBG_PRINTF("wasi_nn_get_output finished with status %d [data_size=%d]",
|
||||
res, *output_tensor_size);
|
||||
return res;
|
||||
}
|
||||
|
||||
/* Non-exposed public functions */
|
||||
|
||||
WASINNContext *
|
||||
wasi_nn_initialize()
|
||||
{
|
||||
NN_DBG_PRINTF("Initializing wasi-nn");
|
||||
WASINNContext *wasi_nn_ctx =
|
||||
(WASINNContext *)wasm_runtime_malloc(sizeof(WASINNContext));
|
||||
if (wasi_nn_ctx == NULL) {
|
||||
NN_ERR_PRINTF("Error when allocating memory for WASI-NN context");
|
||||
return NULL;
|
||||
}
|
||||
wasi_nn_ctx->is_initialized = true;
|
||||
wasi_nn_ctx->current_encoding = 3;
|
||||
return wasi_nn_ctx;
|
||||
}
|
||||
|
||||
void
|
||||
wasi_nn_destroy(WASINNContext *wasi_nn_ctx)
|
||||
{
|
||||
if (wasi_nn_ctx == NULL) {
|
||||
NN_ERR_PRINTF(
|
||||
"Error when deallocating memory. WASI-NN context is NULL");
|
||||
return;
|
||||
}
|
||||
NN_DBG_PRINTF("Freeing wasi-nn");
|
||||
NN_DBG_PRINTF("-> is_initialized: %d", wasi_nn_ctx->is_initialized);
|
||||
NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->current_encoding);
|
||||
tensorflowlite_destroy();
|
||||
wasm_runtime_free(wasi_nn_ctx);
|
||||
}
|
||||
|
||||
/* Register WASI-NN in WAMR */
|
||||
|
||||
/* clang-format off */
|
||||
#define REG_NATIVE_FUNC(func_name, signature) \
|
||||
{ #func_name, wasi_nn_##func_name, signature, NULL }
|
||||
/* clang-format on */
|
||||
|
||||
static NativeSymbol native_symbols_wasi_nn[] = {
|
||||
REG_NATIVE_FUNC(load, "(*ii*)i"),
|
||||
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
|
||||
REG_NATIVE_FUNC(set_input, "(ii*)i"),
|
||||
REG_NATIVE_FUNC(compute, "(i)i"),
|
||||
REG_NATIVE_FUNC(get_output, "(ii**)i"),
|
||||
};
|
||||
|
||||
uint32_t
|
||||
get_wasi_nn_export_apis(NativeSymbol **p_libc_wasi_apis)
|
||||
{
|
||||
*p_libc_wasi_apis = native_symbols_wasi_nn;
|
||||
return sizeof(native_symbols_wasi_nn) / sizeof(NativeSymbol);
|
||||
}
|
||||
30
core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h
Normal file
30
core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h
Normal file
@ -0,0 +1,30 @@
|
||||
/*
|
||||
* Copyright (C) 2019 Intel Corporation. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
*/
|
||||
|
||||
#ifndef WASI_NN_PRIVATE_H
|
||||
#define WASI_NN_PRIVATE_H
|
||||
|
||||
#include "wasi_nn_types.h"
|
||||
|
||||
typedef struct {
|
||||
bool is_initialized;
|
||||
graph_encoding current_encoding;
|
||||
} WASINNContext;
|
||||
|
||||
/**
|
||||
* @brief Initialize wasi-nn
|
||||
*
|
||||
*/
|
||||
WASINNContext *
|
||||
wasi_nn_initialize();
|
||||
/**
|
||||
* @brief Destroy wasi-nn when the app exits
|
||||
*
|
||||
*/
|
||||
|
||||
void
|
||||
wasi_nn_destroy(WASINNContext *wasi_nn_ctx);
|
||||
|
||||
#endif
|
||||
210
core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp
Normal file
210
core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp
Normal file
@ -0,0 +1,210 @@
|
||||
/*
|
||||
* Copyright (C) 2019 Intel Corporation. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
*/
|
||||
|
||||
#include "wasi_nn.h"
|
||||
#include "wasi_nn_tensorflowlite.hpp"
|
||||
#include "logger.h"
|
||||
|
||||
#include "bh_common.h"
|
||||
#include "bh_platform.h"
|
||||
#include "platform_common.h"
|
||||
|
||||
#include <tensorflow/lite/interpreter.h>
|
||||
#include <tensorflow/lite/kernels/register.h>
|
||||
#include <tensorflow/lite/model.h>
|
||||
#include <tensorflow/lite/optional_debug_tools.h>
|
||||
#include <tensorflow/lite/error_reporter.h>
|
||||
|
||||
/* Global variables */
|
||||
|
||||
/* Interpreter built from `model`; NULL until a model has been loaded. */
static std::unique_ptr<tflite::Interpreter> interpreter;

/* Model parsed from the buffer at `model_pointer`. */
static std::unique_ptr<tflite::FlatBufferModel> model;

/* Private copy of the model bytes, allocated with wasm_runtime_malloc(). */
static char *model_pointer = NULL;
|
||||
|
||||
/* WASI-NN (tensorflow) implementation */
|
||||
|
||||
/*
 * Load a TensorFlow Lite model from the first (and only) builder and build
 * an interpreter for it.
 *
 * Any previously loaded model is released first, in dependency order
 * (interpreter, then model, then the byte buffer both refer to), so nothing
 * is left pointing at freed memory. The same applies on every failure path:
 * after an error no interpreter, model or buffer remains.
 *
 * @param builder   native builder array; must contain exactly one entry
 * @param encoding  must be tensorflowlite
 * @param target    must be cpu
 * @param g         graph handle output; not written by this backend
 *
 * @return success, invalid_argument for an unsupported builder count,
 *         encoding or target, or missing_memory when the buffer allocation,
 *         model parsing or interpreter construction fails.
 */
error
tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
                    execution_target target, graph *g)
{
    interpreter.reset();
    model.reset();
    if (model_pointer != NULL) {
        wasm_runtime_free(model_pointer);
        model_pointer = NULL;
    }

    if (builder->size != 1) {
        NN_ERR_PRINTF("Unexpected builder format.");
        return invalid_argument;
    }

    if (encoding != tensorflowlite) {
        NN_ERR_PRINTF("Encoding is not tensorflowlite.");
        return invalid_argument;
    }

    if (target != cpu) {
        NN_ERR_PRINTF("Only CPU target is supported.");
        return invalid_argument;
    }

    uint32_t size = builder->buf[0].size;

    /* Keep a private copy: the builder's buffer is not owned by this
       backend, while the model keeps referring to its bytes. */
    model_pointer = (char *)wasm_runtime_malloc(size);
    if (model_pointer == NULL) {
        NN_ERR_PRINTF("Error when allocating memory for model.");
        return missing_memory;
    }

    bh_memcpy_s(model_pointer, size, builder->buf[0].buf, size);

    model = tflite::FlatBufferModel::BuildFromBuffer(model_pointer, size, NULL);
    if (model == NULL) {
        NN_ERR_PRINTF("Loading model error.");
        wasm_runtime_free(model_pointer);
        model_pointer = NULL;
        return missing_memory;
    }

    // Build the interpreter with the InterpreterBuilder.
    tflite::ops::builtin::BuiltinOpResolver resolver;
    tflite::InterpreterBuilder tflite_builder(*model, resolver);
    tflite_builder(&interpreter);
    if (interpreter == NULL) {
        NN_ERR_PRINTF("Error when generating the interpreter.");
        /* Drop the model before the buffer it was built from. */
        model.reset();
        wasm_runtime_free(model_pointer);
        model_pointer = NULL;
        return missing_memory;
    }

    return success;
}
|
||||
|
||||
/*
 * Allocate the interpreter's tensors so inputs can be set.
 *
 * `g` and `ctx` are not used by this backend.
 *
 * @return success, or runtime_error when no interpreter has been built or
 *         the interpreter fails to allocate its tensors.
 */
error
tensorflowlite_init_execution_context(graph g, graph_execution_context *ctx)
{
    if (interpreter == NULL) {
        NN_ERR_PRINTF("Non-initialized interpreter.");
        return runtime_error;
    }

    if (interpreter->AllocateTensors() != kTfLiteOk) {
        NN_ERR_PRINTF("Error when allocating tensors.");
        return runtime_error;
    }
    return success;
}
|
||||
|
||||
/*
 * Copy the caller's tensor data into input tensor `index` of the
 * interpreter.
 *
 * The total element count derived from the caller's dimensions must equal
 * the element count of the model's input tensor; the data is copied as
 * float values.
 *
 * @return success; runtime_error when no interpreter exists or `index` is
 *         out of range; invalid_argument when the element counts differ;
 *         missing_memory when the interpreter cannot provide the tensor or
 *         its float data pointer.
 */
error
tensorflowlite_set_input(graph_execution_context ctx, uint32_t index,
                         tensor *input_tensor)
{
    if (interpreter == NULL) {
        NN_ERR_PRINTF("Non-initialized interpreter.");
        return runtime_error;
    }

    uint32_t num_tensors = interpreter->inputs().size();
    NN_DBG_PRINTF("Number of tensors (%d)", num_tensors);
    /* `index >= n` rather than `index + 1 > n`: the latter wraps to 0 for
       index == UINT32_MAX and would let the index through. */
    if (index >= num_tensors) {
        NN_ERR_PRINTF("Input tensor index %d is out of range", index);
        return runtime_error;
    }

    auto model_tensor = interpreter->input_tensor(index);
    if (model_tensor == NULL) {
        NN_ERR_PRINTF("Missing memory");
        return missing_memory;
    }

    uint32_t model_tensor_size = 1;
    for (int i = 0; i < model_tensor->dims->size; ++i)
        model_tensor_size *= (uint32_t)model_tensor->dims->data[i];

    uint32_t input_tensor_size = 1;
    for (uint32_t i = 0; i < input_tensor->dimensions->size; i++)
        input_tensor_size *= (uint32_t)input_tensor->dimensions->buf[i];

    if (model_tensor_size != input_tensor_size) {
        NN_ERR_PRINTF("Input tensor shape from the model is different than the "
                      "one provided");
        return invalid_argument;
    }

    auto *input = interpreter->typed_input_tensor<float>(index);
    if (input == NULL)
        return missing_memory;

    bh_memcpy_s(input, model_tensor_size * sizeof(float), input_tensor->data,
                model_tensor_size * sizeof(float));
    return success;
}
|
||||
|
||||
/*
 * Run inference on the loaded model.
 *
 * `ctx` is not used by this backend.
 *
 * @return success, or runtime_error when no interpreter has been built or
 *         the interpreter reports a failure while invoking the model.
 */
error
tensorflowlite_compute(graph_execution_context ctx)
{
    if (interpreter == NULL) {
        NN_ERR_PRINTF("Non-initialized interpreter.");
        return runtime_error;
    }

    if (interpreter->Invoke() != kTfLiteOk) {
        NN_ERR_PRINTF("Error when invoking the interpreter.");
        return runtime_error;
    }
    return success;
}
|
||||
|
||||
/*
 * Copy output tensor `index` of the interpreter into `output_tensor`.
 *
 * @param ctx                 not used by this backend
 * @param index               output tensor index
 * @param output_tensor       destination buffer
 * @param output_tensor_size  in: capacity of the destination, compared
 *                            against the tensor's element count; out: the
 *                            number of elements copied
 *
 * The data is copied as float values (element count * sizeof(float) bytes).
 *
 * @return success; runtime_error when no interpreter exists or `index` is
 *         out of range; missing_memory when the tensor or its float data is
 *         unavailable, or when the destination capacity is too small.
 */
error
tensorflowlite_get_output(graph_execution_context ctx, uint32_t index,
                          tensor_data output_tensor,
                          uint32_t *output_tensor_size)
{
    if (interpreter == NULL) {
        NN_ERR_PRINTF("Non-initialized interpreter.");
        return runtime_error;
    }

    uint32_t num_output_tensors = interpreter->outputs().size();
    NN_DBG_PRINTF("Number of tensors (%d)", num_output_tensors);

    /* `index >= n` rather than `index + 1 > n`: the latter wraps to 0 for
       index == UINT32_MAX and would let the index through. */
    if (index >= num_output_tensors) {
        NN_ERR_PRINTF("Output tensor index %d is out of range", index);
        return runtime_error;
    }

    auto model_tensor = interpreter->output_tensor(index);
    if (model_tensor == NULL) {
        NN_ERR_PRINTF("Missing memory");
        return missing_memory;
    }

    uint32_t model_tensor_size = 1;
    for (int i = 0; i < (int)model_tensor->dims->size; ++i)
        model_tensor_size *= (uint32_t)model_tensor->dims->data[i];

    if (*output_tensor_size < model_tensor_size) {
        NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
        return missing_memory;
    }

    float *tensor_f = interpreter->typed_output_tensor<float>(index);
    /* No float view of this tensor is available; do not dereference. */
    if (tensor_f == NULL) {
        NN_ERR_PRINTF("Missing memory");
        return missing_memory;
    }

    for (uint32_t i = 0; i < model_tensor_size; ++i)
        NN_DBG_PRINTF("output: %f", tensor_f[i]);

    *output_tensor_size = model_tensor_size;
    bh_memcpy_s(output_tensor, model_tensor_size * sizeof(float), tensor_f,
                model_tensor_size * sizeof(float));
    return success;
}
|
||||
|
||||
void
|
||||
tensorflowlite_destroy()
|
||||
{
|
||||
/*
|
||||
TensorFlow Lite memory is man
|
||||
|
||||
Related issues:
|
||||
* https://github.com/tensorflow/tensorflow/issues/15880
|
||||
*/
|
||||
NN_DBG_PRINTF("Freeing memory.");
|
||||
model.reset(nullptr);
|
||||
model = NULL;
|
||||
interpreter.reset(nullptr);
|
||||
interpreter = NULL;
|
||||
wasm_runtime_free(model_pointer);
|
||||
model_pointer = NULL;
|
||||
NN_DBG_PRINTF("Memory free'd.");
|
||||
}
|
||||
41
core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.hpp
Normal file
41
core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.hpp
Normal file
@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Copyright (C) 2019 Intel Corporation. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
*/
|
||||
|
||||
#ifndef WASI_NN_TENSORFLOWLITE_HPP
|
||||
#define WASI_NN_TENSORFLOWLITE_HPP
|
||||
|
||||
#include "wasi_nn.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
error
|
||||
tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
|
||||
execution_target target, graph *g);
|
||||
|
||||
error
|
||||
tensorflowlite_init_execution_context(graph g, graph_execution_context *ctx);
|
||||
|
||||
error
|
||||
tensorflowlite_set_input(graph_execution_context ctx, uint32_t index,
|
||||
tensor *input_tensor);
|
||||
|
||||
error
|
||||
tensorflowlite_compute(graph_execution_context ctx);
|
||||
|
||||
error
|
||||
tensorflowlite_get_output(graph_execution_context ctx, uint32_t index,
|
||||
tensor_data output_tensor,
|
||||
uint32_t *output_tensor_size);
|
||||
|
||||
void
|
||||
tensorflowlite_destroy();
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user