-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtrain.py
81 lines (72 loc) · 2.44 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import numpy as np
from maf import MAF
from made import MADE
from datasets.data_loaders import get_data, get_data_loaders
from utils.train import train_one_epoch_maf, train_one_epoch_made
from utils.validation import val_maf, val_made
from utils.test import test_maf, test_made
from utils.plot import sample_digits_maf, plot_losses
# --------- SET PARAMETERS ----------
model_name = "maf" # 'MAF' or 'MADE'
dataset_name = "mnist"
batch_size = 128
n_mades = 5
hidden_dims = [512]
lr = 1e-4
random_order = False
patience = 30 # For early stopping
seed = 290713
plot = True
max_epochs = 1000
# -----------------------------------
# Get dataset.
data = get_data(dataset_name)
train = torch.from_numpy(data.train.x)
# Get data loaders.
train_loader, val_loader, test_loader = get_data_loaders(data, batch_size)
# Get model.
n_in = data.n_dims
if model_name.lower() == "maf":
model = MAF(n_in, n_mades, hidden_dims)
elif model_name.lower() == "made":
model = MADE(n_in, hidden_dims, random_order=random_order, seed=seed, gaussian=True)
# Get optimiser.
optimiser = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-6)
# Format name of model save file.
save_name = f"{model_name}_{dataset_name}_{'_'.join(str(d) for d in hidden_dims)}.pt"
# Initialise list for plotting.
epochs_list = []
train_losses = []
val_losses = []
# Initialiise early stopping.
i = 0
max_loss = np.inf
# Training loop.
for epoch in range(1, max_epochs):
if model_name == "maf":
train_loss = train_one_epoch_maf(model, epoch, optimiser, train_loader)
val_loss = val_maf(model, train, val_loader)
elif model_name == "made":
train_loss = train_one_epoch_made(model, epoch, optimiser, train_loader)
val_loss = val_made(model, val_loader)
if plot:
sample_digits_maf(model, epoch, random_order=random_order, seed=5)
epochs_list.append(epoch)
train_losses.append(train_loss)
val_losses.append(val_loss)
# Early stopping. Save model on each epoch with improvement.
if val_loss < max_loss:
i = 0
max_loss = val_loss
torch.save(
model, "model_saves/" + save_name
) # Will print a UserWarning 1st epoch.
else:
i += 1
if i < patience:
print("Patience counter: {}/{}".format(i, patience))
else:
print("Patience counter: {}/{}\n Terminate training!".format(i, patience))
break
plot_losses(epochs_list, train_losses, val_losses, title=None)