-
Notifications
You must be signed in to change notification settings - Fork 1
/
usage_example.py
133 lines (107 loc) · 3.29 KB
/
usage_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from shutil import copytree
import os
import re
import torch
from torch.utils.data import DataLoader
import trainloop
import loss
import dataset
import multitaskmodel
####################################
############ Parameters ############
####################################
# Every value in this section is a plain module-level constant; the rest of the
# script reads them when building the datasets, loaders, model and trainer.
# dataset
# layout of the .hdf5 file: all images are stored as one dataset, while the
# labels for every image are stored as datasets in a group
dataset_path = "datagen/data/DSBOV.hdf5"
trainset_dataset_name = "trainset"
trainset_labels_group_name = "trainset_labels"
testset_dataset_name = "testset"
testset_labels_group_name = "testset_labels"
# number of star distance directions
# NOTE(review): not referenced anywhere else in the visible script — confirm it is still needed
num_rays = 32
# elastic deformation (data augmentation)
# NOTE(review): these three values are not passed to dataset.Dataset in the
# visible script — confirm whether the augmentation is configured elsewhere
elastic_deform_sigma = 7
elastic_deform_points = 3
zoom_factor = 1.2
# data loader
# batch sizes of the training and the test DataLoader respectively
train_batch_size = 4
test_batch_size = 1
# used by both loaders
num_workers = 15 # heavy preprocessing requires many workers
# model parameters (forwarded to multitaskmodel.MultitaskModel)
out_channels = 256
fmaps = (16, 32, 64, 128, 256) # larger and more complicated images require more complex models
# optimizer learning rate (Adam)
lr = 1e-4
# training
# number of epochs passed to Trainer.train_model
epochs = 40
# the three intervals below are forwarded to trainloop.Trainer
# presumably counted in epochs — TODO confirm against trainloop.Trainer
plot_every = 10
evaluate_every = 1
save_every = 5
# NMS sampling in tensorboard (forwarded to trainloop.Trainer)
# presumably: number of sampled proposals, IoU threshold for suppression and
# minimum object probability — TODO confirm against trainloop.Trainer
num_proposals = 500
iou_thres = 0.1
min_objprob = 0.3
#############################################################################
#############################################################################
# gpu selection
# The CUDA device index is read interactively. Surrounding whitespace is
# stripped so a stray space does not end up inside the device string.
gpu_id = input("Select gpu: ").strip()
device = torch.device("cuda:" + gpu_id)
# automatic experiment name and path determination
# Start at "run1" and bump the numeric suffix until a name is found for which
# no entry exists under experiments/results.
exp_name = "run1"
while os.path.exists(os.path.join('experiments/results', exp_name)):
    # raw string: keeps \d a regex escape instead of an invalid string escape
    exp_name = "run" + str(int(re.findall(r'\d+', exp_name)[0]) + 1)
# Both splits are read from the same file; only the image-dataset name and the
# label-group name differ between them.
trainset = dataset.Dataset(path=dataset_path,
                           images_dataset_name=trainset_dataset_name,
                           labels_group_name=trainset_labels_group_name)
testset = dataset.Dataset(path=dataset_path,
                          images_dataset_name=testset_dataset_name,
                          labels_group_name=testset_labels_group_name)
# images handed to the trainer for plotting predictions in tensorboard
# (the argument 2 is presumably the number of images — confirm in dataset.Dataset)
plot_trainset = trainset.get_plot_images(2)
plot_testset = testset.get_plot_images(2)
# training batches are shuffled, test batches keep the dataset order
trainloader = DataLoader(trainset, batch_size=train_batch_size,
                         shuffle=True, num_workers=num_workers)
testloader = DataLoader(testset, batch_size=test_batch_size,
                        shuffle=False, num_workers=num_workers)
# network
model = multitaskmodel.MultitaskModel(out_channels=out_channels, fmaps=fmaps)
# homoscedastic uncertainty loss: the wrapper is built around the model, and
# the optimizer is given the wrapper's parameters rather than the bare model's
mtl = loss.MultiTaskLossWrapper(model=model)
optimizer = torch.optim.Adam(params=mtl.parameters(), lr=lr)
# make experiment results directory
# exist_ok=False makes os.makedirs raise FileExistsError if the directory is
# already there, so results are not written into a pre-existing run directory.
exp_results_path = os.path.join('experiments/results', exp_name)
os.makedirs(exp_results_path, exist_ok=False)
# The trainer receives everything it needs up front: output path, model and
# loss wrapper, optimizer, both loaders, the target device, the NMS settings
# for tensorboard plots, the plot images and the logging/saving intervals.
trainer = trainloop.Trainer(
    exp_path=exp_results_path,
    model=model,
    mtl=mtl,
    optimizer=optimizer,
    trainloader=trainloader,
    testloader=testloader,
    device=device,
    num_proposals=num_proposals,
    iou_thres=iou_thres,
    min_objprob=min_objprob,
    plot_trainset=plot_trainset,
    plot_testset=plot_testset,
    plot_every=plot_every,
    evaluate_every=evaluate_every,
    save_every=save_every,
)
trainer.train_model(epochs=epochs)
# save final model
# Only the state dict of the trainer's model is stored, next to the other
# results of this run.
torch.save(trainer.model.state_dict(), os.path.join(exp_results_path, 'state_dict.pt'))