fix save_inference_model (PaddlePaddle#1198)
* fix save_inference_model
ceci3 authored Jun 28, 2022
1 parent a10fc88 commit e95a22c
Showing 6 changed files with 39 additions and 26 deletions.
2 changes: 1 addition & 1 deletion paddleslim/auto_compression/auto_strategy.py
@@ -105,7 +105,7 @@ def create_strategy_config(strategy_str, model_type):
             'prune_strategy':
             'gmp',  ### default unstruture prune strategy is gmp
             'prune_mode': 'ratio',
-            'pruned_ratio': float(tmp_s[1]),
+            'ratio': float(tmp_s[1]),
             'local_sparsity': True,
             'prune_params_type': 'conv1x1_only'
         }
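For reference, this hunk only renames the sparsity key the unstructured-prune strategy reads ('pruned_ratio' -> 'ratio'). A minimal sketch of the resulting config dict; the parsed strategy token tmp_s and the concrete values are illustrative, not taken from the repository:

# Illustrative only: the unstructured-prune strategy config after this commit.
tmp_s = 'sparse_0.75'.split('_')   # hypothetical parsed strategy token -> ['sparse', '0.75']

unstructure_prune_config = {
    'prune_strategy': 'gmp',       # gradual magnitude pruning (default)
    'prune_mode': 'ratio',
    'ratio': float(tmp_s[1]),      # key was 'pruned_ratio' before this fix
    'local_sparsity': True,
    'prune_params_type': 'conv1x1_only'
}
print(unstructure_prune_config['ratio'])   # 0.75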
22 changes: 11 additions & 11 deletions paddleslim/auto_compression/compressor.py
@@ -205,13 +205,14 @@ def _get_final_train_config(self, train_config, strategy_config,
 
         train_configs = [train_config]
         for idx in range(1, len(self._strategy)):
-            if 'qat' in self._strategy[idx]:
-                ### if compress strategy more than one, the train config in the yaml set for prune
-                ### the train config for quantization is extrapolate from the yaml
+            if 'qat' in self._strategy[idx] or 'ptq' in self._strategy[idx]:
+                ### If compress strategy more than one, the TrainConfig in the yaml only used in prune.
+                ### The TrainConfig for quantization is extrapolate from above.
                 tmp_train_config = copy.deepcopy(train_config.__dict__)
                 ### the epoch, train_iter, learning rate of quant is 10% of the prune compress
-                tmp_train_config['epochs'] = max(
-                    int(train_config.epochs * 0.1), 1)
+                if self.model_type != 'transformer':
+                    tmp_train_config['epochs'] = max(
+                        int(train_config.epochs * 0.1), 1)
                 if train_config.train_iter is not None:
                     tmp_train_config['train_iter'] = int(
                         train_config.train_iter * 0.1)
@@ -228,8 +229,6 @@ def _get_final_train_config(self, train_config, strategy_config,
                         map(lambda x: x * 0.1, train_config.learning_rate[
                             'values']))
                 train_cfg = TrainConfig(**tmp_train_config)
-            elif 'ptq' in self._strategy[idx]:
-                train_cfg = None
             else:
                 tmp_train_config = copy.deepcopy(train_config.__dict__)
                 train_cfg = TrainConfig(**tmp_train_config)
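Taken together, these two hunks mean that a 'ptq' stage now reuses the same derived config as 'qat' (instead of getting train_cfg = None), and the 10% epoch reduction is skipped for transformer models. A standalone sketch of that selection logic with a simplified TrainConfig stand-in; everything except the strategy keys and the 0.1 factors is an assumption for illustration:

import copy
from dataclasses import dataclass

@dataclass
class TrainConfig:                 # simplified stand-in, not paddleslim's class
    epochs: int = 10
    train_iter: int = None

def derive_stage_train_config(train_config, strategy, model_type):
    # Quant stages ('qat' or 'ptq') train on a fraction of the prune schedule.
    if 'qat' in strategy or 'ptq' in strategy:
        cfg = copy.deepcopy(train_config.__dict__)
        if model_type != 'transformer':
            cfg['epochs'] = max(int(train_config.epochs * 0.1), 1)
        if train_config.train_iter is not None:
            cfg['train_iter'] = int(train_config.train_iter * 0.1)
        return TrainConfig(**cfg)
    # Any other stage keeps a copy of the original settings.
    return TrainConfig(**copy.deepcopy(train_config.__dict__))

print(derive_stage_train_config(TrainConfig(epochs=30), 'qat_dis', 'detection').epochs)    # 3
print(derive_stage_train_config(TrainConfig(epochs=30), 'ptq_hpo', 'transformer').epochs)  # 30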
@@ -802,11 +801,12 @@ def _save_model(self, test_program_info, strategy, strategy_idx):
             for name in test_program_info.feed_target_names
         ]
 
+        model_name = '.'.join(self.model_filename.split(
+            '.')[:-1]) if self.model_filename is not None else 'model'
+        path_prefix = os.path.join(model_dir, model_name)
         paddle.static.save_inference_model(
-            path_prefix=str(model_dir),
+            path_prefix=path_prefix,
             feed_vars=feed_vars,
             fetch_vars=test_program_info.fetch_targets,
             executor=self._exe,
-            program=test_program,
-            model_filename=self.model_filename,
-            params_filename=self.params_filename)
+            program=test_program)
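This is the core of the fix: the Paddle 2.x paddle.static.save_inference_model API works from a single path_prefix and derives <prefix>.pdmodel / <prefix>.pdiparams itself (the model_filename/params_filename keywords belong to the legacy fluid.io API), so the prefix is rebuilt from the configured filename instead of passing the bare directory. A small sketch of the prefix derivation with example paths:

import os

def build_path_prefix(model_dir, model_filename):
    # Mirrors the prefix logic in the hunk above; the inputs are examples.
    model_name = ('.'.join(model_filename.split('.')[:-1])
                  if model_filename is not None else 'model')
    return os.path.join(model_dir, model_name)

print(build_path_prefix('output/final_model', 'inference.pdmodel'))  # output/final_model/inference
print(build_path_prefix('output/final_model', None))                 # output/final_model/model
# paddle.static.save_inference_model(path_prefix=..., ...) then writes
# output/final_model/inference.pdmodel and output/final_model/inference.pdiparams.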
6 changes: 4 additions & 2 deletions paddleslim/auto_compression/utils/fake_ptq.py
@@ -1,3 +1,4 @@
+import os
 import paddle
 from paddle.fluid.framework import IrGraph
 from paddle.framework import core
@@ -111,10 +112,11 @@ def post_quant_fake(executor,
     _program = graph.to_program()
 
     feed_vars = [_program.global_block().var(name) for name in _feed_list]
+    model_name = model_filename.split('.')[
+        0] if model_filename is not None else 'model'
+    save_model_path = os.path.join(save_model_path, model_name)
     paddle.static.save_inference_model(
         path_prefix=save_model_path,
-        model_filename=model_filename,
-        params_filename=params_filename,
         feed_vars=feed_vars,
         fetch_vars=_fetch_list,
         executor=executor,
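post_quant_fake applies the same prefix idea, but it takes the stem with split('.')[0] rather than joining everything before the last dot; for single-extension names such as 'inference.pdmodel' both give the same stem, and they only diverge when the filename itself contains extra dots. A quick comparison (example names only):

name = 'inference.pdmodel'
print(name.split('.')[0])               # inference
print('.'.join(name.split('.')[:-1]))   # inference

name = 'model.fp32.pdmodel'             # hypothetical multi-dot name
print(name.split('.')[0])               # model
print('.'.join(name.split('.')[:-1]))   # model.fp32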
15 changes: 12 additions & 3 deletions paddleslim/auto_compression/utils/load_model.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import paddle
 
 __all__ = ['load_inference_model']
@@ -29,8 +30,16 @@ def load_inference_model(path_prefix,
                 model_filename=model_filename,
                 params_filename=params_filename))
     else:
-        [inference_program, feed_target_names, fetch_targets] = (
-            paddle.static.load_inference_model(
-                path_prefix=path_prefix, executor=executor))
+        model_name = '.'.join(model_filename.split('.')
+                              [:-1]) if model_filename is not None else 'model'
+        if os.path.exists(os.path.join(path_prefix, model_name + '.pdmodel')):
+            model_path_prefix = os.path.join(path_prefix, model_name)
+            [inference_program, feed_target_names, fetch_targets] = (
+                paddle.static.load_inference_model(
+                    path_prefix=model_path_prefix, executor=executor))
+        else:
+            [inference_program, feed_target_names, fetch_targets] = (
+                paddle.static.load_inference_model(
+                    path_prefix=path_prefix, executor=executor))
 
     return [inference_program, feed_target_names, fetch_targets]
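On the load side, the helper now first probes for <path_prefix>/<model_name>.pdmodel, which is the layout the new save path produces, and otherwise falls back to treating path_prefix itself as the prefix, so directories written before and after this change both load. A minimal sketch of just that resolution step (not the library function; paths are examples):

import os

def resolve_model_prefix(path_prefix, model_filename):
    # Restatement of the fallback above, for illustration.
    model_name = ('.'.join(model_filename.split('.')[:-1])
                  if model_filename is not None else 'model')
    candidate = os.path.join(path_prefix, model_name)
    if os.path.exists(candidate + '.pdmodel'):
        return candidate     # new-style layout: <dir>/<stem>.pdmodel
    return path_prefix       # fall back to the prefix the caller gave

print(resolve_model_prefix('output/final_model', 'inference.pdmodel'))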
14 changes: 8 additions & 6 deletions paddleslim/auto_compression/utils/prune_model.py
@@ -86,14 +86,15 @@ def get_sparse_model(executor, places, model_file, param_file, ratio,
     feed_vars = [
         inference_program.global_block().var(name) for name in feed_target_names
     ]
+    model_name = '.'.join(model_name.split('.')
+                          [:-1]) if model_name is not None else 'model'
+    save_path = os.path.join(save_path, model_name)
     static.save_inference_model(
         save_path,
         feed_vars=feed_vars,
         fetch_vars=fetch_targets,
         executor=executor,
-        program=inference_program,
-        model_filename=model_name,
-        params_filename=param_name)
+        program=inference_program)
     print("The pruned model is saved in: ", save_path)


@@ -160,11 +161,12 @@ def get_prune_model(executor, places, model_file, param_file, ratio, save_path):
     feed_vars = [
         main_program.global_block().var(name) for name in feed_target_names
     ]
+    model_name = '.'.join(model_name.split('.')
+                          [:-1]) if model_name is not None else 'model'
+    save_path = os.path.join(save_path, model_name)
     static.save_inference_model(
         save_path,
         feed_vars=feed_vars,
         fetch_vars=fetch_targets,
         executor=executor,
-        program=main_program,
-        model_filename=model_name,
-        params_filename=param_name)
+        program=main_program)
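Both pruning helpers now fold the model stem into save_path before saving, so the caller keeps passing a directory and the files land at <save_path>/<stem>.pdmodel. A hypothetical usage sketch based on the signatures shown in the hunk headers; the argument values and exact calling convention are assumptions:

import paddle
from paddle import static
from paddleslim.auto_compression.utils.prune_model import get_prune_model

paddle.enable_static()
exe = static.Executor(paddle.CPUPlace())
get_prune_model(
    exe,                                          # executor
    paddle.CPUPlace(),                            # places
    model_file='infer_model/inference.pdmodel',
    param_file='infer_model/inference.pdiparams',
    ratio=0.3,
    save_path='./pruned_out')
# After this commit the output is ./pruned_out/inference.pdmodel and
# ./pruned_out/inference.pdiparams rather than names passed via
# model_filename/params_filename.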
6 changes: 3 additions & 3 deletions paddleslim/quant/post_quant_hpo.py
@@ -307,7 +307,7 @@ def quantize(cfg):
     quant_scope = paddle.static.Scope()
     with paddle.static.scope_guard(float_scope):
         [float_inference_program, feed_target_names, fetch_targets]= fluid.io.load_inference_model( \
-            dirname=g_quant_config.model_filename, \
+            dirname=g_quant_config.float_infer_model_path, \
             model_filename=g_quant_config.model_filename, params_filename=g_quant_config.params_filename,
             executor=g_quant_config.executor)
         float_metric = g_quant_config.eval_function(
@@ -320,8 +320,8 @@ def quantize(cfg):
             model_filename=g_quant_config.model_filename, params_filename=g_quant_config.params_filename,
             executor=g_quant_config.executor)
         quant_metric = g_quant_config.eval_function(
-            g_quant_config.executor, inference_program, feed_target_names,
-            fetch_targets)
+            g_quant_config.executor, quant_inference_program,
+            feed_target_names, fetch_targets)
 
     emd_loss = float(abs(float_metric - quant_metric)) / float_metric
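The first hunk points fluid.io.load_inference_model at the float model directory (float_infer_model_path) instead of reusing the model filename as dirname, and the second makes the evaluation run on quant_inference_program, so emd_loss really compares the float and quantized metrics. A toy restatement of that objective with made-up numbers:

float_metric = 0.765    # accuracy of the float inference program (made up)
quant_metric = 0.742    # accuracy of the quantized program (made up)

emd_loss = float(abs(float_metric - quant_metric)) / float_metric
print(round(emd_loss, 4))    # 0.0301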

