Instruction_finetune.py
import argparse

import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer
# Argument parser
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default='./t5_model')
parser.add_argument('--data_path', type=str, default='instruction.csv')
parser.add_argument('--output_dir', type=str, default='./instructed_t5_model')
parser.add_argument('--num_train_epochs', type=int, default=5)
parser.add_argument('--learning_rate', type=float, default=3e-4)
parser.add_argument('--train_batch_size', type=int, default=4)
parser.add_argument('--eval_batch_size', type=int, default=4)
parser.add_argument('--warmup_steps', type=int, default=100)
parser.add_argument('--weight_decay', type=float, default=0.01)
parser.add_argument('--logging_dir', type=str, default='./logs')
parser.add_argument('--logging_steps', type=int, default=150)
parser.add_argument('--eval_steps', type=int, default=300)
parser.add_argument('--save_steps', type=int, default=300)
parser.add_argument('--train_frac', type=float, default=0.95)
args = parser.parse_args()
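# Example invocation (paths are illustrative and match the defaults above; the CSV is
# expected to contain 'input_text' and 'output_text' columns, as read by MyDataset below):
#   python Instruction_finetune.py --model_path ./t5_model --data_path instruction.csv \
#       --output_dir ./instructed_t5_model --num_train_epochs 5 --train_batch_size 4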
# Load the T5 model and tokenizer from the given checkpoint path
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_path)
tokenizer = AutoTokenizer.from_pretrained(args.model_path, model_max_length=512)
# Define the dataset class
class MyDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        input_sequence = self.data.iloc[idx]['input_text']
        output_sequence = self.data.iloc[idx]['output_text']
        input_encoding = self.tokenizer(input_sequence, padding='max_length', max_length=512, truncation=True, return_tensors='pt')
        output_encoding = self.tokenizer(output_sequence, padding='max_length', max_length=512, truncation=True, return_tensors='pt')
        input_ids = input_encoding['input_ids'].squeeze()
        attention_mask = input_encoding['attention_mask'].squeeze()
        label_ids = output_encoding['input_ids'].squeeze()
        # Mask padding positions in the labels with -100 so they are ignored by the loss
        label_ids[label_ids == self.tokenizer.pad_token_id] = -100
        return {'input_ids': input_ids, 'attention_mask': attention_mask, 'label_ids': label_ids}

    def __len__(self):
        return len(self.data)
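# A quick sanity check of a single item (a minimal sketch with a hypothetical one-row DataFrame):
#   demo = MyDataset(pd.DataFrame({'input_text': ['summarize: some text'],
#                                  'output_text': ['a summary']}), tokenizer)
#   item = demo[0]
#   item['input_ids'].shape   # torch.Size([512]) after max_length padding
#   item['label_ids'][:5]     # target token ids, with padding positions set to -100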
# Define the data collator function
def data_collator(batch):
    # Stack individual examples into batch tensors and rename 'label_ids' to the
    # 'labels' key expected by the model's forward pass / the Hugging Face Trainer.
    input_ids = torch.stack([example['input_ids'] for example in batch])
    attention_mask = torch.stack([example['attention_mask'] for example in batch])
    label_ids = torch.stack([example['label_ids'] for example in batch])
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': label_ids}
def main(args):
    # Load the data
    preprocess_data = pd.read_csv(args.data_path)
    # Split the data into train and validation sets
    train_data = preprocess_data.sample(frac=args.train_frac, random_state=1)
    val_data = preprocess_data.drop(train_data.index)
    # Create the datasets
    train_dataset = MyDataset(train_data, tokenizer)
    val_dataset = MyDataset(val_data, tokenizer)
    # Define the training arguments
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        warmup_steps=args.warmup_steps,
        weight_decay=args.weight_decay,
        logging_dir=args.logging_dir,
        logging_steps=args.logging_steps,
        evaluation_strategy='steps',
        eval_steps=args.eval_steps,
        save_steps=args.save_steps
    )
    # Define the trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )
    # Fine-tune the model
    trainer.train()
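    # Optionally persist the final fine-tuned model and tokenizer to output_dir;
    # without this, only intermediate checkpoint folders are written under output_dir.
    trainer.save_model(args.output_dir)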
if __name__ == "__main__":
    main(args)
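# A minimal usage sketch after training (assumes the model was saved to output_dir,
# e.g. via trainer.save_model, and that prompts follow the same instruction format
# as the 'input_text' column):
#   model = AutoModelForSeq2SeqLM.from_pretrained('./instructed_t5_model')
#   tokenizer = AutoTokenizer.from_pretrained('./instructed_t5_model')
#   inputs = tokenizer('your instruction here', return_tensors='pt')
#   output_ids = model.generate(**inputs, max_new_tokens=128)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))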