-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel22-resnext.py
108 lines (71 loc) · 2.9 KB
/
model22-resnext.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python
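# ResNeXt-29 (cardinality 16, base width 64) classifier on the 35x35
# rechits image only, followed by a sigmoid on a single output node.
# NOTE(review): this header used to read "like model06 but with dropout
# layer only applied to the rechits variables, not the other (track iso)
# variables", but no dropout layer and no track iso input appear in this
# file; that description looks copied from another model file.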
import rechitmodelutils
import numpy as np
import math
import torch.nn as nn
import torch.nn.functional as F
# Bind the top-level ``torch`` package explicitly: ``import torch.nn as nn``
# only binds the name ``nn``, so without this import the assignment below
# raises NameError (unless the loader happened to pre-populate ``torch`` in
# this module's globals).
import torch

# Let cuDNN benchmark the candidate convolution algorithms and keep the
# fastest one; this pays off when input sizes do not change between
# batches (presumably always 35x35 images here).
# see http://torch.ch/blog/2016/02/04/resnets.html
# and https://discuss.pytorch.org/t/pytorch-performance/3079
# NOTE: the linked recipe also sets a ``fastest`` cuDNN flag; it did not
# seem to exist in PyTorch when this was written, so it is not set:
#   torch.backends.cudnn.fastest = True
torch.backends.cudnn.benchmark = True
#----------------------------------------------------------------------
# model
#----------------------------------------------------------------------
def makeModel():
    """Build and return the network for this model file.

    The result is an ``nn.Sequential`` made of, in order:

    * a ``ParallelTable`` with a single branch: a ResNeXt built by
      ``resnext.ModelCreator`` (depth 29, cardinality 16, base width 64,
      'resnext_C' bottlenecks, 'cifar10' layout, one input plane, one
      output node, average-pooling kernel 9);
    * ``JoinTable(1)`` (presumably concatenating the branch outputs along
      dimension 1 -- there is only one branch here; see JoinTable.py);
    * ``nn.Sigmoid()`` applied to the single output node.

    NOTE(review): the ParallelTable wrapper presumably means the network is
    fed a list with one input per branch -- confirm against ParallelTable.py.
    """
    # project-local modules, imported lazily as in the rest of this file
    import resnext
    from ParallelTable import ParallelTable
    from JoinTable import JoinTable

    backbone = resnext.ModelCreator(
        depth=29,
        cardinality=16,
        baseWidth=64,
        dataset='cifar10',
        bottleneckType='resnext_C',
        numInputPlanes=1,
        avgKernelSize=9,  # chosen for 35x35 inputs
        numOutputNodes=1,
    ).create()

    return nn.Sequential(
        ParallelTable([backbone]),
        JoinTable(1),
        nn.Sigmoid(),
    )
#----------------------------------------------------------------------
# function to prepare input data samples
#----------------------------------------------------------------------
# Module-level rechits unpacker used by MyDataset, configured with a width
# and a height of 35 (the "35x35 inputs" that makeModel's avgKernelSize
# refers to).
# The optional offsets that would shift position (18, 18) to (4, 12) are
# currently not passed:
#   recHitsXoffset = -18 + 4, recHitsYoffset = -18 + 12
unpacker = rechitmodelutils.RecHitsUnpacker(35, 35)  # width, height
#----------------------------------------------------------------------
import torch.utils.data
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset returning ``[weight, target, recHits]`` per row.

    The rechits of every row are unpacked once, up front, with the
    module-level ``unpacker``; fetching an item is then a plain index lookup
    into the three stored columns.
    See also http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

    Args:
        dataset: mapping with at least ``'weights'`` and ``'labels'``
            entries (indexable; presumably of equal length -- the row count
            is taken from ``'weights'``), plus whatever ``unpacker.unpack``
            reads to build the rechits.
    """

    def __init__(self, dataset):
        self.weights, self.targets = dataset['weights'], dataset['labels']
        # the row count is defined by the weights column
        self.nrows = len(self.weights)
        # unpack the rechits of all rows eagerly, once
        self.recHits = unpacker.unpack(dataset, range(self.nrows))

    def __len__(self):
        return self.nrows

    def __getitem__(self, index):
        # one entry per column, in the order weight, target, rechits
        return [column[index]
                for column in (self.weights, self.targets, self.recHits)]
#----------------------------------------------------------------------
def makeDataSet(dataset):
    """Wrap the raw ``dataset`` mapping in this model's ``MyDataset``.

    This lives in the model file, not in shared code, because how the
    input files are turned into samples (e.g. how the rechits are unpacked)
    depends on the model.
    """
    samples = MyDataset(dataset)
    return samples