# train.py
import json
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import ConcatDataset
from FOD.Trainer import Trainer
from FOD.dataset import AutoFocusDataset
# Training entry point: read config.json, seed the random number generators,
# build the train/validation DataLoaders over every configured dataset and
# hand both loaders to the Trainer.
with open('config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)

# Seed NumPy *and* PyTorch. DataLoader(shuffle=True) draws its permutation
# from PyTorch's generator, so seeding NumPy alone leaves the batch order
# different on every run.
seed = config['General']['seed']
np.random.seed(seed)
if seed is not None:
    # torch.manual_seed only accepts an int, whereas np.random.seed(None)
    # is valid; skip the torch call for a null seed rather than raise.
    torch.manual_seed(seed)

list_data = config['Dataset']['paths']['list_datasets']
batch_size = config['General']['batch_size']

## train set
# One AutoFocusDataset per configured dataset name, concatenated into a
# single dataset.
autofocus_datasets_train = [
    AutoFocusDataset(config, dataset_name, 'train')
    for dataset_name in list_data
]
train_data = ConcatDataset(autofocus_datasets_train)
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

## validation set
autofocus_datasets_val = [
    AutoFocusDataset(config, dataset_name, 'val')
    for dataset_name in list_data
]
val_data = ConcatDataset(autofocus_datasets_val)
# NOTE(review): the validation loader is shuffled as well; kept as in the
# original - confirm whether Trainer depends on a shuffled validation order.
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)

trainer = Trainer(config)
trainer.train(train_dataloader, val_dataloader)