wasi-nn: Support multiple TFLite models (#2002)
Remove restrictions:
- Only 1 WASM app at a time
- Only 1 model at a time
- `graph` and `graph-execution-context` are ignored
Refer to the previous documentation:
e8d718096d/core/iwasm/libraries/wasi-nn/README.md
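With per-instance state, one WASM app can now hold several graphs and execution contexts at once. A minimal app-side sketch of what this enables, written against the `wasi_nn.h` interface (the two pre-filled builder arrays are hypothetical placeholders):

    #include "wasi_nn.h"

    /* Hypothetical: each builder array wraps one TFLite flatbuffer in memory. */
    extern graph_builder_array model_a_builder, model_b_builder;

    void
    run_two_models(void)
    {
        graph g_a, g_b;
        graph_execution_context ctx_a, ctx_b;

        /* Both loads now succeed and return distinct handles. */
        load(&model_a_builder, tensorflowlite, cpu, &g_a);
        load(&model_b_builder, tensorflowlite, cpu, &g_b);

        /* Each graph gets its own interpreter; the handles are no longer ignored. */
        init_execution_context(g_a, &ctx_a);
        init_execution_context(g_b, &ctx_b);
    }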
@@ -13,51 +13,57 @@
     (strrchr(__FILE__, '/') ? strrchr(__FILE__, '/') + 1 : __FILE__)
 
-/* Disable a level by removing the define */
-#define ENABLE_ERR_LOG
-#define ENABLE_WARN_LOG
-#define ENABLE_DBG_LOG
-#define ENABLE_INFO_LOG
+#ifndef NN_LOG_LEVEL
+/*
+    0 -> debug, info, warn, err
+    1 -> info, warn, err
+    2 -> warn, err
+    3 -> err
+    4 -> NO LOGS
+*/
+#define NN_LOG_LEVEL 0
+#endif
 
 // Definition of the levels
-#ifdef ENABLE_ERR_LOG
-#define NN_ERR_PRINTF(fmt, ...)                                        \
-    do {                                                               \
-        printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
-        printf("\n");                                                  \
-        fflush(stdout);                                                \
-    } while (0)
+#if NN_LOG_LEVEL <= 3
+#define NN_ERR_PRINTF(fmt, ...)                                              \
+    do {                                                                     \
+        printf("[%s:%d ERROR] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
+        printf("\n");                                                        \
+        fflush(stdout);                                                      \
+    } while (0)
 #else
 #define NN_ERR_PRINTF(fmt, ...)
 #endif
-#ifdef ENABLE_WARN_LOG
-#define NN_WARN_PRINTF(fmt, ...)                                       \
-    do {                                                               \
-        printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
-        printf("\n");                                                  \
-        fflush(stdout);                                                \
-    } while (0)
+#if NN_LOG_LEVEL <= 2
+#define NN_WARN_PRINTF(fmt, ...)                                               \
+    do {                                                                       \
+        printf("[%s:%d WARNING] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
+        printf("\n");                                                          \
+        fflush(stdout);                                                        \
+    } while (0)
 #else
 #define NN_WARN_PRINTF(fmt, ...)
 #endif
-#ifdef ENABLE_DBG_LOG
-#define NN_DBG_PRINTF(fmt, ...)                                        \
-    do {                                                               \
-        printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
-        printf("\n");                                                  \
-        fflush(stdout);                                                \
-    } while (0)
-#else
-#define NN_DBG_PRINTF(fmt, ...)
-#endif
-#ifdef ENABLE_INFO_LOG
-#define NN_INFO_PRINTF(fmt, ...)                                       \
-    do {                                                               \
-        printf("[%s:%d] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
-        printf("\n");                                                  \
-        fflush(stdout);                                                \
-    } while (0)
+#if NN_LOG_LEVEL <= 1
+#define NN_INFO_PRINTF(fmt, ...)                                            \
+    do {                                                                    \
+        printf("[%s:%d INFO] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
+        printf("\n");                                                       \
+        fflush(stdout);                                                     \
+    } while (0)
 #else
 #define NN_INFO_PRINTF(fmt, ...)
 #endif
+#if NN_LOG_LEVEL <= 0
+#define NN_DBG_PRINTF(fmt, ...)                                              \
+    do {                                                                     \
+        printf("[%s:%d DEBUG] " fmt, __FILENAME__, __LINE__, ##__VA_ARGS__); \
+        printf("\n");                                                        \
+        fflush(stdout);                                                      \
+    } while (0)
+#else
+#define NN_DBG_PRINTF(fmt, ...)
+#endif
 
 #endif
 
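Call sites are unaffected by this rework; only the compile-time switch changes from per-level defines to a single numeric threshold. A usage sketch (the variables are illustrative):

    /* Build with e.g. -DNN_LOG_LEVEL=2 to keep only warnings and errors. */
    NN_ERR_PRINTF("Invalid argument");             /* compiled in while NN_LOG_LEVEL <= 3 */
    NN_WARN_PRINTF("Falling back to CPU");         /* compiled in while NN_LOG_LEVEL <= 2 */
    NN_INFO_PRINTF("model size: %d", size);        /* compiled in while NN_LOG_LEVEL <= 1 */
    NN_DBG_PRINTF("output[0] = %f", tensor_f[0]);  /* compiled in only at level 0 */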
@@ -22,13 +22,14 @@
 
 /* Definition of 'wasi_nn.h' structs in WASM app format (using offset) */
 
-typedef error (*LOAD)(graph_builder_array *, graph_encoding, execution_target,
-                      graph *);
-typedef error (*INIT_EXECUTION_CONTEXT)(graph, graph_execution_context *);
-typedef error (*SET_INPUT)(graph_execution_context, uint32_t, tensor *);
-typedef error (*COMPUTE)(graph_execution_context);
-typedef error (*GET_OUTPUT)(graph_execution_context, uint32_t, tensor_data,
-                            uint32_t *);
+typedef error (*LOAD)(void *, graph_builder_array *, graph_encoding,
+                      execution_target, graph *);
+typedef error (*INIT_EXECUTION_CONTEXT)(void *, graph,
+                                        graph_execution_context *);
+typedef error (*SET_INPUT)(void *, graph_execution_context, uint32_t, tensor *);
+typedef error (*COMPUTE)(void *, graph_execution_context);
+typedef error (*GET_OUTPUT)(void *, graph_execution_context, uint32_t,
+                            tensor_data, uint32_t *);
 
 typedef struct {
     LOAD load;
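These typedefs back the per-encoding `lookup` table used throughout the hunks below; every backend function now takes the instance's opaque backend context as its first argument. A sketch of the dispatch shape (only the `LOAD load;` field is shown verbatim in this hunk; the struct name and remaining fields are inferred from the typedef names):

    typedef struct {
        LOAD load;
        INIT_EXECUTION_CONTEXT init_execution_context;
        SET_INPUT set_input;
        COMPUTE compute;
        GET_OUTPUT get_output;
    } api_function;

    /* A call then dispatches as, e.g.: */
    res = lookup[encoding].load(wasi_nn_ctx->tflite_ctx, &builder_native,
                                encoding, target, g);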
@@ -123,12 +124,12 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
         goto fail;
     }
 
-    res = lookup[encoding].load(&builder_native, encoding, target, g);
+    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
+    res = lookup[encoding].load(wasi_nn_ctx->tflite_ctx, &builder_native,
+                                encoding, target, g);
 
     NN_DBG_PRINTF("wasi_nn_load finished with status %d [graph=%d]", res, *g);
 
-    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
-
     wasi_nn_ctx->current_encoding = encoding;
     wasi_nn_ctx->is_initialized = true;
 
@@ -160,8 +161,9 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
         return invalid_argument;
     }
 
-    res = lookup[wasi_nn_ctx->current_encoding].init_execution_context(g, ctx);
-    *ctx = g;
+    res = lookup[wasi_nn_ctx->current_encoding].init_execution_context(
+        wasi_nn_ctx->tflite_ctx, g, ctx);
 
     NN_DBG_PRINTF(
         "wasi_nn_init_execution_context finished with status %d [ctx=%d]", res,
         *ctx);
@@ -189,8 +191,8 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
                       &input_tensor_native)))
         return res;
 
-    res = lookup[wasi_nn_ctx->current_encoding].set_input(ctx, index,
-                                                          &input_tensor_native);
+    res = lookup[wasi_nn_ctx->current_encoding].set_input(
+        wasi_nn_ctx->tflite_ctx, ctx, index, &input_tensor_native);
 
     // XXX: Free intermediate structure pointers
     if (input_tensor_native.dimensions)
@@ -213,7 +215,8 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
     if (success != (res = is_model_initialized(wasi_nn_ctx)))
         return res;
 
-    res = lookup[wasi_nn_ctx->current_encoding].compute(ctx);
+    res = lookup[wasi_nn_ctx->current_encoding].compute(wasi_nn_ctx->tflite_ctx,
+                                                        ctx);
     NN_DBG_PRINTF("wasi_nn_compute finished with status %d", res);
     return res;
 }
@@ -241,7 +244,7 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
     }
 
     res = lookup[wasi_nn_ctx->current_encoding].get_output(
-        ctx, index, output_tensor, output_tensor_size);
+        wasi_nn_ctx->tflite_ctx, ctx, index, output_tensor, output_tensor_size);
     NN_DBG_PRINTF("wasi_nn_get_output finished with status %d [data_size=%d]",
                   res, *output_tensor_size);
     return res;
@@ -261,6 +264,7 @@ wasi_nn_initialize()
     }
     wasi_nn_ctx->is_initialized = true;
     wasi_nn_ctx->current_encoding = 3;
+    tensorflowlite_initialize(&wasi_nn_ctx->tflite_ctx);
     return wasi_nn_ctx;
 }
 
@@ -275,7 +279,7 @@ wasi_nn_destroy(WASINNContext *wasi_nn_ctx)
     NN_DBG_PRINTF("Freeing wasi-nn");
     NN_DBG_PRINTF("-> is_initialized: %d", wasi_nn_ctx->is_initialized);
     NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->current_encoding);
-    tensorflowlite_destroy();
+    tensorflowlite_destroy(wasi_nn_ctx->tflite_ctx);
     wasm_runtime_free(wasi_nn_ctx);
 }
 
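Taken together, the two hunks above tie the TFLite backend's lifetime to the per-instance `WASINNContext`. A condensed lifecycle sketch using only the calls shown in this commit:

    /* Per WASM instance: */
    static void
    wasi_nn_lifecycle_sketch(WASINNContext *wasi_nn_ctx)
    {
        /* wasi_nn_initialize(): create the backend state */
        tensorflowlite_initialize(&wasi_nn_ctx->tflite_ctx);

        /* ... load / init_execution_context / set_input / compute / get_output ... */

        /* wasi_nn_destroy(): tear the backend state down, then the wrapper */
        tensorflowlite_destroy(wasi_nn_ctx->tflite_ctx);
        wasm_runtime_free(wasi_nn_ctx);
    }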
@@ -11,6 +11,7 @@
 typedef struct {
     bool is_initialized;
     graph_encoding current_encoding;
+    void *tflite_ctx;
 } WASINNContext;
 
 /**
@@ -16,25 +16,105 @@
 #include <tensorflow/lite/model.h>
 #include <tensorflow/lite/optional_debug_tools.h>
 #include <tensorflow/lite/error_reporter.h>
 
 #if defined(WASI_NN_ENABLE_GPU)
 #include <tensorflow/lite/delegates/gpu/delegate.h>
 #endif
 
-/* Global variables */
+/* Maximum number of graphs per WASM instance */
+#define MAX_GRAPHS_PER_INST 10
+/* Maximum number of graph execution contexts per WASM instance */
+#define MAX_GRAPH_EXEC_CONTEXTS_PER_INST 10
 
-static std::unique_ptr<tflite::Interpreter> interpreter;
-static std::unique_ptr<tflite::FlatBufferModel> model;
+typedef struct {
+    std::unique_ptr<tflite::Interpreter> interpreter;
+} Interpreter;
 
-static char *model_pointer = NULL;
+typedef struct {
+    char *model_pointer;
+    std::unique_ptr<tflite::FlatBufferModel> model;
+    execution_target target;
+} Model;
+
+typedef struct {
+    uint32_t current_models;
+    Model models[MAX_GRAPHS_PER_INST];
+    uint32_t current_interpreters;
+    Interpreter interpreters[MAX_GRAPH_EXEC_CONTEXTS_PER_INST];
+    korp_mutex g_lock;
+} TFLiteContext;
+
+/* Utils */
+
+static error
+initialize_g(TFLiteContext *tfl_ctx, graph *g)
+{
+    os_mutex_lock(&tfl_ctx->g_lock);
+    if (tfl_ctx->current_models == MAX_GRAPHS_PER_INST) {
+        os_mutex_unlock(&tfl_ctx->g_lock);
+        NN_ERR_PRINTF("Exceeded max graphs per WASM instance");
+        return runtime_error;
+    }
+    *g = tfl_ctx->current_models++;
+    os_mutex_unlock(&tfl_ctx->g_lock);
+    return success;
+}
+static error
+initialize_graph_ctx(TFLiteContext *tfl_ctx, graph g,
+                     graph_execution_context *ctx)
+{
+    os_mutex_lock(&tfl_ctx->g_lock);
+    if (tfl_ctx->current_interpreters == MAX_GRAPH_EXEC_CONTEXTS_PER_INST) {
+        os_mutex_unlock(&tfl_ctx->g_lock);
+        NN_ERR_PRINTF("Exceeded max graph execution contexts per WASM instance");
+        return runtime_error;
+    }
+    *ctx = tfl_ctx->current_interpreters++;
+    os_mutex_unlock(&tfl_ctx->g_lock);
+    return success;
+}
+
+static error
+is_valid_graph(TFLiteContext *tfl_ctx, graph g)
+{
+    if (g >= MAX_GRAPHS_PER_INST) {
+        NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
+        return runtime_error;
+    }
+    if (tfl_ctx->models[g].model_pointer == NULL) {
+        NN_ERR_PRINTF("Context (model) non-initialized.");
+        return runtime_error;
+    }
+    if (tfl_ctx->models[g].model == NULL) {
+        NN_ERR_PRINTF("Context (tflite model) non-initialized.");
+        return runtime_error;
+    }
+    return success;
+}
+
+static error
+is_valid_graph_execution_context(TFLiteContext *tfl_ctx,
+                                 graph_execution_context ctx)
+{
+    if (ctx >= MAX_GRAPH_EXEC_CONTEXTS_PER_INST) {
+        NN_ERR_PRINTF("Invalid graph execution context: %d >= %d", ctx,
+                      MAX_GRAPH_EXEC_CONTEXTS_PER_INST);
+        return runtime_error;
+    }
+    if (tfl_ctx->interpreters[ctx].interpreter == NULL) {
+        NN_ERR_PRINTF("Context (interpreter) non-initialized.");
+        return runtime_error;
+    }
+    return success;
+}
 
 /* WASI-NN (tensorflow) implementation */
 
 error
-tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
-                    execution_target target, graph *g)
+tensorflowlite_load(void *tflite_ctx, graph_builder_array *builder,
+                    graph_encoding encoding, execution_target target, graph *g)
 {
-    if (model_pointer != NULL) {
-        wasm_runtime_free(model_pointer);
-        model_pointer = NULL;
-    }
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
 
     if (builder->size != 1) {
         NN_ERR_PRINTF("Unexpected builder format.");
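Under this scheme a `graph` or `graph_execution_context` handle is simply the next free index into the per-instance arrays, handed out under `g_lock`. Note that the counters only grow; slots are not recycled until `tensorflowlite_destroy()`. A behavioural sketch, assuming a freshly initialized `TFLiteContext *tfl_ctx`:

    graph g0, g1;
    graph_execution_context c0;

    initialize_g(tfl_ctx, &g0);              /* g0 == 0 */
    initialize_g(tfl_ctx, &g1);              /* g1 == 1 */
    initialize_graph_ctx(tfl_ctx, g1, &c0);  /* c0 == 0; interpreter slot 0 */
    /* An 11th call to either helper fails with runtime_error (limit: 10). */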
@@ -51,39 +131,68 @@ tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
         return invalid_argument;
     }
 
+    error res;
+    if (success != (res = initialize_g(tfl_ctx, g)))
+        return res;
+
     uint32_t size = builder->buf[0].size;
 
-    model_pointer = (char *)wasm_runtime_malloc(size);
-    if (model_pointer == NULL) {
+    // Save model
+    tfl_ctx->models[*g].model_pointer = (char *)wasm_runtime_malloc(size);
+    if (tfl_ctx->models[*g].model_pointer == NULL) {
         NN_ERR_PRINTF("Error when allocating memory for model.");
         return missing_memory;
     }
 
-    bh_memcpy_s(model_pointer, size, builder->buf[0].buf, size);
+    bh_memcpy_s(tfl_ctx->models[*g].model_pointer, size, builder->buf[0].buf,
+                size);
 
-    model = tflite::FlatBufferModel::BuildFromBuffer(model_pointer, size, NULL);
-    if (model == NULL) {
+    // Save model flatbuffer
+    tfl_ctx->models[*g].model =
+        std::move(tflite::FlatBufferModel::BuildFromBuffer(
+            tfl_ctx->models[*g].model_pointer, size, NULL));
+
+    if (tfl_ctx->models[*g].model == NULL) {
         NN_ERR_PRINTF("Loading model error.");
-        wasm_runtime_free(model_pointer);
-        model_pointer = NULL;
+        wasm_runtime_free(tfl_ctx->models[*g].model_pointer);
+        tfl_ctx->models[*g].model_pointer = NULL;
         return missing_memory;
     }
 
+    // Save target
+    tfl_ctx->models[*g].target = target;
+    return success;
+}
+
+error
+tensorflowlite_init_execution_context(void *tflite_ctx, graph g,
+                                      graph_execution_context *ctx)
+{
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
+
+    error res;
+    if (success != (res = is_valid_graph(tfl_ctx, g)))
+        return res;
+
+    if (success != (res = initialize_graph_ctx(tfl_ctx, g, ctx)))
+        return res;
+
     // Build the interpreter with the InterpreterBuilder.
     tflite::ops::builtin::BuiltinOpResolver resolver;
-    tflite::InterpreterBuilder tflite_builder(*model, resolver);
-    tflite_builder(&interpreter);
-    if (interpreter == NULL) {
+    tflite::InterpreterBuilder tflite_builder(*tfl_ctx->models[g].model,
+                                              resolver);
+    tflite_builder(&tfl_ctx->interpreters[*ctx].interpreter);
+    if (tfl_ctx->interpreters[*ctx].interpreter == NULL) {
         NN_ERR_PRINTF("Error when generating the interpreter.");
-        wasm_runtime_free(model_pointer);
-        model_pointer = NULL;
         return missing_memory;
     }
 
     bool use_default = false;
-    switch (target) {
+    switch (tfl_ctx->models[g].target) {
         case gpu:
         {
 #if defined(WASI_NN_ENABLE_GPU)
             NN_WARN_PRINTF("GPU enabled.");
             // https://www.tensorflow.org/lite/performance/gpu
             auto options = TfLiteGpuDelegateOptionsV2Default();
             options.inference_preference =
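One reason each `Model` keeps its own `model_pointer`: `tflite::FlatBufferModel::BuildFromBuffer()` does not copy the caller's buffer, so the bytes must stay alive as long as the model does. Condensed from the hunk above:

    // The buffer and the FlatBufferModel built over it travel together, per graph:
    tfl_ctx->models[*g].model_pointer = (char *)wasm_runtime_malloc(size);
    bh_memcpy_s(tfl_ctx->models[*g].model_pointer, size, builder->buf[0].buf, size);
    tfl_ctx->models[*g].model = tflite::FlatBufferModel::BuildFromBuffer(
        tfl_ctx->models[*g].model_pointer, size, NULL);
    // The buffer is freed only on load failure or in tensorflowlite_destroy().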
@@ -91,10 +200,16 @@ tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
             options.inference_priority1 =
                 TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY;
             auto *delegate = TfLiteGpuDelegateV2Create(&options);
-            if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
+            if (tfl_ctx->interpreters[*ctx]
+                    .interpreter->ModifyGraphWithDelegate(delegate)
+                != kTfLiteOk) {
                 NN_ERR_PRINTF("Error when enabling GPU delegate.");
                 use_default = true;
             }
 #else
             NN_WARN_PRINTF("GPU not enabled.");
             use_default = true;
 #endif
             break;
         }
         default:
@@ -103,36 +218,28 @@ tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
     if (use_default)
         NN_WARN_PRINTF("Default encoding is CPU.");
 
+    tfl_ctx->interpreters[*ctx].interpreter->AllocateTensors();
     return success;
 }
 
 error
-tensorflowlite_init_execution_context(graph g, graph_execution_context *ctx)
+tensorflowlite_set_input(void *tflite_ctx, graph_execution_context ctx,
+                         uint32_t index, tensor *input_tensor)
 {
-    if (interpreter == NULL) {
-        NN_ERR_PRINTF("Non-initialized interpreter.");
-        return runtime_error;
-    }
-    interpreter->AllocateTensors();
-    return success;
-}
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
 
-error
-tensorflowlite_set_input(graph_execution_context ctx, uint32_t index,
-                         tensor *input_tensor)
-{
-    if (interpreter == NULL) {
-        NN_ERR_PRINTF("Non-initialized interpreter.");
-        return runtime_error;
-    }
+    error res;
+    if (success != (res = is_valid_graph_execution_context(tfl_ctx, ctx)))
+        return res;
 
-    uint32_t num_tensors = interpreter->inputs().size();
+    uint32_t num_tensors =
+        tfl_ctx->interpreters[ctx].interpreter->inputs().size();
     NN_DBG_PRINTF("Number of tensors (%d)", num_tensors);
     if (index + 1 > num_tensors) {
         return runtime_error;
     }
 
-    auto tensor = interpreter->input_tensor(index);
+    auto tensor = tfl_ctx->interpreters[ctx].interpreter->input_tensor(index);
     if (tensor == NULL) {
         NN_ERR_PRINTF("Missing memory");
         return missing_memory;
@@ -152,7 +259,9 @@ tensorflowlite_set_input(graph_execution_context ctx, uint32_t index,
         return invalid_argument;
     }
 
-    auto *input = interpreter->typed_input_tensor<float>(index);
+    auto *input =
+        tfl_ctx->interpreters[ctx].interpreter->typed_input_tensor<float>(
+            index);
     if (input == NULL)
         return missing_memory;
 
@@ -162,34 +271,38 @@ tensorflowlite_set_input(graph_execution_context ctx, uint32_t index,
 }
 
 error
-tensorflowlite_compute(graph_execution_context ctx)
+tensorflowlite_compute(void *tflite_ctx, graph_execution_context ctx)
 {
-    if (interpreter == NULL) {
-        NN_ERR_PRINTF("Non-initialized interpreter.");
-        return runtime_error;
-    }
-    interpreter->Invoke();
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
+
+    error res;
+    if (success != (res = is_valid_graph_execution_context(tfl_ctx, ctx)))
+        return res;
+
+    tfl_ctx->interpreters[ctx].interpreter->Invoke();
     return success;
 }
 
 error
-tensorflowlite_get_output(graph_execution_context ctx, uint32_t index,
-                          tensor_data output_tensor,
+tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
+                          uint32_t index, tensor_data output_tensor,
                           uint32_t *output_tensor_size)
 {
-    if (interpreter == NULL) {
-        NN_ERR_PRINTF("Non-initialized interpreter.");
-        return runtime_error;
-    }
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
 
-    uint32_t num_output_tensors = interpreter->outputs().size();
+    error res;
+    if (success != (res = is_valid_graph_execution_context(tfl_ctx, ctx)))
+        return res;
+
+    uint32_t num_output_tensors =
+        tfl_ctx->interpreters[ctx].interpreter->outputs().size();
     NN_DBG_PRINTF("Number of tensors (%d)", num_output_tensors);
 
     if (index + 1 > num_output_tensors) {
         return runtime_error;
     }
 
-    auto tensor = interpreter->output_tensor(index);
+    auto tensor = tfl_ctx->interpreters[ctx].interpreter->output_tensor(index);
     if (tensor == NULL) {
         NN_ERR_PRINTF("Missing memory");
         return missing_memory;
@@ -204,7 +317,9 @@ tensorflowlite_get_output(graph_execution_context ctx, uint32_t index,
         return missing_memory;
     }
 
-    float *tensor_f = interpreter->typed_output_tensor<float>(index);
+    float *tensor_f =
+        tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
+            index);
     for (uint32_t i = 0; i < model_tensor_size; ++i)
         NN_DBG_PRINTF("output: %f", tensor_f[i]);
 
@@ -215,20 +330,51 @@ tensorflowlite_get_output(graph_execution_context ctx, uint32_t index,
 }
 
 void
-tensorflowlite_destroy()
+tensorflowlite_initialize(void **tflite_ctx)
 {
+    TFLiteContext *tfl_ctx = new TFLiteContext();
+    if (tfl_ctx == NULL) {
+        NN_ERR_PRINTF("Error when allocating memory for tensorflowlite.");
+        return;
+    }
+
+    NN_DBG_PRINTF("Initializing models.");
+    tfl_ctx->current_models = 0;
+    for (int i = 0; i < MAX_GRAPHS_PER_INST; ++i) {
+        tfl_ctx->models[i].model_pointer = NULL;
+    }
+    NN_DBG_PRINTF("Initializing interpreters.");
+    tfl_ctx->current_interpreters = 0;
+
+    if (os_mutex_init(&tfl_ctx->g_lock) != 0) {
+        NN_ERR_PRINTF("Error while initializing the lock");
+    }
+
+    *tflite_ctx = (void *)tfl_ctx;
+}
+
+void
+tensorflowlite_destroy(void *tflite_ctx)
+{
     /*
-        TensorFlow Lite memory is managed by tensorflow
+        TensorFlow Lite memory is internally managed by tensorflow
 
         Related issues:
         * https://github.com/tensorflow/tensorflow/issues/15880
     */
+    TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
+
     NN_DBG_PRINTF("Freeing memory.");
-    model.reset(nullptr);
-    model = NULL;
-    interpreter.reset(nullptr);
-    interpreter = NULL;
-    wasm_runtime_free(model_pointer);
-    model_pointer = NULL;
+    for (int i = 0; i < MAX_GRAPHS_PER_INST; ++i) {
+        tfl_ctx->models[i].model.reset();
+        if (tfl_ctx->models[i].model_pointer)
+            wasm_runtime_free(tfl_ctx->models[i].model_pointer);
+        tfl_ctx->models[i].model_pointer = NULL;
+    }
+    for (int i = 0; i < MAX_GRAPH_EXEC_CONTEXTS_PER_INST; ++i) {
+        tfl_ctx->interpreters[i].interpreter.reset();
+    }
+    os_mutex_destroy(&tfl_ctx->g_lock);
+    delete tfl_ctx;
     NN_DBG_PRINTF("Memory free'd.");
 }
 
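The teardown order in `tensorflowlite_destroy()` matters: each `FlatBufferModel` still references its backing buffer, so the `unique_ptr` is reset before the buffer is freed. Condensed from the loop above:

    tfl_ctx->models[i].model.reset();  // drop the flatbuffer view first
    if (tfl_ctx->models[i].model_pointer)
        wasm_runtime_free(tfl_ctx->models[i].model_pointer);  // then the bytes
    tfl_ctx->models[i].model_pointer = NULL;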
@@ -13,26 +13,30 @@ extern "C" {
 #endif
 
 error
-tensorflowlite_load(graph_builder_array *builder, graph_encoding encoding,
-                    execution_target target, graph *g);
+tensorflowlite_load(void *tflite_ctx, graph_builder_array *builder,
+                    graph_encoding encoding, execution_target target, graph *g);
 
 error
-tensorflowlite_init_execution_context(graph g, graph_execution_context *ctx);
+tensorflowlite_init_execution_context(void *tflite_ctx, graph g,
+                                      graph_execution_context *ctx);
 
 error
-tensorflowlite_set_input(graph_execution_context ctx, uint32_t index,
-                         tensor *input_tensor);
+tensorflowlite_set_input(void *tflite_ctx, graph_execution_context ctx,
+                         uint32_t index, tensor *input_tensor);
 
 error
-tensorflowlite_compute(graph_execution_context ctx);
+tensorflowlite_compute(void *tflite_ctx, graph_execution_context ctx);
 
 error
-tensorflowlite_get_output(graph_execution_context ctx, uint32_t index,
-                          tensor_data output_tensor,
+tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
+                          uint32_t index, tensor_data output_tensor,
                           uint32_t *output_tensor_size);
 
 void
-tensorflowlite_destroy();
+tensorflowlite_initialize(void **tflite_ctx);
+
+void
+tensorflowlite_destroy(void *tflite_ctx);
 
 #ifdef __cplusplus
 }