-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexport_model.py
176 lines (148 loc) · 6.47 KB
/
export_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Export model (float or quantized tflite, and saved model) from a trained checkpoint."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf
import efficientnet_builder
import imagenet_input
from edgetpu import efficientnet_edgetpu_builder
flags.DEFINE_string("model_name", None, "Model name to eval.")
flags.DEFINE_string("ckpt_dir", None, "Path to the training checkpoint")
flags.DEFINE_boolean("enable_ema", True, "Enable exponential moving average.")
flags.DEFINE_string("data_dir", None,
"Image dataset directory for post training quantization.")
flags.DEFINE_string("output_tflite", None, "Path to output tflite file.")
flags.DEFINE_bool("quantize", True,
"Quantize model to uint8 before exporting tflite model.")
flags.DEFINE_integer(
"num_steps", 1000,
"Number of post-training quantization calibration steps to run.")
flags.DEFINE_integer("image_size", 224, "Size of the input image.")
flags.DEFINE_integer("batch_size", 1, "Batch size of input tensor.")
flags.DEFINE_string("endpoint_name", None, "Endpoint name")
flags.DEFINE_string("output_saved_model_dir", None,
"Directory in which to save the saved_model.")
FLAGS = flags.FLAGS
def get_model_builder(model_name):
if model_name.startswith("efficientnet-edgetpu"):
return efficientnet_edgetpu_builder
elif model_name.startswith("efficientnet"):
return efficientnet_builder
else:
raise ValueError(
"Model must be either efficientnet-b* or efficientnet-edgetpu")
def restore_model(sess, ckpt_dir, enable_ema=True):
"""Restore variables from checkpoint dir."""
sess.run(tf.global_variables_initializer())
checkpoint = tf.train.latest_checkpoint(ckpt_dir)
if enable_ema:
ema = tf.train.ExponentialMovingAverage(decay=0.0)
ema_vars = tf.trainable_variables() + tf.get_collection("moving_vars")
for v in tf.global_variables():
if "moving_mean" in v.name or "moving_variance" in v.name:
ema_vars.append(v)
ema_vars = list(set(ema_vars))
var_dict = ema.variables_to_restore(ema_vars)
else:
var_dict = None
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(var_dict, max_to_keep=1)
saver.restore(sess, checkpoint)
def representative_dataset_gen():
"""Gets a python generator of image numpy arrays for ImageNet."""
params = dict(batch_size=FLAGS.batch_size)
imagenet_eval = imagenet_input.ImageNetInput(
is_training=False,
data_dir=FLAGS.data_dir,
transpose_input=False,
cache=False,
image_size=FLAGS.image_size,
num_parallel_calls=1,
use_bfloat16=False,
include_background_label=True,
)
data = imagenet_eval.input_fn(params)
def preprocess_map_fn(images, labels):
del labels
model_builder = get_model_builder(FLAGS.model_name)
images -= tf.constant(
model_builder.MEAN_RGB, shape=[1, 1, 3], dtype=images.dtype)
images /= tf.constant(
model_builder.STDDEV_RGB, shape=[1, 1, 3], dtype=images.dtype)
return images
data = data.map(preprocess_map_fn)
iterator = data.make_one_shot_iterator()
for _ in range(FLAGS.num_steps):
# In eager context, we can get a python generator from a dataset iterator.
images = iterator.get_next()
yield [images]
def main(_):
# Enables eager context for TF 1.x. TF 2.x will use eager by default.
# This is used to conveniently get a representative dataset generator using
# TensorFlow training input helper.
tf.enable_eager_execution()
model_builder = get_model_builder(FLAGS.model_name)
with tf.Graph().as_default(), tf.Session() as sess:
images = tf.placeholder(
tf.float32,
shape=(1, FLAGS.image_size, FLAGS.image_size, 3),
name="images")
logits, endpoints = model_builder.build_model(images, FLAGS.model_name,
False)
if FLAGS.endpoint_name:
output_tensor = endpoints[FLAGS.endpoint_name]
else:
output_tensor = tf.nn.softmax(logits)
restore_model(sess, FLAGS.ckpt_dir, FLAGS.enable_ema)
if FLAGS.output_saved_model_dir:
signature_def_map = {
"serving_default":
tf.compat.v1.saved_model.signature_def_utils
.predict_signature_def({"input": images},
{"output": output_tensor})
}
builder = tf.compat.v1.saved_model.Builder(FLAGS.output_saved_model_dir)
builder.add_meta_graph_and_variables(
sess, ["serve"], signature_def_map=signature_def_map)
builder.save()
print("Saved model written to %s" % FLAGS.output_saved_model_dir)
converter = tf.lite.TFLiteConverter.from_session(sess, [images],
[output_tensor])
if FLAGS.quantize:
if not FLAGS.data_dir:
raise ValueError(
"Post training quantization requires data_dir flag to point to the "
"calibration dataset. To export a float model, set "
"--quantize=False.")
converter.representative_dataset = tf.lite.RepresentativeDataset(
representative_dataset_gen)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.lite.constants.QUANTIZED_UINT8
converter.inference_output_type = tf.lite.constants.QUANTIZED_UINT8
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS_INT8
]
tflite_buffer = converter.convert()
tf.gfile.GFile(FLAGS.output_tflite, "wb").write(tflite_buffer)
print("tflite model written to %s" % FLAGS.output_tflite)
if __name__ == "__main__":
flags.mark_flag_as_required("model_name")
flags.mark_flag_as_required("ckpt_dir")
flags.mark_flag_as_required("output_tflite")
app.run(main)