-
Notifications
You must be signed in to change notification settings - Fork 0
/
kieugpt_char.py
139 lines (106 loc) · 4.15 KB
/
kieugpt_char.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
Trains a character-level language model.
"""
import os
import sys
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from mingpt.model import GPT
from mingpt.trainer import Trainer
from mingpt.utils import set_seed, setup_logging, CfgNode as CN
# -----------------------------------------------------------------------------
def get_config():
    """
    Assemble the default configuration tree for this training run.

    The returned node has four sub-nodes, created in this order:
      - ``system``:  random seed and the output directory for checkpoints/logs
      - ``data``:    CharDataset defaults (context length)
      - ``model``:   GPT defaults with the 'gpt-mini' model_type selected
      - ``trainer``: Trainer defaults with a raised learning rate
    Command-line overrides are merged in by the caller.
    """
    cfg = CN()

    # system: reproducibility + where outputs are written
    system = CN()
    system.seed = 3407
    system.work_dir = './out/kieu_char'
    cfg.system = system

    # data: whatever the dataset class declares as its defaults
    cfg.data = CharDataset.get_default_config()

    # model: start from GPT defaults, then pick the preset
    model_cfg = GPT.get_default_config()
    model_cfg.model_type = 'gpt-mini'
    cfg.model = model_cfg

    # trainer: a small model tolerates a somewhat higher learning rate
    trainer_cfg = Trainer.get_default_config()
    trainer_cfg.learning_rate = 5e-4
    cfg.trainer = trainer_cfg

    return cfg
# -----------------------------------------------------------------------------
class CharDataset(Dataset):
    """
    Character-level dataset: each item is a single (x, y) training example.

    ``x`` is a window of ``block_size`` character indices taken from the text,
    and ``y`` is the same window shifted one character to the right, so that
    ``y[t]`` is the next-character target for position ``t``. Batching is done
    by whatever DataLoader consumes this dataset, not here.
    """

    @staticmethod
    def get_default_config():
        """Return the dataset's default config node (only ``block_size``)."""
        C = CN()
        C.block_size = 32
        return C

    def __init__(self, config, data):
        """
        Build the character vocabulary for ``data``.

        Args:
            config: config node; only ``config.block_size`` is read.
            data: the full training text as a single string.
        """
        self.config = config
        # sorting makes the char <-> index mapping deterministic across runs
        chars = sorted(set(data))
        data_size, vocab_size = len(data), len(chars)
        print('data has %d characters, %d unique.' % (data_size, vocab_size))
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}
        self.vocab_size = vocab_size
        self.data = data

    def get_vocab_size(self):
        """Return the number of distinct characters in the text."""
        return self.vocab_size

    def get_block_size(self):
        """Return the context length (characters per example)."""
        return self.config.block_size

    def __len__(self):
        """Return the number of (x, y) windows; 0 if the text is too short."""
        # Each example consumes block_size + 1 characters. Clamp at 0 so that a
        # text shorter than one block yields an empty dataset instead of making
        # len() raise ValueError on a negative value.
        return max(0, len(self.data) - self.config.block_size)

    def __getitem__(self, idx):
        """Return ``(x, y)`` long tensors of length ``block_size`` starting at ``idx``."""
        # grab a chunk of (block_size + 1) characters from the data
        chunk = self.data[idx:idx + self.config.block_size + 1]
        # encode every character to an integer
        dix = [self.stoi[s] for s in chunk]
        # x is the input window; y is the same window shifted by one (targets).
        # No per-item printing here: this runs for every sample fetched during
        # training, so printing would flood stdout and slow the loop.
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # get default config and overrides from the command line, if any
    config = get_config()
    config.merge_from_args(sys.argv[1:])
    print(config)
    setup_logging(config)
    set_seed(config.system.seed)

    # construct the training dataset (the context manager closes the file)
    with open('truyenkieu.txt', 'r', encoding='utf-8') as f:
        text = f.read()
    train_dataset = CharDataset(config.data, text)

    # construct the model, sized to the dataset's vocabulary and context length
    config.model.vocab_size = train_dataset.get_vocab_size()
    config.model.block_size = train_dataset.get_block_size()
    model = GPT(config.model)

    # construct the trainer object
    trainer = Trainer(config.trainer, model, train_dataset)

    # iteration callback
    def batch_end_callback(trainer):
        """Log every 10 iterations; sample a completion and checkpoint every 500."""
        if trainer.iter_num % 10 == 0:
            print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")

        if trainer.iter_num % 500 == 0:
            # qualitative check: sample a continuation of a fixed prompt
            model.eval()
            with torch.no_grad():
                # NOTE(review): every prompt character must occur in the training
                # text (same Unicode normalization), otherwise stoi raises KeyError
                context = "Trăm năm "
                x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None, ...].to(trainer.device)
                y = model.generate(x, 100, temperature=1.0, do_sample=True, top_k=10)[0]
                completion = ''.join([train_dataset.itos[int(i)] for i in y])
                print(completion)
            # save the latest model (one checkpoint file per sampled iteration)
            print("saving model")
            os.makedirs(config.system.work_dir, exist_ok=True)
            ckpt_path = os.path.join(config.system.work_dir, f"{trainer.iter_num}_kieu_char_model.pt")
            torch.save(model.state_dict(), ckpt_path)
            # revert model to training mode
            model.train()

    trainer.set_callback('on_batch_end', batch_end_callback)

    # run the optimization (this script exists to train the model)
    trainer.run()