wasi-nn: Support uint8 quantized networks (#2433)

Add support for (non-full) uint8 quantized networks.
Inputs and outputs are still required to be `float`; the (de)quantization is handled internally by wasi-nn.
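
For context, the snippet below sketches the standard affine uint8 (de)quantization that such a model implies, using the TFLite per-tensor parameters `scale` and `zero_point`. The helper names and the example values are illustrative only, not part of the wasi-nn API or of this commit:

```python
import numpy as np

def quantize_u8(x, scale, zero_point):
    # float32 -> uint8: q = round(x / scale) + zero_point, clamped to [0, 255]
    q = np.round(x / scale) + zero_point
    return np.clip(q, 0, 255).astype(np.uint8)

def dequantize_u8(q, scale, zero_point):
    # uint8 -> float32: x = (q - zero_point) * scale
    return (q.astype(np.float32) - zero_point) * scale

# Round-trip with arbitrary (illustrative) quantization parameters.
scale, zero_point = 0.1, 128
x = np.array([-1.0, 0.0, 2.5], dtype=np.float32)
q = quantize_u8(x, scale, zero_point)
print(q, dequantize_u8(q, scale, zero_point))
```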

Example generated from `quantized_model.py`:
![Screenshot from 2023-08-07 17-57-05](https://github.com/bytecodealliance/wasm-micro-runtime/assets/80318361/91f12ff6-870c-427a-b1dc-e307f7d1f5ee)

Visualization with [netron](https://netron.app/).
Author: tonibofarull
Date: 2023-08-11 01:55:40 +02:00
Committed by: GitHub
Parent: a550f4d9f7
Commit: 0b0af1b3df
7 changed files with 176 additions and 17 deletions


quantized_model.py
@@ -0,0 +1,30 @@
# Copyright (C) 2019 Intel Corporation. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import tensorflow as tf
import numpy as np
import pathlib

# Tiny model: a single 5x5 average-pooling layer over a 5x5x1 input.
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=[5, 5, 1]),
    tf.keras.layers.AveragePooling2D(
        pool_size=(5, 5), strides=None, padding="valid", data_format=None)
])

# Representative dataset used by the converter to calibrate the
# quantization parameters (scale and zero point).
def representative_dataset():
    for _ in range(1000):
        data = np.random.randint(0, 25, (1, 5, 5, 1))
        yield [data.astype(np.float32)]

# Convert to a uint8 quantized TFLite model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8  # or tf.int8
converter.inference_output_type = tf.uint8  # or tf.int8
tflite_model = converter.convert()

# Write the quantized model next to this script.
tflite_models_dir = pathlib.Path("./")
tflite_model_file = tflite_models_dir / "quantized_model.tflite"
tflite_model_file.write_bytes(tflite_model)
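
As a quick sanity check (not part of this commit), the generated model's quantized inputs and outputs can be inspected with the TFLite interpreter; this assumes `quantized_model.tflite` is in the current directory:

```python
import tensorflow as tf

# Load the quantized model produced by quantized_model.py.
interpreter = tf.lite.Interpreter(model_path="quantized_model.tflite")
interpreter.allocate_tensors()

# Each entry reports the tensor dtype (uint8 here) and its
# (scale, zero_point) quantization parameters.
for detail in interpreter.get_input_details() + interpreter.get_output_details():
    print(detail["name"], detail["dtype"], detail["quantization"])
```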