wasi-nn: Support uint8 quantized networks (#2433)
Support (non-fully) uint8 quantized networks: model inputs and outputs are still required to be `float`, and the (de)quantization is performed internally by wasi-nn. An example model can be generated with `quantized_model.py` and inspected with [netron](https://netron.app/).
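For reference, the per-tensor affine scheme the patch implements maps a float value to uint8 as `q = f / scale + zero_point` and back as `f = (q - zero_point) * scale`, exactly the two loops added below. A minimal standalone sketch of that round trip (the `scale` and `zero_point` values here are made up for illustration; real models carry them in their `TfLiteAffineQuantization` metadata):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Quantize one float with per-tensor affine parameters,
// mirroring the conversion added to tensorflowlite_set_input.
static uint8_t
quantize(float f, float scale, float zero_point)
{
    return (uint8_t)(f / scale + zero_point);
}

// Dequantize back to float, mirroring tensorflowlite_get_output.
static float
dequantize(uint8_t q, float scale, float zero_point)
{
    return ((float)q - zero_point) * scale;
}

int
main()
{
    // Hypothetical per-tensor parameters (1/256 scale, midpoint offset).
    const float scale = 0.00390625f;
    const float zero_point = 128.0f;

    std::vector<float> in = { -0.5f, 0.0f, 0.25f };
    for (float f : in) {
        uint8_t q = quantize(f, scale, zero_point);
        printf("%f -> %u -> %f\n", f, (unsigned)q,
               dequantize(q, scale, zero_point));
    }
    return 0;
}
```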
```diff
@@ -285,14 +285,37 @@ tensorflowlite_set_input(void *tflite_ctx, graph_execution_context ctx,
         return invalid_argument;
     }
 
-    auto *input =
-        tfl_ctx->interpreters[ctx].interpreter->typed_input_tensor<float>(
-            index);
-    if (input == NULL)
-        return missing_memory;
+    if (tensor->quantization.type == kTfLiteNoQuantization) {
+        NN_DBG_PRINTF("No quantization information. Using float as default");
+        float *it =
+            tfl_ctx->interpreters[ctx].interpreter->typed_input_tensor<float>(
+                index);
+
+        int size = model_tensor_size * sizeof(float);
+        bh_memcpy_s(it, size, input_tensor->data, size);
+    }
+    else { // TODO: Assuming uint8 quantized networks.
+        TfLiteAffineQuantization *quant_info =
+            (TfLiteAffineQuantization *)tensor->quantization.params;
+        if (quant_info->scale->size != 1 || quant_info->zero_point->size != 1) {
+            NN_ERR_PRINTF("Quantization per channel is not supported");
+            return runtime_error;
+        }
+        uint8_t *it =
+            tfl_ctx->interpreters[ctx].interpreter->typed_input_tensor<uint8_t>(
+                index);
+
+        float scale = quant_info->scale->data[0];
+        float zero_point = (float)quant_info->zero_point->data[0];
+        NN_DBG_PRINTF("input tensor: (scale, offset) = (%f, %f)", scale,
+                      zero_point);
+
+        float *input_tensor_f = (float *)input_tensor->data;
+        for (uint32_t i = 0; i < model_tensor_size; ++i) {
+            it[i] = (uint8_t)(input_tensor_f[i] / scale + zero_point);
+        }
+    }
 
-    bh_memcpy_s(input, model_tensor_size * sizeof(float), input_tensor->data,
-                model_tensor_size * sizeof(float));
     return success;
 }
```
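This first hunk makes `tensorflowlite_set_input` branch on the model tensor's quantization metadata: float tensors keep the plain `bh_memcpy_s` path, while quantized tensors get the caller's float data quantized into the interpreter's uint8 input buffer. Factored out, the parameter extraction and the per-channel guard look roughly like this (a sketch against the TFLite C API types, not code from the patch; `get_per_tensor_quant_params` is a hypothetical helper name):

```cpp
#include <tensorflow/lite/c/common.h>

// Sketch: read per-tensor affine quantization parameters from a
// TfLiteTensor, rejecting the per-channel case exactly as the patch does.
// Returns true on success; `scale` and `zero_point` are outputs.
static bool
get_per_tensor_quant_params(const TfLiteTensor *tensor, float *scale,
                            float *zero_point)
{
    if (tensor->quantization.type != kTfLiteAffineQuantization)
        return false;
    auto *quant_info =
        (const TfLiteAffineQuantization *)tensor->quantization.params;
    // More than one (scale, zero_point) pair means per-channel
    // quantization, which this commit does not support.
    if (quant_info->scale->size != 1 || quant_info->zero_point->size != 1)
        return false;
    *scale = quant_info->scale->data[0];
    *zero_point = (float)quant_info->zero_point->data[0];
    return true;
}
```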
```diff
@@ -325,6 +348,7 @@ tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
     NN_DBG_PRINTF("Number of tensors (%d)", num_output_tensors);
 
     if (index + 1 > num_output_tensors) {
+        NN_ERR_PRINTF("Index %d is invalid.", index);
         return runtime_error;
     }
 
```
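The second hunk only adds a diagnostic to `tensorflowlite_get_output`'s bounds check; the quantization-aware output path follows in the third hunk, which mirrors the input side with the inverse conversion.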
```diff
@@ -343,15 +367,37 @@ tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
         return missing_memory;
     }
 
-    float *tensor_f =
-        tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
-            index);
-    for (uint32_t i = 0; i < model_tensor_size; ++i)
-        NN_DBG_PRINTF("output: %f", tensor_f[i]);
+    if (tensor->quantization.type == kTfLiteNoQuantization) {
+        NN_DBG_PRINTF("No quantization information");
+        float *ot =
+            tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
+                index);
+
+        int size = model_tensor_size * sizeof(float);
+        bh_memcpy_s(output_tensor, size, ot, size);
+    }
+    else { // TODO: Assuming uint8 quantized networks.
+        TfLiteAffineQuantization *quant_info =
+            (TfLiteAffineQuantization *)tensor->quantization.params;
+        if (quant_info->scale->size != 1 || quant_info->zero_point->size != 1) {
+            NN_ERR_PRINTF("Quantization per channel is not supported");
+            return runtime_error;
+        }
+        uint8_t *ot = tfl_ctx->interpreters[ctx]
+                          .interpreter->typed_output_tensor<uint8_t>(index);
+
+        float scale = quant_info->scale->data[0];
+        float zero_point = (float)quant_info->zero_point->data[0];
+        NN_DBG_PRINTF("output tensor: (scale, offset) = (%f, %f)", scale,
+                      zero_point);
+
+        float *output_tensor_f = (float *)output_tensor;
+        for (uint32_t i = 0; i < model_tensor_size; ++i) {
+            output_tensor_f[i] = (ot[i] - zero_point) * scale;
+        }
+    }
 
     *output_tensor_size = model_tensor_size;
-    bh_memcpy_s(output_tensor, model_tensor_size * sizeof(float), tensor_f,
-                model_tensor_size * sizeof(float));
     return success;
 }
```
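From the guest side nothing changes: a WebAssembly module keeps exchanging fp32 tensors with wasi-nn even when the loaded model is uint8-quantized. An illustrative sketch of that usage (not part of the patch; the function and type names follow WAMR's `wasi_nn.h` of the time and should be checked against your copy of the header):

```cpp
#include <stdint.h>
#include "wasi_nn.h"

// Guest-side sketch: run inference on a graph that may be uint8-quantized.
// The module still declares and passes fp32 data; the host backend now
// performs the (de)quantization internally.
error
run_quantized(graph_execution_context ctx, float *input, uint32_t input_len,
              float *output, uint32_t *output_len)
{
    uint32_t dims[2] = { 1, input_len };
    tensor_dimensions dimensions;
    dimensions.buf = dims;
    dimensions.size = 2;

    tensor t;
    t.dimensions = &dimensions;
    t.type = fp32; // still fp32, not up8: conversion happens in the host
    t.data = (uint8_t *)input;

    error err = set_input(ctx, 0, &t);
    if (err != success)
        return err;
    if ((err = compute(ctx)) != success)
        return err;
    return get_output(ctx, 0, (uint8_t *)output, output_len);
}
```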