RWKV-7 conversion and evals #135

Merged
merged 8 commits on Jan 22, 2025
58 changes: 58 additions & 0 deletions tests/utils/test_rwkv7_conversion.py
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-

import argparse
import json

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

from fla.models.rwkv7 import RWKV7ForCausalLM


def test_rwkv7_lm_eval(model, tokenizer, task_names=["lambada_openai"]):
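    # wrap the raw `tokenizers.Tokenizer` in a PreTrainedTokenizerFast so
    # lm-eval's HFLM harness can drive it like a regular HF tokenizer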
tokenizer1 = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
eos_token="<|endoftext|>",
pad_token="<|padding|>"
)
hf_model = HFLM(pretrained=model, tokenizer=tokenizer1)
results = evaluator.simple_evaluate(
model=hf_model,
tasks=task_names,
batch_size=1,
)
# {
# "lambada_openai": {
# "perplexity,none": 14.457888475382047,
# "perplexity_stderr,none": 0.4455143803996477,
# "acc,none": 0.4585678245682127,
# "acc_stderr,none": 0.006942020515885241,
# "alias": "lambada_openai"
# }
# }
print(json.dumps(results['results'], indent=2))

# official results:
# pile 168M: lambada_openai ppl 14.2 acc 45.6%
# pile 421M: lambada_openai ppl 8.14 acc 55.6%
# pile 1.47B: lambada_openai ppl 5.04 acc 64.9%


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate a converted RWKV-7 model with lm-eval')
parser.add_argument('model', type=str, help='path to model')
parser.add_argument('tokenizer', type=str, help='path to tokenizer')
parser.add_argument('--tasks', type=str, nargs='*',
default=['lambada_openai'])
args = parser.parse_args()

model = RWKV7ForCausalLM.from_pretrained(
args.model,
torch_dtype="auto",
device_map="cuda",
).half().eval()
tokenizer = Tokenizer.from_file(args.tokenizer)

    test_rwkv7_lm_eval(model, tokenizer, task_names=args.tasks)
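
A minimal usage sketch (both paths are placeholders: a converted checkpoint directory and its matching tokenizer.json):

python tests/utils/test_rwkv7_conversion.py ./rwkv7-168m-fla ./tokenizer.json --tasks lambada_openai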
101 changes: 55 additions & 46 deletions utils/convert_from_rwkv7.py
@@ -3,53 +3,53 @@
# script for converting pretrained RWKV-7 model weights to fla style

import argparse
import os
import re

import torch
from transformers import AutoModelForCausalLM

import fla # noqa
from fla.models.rwkv7 import RWKV7Config


def sizeof_fmt(num, suffix='B'):
for unit in ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi'):
if abs(num) < 1024.0:
return f'{num:.2f}{unit}{suffix}'
num /= 1024.0
return f'{num:.2f}Yi{suffix}'


def convert(
rwkv7: str,
output: str
output: str,
precision: str = 'float32'
):
weights = torch.load(rwkv7, weights_only=True)
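    # infer every RWKV7Config hyper-parameter from tensor shapes in the .pth
    # checkpoint; the inline numbers show the values for the 168M Pile model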
config = RWKV7Config()
config.vocab_size = weights['emb.weight'].shape[0] # 50304
config.hidden_size = weights['blocks.0.ffn.key.weight'].shape[1] # 768
config.hidden_ratio = weights['blocks.0.ffn.key.weight'].shape[0] / weights['blocks.0.ffn.key.weight'].shape[1] # 4.0
config.intermediate_size = weights['blocks.0.ffn.key.weight'].shape[0]
config.num_hidden_layers = 0
while f'blocks.{config.num_hidden_layers}.ffn.key.weight' in weights:
config.num_hidden_layers += 1
# 12
config.decay_low_rank_dim = weights['blocks.0.att.w1'].shape[1] # 64
config.gate_low_rank_dim = weights['blocks.0.att.g1'].shape[1] # 128
config.a_low_rank_dim = weights['blocks.0.att.a1'].shape[1] # 64
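    # layer 0 has no value-residual LoRA (its v0/v1/v2, if present in the
    # pth file, go unused), so probe layer 1 and fall back to the default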
try:
config.v_low_rank_dim = weights['blocks.1.att.v1'].shape[1] # 32
except Exception:
except KeyError:
config.v_low_rank_dim = 32
config.torch_dtype = precision

print(f"Creating model with config:\n{config}")
model = AutoModelForCausalLM.from_config(config)
print(model)
model_dict = model.state_dict()
model_names = [n for n in model_dict]

unused_weights = [
'model.layers.0.attn.v_lora.lora.0.weight',
'model.layers.0.attn.v_lora.lora.2.weight',
'model.layers.0.attn.v_lora.lora.2.bias'
# these parameters may be present in pth file but are never used:
unused_names = ['blocks.0.attn.v0', 'blocks.0.attn.v1', 'blocks.0.attn.v2']
# these parameters may or may not be present in pth file:
possible_absent_weights = [
'model.layers.0.pre_norm.weight', 'model.layers.0.pre_norm.bias'
]
possible_absent_weights = ['model.layers.0.pre_norm.weight', 'model.layers.0.pre_norm.bias']
# other parameters may raise a KeyError

def translate_into_fla(name):
transposed = False
@@ -66,6 +66,8 @@ def translate_into_fla(name):
'ln_x': 'g_norm',
'output': 'o_proj',
}
if name in unused_names:
return '', False
if name in emb_head:
return emb_head[name], False
name_compo = name.split('.')
@@ -79,7 +81,9 @@
'ln1': 'attn_norm',
'ln2': 'ffn_norm'
}[name_compo[2]]
if re.match("[wvag][012]", name_compo[3]):
if name_compo[2] == 'attn' and re.match("x_[rwkvag]", name_compo[3]):
name_compo[3] = 'x_x'
elif re.match("[wvag][012]", name_compo[3]):
typ, num = name_compo[3]
name_compo[3] = f'{typ}_lora.lora.' + {
'0': '2.bias',
@@ -89,46 +93,51 @@
transposed |= (num in ['1', '2'])
elif name_compo[2] == 'attn' and name_compo[3] in proj:
name_compo[3] = proj[name_compo[3]]
elif name_compo[2] == 'attn' and re.match('x_[rwkvag]', name_compo[3]):
name_compo[3] = 'x_x'
return '.'.join(name_compo), transposed

x_x = {}
for name in weights:
fla_name, transposed = translate_into_fla(name)
print(f"{name:32} -> {fla_name:50}, {transposed}")
if re.match('.*att.x_[rwkvag]', name):
x_x[name] = weights[name]
if len(x_x) == 6:
weight = torch.stack(list(x_x.values())).squeeze_()
x_x = {}
else:
continue
else:
weight = weights[name] if not transposed else weights[name].T
if re.match('.*[wva]0', name):
print(f'{name:32} -> {fla_name:42}, {transposed}')
if not fla_name:
print('redundant parameters in source weight: ', name, '\n')
continue
weight = weights[name]
# print shape information
shape1 = list(weight.shape)
shape2 = list(model_dict[fla_name].shape)
print(f'{str(shape1):32} {str(shape2)}\n')

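        # LoRA matrices (suffix 1/2) arrive transposed relative to fla's
        # layout; [1, 1, hidden_size] mixing vectors are squeezed to 1-D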
if transposed:
weight.t_()
if shape1 == [1, 1, config.hidden_size]:
weight.squeeze_()
if re.match('.*att.[kr]_[k_a]', name):
weight.squeeze_()
if re.match('.*ffn.x_[xk]', name):
weight.squeeze_()
assert model_dict[fla_name].shape == weight.shape
model_dict[fla_name].data.copy_(weight)
model_names.remove(fla_name)

print("unused parameters: ", model_names)
# fix: fusing x_[rwkvag] to x_x
if fla_name.endswith('attn.x_x'):
model_dict[fla_name].data['rwkvag'.find(name[-1])].copy_(weight)
if fla_name in model_names:
model_names.remove(fla_name)
else:
assert model_dict[fla_name].shape == weight.shape
model_dict[fla_name].data.copy_(weight)
model_names.remove(fla_name)

print("uninitialized parameters: ", model_names)
for n in model_names:
if not (n in unused_weights or n in possible_absent_weights):
if n not in possible_absent_weights:
raise KeyError(n)

print(f"Saving model to {output}")
model.save_pretrained(output)
os.makedirs(output, exist_ok=True)

from safetensors.torch import save_file
save_file(model.state_dict(), os.path.join(output, 'model.safetensors'))
model.config.save_pretrained(output)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model")
parser.add_argument("--output")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Convert RWKV7')
parser.add_argument('--rwkv7', type=str, help='Path to the input model')
parser.add_argument('--output', type=str, help='Directory to save model')
parser.add_argument('--precision', type=str, default='float32')
args = parser.parse_args()
convert(args.model, args.output)
convert(args.rwkv7, args.output, precision=args.precision)
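
For context, an end-to-end sketch of how the two scripts in this PR fit together; the checkpoint and tokenizer file names below are hypothetical placeholders:

python utils/convert_from_rwkv7.py --rwkv7 RWKV-x070-Pile-168M.pth --output ./rwkv7-168m-fla --precision float32
python tests/utils/test_rwkv7_conversion.py ./rwkv7-168m-fla ./tokenizer.json --tasks lambada_openai

And a quick sanity check that the six x_[rwkvag] vectors were fused as intended; a sketch only, with the [6, hidden_size] layout assumed from the per-row copy in the converter above:

from fla.models.rwkv7 import RWKV7ForCausalLM

model = RWKV7ForCausalLM.from_pretrained('./rwkv7-168m-fla')
x_x = model.state_dict()['model.layers.0.attn.x_x']
# rows hold the r, w, k, v, a, g mixing vectors, in 'rwkvag' order
assert x_x.shape == (6, model.config.hidden_size)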