-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconvert_tf_to_trt.py
56 lines (47 loc) · 1.8 KB
/
convert_tf_to_trt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#from helper import ModelOptimizer
import tensorrt as trt
import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert
from time import perf_counter
# Target precision for the TF-TRT conversion. Passed straight through as the
# converter's precision_mode (TF-TRT accepts "FP32", "FP16" or "INT8") and also
# used to build the output directory suffix, so the two cannot drift apart.
# NOTE(review): "INT8" additionally needs a calibration input_fn passed to
# converter.convert(); that is not wired up here.
PRECISION = "FP16"

# Candidate TensorRT workspace sizes in bytes. These are decimal gigabytes
# (4e9 bytes), not GiB.
GPU_RAM_4G = 4000000000
GPU_RAM_6G = 6000000000
GPU_RAM_8G = 8000000000

# SavedModel directories of the PeekingDuck models that can be converted.
MPL = "/home/aisg/src/ongtw/PeekingDuck/peekingduck_weights/movenet/multipose_lightning"
SPL = "/home/aisg/src/ongtw/PeekingDuck/peekingduck_weights/movenet/singlepose_lightning"
SPT = "/home/aisg/src/ongtw/PeekingDuck/peekingduck_weights/movenet/singlepose_thunder"
YLXT = "/home/aisg/src/ongtw/PeekingDuck/peekingduck_weights/yolox_tiny_tf"

# Model to convert. The converted model is written next to it with a
# precision suffix, e.g. ".../yolox_tiny_tf_fp16" for PRECISION = "FP16".
model_dir = YLXT
model_out_dir = f"{model_dir}_{PRECISION.lower()}"

# NOTE: two earlier approaches failed and were abandoned:
#   * helper.ModelOptimizer -- the `helper` module could not be found;
#   * trt_convert.create_inference_graph() -- it raised "missing 2 required
#     positional arguments: 'input_graph_def' and 'outputs'".
# TrtGraphConverterV2 is used in main() instead.


def _timed(label, action):
    """Call ``action()`` and print its wall-clock duration as "<label> = N.NN sec"."""
    start = perf_counter()
    action()
    print(f"{label} = {perf_counter() - start:.2f} sec")


def main():
    """Convert the SavedModel at ``model_dir`` with TF-TRT and save it to ``model_out_dir``."""
    print(f"tensorflow version={tf.__version__}")
    print(f"tensorrt version={trt.__version__}")

    conv_parms = trt_convert.TrtConversionParams(
        precision_mode=PRECISION,
        max_workspace_size_bytes=GPU_RAM_4G,
        # max_batch_size = 1
    )
    converter = trt_convert.TrtGraphConverterV2(
        input_saved_model_dir=model_dir,
        conversion_params=conv_parms,
    )

    print(f"generating {model_out_dir}")
    print("converting original model...")
    _timed("conversion time", converter.convert)
    # converter.build(input_fn=...) is intentionally not called, so TensorRT
    # engines are not pre-built here; per the TF-TRT docs they are built
    # lazily at first inference instead.

    print("saving generated model...")
    _timed("save time", lambda: converter.save(model_out_dir))


if __name__ == "__main__":
    main()