wasi-nn: Support uint8 quantized networks (#2433)

Support uint8 quantized networks (not fully quantized ones: inputs and outputs are still required to be `float`). The (de)quantization is done internally by wasi-nn.
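
For reference, TensorFlow Lite's per-tensor affine quantization maps a float value `x` to a `uint8` value via `q = x / scale + zero_point`, and back via `x = (q - zero_point) * scale`. A minimal sketch of the two mappings as used here (helper names are illustrative, not part of this change):

```cpp
#include <cstdint>

// Affine quantization, float -> uint8, as applied to input tensors.
// Note: the patch uses a plain truncating cast rather than rounding.
static uint8_t
quantize(float x, float scale, float zero_point)
{
    return (uint8_t)(x / scale + zero_point);
}

// Inverse mapping, uint8 -> float, as applied to output tensors.
static float
dequantize(uint8_t q, float scale, float zero_point)
{
    return ((float)q - zero_point) * scale;
}
```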

Example generated from `quantized_model.py`:
![Screenshot from 2023-08-07 17-57-05](https://github.com/bytecodealliance/wasm-micro-runtime/assets/80318361/91f12ff6-870c-427a-b1dc-e307f7d1f5ee)

Visualization with [netron](https://netron.app/).
Author: tonibofarull
Date: 2023-08-11 01:55:40 +02:00 (committed by GitHub)
Parent: a550f4d9f7
Commit: 0b0af1b3df
7 changed files with 176 additions and 17 deletions

```diff
@@ -285,14 +285,37 @@ tensorflowlite_set_input(void *tflite_ctx, graph_execution_context ctx,
         return invalid_argument;
     }
 
-    auto *input =
-        tfl_ctx->interpreters[ctx].interpreter->typed_input_tensor<float>(
-            index);
-    if (input == NULL)
+    auto tensor = tfl_ctx->interpreters[ctx].interpreter->input_tensor(index);
+    if (tensor == NULL)
         return missing_memory;
 
+    if (tensor->quantization.type == kTfLiteNoQuantization) {
+        NN_DBG_PRINTF("No quantization information. Using float as default");
+        float *it =
+            tfl_ctx->interpreters[ctx].interpreter->typed_input_tensor<float>(
+                index);
+        int size = model_tensor_size * sizeof(float);
+        bh_memcpy_s(it, size, input_tensor->data, size);
+    }
+    else { // TODO: Assuming uint8 quantized networks.
+        TfLiteAffineQuantization *quant_info =
+            (TfLiteAffineQuantization *)tensor->quantization.params;
+        if (quant_info->scale->size != 1 || quant_info->zero_point->size != 1) {
+            NN_ERR_PRINTF("Quantization per channel is not supported");
+            return runtime_error;
+        }
+        uint8_t *it =
+            tfl_ctx->interpreters[ctx].interpreter->typed_input_tensor<uint8_t>(
+                index);
+
+        float scale = quant_info->scale->data[0];
+        float zero_point = (float)quant_info->zero_point->data[0];
+        NN_DBG_PRINTF("input tensor: (scale, offset) = (%f, %f)", scale,
+                      zero_point);
+
+        float *input_tensor_f = (float *)input_tensor->data;
+        for (uint32_t i = 0; i < model_tensor_size; ++i) {
+            it[i] = (uint8_t)(input_tensor_f[i] / scale + zero_point);
+        }
+    }
 
-    bh_memcpy_s(input, model_tensor_size * sizeof(float), input_tensor->data,
-                model_tensor_size * sizeof(float));
     return success;
 }
```
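
The quantized branch above converts the caller's float buffer element by element into the interpreter's `uint8_t` input buffer, using the tensor's per-tensor `(scale, zero_point)`. A standalone sketch of that conversion loop (the `src`, `dst`, and `n` names are illustrative, not from the patch):

```cpp
#include <cstdint>

// Mirror of the quantizing loop in tensorflowlite_set_input: fill the
// interpreter's uint8 input tensor from the float data supplied by the
// Wasm module.
static void
quantize_input(const float *src, uint8_t *dst, uint32_t n, float scale,
               float zero_point)
{
    for (uint32_t i = 0; i < n; ++i)
        dst[i] = (uint8_t)(src[i] / scale + zero_point); // truncating cast
}
```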
```diff
@@ -325,6 +348,7 @@ tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
     NN_DBG_PRINTF("Number of tensors (%d)", num_output_tensors);
 
     if (index + 1 > num_output_tensors) {
+        NN_ERR_PRINTF("Index %d is invalid.", index);
         return runtime_error;
     }
```
```diff
@@ -343,15 +367,37 @@ tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
         return missing_memory;
     }
 
-    float *tensor_f =
-        tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
-            index);
-    for (uint32_t i = 0; i < model_tensor_size; ++i)
-        NN_DBG_PRINTF("output: %f", tensor_f[i]);
+    if (tensor->quantization.type == kTfLiteNoQuantization) {
+        NN_DBG_PRINTF("No quantization information");
+        float *ot =
+            tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
+                index);
+        int size = model_tensor_size * sizeof(float);
+        bh_memcpy_s(output_tensor, size, ot, size);
+    }
+    else { // TODO: Assuming uint8 quantized networks.
+        TfLiteAffineQuantization *quant_info =
+            (TfLiteAffineQuantization *)tensor->quantization.params;
+        if (quant_info->scale->size != 1 || quant_info->zero_point->size != 1) {
+            NN_ERR_PRINTF("Quantization per channel is not supported");
+            return runtime_error;
+        }
+        uint8_t *ot = tfl_ctx->interpreters[ctx]
+                          .interpreter->typed_output_tensor<uint8_t>(index);
+
+        float scale = quant_info->scale->data[0];
+        float zero_point = (float)quant_info->zero_point->data[0];
+        NN_DBG_PRINTF("output tensor: (scale, offset) = (%f, %f)", scale,
+                      zero_point);
+
+        float *output_tensor_f = (float *)output_tensor;
+        for (uint32_t i = 0; i < model_tensor_size; ++i) {
+            output_tensor_f[i] = (ot[i] - zero_point) * scale;
+        }
+    }
 
     *output_tensor_size = model_tensor_size;
-    bh_memcpy_s(output_tensor, model_tensor_size * sizeof(float), tensor_f,
-                model_tensor_size * sizeof(float));
     return success;
 }
```
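
End to end, the caller keeps exchanging plain `float` buffers with wasi-nn; precision is only lost up to the tensor's scale when values pass through the uint8 representation. A small self-contained round-trip demonstration (the parameter values are hypothetical):

```cpp
#include <cstdint>
#include <cstdio>

int
main()
{
    // Hypothetical per-tensor quantization parameters.
    const float scale = 1.0f / 255.0f;
    const float zero_point = 0.0f;

    const float x = 0.5f;
    // Quantize as in tensorflowlite_set_input ...
    uint8_t q = (uint8_t)(x / scale + zero_point);
    // ... then dequantize as in tensorflowlite_get_output.
    float y = ((float)q - zero_point) * scale;

    printf("x = %f, q = %u, y = %f\n", x, (unsigned)q, y);
    return 0;
}
```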