wasi-nn: Add a new target for llama.cpp as a wasi-nn backend (#3709)
Minimum support:

- [x] Accept (WasmEdge) customized model parameters via metadata (see the usage sketch below).
- [x] Target [wasmedge-ggml examples](https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/wasmedge-ggml)
  - [x] basic
  - [x] chatml
  - [x] gemma
  - [x] llama
  - [x] qwen

---

In the future, to support if required:

- [ ] Target [wasmedge-ggml examples](https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/wasmedge-ggml)
  - [ ] command-r (>70 GB memory requirement)
  - [ ] embedding (embedding mode)
  - [ ] grammar (use the grammar option to constrain the model to generate JSON output)
  - [ ] llama-stream (new APIs `compute_single`, `get_output_single`, `fini_single`)
  - [ ] llava (image representation)
  - [ ] llava-base64-stream (image representation)
  - [ ] multimodel (image representation)
- [ ] Target [llamaedge](https://github.com/LlamaEdge/LlamaEdge)
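For orientation, here is a minimal guest-side sketch of the new `load_by_name_with_config` path, modeled on the wasmedge-ggml examples. It assumes the natives are imported from the `wasi_ephemeral_nn` module with the signatures registered in `wasi_nn.c` below; the wrapper name, model file name, and JSON values are illustrative and not part of this commit.

```c
/* Hypothetical guest-side sketch (wasm32), not part of this commit. */
#include <stdint.h>
#include <string.h>

__attribute__((import_module("wasi_ephemeral_nn"),
               import_name("load_by_name_with_config"))) extern int32_t
wasi_nn_load_by_name_with_config(const char *name, int32_t name_len,
                                 const char *config, int32_t config_len,
                                 uint32_t *graph_id);

int
load_llama(void)
{
    /* WasmEdge-compatible metadata; the keys mirror
       wasm_edge_llama_apply_configuration() in wasi_nn_llamacpp.c */
    const char *config =
        "{\"n-predict\":128, \"ctx-size\":1024, \"stream-stdout\":true}";
    const char *model = "llama-2-7b-chat.Q5_K_M.gguf"; /* illustrative name */
    uint32_t graph_id = 0;

    /* returns 0 (success) on a successful load */
    return wasi_nn_load_by_name_with_config(model, (int32_t)strlen(model),
                                            config, (int32_t)strlen(config),
                                            &graph_id);
}
```

After loading, the flow follows the usual wasi-nn sequence: `init_execution_context`, `set_input` with the prompt as a NUL-terminated string, `compute`, then `get_output` (index 0 returns the generated text, index 1 a small JSON metadata blob; see `wasi_nn_llamacpp.c` below).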
@@ -29,7 +29,7 @@
struct backends_api_functions {
    void *backend_handle;
    api_function functions;
} lookup[autodetect] = { 0 };
} lookup[autodetect + 1] = { 0 };

#define call_wasi_nn_func(backend_encoding, func, wasi_error, ...) \
    do { \
@@ -168,14 +168,7 @@ wasi_nn_destroy()
            lookup[i].backend_handle = NULL;
        }

        lookup[i].functions.init = NULL;
        lookup[i].functions.deinit = NULL;
        lookup[i].functions.load = NULL;
        lookup[i].functions.load_by_name = NULL;
        lookup[i].functions.init_execution_context = NULL;
        lookup[i].functions.set_input = NULL;
        lookup[i].functions.compute = NULL;
        lookup[i].functions.get_output = NULL;
        memset(&lookup[i].functions, 0, sizeof(api_function));
    }
}

@@ -208,6 +201,10 @@ choose_a_backend()
        return ggml;
    }

#ifndef NDEBUG
    NN_WARN_PRINTF("%s", dlerror());
#endif

    handle = dlopen(OPENVINO_BACKEND_LIB, RTLD_LAZY);
    if (handle) {
        NN_INFO_PRINTF("Using openvino backend");
@@ -215,6 +212,10 @@ choose_a_backend()
        return openvino;
    }

#ifndef NDEBUG
    NN_WARN_PRINTF("%s", dlerror());
#endif

    handle = dlopen(TFLITE_BACKEND_LIB, RTLD_LAZY);
    if (handle) {
        NN_INFO_PRINTF("Using tflite backend");
@@ -222,6 +223,11 @@ choose_a_backend()
        return tensorflowlite;
    }

#ifndef NDEBUG
    NN_WARN_PRINTF("%s", dlerror());
#endif

    NN_WARN_PRINTF("No backend found");
    return unknown_backend;
}

@@ -257,6 +263,14 @@ register_backend(void *handle, api_function *functions)
    }
    functions->load_by_name = load_by_name;

    LOAD_BY_NAME_WITH_CONFIG load_by_name_with_config =
        (LOAD_BY_NAME_WITH_CONFIG)dlsym(handle, "load_by_name_with_config");
    if (!load_by_name_with_config) {
        NN_WARN_PRINTF("load_by_name_with_config() not found");
        // since only llama.cpp backend need to support this function
    }
    functions->load_by_name_with_config = load_by_name_with_config;

    INIT_EXECUTION_CONTEXT init_execution_context =
        (INIT_EXECUTION_CONTEXT)dlsym(handle, "init_execution_context");
    if (!init_execution_context) {
@@ -329,21 +343,23 @@ graph_encoding_to_backend_lib_name(graph_encoding encoding)
static bool
detect_and_load_backend(graph_encoding backend_hint,
                        struct backends_api_functions *backends,
                        graph_encoding *loaded_backed)
                        graph_encoding *loaded_backend)
{
    if (backend_hint >= autodetect)
    if (backend_hint > autodetect)
        return false;

    if (backend_hint == autodetect)
        backend_hint = choose_a_backend();

    /* if already loaded */
    if (lookup[backend_hint].backend_handle) {
        *loaded_backed = backend_hint;
        return true;
    }
    if (backend_hint == unknown_backend)
        return false;

    *loaded_backend = backend_hint;

    /* if already loaded */
    if (lookup[backend_hint].backend_handle)
        return true;

    *loaded_backed = backend_hint;
    const char *backend_lib_name =
        graph_encoding_to_backend_lib_name(backend_hint);
    if (!backend_lib_name)
@@ -353,6 +369,7 @@ detect_and_load_backend(graph_encoding backend_hint,
}

/* WASI-NN implementation */

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_wasm *builder,
@@ -392,15 +409,15 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
        goto fail;
    }

    graph_encoding loaded_backed = autodetect;
    if (!detect_and_load_backend(encoding, lookup, &loaded_backed)) {
    graph_encoding loaded_backend = autodetect;
    if (!detect_and_load_backend(encoding, lookup, &loaded_backend)) {
        res = invalid_encoding;
        NN_ERR_PRINTF("load backend failed");
        goto fail;
    }

    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
    wasi_nn_ctx->backend = loaded_backed;
    wasi_nn_ctx->backend = loaded_backend;

    /* init() the backend */
    call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
@@ -413,7 +430,6 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
    if (res != success)
        goto fail;

    wasi_nn_ctx->backend = loaded_backed;
    wasi_nn_ctx->is_model_loaded = true;

fail:
@@ -428,8 +444,6 @@ wasi_nn_error
wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
                     graph *g)
{
    NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name);

    wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
    if (!instance) {
        return runtime_error;
@@ -446,15 +460,23 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
        return invalid_argument;
    }

    graph_encoding loaded_backed = autodetect;
    if (detect_and_load_backend(autodetect, lookup, &loaded_backed)) {
    if (name_len == 0 || name[name_len] != '\0') {
        NN_ERR_PRINTF("Invalid filename");
        return invalid_argument;
    }

    NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name);

    graph_encoding loaded_backend = autodetect;
    if (!detect_and_load_backend(autodetect, lookup, &loaded_backend)) {
        NN_ERR_PRINTF("load backend failed");
        return invalid_encoding;
    }

    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
    wasi_nn_error res;
    wasi_nn_ctx->backend = loaded_backend;

    wasi_nn_error res;
    /* init() the backend */
    call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
                      &wasi_nn_ctx->backend_ctx);
@@ -466,7 +488,67 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
    if (res != success)
        return res;

    wasi_nn_ctx->backend = loaded_backed;
    wasi_nn_ctx->backend = loaded_backend;
    wasi_nn_ctx->is_model_loaded = true;
    return success;
}

wasi_nn_error
wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
                                 int32_t name_len, char *config,
                                 int32_t config_len, graph *g)
{
    wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
    if (!instance) {
        return runtime_error;
    }

    if (!wasm_runtime_validate_native_addr(instance, name, name_len)) {
        NN_ERR_PRINTF("name is invalid");
        return invalid_argument;
    }

    if (!wasm_runtime_validate_native_addr(instance, g,
                                           (uint64)sizeof(graph))) {
        NN_ERR_PRINTF("graph is invalid");
        return invalid_argument;
    }

    if (name_len == 0 || name[name_len] != '\0') {
        NN_ERR_PRINTF("Invalid filename");
        return invalid_argument;
    }

    if (!config || config_len == 0 || config[config_len] != '\0') {
        NN_ERR_PRINTF("Invalid config");
        return invalid_argument;
    }

    NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config);

    graph_encoding loaded_backend = autodetect;
    if (!detect_and_load_backend(autodetect, lookup, &loaded_backend)) {
        NN_ERR_PRINTF("load backend failed");
        return invalid_encoding;
    }

    WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
    wasi_nn_ctx->backend = loaded_backend;

    wasi_nn_error res;
    /* init() the backend */
    call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
                      &wasi_nn_ctx->backend_ctx);
    if (res != success)
        return res;

    call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res,
                      wasi_nn_ctx->backend_ctx, name, name_len, config,
                      config_len, g);
    if (res != success)
        return res;

    wasi_nn_ctx->backend = loaded_backend;
    wasi_nn_ctx->is_model_loaded = true;
    return success;
}
@@ -608,6 +690,7 @@ static NativeSymbol native_symbols_wasi_nn[] = {
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
    REG_NATIVE_FUNC(load, "(*iii*)i"),
    REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
    REG_NATIVE_FUNC(load_by_name_with_config, "(*i*i*)i"),
    REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
    REG_NATIVE_FUNC(set_input, "(ii*)i"),
    REG_NATIVE_FUNC(compute, "(i)i"),
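As a reading aid for the registration table above: WAMR's native-symbol signature strings encode the wasm-visible parameter list, with `*` marking an application pointer that the runtime validates and `i` an i32. Under that convention (an assumption stated here, not something this diff spells out), the new entry decodes as:

```c
/* "(*i*i*)i" <-> wasi_nn_load_by_name_with_config(exec_env,
 *     char *name,          -- '*' : validated app pointer
 *     int32_t name_len,    -- 'i' : i32
 *     char *config,        -- '*' : validated app pointer
 *     int32_t config_len,  -- 'i' : i32
 *     graph *g)            -- '*' : validated app pointer
 *  -> 'i' after the closing parenthesis: the i32 wasi_nn_error result. */
```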
core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c (new file, 601 lines)

@@ -0,0 +1,601 @@
/*
 * Copyright (C) 2019 Intel Corporation. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 */
#include "wasi_nn_types.h"
#include "utils/logger.h"
#include "llama.h"
#include "ggml.h"
#include "cJSON.h"

// build info
extern int LLAMA_BUILD_NUMBER;
extern char const *LLAMA_COMMIT;
extern char const *LLAMA_COMPILER;
extern char const *LLAMA_BUILD_TARGET;

// compatable with WasmEdge
// https://github.com/second-state/WasmEdge-WASINN-examples/blob/master/wasmedge-ggml/README.md#parameters
// https://github.com/WasmEdge/WasmEdge/blob/master/plugins/wasi_nn/ggml.cpp
struct wasi_nn_llama_config {
    // Backend(plugin in WasmEdge) parameters:
    bool enable_log;
    bool enable_debug_log;
    bool stream_stdout;
    // embedding mode
    bool embedding;
    // TODO: can it be -1?
    // can't bigger than ctx_size
    int32_t n_predict;
    char *reverse_prompt;

    // Used by LLaVA
    // multi-model project file
    char *mmproj;
    char *image;

    // Model parameters (need to reload the model if updated):
    // align to definition of struct llama_model_params
    int32_t n_gpu_layers;
    int32_t main_gpu;
    // limited size: llama_max_devices()
    float *tensor_split;
    bool use_mmap;

    // Context parameters (used by the llama context):
    uint32_t ctx_size;
    uint32_t batch_size;
    uint32_t ubatch_size;
    uint32_t threads;

    // Sampling parameters (used by the llama sampling context).
    float temp;
    float topP;
    float repeat_penalty;
    float presence_penalty;
    float frequency_penalty;
};

struct LlamaContext {
    struct llama_context *ctx;
    struct llama_model *model;
    llama_token *prompt;
    size_t prompt_len;
    llama_token *generation;
    size_t generation_len;
    struct wasi_nn_llama_config config;
};

static void
wasm_edge_llama_default_configuration(struct wasi_nn_llama_config *output)
{
    output->enable_log = false;
    output->enable_debug_log = false;
    output->stream_stdout = false;
    output->embedding = false;
    output->n_predict = 512;
    output->reverse_prompt = NULL;

    output->mmproj = NULL;
    output->image = NULL;

    output->main_gpu = 0;
    output->n_gpu_layers = 0;
    output->tensor_split = NULL;
    output->use_mmap = true;

    // 0 = from model
    output->ctx_size = 0;
    output->batch_size = 512;
    output->ubatch_size = output->batch_size;
    output->threads = 1;

    output->temp = 0.80;
    output->topP = 0.95;
    output->repeat_penalty = 1.10;
    output->presence_penalty = 0.0;
    output->frequency_penalty = 0.0;
}

static void
wasm_edge_llama_apply_configuration(const char *config_json,
                                    struct wasi_nn_llama_config *output)
{
    cJSON *root = cJSON_Parse(config_json);
    if (root == NULL) {
        const char *error_ptr = cJSON_GetErrorPtr();
        if (error_ptr != NULL) {
            NN_WARN_PRINTF("Error before: %s\n", error_ptr);
        }
        else {
            NN_WARN_PRINTF("Failed to parse JSON");
        }
        return;
    }

    cJSON *item = NULL;

    item = cJSON_GetObjectItem(root, "enable-log");
    if (item != NULL) {
        output->enable_log = cJSON_IsTrue(item);
        NN_DBG_PRINTF("apply enable-log %d", output->enable_log);
    }

    item = cJSON_GetObjectItem(root, "enable-debug-log");
    if (item != NULL) {
        output->enable_debug_log = cJSON_IsTrue(item);
        NN_DBG_PRINTF("apply enable-debug-log %d", output->enable_debug_log);
    }

    item = cJSON_GetObjectItem(root, "stream-stdout");
    if (item != NULL) {
        output->stream_stdout = cJSON_IsTrue(item);
        NN_DBG_PRINTF("apply stream-stdout %d", output->stream_stdout);
    }

    item = cJSON_GetObjectItem(root, "embedding");
    if (item != NULL) {
        output->embedding = cJSON_IsTrue(item);
        NN_DBG_PRINTF("apply embedding %d", output->embedding);
    }

    item = cJSON_GetObjectItem(root, "n-predict");
    if (item != NULL) {
        output->n_predict = (int32_t)cJSON_GetNumberValue(item);
        NN_DBG_PRINTF("apply n-predict %d", output->n_predict);
    }

    item = cJSON_GetObjectItem(root, "n-gpu-layers");
    if (item != NULL) {
        output->n_gpu_layers = (int32_t)cJSON_GetNumberValue(item);
        NN_DBG_PRINTF("apply n_gpu_layers %d", output->n_gpu_layers);
    }

    item = cJSON_GetObjectItem(root, "ctx-size");
    if (item != NULL) {
        output->ctx_size = (uint32_t)cJSON_GetNumberValue(item);
        NN_DBG_PRINTF("apply ctx-size %d", output->ctx_size);
    }

    // more ...

    cJSON_Delete(root);
}
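To make the parser above concrete, here is an illustrative metadata string it would accept; only the keys handled above take effect, the values are examples rather than defaults, and keys it does not look up are simply ignored.

```c
/* Illustrative WasmEdge-style metadata for load_by_name_with_config(): */
static const char *example_config =
    "{"
    "\"enable-log\": true,"
    "\"stream-stdout\": true,"
    "\"n-predict\": 128,"
    "\"n-gpu-layers\": 0,"
    "\"ctx-size\": 1024"
    "}";
/* wasm_edge_llama_apply_configuration(example_config, &cfg) would then set
   cfg.enable_log, cfg.stream_stdout, cfg.n_predict, cfg.n_gpu_layers and
   cfg.ctx_size on top of the defaults. */
```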

static struct llama_model_params
llama_model_params_from_wasi_nn_llama_config(
    struct wasi_nn_llama_config *config)
{
    struct llama_model_params result = llama_model_default_params();

    // TODO: support more
    result.main_gpu = config->main_gpu;
    result.n_gpu_layers = config->n_gpu_layers;
    result.use_mmap = config->use_mmap;

    return result;
}

static struct llama_context_params
llama_context_params_from_wasi_nn_llama_config(
    struct wasi_nn_llama_config *config)
{
    struct llama_context_params result = llama_context_default_params();

    // TODO: support more
    result.n_ctx = config->ctx_size;
    // result.embeddings = config->embedding;

    return result;
}

static void
llama_batch_clear(struct llama_batch *batch)
{
    batch->n_tokens = 0;
}

static void
llama_batch_add(struct llama_batch *batch, llama_token id, llama_pos pos,
                llama_seq_id *seq_ids, size_t seq_ids_len, bool logits)
{
    batch->token[batch->n_tokens] = id;
    batch->pos[batch->n_tokens] = pos;
    batch->n_seq_id[batch->n_tokens] = seq_ids_len;
    for (size_t i = 0; i < seq_ids_len; ++i) {
        batch->seq_id[batch->n_tokens][i] = seq_ids[i];
    }
    batch->logits[batch->n_tokens] = logits;

    batch->n_tokens++;
}

// always output ERROR and WARN
// INFO needs enable_log
// DEBUG needs enable_debug_log
static void
llama_log_callback_local(enum ggml_log_level level, const char *text,
                         void *user_data)
{
    struct LlamaContext *backend_ctx = (struct LlamaContext *)user_data;

    if (level == GGML_LOG_LEVEL_DEBUG && !backend_ctx->config.enable_debug_log)
        return;

    if (level == GGML_LOG_LEVEL_INFO && !backend_ctx->config.enable_log)
        return;

    printf("%s", text);
}

static void
llama_build_output_metadata(const struct LlamaContext *backend_ctx,
                            char *output_buf, size_t output_buf_size)
{
    snprintf(output_buf, output_buf_size,
             "{\"input_tokens\":%ld, \"output_tokens\":%ld, "
             "\"llama_build_number\":%d,"
             "\"llama_commit\":\"%s\"}",
             backend_ctx->prompt_len, backend_ctx->generation_len,
             LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
}

__attribute__((visibility("default"))) wasi_nn_error
init_backend(void **ctx)
{
    struct LlamaContext *backend_ctx = calloc(1, sizeof(struct LlamaContext));
    if (!backend_ctx) {
        NN_ERR_PRINTF("Allocate for OpenVINOContext failed");
        return runtime_error;
    }

    llama_backend_init();
    // llama_numa_init();
    llama_log_set(llama_log_callback_local, backend_ctx);

#ifndef NDEBUG
    NN_INFO_PRINTF("llama_build_number: % d, llama_commit: %s, llama_compiler: "
                   "%s, llama_build_target: %s",
                   LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER,
                   LLAMA_BUILD_TARGET);
#endif

    *ctx = (void *)backend_ctx;
    return success;
}

__attribute__((visibility("default"))) wasi_nn_error
deinit_backend(void *ctx)
{
    struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;

    if (!backend_ctx)
        return invalid_argument;

    if (backend_ctx->generation)
        free(backend_ctx->generation);

    if (backend_ctx->prompt)
        free(backend_ctx->prompt);

    if (backend_ctx->ctx)
        llama_free(backend_ctx->ctx);

    if (backend_ctx->model)
        llama_free_model(backend_ctx->model);

    llama_backend_free();

    os_free(backend_ctx);
    return success;
}

__attribute__((visibility("default"))) wasi_nn_error
load(void *ctx, graph_builder_array *builder, graph_encoding encoding,
     execution_target target, graph *g)
{
    return unsupported_operation;
}

static wasi_nn_error
__load_by_name_with_configuration(void *ctx, const char *filename, graph *g)
{
    struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;

    // make sure backend_ctx->config is initialized

    struct llama_model_params model_params =
        llama_model_params_from_wasi_nn_llama_config(&backend_ctx->config);
    struct llama_model *model =
        llama_load_model_from_file(filename, model_params);
    if (model == NULL) {
        NN_ERR_PRINTF("Failed to load model from file %s", filename);
        return runtime_error;
    }

#ifndef NDEBUG
    char buf[128] = { 0 };
    llama_model_desc(model, buf, 127);
    NN_INFO_PRINTF("Model desc %s", buf);
#endif

    backend_ctx->model = model;

    return success;
}

__attribute__((visibility("default"))) wasi_nn_error
load_by_name(void *ctx, const char *filename, uint32_t filename_len, graph *g)
{
    struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;

    // use default params
    wasm_edge_llama_default_configuration(&backend_ctx->config);
    return __load_by_name_with_configuration(ctx, filename, g);
}

__attribute__((visibility("default"))) wasi_nn_error
load_by_name_with_config(void *ctx, const char *filename, uint32_t filename_len,
                         const char *config, uint32_t config_len, graph *g)
{
    struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;

    wasm_edge_llama_default_configuration(&backend_ctx->config);

    if (config != NULL) {
        // parse wasmedge config
        wasm_edge_llama_apply_configuration(config, &backend_ctx->config);
    }
    else {
        NN_INFO_PRINTF("No configuration provided, use default");
    }

    return __load_by_name_with_configuration(ctx, filename, g);
}

// It is assumed that model params shouldn't be changed in Config stage.
// We only load the model once in the Load stage.
__attribute__((visibility("default"))) wasi_nn_error
init_execution_context(void *ctx, graph g, graph_execution_context *exec_ctx)
{
    struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;

    struct llama_context_params ctx_params =
        llama_context_params_from_wasi_nn_llama_config(&backend_ctx->config);
    struct llama_context *llama_ctx =
        llama_new_context_with_model(backend_ctx->model, ctx_params);
    if (llama_ctx == NULL) {
        NN_ERR_PRINTF("Failed to create context for model");
        return runtime_error;
    }

    backend_ctx->ctx = llama_ctx;

    NN_INFO_PRINTF("n_predict = %d, n_ctx = %d", backend_ctx->config.n_predict,
                   llama_n_ctx(backend_ctx->ctx));
    return success;
}

__attribute__((visibility("default"))) wasi_nn_error
set_input(void *ctx, graph_execution_context exec_ctx, uint32_t index,
          tensor *wasi_nn_tensor)
{
    struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
    // tensor->data is the prompt string. ends with \0
    char *prompt_text = (char *)wasi_nn_tensor->data;

#ifndef NDEBUG
    NN_DBG_PRINTF("--------------------------------------------------");
    NN_DBG_PRINTF("prompt_text: %s", prompt_text);
    NN_DBG_PRINTF("--------------------------------------------------");
#endif

    // tokenize the prompt
    uint32_t n_token_max = llama_n_ctx(backend_ctx->ctx);
    uint32_t prompt_text_len = strlen(prompt_text);

    if (backend_ctx->prompt == NULL) {
        backend_ctx->prompt = calloc(n_token_max, sizeof(llama_token));
        if (backend_ctx->prompt == NULL) {
            NN_ERR_PRINTF("Failed to allocate tokens_list");
            return runtime_error;
        }
    }

    int32_t n_tokens =
        llama_tokenize(backend_ctx->model, prompt_text, prompt_text_len,
                       backend_ctx->prompt, n_token_max, true, false);
    if (n_tokens < 0) {
        NN_ERR_PRINTF("Failed to tokenize prompt text");
        return runtime_error;
    }

    backend_ctx->prompt_len = n_tokens;

    // make sure the KV cache is big enough to hold all the prompt and generated
    // tokens
    int n_kv_req = n_tokens + (backend_ctx->config.n_predict - n_tokens);
    if (n_kv_req < 0 || (uint32_t)n_kv_req > n_token_max) {
        NN_ERR_PRINTF("the required KV cache size is not big enough, either "
                      "reduce n_predict or increase n_ctx");
        return runtime_error;
    }

    return success;
}
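A note on the input convention implemented above: the backend reads the whole input tensor as a single NUL-terminated UTF-8 prompt, so any chat template is the guest's responsibility. As an illustration (borrowed from the chatml example family the PR description targets, not from this diff), the prompt might look like:

```c
/* Illustrative chatml-style prompt passed via tensor->data: */
static const char *prompt =
    "<|im_start|>system\n"
    "You are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n"
    "What is WASI-NN?<|im_end|>\n"
    "<|im_start|>assistant\n";
```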

__attribute__((visibility("default"))) wasi_nn_error
compute(void *ctx, graph_execution_context exec_ctx)
{
    struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
    wasi_nn_error ret = runtime_error;

    // reset the generation buffer
    if (backend_ctx->generation == NULL) {
        backend_ctx->generation =
            calloc(backend_ctx->config.n_predict, sizeof(llama_token));
        if (backend_ctx->generation == NULL) {
            NN_ERR_PRINTF("Failed to allocate generation");
            return runtime_error;
        }
    }

    backend_ctx->generation_len = 0;

    // check KV cache
    uint32_t n_ctx = llama_n_ctx(backend_ctx->ctx);
    if (n_ctx <= backend_ctx->generation_len) {
        NN_ERR_PRINTF(
            "ctx_size(%u) is not big enough(<%ld), please increase it", n_ctx,
            backend_ctx->generation_len);
        return context_full;
    }

    // prepare the batch
    struct llama_batch batch =
        llama_batch_init(backend_ctx->config.batch_size, 0, 1);

    // evaluate the initial prompt
    llama_seq_id seq_ids[1] = { 0 };
    for (size_t i = 0; i < backend_ctx->prompt_len; i++) {
        llama_batch_add(&batch, backend_ctx->prompt[i], i, seq_ids,
                        sizeof(seq_ids) / sizeof(seq_ids[0]), false);
    }

    batch.logits[batch.n_tokens - 1] = true;

    if (batch.n_tokens > backend_ctx->config.n_predict) {
        NN_DBG_PRINTF("n_predict(%d) is not big enough(%d), please increase it",
                      backend_ctx->config.n_predict, batch.n_tokens);
        return prompt_tool_long;
    }

    if (llama_decode(backend_ctx->ctx, batch) != 0) {
        NN_ERR_PRINTF("First decode failed");
        return runtime_error;
    }

    // main loop
    int32_t n_cur = batch.n_tokens;
    int n_decode = 0;
    int32_t n_vocab = llama_n_vocab(backend_ctx->model);
    llama_token_data *candidates = NULL;

    candidates = calloc(n_vocab, sizeof(llama_token_data));
    if (candidates == NULL) {
        NN_ERR_PRINTF("Failed to allocate candidates");
        goto fail;
    }

    while (n_cur <= backend_ctx->config.n_predict) {
        // sample the next token
        float *logits =
            llama_get_logits_ith(backend_ctx->ctx, batch.n_tokens - 1);

        memset(candidates, 0, sizeof(llama_token_data) * n_vocab);
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            candidates[token_id].id = token_id;
            candidates[token_id].logit = logits[token_id];
            candidates[token_id].p = 0.0f;
        }

        llama_token_data_array candidates_p = { candidates, n_vocab, false };

        // sample the most likely token
        llama_token new_token_id =
            llama_sample_token_greedy(backend_ctx->ctx, &candidates_p);

        backend_ctx->generation[backend_ctx->generation_len++] = new_token_id;

#ifndef NDEBUG
        {
            char buf[128] = { 0 };
            llama_token_to_piece(backend_ctx->model, new_token_id, buf, 120, 0,
                                 true);
            printf("%d(%s),", new_token_id, buf);
        }
#endif

        // is it an end of generation?
        if (llama_token_is_eog(backend_ctx->model, new_token_id)) {
            printf("\n");
            NN_INFO_PRINTF("reach the end of generation");
            break;
        }

        // prepare the next batch
        llama_batch_clear(&batch);
        // push this new token for next evaluation
        llama_batch_add(&batch, new_token_id, n_cur, seq_ids,
                        sizeof(seq_ids) / sizeof(seq_ids[0]), true);
        n_decode++;
        n_cur++;

        if (llama_decode(backend_ctx->ctx, batch) != 0) {
            NN_ERR_PRINTF("Secondary decode failed");
            goto fail;
        }
    }

    printf("\n");
    ret = success;
fail:
    llama_batch_free(batch);
    if (candidates != NULL) {
        free(candidates);
    }
    return ret;
}

__attribute__((visibility("default"))) wasi_nn_error
get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)
{
    struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;

    // Compatibility with WasmEdge
    if (index > 1) {
        NN_ERR_PRINTF("Invalid output index %d", index);
        return invalid_argument;
    }

    // Index 1 is for the metadata of the outputs.
    if (index == 1) {
        char output_metadata[128] = { 0 };
        llama_build_output_metadata(backend_ctx, output_metadata, 127);

        if (backend_ctx->config.stream_stdout) {
            printf("%s\n", output_metadata);
        }

        memcpy(output_tensor, output_metadata, strlen(output_metadata));
        *output_tensor_size = strlen(output_metadata);
        return success;
    }

    // token -> piece -> output_tensor
    if (backend_ctx->config.stream_stdout) {
        printf("\n");
    }

    size_t end_pos = 0;
    for (size_t i = 0; i < backend_ctx->generation_len; i++) {
        char buf[128] = { 0 };
        llama_token_to_piece(backend_ctx->model, backend_ctx->generation[i],
                             buf, 120, 0, true);

        if (backend_ctx->config.stream_stdout) {
            printf("%s", buf);
        }

        memcpy(output_tensor + end_pos, buf, strlen(buf));
        end_pos += strlen(buf);
    }

    if (backend_ctx->config.stream_stdout) {
        printf("\n");
    }

    *output_tensor_size = end_pos;
    return success;
}
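For completeness, the index-1 metadata produced by `llama_build_output_metadata()` above has the shape shown below; the field values here are illustrative placeholders, and index 0 simply returns the detokenized generation as one UTF-8 string.

```c
/* Example shape of the index-1 metadata (values are illustrative): */
static const char *example_metadata =
    "{\"input_tokens\":24, \"output_tokens\":128, "
    "\"llama_build_number\":1234,"
    "\"llama_commit\":\"abcdef0\"}";
```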