-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path train_manga_tp.py
30 lines (27 loc) · 987 Bytes
/
train_manga_tp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
import torch
from data.dataloaders import PairwiseMangaDataNoIMG, PairwiseMangaData
from models.tpsortmodels import base_order_model, MMOrderModel
# Manga109 sentence-order annotations and page images used for both splits.
DATA_JSON = "manga109/sentence_order.json"
IMAGE_DIR = "manga109/images"
# Weights handed to MMOrderModel(load_from_pretrained=...) before fine-tuning.
PRETRAINED_WEIGHTS = "pytorch_model.bin"
# Destination of the model's state_dict once training has finished.
FINAL_WEIGHTS = "unmasked.pth"


def main():
    """Fine-tune MMOrderModel on the Manga109 pairwise sentence-order data.

    Builds the "train" and "validation" splits, trains with the Hugging Face
    ``Trainer`` (checkpoints go to ``ckpt/``, logs to ``./logs``), and then
    writes the trained model's ``state_dict`` to ``FINAL_WEIGHTS``.
    """
    train_set = PairwiseMangaData(DATA_JSON, IMAGE_DIR, split="train")
    valid_set = PairwiseMangaData(DATA_JSON, IMAGE_DIR, split="validation")
    model = MMOrderModel(load_from_pretrained=PRETRAINED_WEIGHTS)

    # Shared logging/checkpoint interval (same value as before, computed once).
    # NOTE(review): Trainer interprets logging_steps/save_steps as optimizer
    # steps, but len(train_set) is a sample count. With batch size 5 on one
    # device this fires about every half epoch rather than every tenth of
    # the data — confirm the intended cadence.
    interval = len(train_set) // 10 + 1

    training_args = TrainingArguments(
        output_dir="ckpt",
        num_train_epochs=2,
        per_device_train_batch_size=5,
        learning_rate=5e-6,
        warmup_steps=0,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=interval,
        save_steps=interval,
        save_total_limit=2,
        dataloader_num_workers=6,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_set,
        # valid_set was previously built and then discarded. With the default
        # evaluation strategy this adds no periodic evaluation; it only makes
        # the split available to trainer.evaluate() or an eval strategy.
        eval_dataset=valid_set,
    )
    trainer.train()
    # Save only after training. The previous version saved before train(),
    # so the file contained just the freshly loaded pretrained weights.
    torch.save(model.state_dict(), FINAL_WEIGHTS)


# Needed because dataloader_num_workers > 0: under the "spawn" start method
# each DataLoader worker re-imports this module, which without the guard would
# rebuild the datasets and restart training in every worker.
if __name__ == "__main__":
    main()