Google发布了TensorFlowLite,也提供了相关的Demo和文档,单如何把一个自己训练的深度神经网络模型,应用到TensorFlowLite Android App上,中间还是有不少的坑…
背景
Google发布了TensorFlowLite,也提供了相关的Demo和文档,单如何把一个自己训练的深度神经网络模型,应用到TensorFlowLite Android App上,中间还是有不少的坑。我们的目标就是完成从训练一个TensorFlow模型到把这个模型在Android平台上run起来整个完整的流程。
实现
1、训练一个TensorFlow模型
这里就不详细描述了,作为示例,我训练了一个mnist的model,采用两层cnn模型,模型精确度0.97
2、生成tflite
这部分的代码主要在tensorflow/tensorflow/contrib/lite/toco下面,主要有两种方式:
- Command-line
- Python API
文档:TOCO: TensorFlow Lite Optimizing Converter
Python API
Python API 文档:Python API examples
笔者是在TensorFlow1.9.0上运行python api,之前的版本可能不支持,或者调用的方式不一样。利用python api生成tflite主要有四种方式:
- Exporting a GraphDef from tf.Session
- Exporting a GraphDef from file
- Exporting a SavedModel
- Exporting a tf.keras File
Exporting a GraphDef from tf.Session
官方示例:
import tensorflow as tf
graph_def_file = "/path/to/Downloads/mobilenet_v1_1.0_224/frozen_graph.pb"
input_arrays = ["input"]
output_arrays = ["MobilenetV1/Predictions/Softmax"]
converter = tf.contrib.lite.TocoConverter.from_frozen_graph(
graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
结论:
from_session方法生成的tflite不能保存update之后的变量,只能保存graph_def,所有变量都是初始值。
这个地方很坑,官方文档写得也不是很清楚,我开始一直以为是我写的Android端框架有问题,或者是输入的问题,最后没想到是保存的tflite的问题。
Exporting a GraphDef from frozen graph model file
这种方法是把官方的Exporting a GraphDef from file方法结合命令行生成frozen model
1、保存graph model和checkpoint文件
def save_model(sess):
saver = tf.train.Saver()
tf.train.write_graph(sess.graph_def, "test_model/", "test_graph.pb", as_text=False)
saver.save(sess, "test_model/test_model.ckpt")
2、利用freeze_graph工具生成frozen model
例如:
freeze_graph --input_graph=/Users/hailiangliao/Develop/git/android-test/deep-neural-networks/tf/mnist_model/mnist_graph.pb --input_checkpoint=/Users/hailiangliao/Develop/git/android-test/deep-neural-networks/tf/mnist_model/mnist_model.ckpt --input_binary=true --output_graph=/Users/hailiangliao/Develop/git/android-test/deep-neural-networks/tf/mnist_model/mnist_frozen.pb --output_node_names=prediction
–input_graph:保存的model文件
–input_checkpoint:checkpoint文件,保存变量的值
–output_graph:输出的frozen model的路径
–output_node_names:output的Tensor的名字,注意不是python变量的名字,而是Tensor的名字
3、生成tflite文件
主要就是利用上述生成的frozen model文件来生成,示例代码:
import tensorflow as tf
graph_def_file = "/Users/hailiangliao/Develop/git/android-test/deep-neural-networks/tf/mnist_model/mnist_frozen.pb"
input_arrays = ["input"]
output_arrays = ["prediction"]
tflite_model_file = "mnist_model.tflite"
def get_tflite_from_frozen_pb(graph_def_file, input_arrays, output_arrays, tflite_model_file):
converter = tf.contrib.lite.TocoConverter.from_frozen_graph(
graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
open(tflite_model_file, "wb").write(tflite_model)
get_tflite_from_frozen_pb(graph_def_file, input_arrays, output_arrays, tflite_model_file)
input_arrays:input的Tensor的名字 output_arrays:output的Tensor的名字 这样生成的tflite就包含了整个模型的定义和训练之后的变量的值。
3、放在TensorFlow Lite框架中使用
这部分以后来写