diff --git a/tools/convert_checkpoint/deepspeed_to_megatron.py b/tools/convert_checkpoint/deepspeed_to_megatron.py index ef1c77e546..f9116b7da0 100755 --- a/tools/convert_checkpoint/deepspeed_to_megatron.py +++ b/tools/convert_checkpoint/deepspeed_to_megatron.py @@ -4,7 +4,7 @@ import os import torch from collections import OrderedDict -from .deepspeed_checkpoint import ARGS_KEY, DeepSpeedCheckpoint +from deepspeed_checkpoint import ARGS_KEY, DeepSpeedCheckpoint MODEL_KEY = 'model' ARGS_KEY = 'args' @@ -92,7 +92,7 @@ def _create_rank_checkpoint(ds_checkpoint, tp_index, pp_index, for_release=False if pp_index == 0: meg_embedding_sd.update(nested_embedding_sd) - if pp_index == ds_checkpoint.pp_degree -1: + if pp_index == ds_checkpoint.pp_degree - 1: for key, value in embedding_sd.items(): if key.startswith(WORD_EMBEDDINGS_KEY): fields = key.split('.') @@ -111,7 +111,7 @@ def _create_rank_checkpoint(ds_checkpoint, tp_index, pp_index, for_release=False if pp_index == 0: checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][EMBEDDING_KEY] = meg_embedding_sd checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][ENCODER_KEY] = meg_encoder_sd - if pp_index == ds_checkpoint.pp_degree -1: + if pp_index == ds_checkpoint.pp_degree - 1: checkpoint_sd[MODEL_KEY][WORD_EMBEDDINGS_FOR_HEAD_KEY] = meg_embedding_for_head_sd checkpoint_sd[ARGS_KEY] = ds_checkpoint.get_args()