Refactor WASI-NN to simplify the support for multiple frameworks (#1834)
- Reorganize the library structure
- Use the latest version of `wasi-nn` wit (Oct 25, 2022):
0f77c48ec1/wasi-nn.wit.md
- Split logic that converts WASM structs to native structs in a separate file
- Simplify addition of new frameworks
This commit is contained in:
@ -3,63 +3,17 @@
|
||||
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
*/
|
||||
|
||||
#ifndef WASI_NN_WASM_H
|
||||
#define WASI_NN_WASM_H
|
||||
|
||||
#include "wasi_nn_common.h"
|
||||
|
||||
/**
|
||||
* Following definition from:
|
||||
* [Aug 10th, 2022]
|
||||
* https://github.com/WebAssembly/wasi-nn/blob/e5e1a6c31f424c7cd63026cd270e9746775675a0/wasi-nn.wit.md
|
||||
* [Oct 25th, 2022]
|
||||
* https://github.com/WebAssembly/wasi-nn/blob/0f77c48ec195748990ff67928a4b3eef5f16c2de/wasi-nn.wit.md
|
||||
*/
|
||||
|
||||
/* The graph initialization data. */
|
||||
#ifndef WASI_NN_H
|
||||
#define WASI_NN_H
|
||||
|
||||
// This consists of an array of buffers because implementing backends may encode
|
||||
// their graph IR in parts (e.g., OpenVINO stores its IR and weights
|
||||
// separately).
|
||||
typedef struct {
|
||||
uint8_t *buf;
|
||||
uint32_t size;
|
||||
} graph_builder;
|
||||
|
||||
typedef struct {
|
||||
graph_builder *buf;
|
||||
uint32_t size;
|
||||
} graph_builder_array;
|
||||
|
||||
/* The dimensions of a tensor. */
|
||||
|
||||
// The array length matches the tensor rank and each element in the array
|
||||
// describes the size of each dimension.
|
||||
typedef struct {
|
||||
uint32_t *buf;
|
||||
uint32_t size;
|
||||
} tensor_dimensions;
|
||||
|
||||
/* The tensor data. */
|
||||
|
||||
// Initially conceived as a sparse representation, each empty cell would be
|
||||
// filled with zeros and the array length must match the product of all of the
|
||||
// dimensions and the number of bytes in the type (e.g., a 2x2 tensor with
|
||||
// 4-byte f32 elements would have a data array of length 16). Naturally, this
|
||||
// representation requires some knowledge of how to lay out data in
|
||||
// memory--e.g., using row-major ordering--and could perhaps be improved.
|
||||
typedef uint8_t *tensor_data;
|
||||
|
||||
/* A tensor. */
|
||||
|
||||
typedef struct {
|
||||
// Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To
|
||||
// represent a tensor containing a single value, use `[1]` for the tensor
|
||||
// dimensions.
|
||||
tensor_dimensions *dimensions;
|
||||
// Describe the type of element in the tensor (e.g., f32).
|
||||
tensor_type type;
|
||||
// Contains the tensor data.
|
||||
tensor_data data;
|
||||
} tensor;
|
||||
#include <stdint.h>
|
||||
#include "wasi_nn_types.h"
|
||||
|
||||
/**
|
||||
* @brief Load an opaque sequence of bytes to use for inference.
|
||||
@ -67,25 +21,31 @@ typedef struct {
|
||||
* @param builder Model builder.
|
||||
* @param encoding Model encoding.
|
||||
* @param target Execution target.
|
||||
* @param graph Graph.
|
||||
* @param g Graph.
|
||||
* @return error Execution status.
|
||||
*/
|
||||
error
|
||||
load(graph_builder_array *builder, graph_encoding encoding,
|
||||
execution_target target, graph *graph)
|
||||
__attribute__((export_module("wasi_nn")))
|
||||
execution_target target, graph *g)
|
||||
__attribute__((import_module("wasi_nn")));
|
||||
|
||||
/**
|
||||
* INFERENCE
|
||||
*
|
||||
*/
|
||||
|
||||
// Bind a `graph` to the input and output tensors for an inference.
|
||||
typedef uint32_t graph_execution_context;
|
||||
|
||||
/**
|
||||
* @brief Create an execution instance of a loaded graph.
|
||||
*
|
||||
* @param graph Graph.
|
||||
* @param g Graph.
|
||||
* @param ctx Execution context.
|
||||
* @return error Execution status.
|
||||
*/
|
||||
error
|
||||
init_execution_context(graph graph, graph_execution_context *ctx)
|
||||
__attribute__((export_module("wasi_nn")))
|
||||
init_execution_context(graph g, graph_execution_context *ctx)
|
||||
__attribute__((import_module("wasi_nn")));
|
||||
|
||||
/**
|
||||
@ -98,7 +58,6 @@ init_execution_context(graph graph, graph_execution_context *ctx)
|
||||
*/
|
||||
error
|
||||
set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
|
||||
__attribute__((export_module("wasi_nn")))
|
||||
__attribute__((import_module("wasi_nn")));
|
||||
|
||||
/**
|
||||
@ -108,8 +67,7 @@ set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
|
||||
* @return error Execution status.
|
||||
*/
|
||||
error
|
||||
compute(graph_execution_context ctx) __attribute__((export_module("wasi_nn")))
|
||||
__attribute__((import_module("wasi_nn")));
|
||||
compute(graph_execution_context ctx) __attribute__((import_module("wasi_nn")));
|
||||
|
||||
/**
|
||||
* @brief Extract the outputs after inference.
|
||||
@ -126,7 +84,6 @@ __attribute__((import_module("wasi_nn")));
|
||||
error
|
||||
get_output(graph_execution_context ctx, uint32_t index,
|
||||
tensor_data output_tensor, uint32_t *output_tensor_size)
|
||||
__attribute__((export_module("wasi_nn")))
|
||||
__attribute__((import_module("wasi_nn")));
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user