Commit

Update merge_lora script (#457)

Fix merging bug for 1.3B models: the 1.3B checkpoint uses the same 4096-dim embeddings as 7B, so model size is now detected from num_hidden_layers instead of embedding width; single-file pytorch_model.bin checkpoints are also supported, and the nonexistent pytorch_model.bin.index.json is no longer copied for 1.3B.
iMountTai authored Dec 11, 2023
1 parent 0189e8b commit 7b60526
Showing 1 changed file with 36 additions and 16 deletions.

scripts/merge_llama2_with_chinese_lora_low_mem.py
@@ -30,13 +30,23 @@
                     help="Show detailed debugging messages")
 
 
-emb_to_model_size = {
-    4096 : '7B',
-    5120 : '13B',
-    8192 : '70B',
+layers_to_model_size = {
+    4 : '1.3B',
+    32 : '7B',
+    40 : '13B',
+    80 : '70B',
 }
-num_shards_of_models = {'7B': 1, '13B': 2, '70B': 8}
+num_shards_of_models = {'1.3B': 1, '7B': 1, '13B': 2, '70B': 8}
 params_of_models = {
+    '1.3B':
+        {
+            "dim": 4096,
+            "multiple_of": 256,
+            "n_heads": 32,
+            "n_layers": 4,
+            "norm_eps": 1e-05,
+            "vocab_size": -1,
+        },
     '7B':
         {
             "dim": 4096,
@@ -73,6 +83,12 @@ def transpose(weight, fan_in_fan_out):
     return weight.T if fan_in_fan_out else weight
 
 
+def jsonload(filename):
+    with open(filename, "r") as file:
+        d = json.load(file)
+    return d
+
+
 # Borrowed and modified from https://github.com/tloen/alpaca-lora
 def translate_state_dict_key(k):
     k = k.replace("base_model.model.", "")
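
The new jsonload helper lets the merge loop read num_hidden_layers from the base model's config.json before any checkpoint shard is loaded (see the hunks below). A trivial usage sketch; the directory name is hypothetical:

    import json
    import os

    def jsonload(filename):
        with open(filename, "r") as file:
            d = json.load(file)
        return d

    # Hypothetical base-model directory; a 1.3B config reports 4 hidden layers.
    layers = jsonload(os.path.join("chinese-llama-2-1.3b", "config.json"))["num_hidden_layers"]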
@@ -241,7 +257,7 @@ def merge_shards(output_dir, num_shards: int):
         lora_state_dict = torch.load(os.path.join(lora_model_path,'adapter_model.bin'),map_location='cpu')
         if 'base_model.model.model.embed_tokens.weight' in lora_state_dict:
             lora_vocab_size = lora_state_dict['base_model.model.model.embed_tokens.weight'].shape[0]
-            assert lora_vocab_size==len(tokenizer), \
+            assert lora_vocab_size == len(tokenizer), \
                 (f"The vocab size of the tokenizer {len(tokenizer)} does not match the vocab size of the LoRA weight {lora_vocab_size}!\n")
         tokenizers_and_loras.append(
             {
@@ -255,19 +271,21 @@ def merge_shards(output_dir, num_shards: int):
     if not os.path.exists(base_model_path):
         print("Cannot find lora model on the disk. Downloading lora model from hub...")
         base_model_path = snapshot_download(repo_id=base_model_path)
-    ckpt_filenames = sorted([f for f in os.listdir(base_model_path) if re.match('pytorch_model-(\d+)-of-(\d+).bin',f)])
-    if len(ckpt_filenames)==0:
+    if os.path.exists(os.path.join(base_model_path, "pytorch_model.bin")):
+        ckpt_filenames = ["pytorch_model.bin"]
+    else:
+        ckpt_filenames = sorted([f for f in os.listdir(base_model_path) if re.match('pytorch_model-(\d+)-of-(\d+).bin',f)])
+    if len(ckpt_filenames) == 0:
         raise FileNotFoundError(f"Cannot find base model checkpoints in ${base_model_path}. Please make sure the checkpoints are saved in the HF format.")
-    embedding_size = None
+    layers = jsonload(os.path.join(base_model_path, "config.json"))["num_hidden_layers"]
     model_size = None
     total_size = 0
     for index, filename in enumerate(ckpt_filenames):
         print(f"Loading ckpt {filename}")
         state_dict = torch.load(os.path.join(base_model_path,filename), map_location='cpu')
         if index == 0:
-            embedding_size = state_dict['model.embed_tokens.weight'].shape[1]
-            model_size = emb_to_model_size[embedding_size]
-            if output_type=='pth':
+            model_size = layers_to_model_size[layers]
+            if output_type == 'pth':
                 params = params_of_models[model_size]
                 num_shards = num_shards_of_models[model_size]
                 n_layers = params["n_layers"]
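
Checkpoint discovery also changes because the 1.3B model ships as a single pytorch_model.bin rather than sharded pytorch_model-XXXXX-of-XXXXX.bin files, so the old regex-only listing found nothing for it. A self-contained sketch of the new lookup order (find_checkpoints is a hypothetical name; the script inlines this logic):

    import os
    import re

    def find_checkpoints(base_model_path):
        # Prefer the single-file layout (how the 1.3B model is distributed),
        # then fall back to the sharded HF naming scheme.
        if os.path.exists(os.path.join(base_model_path, "pytorch_model.bin")):
            return ["pytorch_model.bin"]
        return sorted(f for f in os.listdir(base_model_path)
                      if re.match(r"pytorch_model-(\d+)-of-(\d+)\.bin", f))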
@@ -297,10 +315,10 @@ def merge_shards(output_dir, num_shards: int):
             weight_size = state_dict[k].numel() * dtype_byte_size(state_dict[k].dtype)
             total_size += weight_size
 
-        if output_type=='huggingface':
+        if output_type == 'huggingface':
             print(f"Saving ckpt {filename} to {output_dir} in HF format...")
             torch.save(state_dict,os.path.join(output_dir, filename))
-        elif output_type=='pth':
+        elif output_type == 'pth':
             print(f"Converting to pth format...")
             save_shards(model_sd=state_dict, num_shards=num_shards,prefix=f"L{index+1}-", verbose=args.verbose)
         del state_dict
@@ -316,6 +334,8 @@ def merge_shards(output_dir, num_shards: int):
 
     if output_type=='huggingface':
         configs = ('config.json', 'generation_config.json', 'pytorch_model.bin.index.json')
+        if model_size == "1.3B":
+            configs = ('config.json', 'generation_config.json')
         for config in configs:
             if os.path.exists(os.path.join(lora_model_path, config)):
                 print(f"Saving {config} from {lora_model_path}")
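
The 1.3B branch drops pytorch_model.bin.index.json from the copy list because a single-file checkpoint has no shard index to carry over. A standalone sketch of the selection (names mirror the diff):

    # Sidecar files to copy next to the merged weights.
    def configs_to_copy(model_size):
        configs = ('config.json', 'generation_config.json', 'pytorch_model.bin.index.json')
        if model_size == "1.3B":  # single pytorch_model.bin: no shard index exists
            configs = ('config.json', 'generation_config.json')
        return configs

    print(configs_to_copy("1.3B"))  # ('config.json', 'generation_config.json')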
@@ -325,9 +345,9 @@ def merge_shards(output_dir, num_shards: int):
                 print(f"Saving {config} from {base_model_path}")
                 with open(os.path.join(base_model_path, config),'r') as f:
                     obj = json.load(f)
-                if config=='config.json':
+                if config == 'config.json':
                     obj['vocab_size'] = len(tokenizers_and_loras[-1]['tokenizer'])
-                if config=='pytorch_model.bin.index.json':
+                if config == 'pytorch_model.bin.index.json':
                     obj['metadata']['total_size'] = total_size
                 with open(os.path.join(output_dir, config), 'w') as f:
                     json.dump(obj, f, indent=2)
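
A quick post-merge sanity check for the 1.3B case, assuming the merge was run with huggingface output (the directory name is hypothetical; the script's CLI flags are defined outside this diff):

    import json
    import os

    output_dir = "merged-chinese-llama2-1.3b"  # hypothetical output directory

    with open(os.path.join(output_dir, "config.json")) as f:
        cfg = json.load(f)

    # The merge patches vocab_size to len(tokenizer); a 1.3B model should
    # report the 4 hidden layers that layers_to_model_size keys on.
    assert cfg["num_hidden_layers"] == 4
    print("merged vocab size:", cfg["vocab_size"])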
