optimizer.py
"""
Pytorch Optimizer and Scheduler Related Task
"""
import math
import logging
import torch
from torch import optim
from config import cfg


def get_optimizer(args, net):
    """
    Build the optimizer (currently only SGD is supported) and an
    exponentially decaying LR scheduler.
    """
    # Collect every parameter of the network into a single parameter group.
    base_params = []
    for name, param in net.named_parameters():
        base_params.append(param)

    if args.sgd:
        optimizer = optim.SGD(base_params,
                              lr=args.lr,
                              weight_decay=5e-4,  # args.weight_decay
                              momentum=args.momentum,
                              nesterov=False)
    else:
        raise ValueError('Not a valid optimizer')

    # Exponentially decay the learning rate with the scheduler step count
    # (treated as the training iteration); args.poly_exp controls the rate
    # over a fixed horizon of 120000 iterations.
    lambda1 = lambda iteration: math.exp(-1 * args.poly_exp * iteration / 120000)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)

    return optimizer, scheduler
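
# Usage sketch (illustrative, not part of the original file): `args` is assumed
# to provide `sgd`, `lr`, `momentum`, and `poly_exp`; `net` can be any nn.Module.
# `train_loader` and `compute_loss` below are hypothetical helpers.
#
#     optimizer, scheduler = get_optimizer(args, net)
#     for iteration, batch in enumerate(train_loader):
#         loss = compute_loss(net, batch)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         scheduler.step()  # assumed to be stepped once per iteration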


def load_weights(net, optimizer, scheduler, snapshot_file, restore_optimizer_bool=False):
    """
    Load weights from snapshot file
    """
    logging.info("Loading weights from model %s", snapshot_file)
    net, optimizer, scheduler, epoch, mean_iu = restore_snapshot(net, optimizer, scheduler, snapshot_file,
                                                                 restore_optimizer_bool)
    return epoch, mean_iu


def restore_snapshot(net, optimizer, scheduler, snapshot, restore_optimizer_bool):
    """
    Restore weights, and optionally the optimizer and scheduler state, for resuming a job.
    """
    checkpoint = torch.load(snapshot, map_location=torch.device('cpu'))
    logging.info("Checkpoint Loaded")

    # Restore optimizer/scheduler state only when explicitly requested and present in the checkpoint.
    if optimizer is not None and 'optimizer' in checkpoint and restore_optimizer_bool:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if scheduler is not None and 'scheduler' in checkpoint and restore_optimizer_bool:
        scheduler.load_state_dict(checkpoint['scheduler'])

    # The checkpoint may store the weights under 'state_dict' or at the top level.
    if 'state_dict' in checkpoint:
        net = forgiving_state_restore(net, checkpoint['state_dict'])
        logging.info("Checkpoint network state_dict loaded")
    else:
        net = forgiving_state_restore(net, checkpoint)
        logging.info("Checkpoint network loaded")

    # Restore the memory items if the checkpoint contains them and the network has a memory module.
    if 'memory' in checkpoint:
        try:
            net.module.memory.m_items = checkpoint['memory'].cuda()
            logging.info("Checkpoint memory loaded")
        except AttributeError:
            print("The network has no memory module; the memory stored in the pretrained model was not loaded")

    return net, optimizer, scheduler, checkpoint['epoch'], checkpoint['mean_iu']
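
# Resume sketch (illustrative, not from the original file): restoring a run from
# a snapshot path; `args.snapshot` and `args.restore_optimizer` are assumed
# command-line arguments, not names defined in this file.
#
#     optimizer, scheduler = get_optimizer(args, net)
#     if args.snapshot:
#         epoch, mean_iu = load_weights(net, optimizer, scheduler,
#                                       args.snapshot,
#                                       restore_optimizer_bool=args.restore_optimizer)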


def forgiving_state_restore(net, loaded_dict):
    """
    Handle partial loading when some tensors do not match up in size,
    e.g. when loading a model that was trained with a different number of classes.
    """
    net_state_dict = net.state_dict()
    new_loaded_dict = {}
    # Keep only the parameters that exist in both state dicts and have matching shapes.
    for k in net_state_dict:
        if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size():
            new_loaded_dict[k] = loaded_dict[k]
        else:
            print("Skipped loading parameter", k)
            # logging.info("Skipped loading parameter %s", k)
    net_state_dict.update(new_loaded_dict)
    net.load_state_dict(net_state_dict)
    return net


def forgiving_state_copy(target_net, source_net):
    """
    Handle partial copying when some tensors do not match up in size,
    e.g. when transferring weights from a model that was trained with a
    different number of classes.
    """
    net_state_dict = target_net.state_dict()
    loaded_dict = source_net.state_dict()
    new_loaded_dict = {}
    # Copy only the parameters that exist in both networks and have matching shapes.
    for k in net_state_dict:
        if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size():
            new_loaded_dict[k] = loaded_dict[k]
            print("Matched", k)
        else:
            print("Skipped loading parameter", k)
            # logging.info("Skipped loading parameter %s", k)
    net_state_dict.update(new_loaded_dict)
    target_net.load_state_dict(net_state_dict)
    return target_net
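
# Transfer sketch (illustrative): copying the compatible weights from a source
# model into a target model with a different classifier head; `source_net` and
# `target_net` are assumed to be nn.Module instances built elsewhere.
#
#     target_net = forgiving_state_copy(target_net, source_net)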