Merge pull request #268 from lakshith-403/LoRA

LoRA experiment

vpj authored Aug 7, 2024
2 parents d4af40b + 61d32f4 commit 5d384d6

Showing 5 changed files with 250 additions and 296 deletions.
89 changes: 89 additions & 0 deletions labml_nn/lora/experiment.ipynb
@@ -0,0 +1,89 @@
{
"cells": [
{
"metadata": {},
"cell_type": "code",
"source": [
"from labml_nn.lora.experiment import Configs\n",
"from labml import experiment"
],
"id": "1b9da2e59ffce5d5",
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true
},
"source": "experiment.create(name=\"lora_gpt2\")",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "conf = Configs()",
"id": "31c9bc08eca2592",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "experiment.configs(conf)",
"id": "fb6ce74326558948",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "conf.initialize()",
"id": "1456cfab47dee3b",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"with experiment.start():\n",
" conf.run()"
],
"id": "3fe4068fd2df9094",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "",
"id": "d3c3c723ebbe854a",
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (ml)",
"language": "python",
"name": "ml"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
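
Taken together, the notebook cells amount to a short driver script; a minimal plain-Python equivalent (a sketch using the same calls as above) would be:

from labml import experiment
from labml_nn.lora.experiment import Configs

experiment.create(name="lora_gpt2")  # register the run
conf = Configs()                     # default hyperparameters: GPT-2 base with LoRA r = 32
experiment.configs(conf)             # hand the configuration to labml for tracking
conf.initialize()                    # build the model, load pretrained GPT-2 weights, prepare the data
with experiment.start():
    conf.run()                       # fine-tune on Tiny Shakespeare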
137 changes: 137 additions & 0 deletions labml_nn/lora/experiment.py
@@ -0,0 +1,137 @@
import torch
from labml import lab, monit, tracker
from labml.configs import BaseConfigs, option
from labml.utils.download import download_file
from labml_helpers.device import DeviceConfigs
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from labml_nn.lora.gpt2 import GPTModel


class Configs(BaseConfigs):
device: torch.device = DeviceConfigs()
layer_norm_epsilon: float = 1e-05
n_embed: int = 768
n_layer: int = 12
n_positions: int = 1024
vocab_size: int = 50257
epochs: int = 10
batch_size: int = 32
learning_rate: float = 1e-4
context_len: int = 512
r: int = 32

text: TensorDataset = "tiny_shakespeare"
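    # "tiny_shakespeare" names the labml config option defined at the bottom of this file;
    # labml resolves it to the TensorDataset returned by that option function.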
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model: GPTModel
optimizer: torch.optim.Adam
criterion = torch.nn.CrossEntropyLoss()
data_loader: DataLoader

def _load_pretrained_weights(self):
hf_model = AutoModelForCausalLM.from_pretrained("gpt2")

state_dict = hf_model.state_dict()

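        # Map Hugging Face GPT-2 parameter names onto this implementation's module names.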
mapping = {
'transformer.wte.weight': 'token_embedding.weight',
'transformer.wpe.weight': 'position_embedding.weight',
'transformer.ln_f.weight': 'final_norm.weight',
'transformer.ln_f.bias': 'final_norm.bias',
'lm_head.weight': 'lm_head.weight'
}

for i in range(12):
mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'

new_state_dict = {}
for old_key, new_key in mapping.items():
if old_key in state_dict:
new_state_dict[new_key] = state_dict[old_key]

        # Hugging Face GPT-2 stores these weights in Conv1D modules with shape (in_features, out_features);
        # transpose them so they can be loaded into Linear layers, which expect (out_features, in_features).
convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
[f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
[f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
[f'blocks.{i}.attn.c_proj.weight' for i in range(12)])

for layer in convo_layers:
new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

        self.model.load_state_dict(new_state_dict, strict=False)  # strict=False: the mapped state dict has no LoRA parameters; those keep their fresh initialization

del hf_model
del state_dict
del new_state_dict

def initialize(self):
self.model = GPTModel(
layer_norm_epsilon=self.layer_norm_epsilon,
n_embd=self.n_embed,
n_layer=self.n_layer,
n_positions=self.n_positions,
vocab_size=self.vocab_size,
r=self.r,
device=self.device
).to(self.device)
self._load_pretrained_weights()

self.optimizer = Adam(self.model.parameters(), lr=self.learning_rate)

self.data_loader = DataLoader(self.text, batch_size=self.batch_size, shuffle=True)

def run(self):
for _ in monit.loop(self.epochs):
for i, batch in monit.enum('Train', self.data_loader):
inputs = batch[0]
inputs = inputs.to(self.device)
labels = inputs.clone()

outputs = self.model(inputs)

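                # Next-token prediction: the logit at position t predicts the token at t + 1,
                # so drop the last logit and the first label to align the two tensors.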
shift_logits = outputs[..., :-1, :]
shift_labels = labels[..., 1:]

loss = self.criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

tracker.add({'loss': loss})

tracker.save()
tracker.add_global_step()
tracker.new_line()


@option(Configs.text)
def tiny_shakespeare(c: Configs):
"""
### Tiny Shakespeare dataset

    Downloads the text if it is not already present, tokenizes it with the GPT-2 tokenizer,
    and packs the tokens into fixed-length sequences of context_len.
"""
path = lab.get_data_path() / 'tiny_shakespeare.txt'
if not path.exists():
download_file("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", path)
with open(path, 'r', encoding='utf-8') as f:
text = f.read()

tokens = c.tokenizer.encode(text)
num_batches = len(tokens) // (c.batch_size * c.context_len)
tokens = tokens[:num_batches * c.batch_size * c.context_len]
input_ids = torch.tensor(tokens).view(-1, c.context_len)
return TensorDataset(input_ids)
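
A quick sanity check on the batching arithmetic above, using made-up numbers (a self-contained sketch; the real token count depends on the tokenizer and the downloaded text):

import torch
from torch.utils.data import TensorDataset

batch_size, context_len = 32, 512
tokens = list(range(100_000))  # stand-in for c.tokenizer.encode(text)

# Keep only whole batches' worth of tokens, then reshape into fixed-length rows.
num_batches = len(tokens) // (batch_size * context_len)   # 100_000 // 16_384 = 6
tokens = tokens[:num_batches * batch_size * context_len]  # 98_304 tokens kept
input_ids = torch.tensor(tokens).view(-1, context_len)    # shape (192, 512)

dataset = TensorDataset(input_ids)
assert len(dataset) == num_batches * batch_size           # 6 * 32 = 192 sequences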
57 changes: 24 additions & 33 deletions labml_nn/lora/gpt2.py
@@ -1,26 +1,13 @@
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from labml_nn.lora import Linear, Embedding

tokenizer = AutoTokenizer.from_pretrained("gpt2")

config = {
"layer_norm_epsilon": 1e-05,
"n_embd": 768,
"n_head": 12,
"n_layer": 12,
"n_positions": 1024,
"vocab_size": 50257,
"device": "cuda"
}


class FFN(nn.Module):
def __init__(self, dim):
def __init__(self, dim: int, n_embed: int, r: int):
super().__init__()
self.c_fc = Linear(config['n_embd'], dim, r=32, bias=True)
self.c_proj = Linear(dim, config['n_embd'], r=32, bias=True)
self.c_fc = Linear(n_embed, dim, r=r, bias=True)
self.c_proj = Linear(dim, n_embed, r=r, bias=True)
self.act = nn.functional.gelu

def forward(self, hidden_states):
@@ -31,15 +18,15 @@ def forward(self, hidden_states):


class MultiHeadAttention(nn.Module):
def __init__(self):
def __init__(self, n_embed: int, r: int):
super().__init__()
self.embed_dim = config['n_embd']
self.num_heads = config['n_head']
self.embed_dim = n_embed
        self.num_heads = 12  # GPT-2 base has 12 attention heads, so head_dim = n_embed // 12 = 64
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim

self.c_att = Linear(config['n_embd'], config['n_embd'] * 3, r=32, bias=True)
self.c_proj = Linear(config['n_embd'], config['n_embd'], r=32, bias=True)
self.c_att = Linear(n_embed, n_embed * 3, r=r, bias=True)
self.c_proj = Linear(n_embed, n_embed, r=r, bias=True)

def _split_heads(self, tensor, num_heads, attn_head_size):
"""
@@ -76,12 +63,12 @@ def forward(self, hidden_states):


class Block(nn.Module):
def __init__(self):
def __init__(self, n_embed: int, layer_norm_epsilon: float, r: int):
super().__init__()
self.pre_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
self.attn = MultiHeadAttention()
self.post_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
self.ffn = FFN(config['n_embd'] * 4)
self.pre_norm = nn.LayerNorm(n_embed, eps=layer_norm_epsilon)
self.attn = MultiHeadAttention(n_embed, r)
self.post_norm = nn.LayerNorm(n_embed, eps=layer_norm_epsilon)
self.ffn = FFN(n_embed * 4, n_embed, r)

def forward(self, hidden_states):
residual = hidden_states
@@ -99,23 +86,27 @@ def forward(self, hidden_states):


class GPTModel(nn.Module):
def __init__(self):
def __init__(self, layer_norm_epsilon: float, n_embd: int, n_layer: int, n_positions: int,
vocab_size: int, r: int, device: torch.device):
super().__init__()

self.token_embedding = Embedding(config['vocab_size'], config['n_embd'], r=32)
self.position_embedding = Embedding(config['n_positions'], config['n_embd'], r=32)
self.token_embedding = Embedding(vocab_size, n_embd, r=r)
self.position_embedding = Embedding(n_positions, n_embd, r=r)

self.blocks = nn.ModuleList([Block(n_embd, layer_norm_epsilon, r=r)
for _ in range(n_layer)])

self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])
self.final_norm = nn.LayerNorm(n_embd, eps=layer_norm_epsilon)

self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
self.lm_head = Linear(n_embd, vocab_size, r=r, bias=False)

self.lm_head = Linear(config['n_embd'], config['vocab_size'], r=32, bias=False)
self.device = device

def forward(self, input_ids):
        batch_size, seq_len = input_ids.size()

        token_embeddings = self.token_embedding(input_ids)  # B T C
        position_ids = torch.arange(seq_len, device=self.device)  # T
position_ids = torch.arange(input_shape, device=self.device) # T C
position_embeddings = self.position_embedding(position_ids) # B T C

hidden_states = token_embeddings + position_embeddings
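
The Linear and Embedding modules imported from labml_nn.lora are among the changed files not shown above (only three of the five files are displayed). For orientation, a LoRA linear layer along these lines could look like the following sketch; the class name, scaling, and initialization details here are assumptions, not the repository's actual implementation:

import math

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Frozen pretrained weight W plus a trainable low-rank update (alpha / r) * B @ A."""

    def __init__(self, in_features: int, out_features: int, r: int, bias: bool = True, alpha: int = 32):
        super().__init__()
        # The pretrained weight and bias stay frozen during fine-tuning;
        # in practice they are filled by load_state_dict from the GPT-2 checkpoint.
        self.weight = nn.Parameter(torch.zeros(out_features, in_features), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False) if bias else None

        # Low-rank factors: A is initialized randomly, B starts at zero,
        # so the LoRA update is a no-op before any training step.
        self.lora_a = nn.Parameter(torch.empty(r, in_features))
        self.lora_b = nn.Parameter(torch.zeros(out_features, r))
        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))

        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        frozen = nn.functional.linear(x, self.weight, self.bias)
        low_rank = nn.functional.linear(nn.functional.linear(x, self.lora_a), self.lora_b)
        return frozen + self.scaling * low_rank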