diff --git a/tools/program.py b/tools/program.py
index 9d574a370b..a541472da1 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -17,6 +17,7 @@ from __future__ import print_function
 
 import os
+import gc
 import sys
 import platform
 import yaml
@@ -208,6 +209,7 @@ def train(
     profiler_options = config["profiler_options"]
     print_mem_info = config["Global"].get("print_mem_info", True)
     uniform_output_enabled = config["Global"].get("uniform_output_enabled", False)
+    export_during_train = config["Global"].get("export_during_train", False)
 
     global_step = 0
     if "global_step" in pre_best_model_dict:
@@ -489,12 +491,14 @@ def train(
                     best_model_dict.update(cur_metric)
                     best_model_dict["best_epoch"] = epoch
                     prefix = "best_accuracy"
-                    if uniform_output_enabled:
+                    if export_during_train:
                         export(
                             config,
                             model,
                             os.path.join(save_model_dir, prefix, "inference"),
                         )
+                        gc.collect()
+                    if uniform_output_enabled:
                         model_info = {"epoch": epoch, "metric": best_model_dict}
                     else:
                         model_info = None
@@ -540,8 +544,10 @@ def train(
             reader_start = time.time()
         if dist.get_rank() == 0:
             prefix = "latest"
-            if uniform_output_enabled:
+            if export_during_train:
                 export(config, model, os.path.join(save_model_dir, prefix, "inference"))
+                gc.collect()
+            if uniform_output_enabled:
                 model_info = {"epoch": epoch, "metric": best_model_dict}
             else:
                 model_info = None
@@ -568,8 +574,10 @@ def train(
         if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            prefix = "iter_epoch_{}".format(epoch)
-            if uniform_output_enabled:
+            if export_during_train:
                 export(config, model, os.path.join(save_model_dir, prefix, "inference"))
+                gc.collect()
+            if uniform_output_enabled:
                 model_info = {"epoch": epoch, "metric": best_model_dict}
             else:
                 model_info = None
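
For context, a minimal, self-contained sketch of the gating pattern this diff introduces (not PaddleOCR's actual `train()` code): inference export during training is now controlled by a separate `Global.export_during_train` flag and is followed by `gc.collect()`, while the uniform-output metadata path stays tied to `Global.uniform_output_enabled`. The helpers `export_model` and `save_checkpoint` below are hypothetical stand-ins for `export()` and the surrounding checkpoint logic in `tools/program.py`.

```python
# Sketch only: export_model and save_checkpoint are hypothetical stand-ins,
# not the functions defined in tools/program.py.
import gc
import os


def export_model(save_dir):
    """Stand-in for export(): pretend to write an inference model to save_dir."""
    os.makedirs(save_dir, exist_ok=True)
    print("exported inference model to", save_dir)


def save_checkpoint(save_model_dir, prefix, config, epoch, best_metric):
    export_during_train = config["Global"].get("export_during_train", False)
    uniform_output_enabled = config["Global"].get("uniform_output_enabled", False)

    if export_during_train:
        export_model(os.path.join(save_model_dir, prefix, "inference"))
        # Export creates sizeable temporary objects; collecting them right away
        # keeps repeated per-epoch exports from accumulating memory.
        gc.collect()

    # The uniform-output metadata path is still governed by its own flag.
    model_info = (
        {"epoch": epoch, "metric": best_metric} if uniform_output_enabled else None
    )
    print(prefix, "model_info:", model_info)


if __name__ == "__main__":
    cfg = {"Global": {"export_during_train": True, "uniform_output_enabled": True}}
    save_checkpoint("./output", "latest", cfg, epoch=3, best_metric={"acc": 0.91})
```

Because both flags default to `False`, existing configs keep their current behavior; to enable export during training, the new flag would be set under the `Global:` section of the YAML config (or passed as an `-o Global.export_during_train=true` override, if using PaddleOCR's usual command-line override mechanism).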