# prepare.py
import os
import random

import numpy as np
import requests
import tiktoken

# GPT-2 BPE tokenizer; token ids fit in uint16 (vocab size 50257 < 65536)
enc = tiktoken.get_encoding("gpt2")
chunk_no = 0

def download_file(url, output_dir):
    """Stream the dataset from `url` into `output_dir` as ~100 MB text chunks."""
    global chunk_no
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        for chunk in response.iter_content(chunk_size=104857600):  # 100 MB per chunk
            chunk_no += 1
            output_filename = os.path.join(output_dir, f'{chunk_no}-dataset.txt')
            with open(output_filename, 'wb') as f:
                f.write(chunk)
            print(f"made chunk {chunk_no}")
        print("downloaded and chunked dataset, proceeding to tokenizing...")
    else:
        print('Error downloading file:', response.status_code)

download_file('https://huggingface.co/datasets/VatsaDev/TinyText/resolve/main/full.txt', 'output')
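
# Optional sanity check (a minimal sketch): confirm the chunks on disk add up
# to a plausible dataset size before spending time on tokenizing.
total_bytes = sum(
    os.path.getsize(os.path.join('output', f))
    for f in os.listdir('output') if f.endswith('.txt')
)
print(f"downloaded {total_bytes / 1e9:.2f} GB across {chunk_no} chunks")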

train_len = 0
val_len = 0
train_no = 0
val_no = 0

# Chunks are assigned to train or val at random with a ~90/10 ratio, so the
# validation blocks are scattered across the whole text rather than taken
# from one contiguous slice, giving a more seamless train/val split.
for filename in os.listdir('output'):
    if filename.endswith('.txt'):
        train_or_val = random.randint(0, 9)
        # errors='ignore' guards against a chunk boundary splitting a
        # multi-byte UTF-8 character
        with open(f'output/{filename}', 'r', encoding='utf-8', errors='ignore') as f:
            data = f.read()
        if train_or_val <= 8:
            train_ids = enc.encode_ordinary(data)
            train_len += len(train_ids)
            train_no += 1
            np.array(train_ids, dtype=np.uint16).tofile(f'train{train_no}.bin')
            print(f"train has {train_len} tokens")
        else:
            val_ids = enc.encode_ordinary(data)
            val_len += len(val_ids)
            val_no += 1
            np.array(val_ids, dtype=np.uint16).tofile(f'val{val_no}.bin')
            print(f"val has {val_len} tokens")
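
# Optional sanity check (a minimal sketch, assuming at least one train shard
# was written): round-trip a few tokens from the first shard through the
# tokenizer to confirm the uint16 .bin format decodes back to readable text.
sample = np.fromfile('train1.bin', dtype=np.uint16, count=64)
print(enc.decode(sample.tolist())[:200])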

def concat_bins(data_dir='/content/'):
    """Merge the per-chunk .bin shards in data_dir into valtotal.bin and traintotal.bin."""
    # accumulators start empty; uint16 matches the shards written above
    total_train_data = np.empty(0, dtype=np.uint16)
    total_val_data = np.empty(0, dtype=np.uint16)
    for filename in os.listdir(data_dir):
        if filename.endswith('.bin'):
            print(f"concat {filename}")
            shard = np.memmap(os.path.join(data_dir, filename), dtype=np.uint16, mode='r')
            if filename.startswith('val'):
                # val files
                total_val_data = np.concatenate([total_val_data, shard])
            else:
                # train files
                total_train_data = np.concatenate([total_train_data, shard])
            del shard
    total_val_data.tofile(os.path.join(data_dir, 'valtotal.bin'))
    total_train_data.tofile(os.path.join(data_dir, 'traintotal.bin'))
    print("concat over")

concat_bins()
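
# A minimal sketch of how the merged shards might be consumed downstream,
# nanoGPT-style; block_size and batch_size are illustrative values, not part
# of this script.
block_size, batch_size = 1024, 8
train_data = np.memmap('/content/traintotal.bin', dtype=np.uint16, mode='r')
ix = np.random.randint(0, len(train_data) - block_size, size=batch_size)
x = np.stack([train_data[i:i + block_size].astype(np.int64) for i in ix])
y = np.stack([train_data[i + 1:i + 1 + block_size].astype(np.int64) for i in ix])
print(x.shape, y.shape)  # (batch_size, block_size) input/target batches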