简单来说:所谓模型就是一个滤波器,训练的权重就是滤波系数,输入经过滤波器后得到一个输出。所以嵌入式AI部署一般就是解析模型得到“滤波系数”,输入信号进行一系列类似"滤波"运算,得到最终输出。
所以需要搞明白模型怎么解析,这篇讲TFllite模型的格式以及它的解析。
1 TFLite格式简介
Tflite文件由Tensorflow提供的TOCO工具生成的轻量级模型,存储格式是flatbuffer,flatbuffer是google开源的一种二进制序列化格式,与protobuf类似。
下图(来自于参考2)描述了 模型训练->模型转化为Tflite格式->模型部署 的大致流程。从图中可以看到获取Tflite的三种方式:
# TensorFlow 2.x
tf.lite.TFLiteConverter.from_saved_model(): # 由SavedModel转化
tf.lite.TFLiteConverter.from_keras_model(): # 由Keras model转化
tf.lite.TFLiteConverter.from_concrete_functions(): # 由具体函数转化
2 TFLite格式分析
例如我们已经训练得到了一个tflite模型(mnist_model.tflite),下面分析其格式:
方法1: Netron查看tflite模型
Netron 是一款常见的可视化工具,支持网页查看常见的AI模型,支持非常丰富的格式(ONNX, Tensorflow, Pytorch, Keras, Caffe等)
网页地址: https://netron.app/
将mnist_model.tflite导入,可以得到下图,可见mnist_model.tflite含有一个Reshape层,2个FullyConnected层,一个Relu层以及一个Softmax层
方法2:利用flatbuffer开源工具flatc
Tflite格式是flatbuffer格式,其优点是:解码速度极快、内存占用小,缺点是:数据没有可读性,需要借助其他工具实现可视化。
可使用google flatbuffer开源工具flatc,flatc可以实现tflite格式到jason文件的自动转换,解析时需要用到schema.fbs协议文件。
step1:安装flatc
# flatbuffer源码 https://github.com/google/flatbuffers
# 下载后进入文件夹,执行如下命令
mkdir build && cd build
cmake ../ # 生成Makefile
make # 编译
make install # 安装flatcstep2:获取schema.fbs
schema.fbs是二进制协议文件,一般改动较小。直接从Tensorflow的源码中获取(如果后面的转换步骤出现问题,可以找到对应TensorFlow版本的schema.fbs文件试试)
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs
step3:转化为json
flatc -t schema.fbs -- mnist_model.tflite这样获取得到mnist_model.json:
{
version: 3,
operator_codes: [
...
],
subgraphs: [
tensors: [],
inputs: [],
outputs: [],
operators: [],
],
description: "MLIR Converted.",
buffers: [],
}这个数据结构描述了tflite的整体框架及所有细节,这个放到另一篇文档里讲。
方法3: 利用tensorflow提供的接口分析
tf.lite.Interpreter可以读取tflite模型,但是python接口没有描述模型结构(op node节点间的连接关系)
import tensorflow as tf
import numpy as np
#加载模型
interpreter = tf.lite.Interpreter(model_path="./mnist_model.tflite")
interpreter.allocate_tensors()
# 模型输入和输出细节
# input_details = interpreter.get_input_details()
# output_details = interpreter.get_output_details()
# 获取模型的tensor的详细信息
tensor = interpreter.get_tensor_details()
print(tensor)得到的结果如下:
[
{'name': 'serving_default_flatten_2_input:0',
'index': 0,
'shape': array([ 1, 28, 28], dtype=int32),
'shape_signature': array([-1, 28, 28], dtype=int32),
'dtype': <class 'numpy.float32'>,
'quantization': (0.0, 0),
'quantization_parameters': {
'scales': array([], dtype=float32),
'zero_points': array([], dtype=int32),
'quantized_dimension': 0
},
'sparsity_parameters': {}
},
{'name': 'sequential_2/dense_5/BiasAdd/ReadVariableOp',
'index': 1,
'shape': array([10], dtype=int32),
'shape_signature': array([10], dtype=int32),
'dtype': <class 'numpy.float32'>,
'quantization': (0.0, 0),
'quantization_parameters': {
'scales': array([], dtype=float32),
'zero_points': array([], dtype=int32),
'quantized_dimension': 0
},
'sparsity_parameters': {}
},
{'name': 'sequential_2/dense_4/BiasAdd/ReadVariableOp',
'index': 2,
'shape': array([128], dtype=int32),
'shape_signature': array([128], dtype=int32),
'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0),
'quantization_parameters': {
'scales': array([], dtype=float32),
'zero_points': array([], dtype=int32),
'quantized_dimension': 0
},
'sparsity_parameters': {}
},
...
]方法4:文本解析tflite文件
Flatbuffer格式的tflite文件,转成可读的python dict格式,并可描述模型完整推理流程。
直接下载Tensorflow提供的visualize.py工具,
下载地址:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/visualize.py
|