train.py
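# Trains a VAE (or a plain autoencoder, depending on args.variational) with
# `sources` latent components on MNIST: logs per-batch -ELL/KLD/loss, anneals
# the KL weight beta over a warm-up period, decays the learning rate each
# epoch, saves reconstruction grids and periodic checkpoints, and plots the
# train/test loss curves. (Summary inferred from the code below; the model,
# loss, argument parser, and data loaders live in the src/ package, which is
# not shown in this file.)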
import numpy as np
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from src.model import *
from src.argparser import *
from src.utils import *
from src.dataloader import *


def train(epoch):
    model.train()
    train_losses = torch.zeros(3)  # running sums of [loss, -ELL, KLD] over the epoch
    for batch_idx, (data, _) in enumerate(train_loader):
        # mix_data comes from src.utils; only its first return value is used here.
        data = mix_data(data.to(device))[0].view(-1, dimx)
        optimizer.zero_grad()
        recon_y, mu_z, logvar_z, _ = model(data)
        # beta is the KL weight, set per-epoch in the training loop below.
        loss, ELL, KLD = loss_function(data, recon_y, mu_z, logvar_z, beta=beta)
        loss.backward()
        optimizer.step()
        train_losses[0] += loss.item()
        train_losses[1] += ELL.item()
        train_losses[2] += KLD.item()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)] \t -ELL: {:5.6f} \t KLD: {:5.6f} \t Loss: {:5.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                ELL.item() / len(data), KLD.item() / len(data), loss.item() / len(data)))
    train_losses /= len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_losses[0]))
    return train_losses


def test(epoch):
    model.eval()
    test_losses = torch.zeros(3)  # running sums of [loss, -ELL, KLD]
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(test_loader):
            data = mix_data(data.to(device))[0].view(-1, dimx)
            recon_y, mu_z, logvar_z, recons = model(data)
            loss, ELL, KLD = loss_function(data, recon_y, mu_z, logvar_z, beta=beta)
            test_losses[0] += loss.item()
            test_losses[1] += ELL.item()
            test_losses[2] += KLD.item()
            # Build a comparison grid with one row per example:
            # input | full reconstruction | one column per source reconstruction.
            n = min(data.size(0), 6)
            ncols = 2 + args.sources
            comparison = torch.zeros(n * ncols, 1, 28, 28)
            comparison[::ncols] = data.view(data.size(0), 1, 28, 28)[:n]
            comparison[1::ncols] = recon_y.view(data.size(0), 1, 28, 28)[:n]
            for k in range(args.sources):
                comparison[(k + 2)::ncols] = recons[:, k].view(data.size(0), 1, 28, 28)[:n]
            # Overwritten on every batch, so the saved file shows the last test batch;
            # save_image builds the grid itself via nrow, so no separate make_grid call is needed.
            save_image(comparison.cpu(), 'results/reconstruction_' + str(epoch) + '.png', nrow=ncols)
    test_losses /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_losses[0]))
    return test_losses


def plot_losses(losses):
    plt.figure()
    epochs = np.arange(1, args.epochs + 1)
    plt.plot(epochs, losses["train"][:, 0].view(-1), label="Train")
    plt.plot(epochs, losses["test"][:, 0].view(-1), label="Test")
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.xlim(1, args.epochs)
    plt.savefig('results/losses.png')
    plt.close()
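

# The Loss(...) object constructed below comes from the src/ package, which is
# not part of this file. As a rough, illustrative sketch only -- assuming a
# Laplace likelihood and a standard-normal prior, which matches the
# likelihood='laplace' argument but is otherwise a guess -- a beta-weighted
# objective of this kind could look like:
#
#   def beta_vae_loss(x, recon_x, mu, logvar, beta=1.0, scale=1.0):
#       # negative Laplace log-likelihood, up to an additive constant
#       NLL = torch.abs(x - recon_x).sum() / scale
#       # KL( N(mu, diag(exp(logvar))) || N(0, I) )
#       KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
#       return NLL + beta * KLD, NLL, KLD
#
# The real implementation also takes `sources`, `variational`, and `prior`
# arguments, whose handling cannot be reconstructed from this file alone.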


args = parser.parse_args()
torch.manual_seed(args.seed)
args.cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")
kwargs = {'num_workers': args.num_workers, 'pin_memory': True} if args.cuda else {}
train_loader, test_loader = get_data_loaders(args.data_directory, args.batch_size, kwargs)

# MNIST images are 28 x 28, flattened to vectors of length 784.
dimx = int(28 * 28)
model = VAE(dimx=dimx, dimz=args.dimz, n_sources=args.sources, device=device,
            variational=args.variational).to(device)
loss_function = Loss(sources=args.sources, likelihood='laplace',
                     variational=args.variational, prior=args.prior, scale=args.scale)
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.decay, last_epoch=-1)

losses = {"train": torch.zeros(args.epochs, 3), "test": torch.zeros(args.epochs, 3)}
for epoch in range(1, args.epochs + 1):
    # Linear KL warm-up: anneal beta from 0 up to beta_max over the first
    # min(epochs, warm_up) epochs, then hold it at beta_max.
    beta = min(1.0, epoch / min(args.epochs, args.warm_up)) * args.beta_max
    losses["train"][epoch - 1] = train(epoch)
    losses["test"][epoch - 1] = test(epoch)
    # Decay the learning rate once per epoch until it falls below the 1e-5 floor.
    if optimizer.param_groups[0]['lr'] >= 1.0e-5:
        scheduler.step()
    if epoch % args.save_interval == 0:
        torch.save(model.state_dict(),
                   'saves/model_' + ('vae' if args.variational else 'ae') + '_K' + str(args.sources) + '.pt')
plot_losses(losses)
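
# Example invocation. The flag names below are inferred from the args.*
# attributes used above; their exact spellings are defined in src/argparser.py
# (not shown), so treat them as assumptions:
#
#   python train.py --epochs 50 --batch_size 128 --dimz 20 --sources 2 \
#       --learning_rate 1e-3 --decay 0.99 --warm_up 10 --beta_max 1.0 \
#       --data_directory data/ --seed 1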