-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsplitSparseRecHits.py
executable file
·296 lines (199 loc) · 9.04 KB
/
splitSparseRecHits.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
#!/usr/bin/env python
# splits a rechits file into test and train file
import sys
import numpy as np
import fnmatch
#----------------------------------------------------------------------
# fraction of the events going into the test sample: an event is assigned
# to the test sample if its uniform random number is <= this value,
# otherwise it goes into the train sample
testFraction = 0.25
# seed given to np.random.seed() before drawing the random numbers for the
# train/test assignment
randSeed = 1337
#----------------------------------------------------------------------
def makeOutputVar(inputData, indices, ignoreKeys):
    """Select the entries given by indices from all non-ignored variables.

    @param inputData  dict-like object (must provide keys() and lookup
                      by key) mapping variable names to indexable arrays.
                      Note that in the npy version a value is never a dict
                      but a numpy object.
    @param indices    index expression applied to each kept value
                      (value[indices])
    @param ignoreKeys is a list of fnmatch patterns, keys matching any
                      of these are not copied to the output

    @return a new dict mapping each kept key to value[indices]
    """
    outputData = {}

    for key in inputData.keys():
        # check if we should not touch this key
        if any(fnmatch.fnmatch(key, pattern) for pattern in ignoreKeys):
            continue

        # look up the value only after the ignore check: lookups can be
        # very slow (at least for npz files) and this way the ignored
        # variables are never fetched at all
        outputData[key] = inputData[key][indices]

    # end of loop over items in the table

    return outputData
#----------------------------------------------------------------------
def makeOutputDataRecHits(indices, inputData, outputData):
    """Copy the rechits of the selected photons from inputData to outputData.

    The rechits of all photons are stored back to back in the flat vectors
    'X/x', 'X/y' and 'X/energy'. For each photon, 'X/firstIndex' is the
    one based position of its first rechit in these vectors and
    'X/numRecHits' is the number of its rechits.

    @param indices    zero based indices of the photons (rows) to select,
                      in the order in which they should appear in the output
    @param inputData  dict-like object providing the 'X/...' input arrays
    @param outputData dict which is filled in place with the 'X/x', 'X/y',
                      'X/energy', 'X/firstIndex' and 'X/numRecHits' arrays
                      of the selected photons

    @return None, the result is stored in outputData

    Raises an AssertionError if firstIndex / numRecHits of a selected
    photon point outside the input rechits vectors.
    """
    numOutputRows = len(indices)

    # assign to variables to avoid dict lookups all the time (which makes it
    # very slow, at least for npz files...)
    inputFirstIndex = inputData['X/firstIndex']
    inputNumRecHits = inputData['X/numRecHits']
    inputDataX = inputData['X/x']
    inputDataY = inputData['X/y']
    inputDataE = inputData['X/energy']

    #----------
    # calculate the total number of output rechits
    #----------
    # selecting with indices already gives the final per photon counts,
    # there is no need to fill them again row by row in the loop below
    outputDataNumRecHits = inputNumRecHits[indices]
    numOutputRecHits = int(outputDataNumRecHits.sum())

    #----------
    # now that we know the total number of output rechits,
    # create the vectors related to the rechits (filled with -1 so that
    # entries which were not copied are easy to spot)
    #----------
    outputDataX = -1 * np.ones(numOutputRecHits, dtype = 'int32')
    outputDataY = -1 * np.ones(numOutputRecHits, dtype = 'int32')
    outputDataE = -1 * np.ones(numOutputRecHits, dtype = 'float32')
    outputDataFirstIndex = -1 * np.ones(numOutputRows, dtype = 'int32')

    outputData['X/x'] = outputDataX
    outputData['X/y'] = outputDataY
    outputData['X/energy'] = outputDataE
    outputData['X/firstIndex'] = outputDataFirstIndex
    outputData['X/numRecHits'] = outputDataNumRecHits

    # tqdm is only used to display the progress, so do not fail
    # if it is not installed
    try:
        import tqdm
    except ImportError:
        tqdm = None

    progbar = None
    if tqdm is not None:
        progbar = tqdm.tqdm(total = numOutputRows,
                            mininterval = 0.1,
                            unit = 'photons',
                            desc = 'splitting rechits')

    # note that we keep the one based convention from Torch here for
    # historical reasons
    firstIndex = 1

    try:
        for i, index in enumerate(indices):
            # index is zero based

            outputDataFirstIndex[i] = firstIndex

            # sanity check of input data
            assert inputFirstIndex[index] >= 1
            assert inputFirstIndex[index] + inputNumRecHits[index] - 1 <= len(inputDataE), "failed at index=" + str(index)

            # baseInputIndex is zero based
            # (converted to plain python integers so that the running
            # offsets do not depend on the width of the input integer type)
            baseInputIndex = int(inputFirstIndex[index]) - 1
            thisNumRecHits = int(inputNumRecHits[index])

            # copy coordinates and energy over
            # note that firstIndex is one based
            inputRange = slice(baseInputIndex, baseInputIndex + thisNumRecHits)
            outputRange = slice(firstIndex - 1, firstIndex - 1 + thisNumRecHits)

            outputDataX[outputRange] = inputDataX[inputRange]
            outputDataY[outputRange] = inputDataY[inputRange]
            outputDataE[outputRange] = inputDataE[inputRange]

            # prepare next iteration
            firstIndex += thisNumRecHits

            if progbar is not None:
                progbar.update(1)

        # end -- loop over photons
    finally:
        # also close the progress bar if a sanity check failed
        if progbar is not None:
            progbar.close()
#----------------------------------------------------------------------
def makeOutputDataTracks(indices, inputData, outputData):
    """Copy the tracks of the selected photons from inputData to outputData.

    All keys of inputData starting with 'tracks/' are handled here. The per
    track variables of all photons are stored back to back in flat vectors.
    For each photon, 'tracks/firstIndex' is the one based position of its
    first track in these vectors and 'tracks/numTracks' is the number of
    its tracks.

    @param indices    zero based indices of the photons (rows) to select,
                      in the order in which they should appear in the output
    @param inputData  dict-like object providing the 'tracks/...' arrays
    @param outputData dict which is filled in place with one entry per
                      'tracks/...' key of inputData

    @return None, the result is stored in outputData
    """
    numOutputRows = len(indices)

    # make local variables for the indexing variables
    # to avoid dict/npz file lookups within the loop
    inputDataNumTracks = inputData['tracks/numTracks']
    inputDataFirstIndex = inputData['tracks/firstIndex']

    #----------
    # calculate the total number of output tracks
    #----------
    # we can just copy numTracks, filtering by indices
    outputDataNumTracks = inputDataNumTracks[indices]
    numOutputTracks = int(outputDataNumTracks.sum())

    # make sure we have int type here
    outputDataFirstIndex = -1 * np.ones(numOutputRows, dtype = 'int32')

    #----------
    # calculate tracks/firstIndex
    # and mapping from input to output index ranges for other variables
    #----------
    inputRanges = []
    outputRanges = []

    # note that we keep the Torch/Lua convention of one based indices
    firstIndex = 1

    for i, index in enumerate(indices):
        outputDataFirstIndex[i] = firstIndex

        # plain python integers so that the running offsets do not
        # depend on the width of the input integer type
        thisNumTracks = int(inputDataNumTracks[index])

        # note we need to subtract one from both output but
        # also from the input indices
        inputStart = int(inputDataFirstIndex[index]) - 1

        inputRanges.append(slice(inputStart, inputStart + thisNumTracks))
        outputRanges.append(slice(firstIndex - 1, firstIndex - 1 + thisNumTracks))

        # prepare next iteration
        firstIndex += thisNumTracks

    #----------
    # fill the output, keeping the order of the keys of inputData
    #
    # outer loop is on variables so we can avoid repetitive lookups
    # by variable name (each variable is looked up only once)
    #----------
    for key in inputData.keys():
        if not key.startswith("tracks/"):
            continue

        if key == 'tracks/firstIndex':
            outputData[key] = outputDataFirstIndex
            continue

        if key == 'tracks/numTracks':
            outputData[key] = outputDataNumTracks
            continue

        # other variables (which need no adjustment like the firstIndex
        # variable): one value per track
        inputVec = inputData[key]

        # take the input dtype for the output
        outputVec = -1 * np.ones(numOutputTracks, dtype = inputVec.dtype)

        for inputRange, outputRange in zip(inputRanges, outputRanges):
            outputVec[outputRange] = inputVec[inputRange]
        # end of loop over photons

        outputData[key] = outputVec

    # end of loop over variables
#----------------------------------------------------------------------
def makeOutputData(indices, inputData, checkFirstIndex = True):
    """Build the output dataset containing only the selected photons.

    @param indices         zero based indices of the photons (rows) to select
    @param inputData       dict-like object with the input arrays
    @param checkFirstIndex if true, the firstIndex / count arrays produced
                           for the rechits (and tracks, if present) are
                           verified with mergeRecHits.checkFirstIndex()

    @return a new dict with the selected rows of all variables

    Raises an AssertionError if one of the firstIndex checks fails.
    """
    if checkFirstIndex:
        # this module is only needed for the consistency checks, so do not
        # require it when the checks are disabled
        import mergeRecHits

    # copy everything except the rechits and the tracks: these are stored
    # as flat vectors which can not be selected by simple row indexing
    outputData = makeOutputVar(inputData, indices, [ 'X/*', 'tracks/*' ])

    makeOutputDataRecHits(indices, inputData, outputData)

    if checkFirstIndex:
        isOk = mergeRecHits.checkFirstIndex(outputData['X/firstIndex'], outputData['X/numRecHits'])
        assert isOk, "internal error with rechits firstIndex"

    # tracks are optional in the input
    if 'tracks/numTracks' in inputData:
        makeOutputDataTracks(indices, inputData, outputData)

        if checkFirstIndex:
            isOk = mergeRecHits.checkFirstIndex(outputData['tracks/firstIndex'], outputData['tracks/numTracks'])
            assert isOk, "internal error with tracks firstIndex"

    return outputData
#----------------------------------------------------------------------
# main
#----------------------------------------------------------------------
if __name__ == "__main__":
    # only run the split when executed as a script so that importing this
    # module (e.g. to use makeOutputData()) has no side effects

    ARGV = sys.argv[1:]

    # explicit check instead of an assert (asserts are skipped with python -O)
    if len(ARGV) != 1:
        sys.exit("usage: " + sys.argv[0] + " input.npz")

    np.random.seed(randSeed)

    inputFile = ARGV.pop(0)

    #----------
    # generate output file names for the test and train samples
    #----------
    if inputFile.endswith(".npz"):
        outputFilePrefix = inputFile[:-4]
    else:
        outputFilePrefix = inputFile

    outputFileTrain = outputFilePrefix + "-train.npz"
    outputFileTest = outputFilePrefix + "-test.npz"

    #----------

    inputData = np.load(inputFile)

    try:
        numEvents = len(inputData['y'])

        # note: print is called with a single, preformatted string so that
        # the output is the same with python 2 and python 3
        print("have %d photons" % numEvents)

        # throw a random number for each event
        randVals = np.random.rand(numEvents)

        allIndices = np.arange(numEvents)
        trainIndices = allIndices[randVals > testFraction]
        testIndices = allIndices[randVals <= testFraction]

        # each event must end up in exactly one of the two samples
        assert len(trainIndices) + len(testIndices) == numEvents

        #----------
        # create and fill the output tensors
        #----------
        print("filling train dataset")
        outputDataTrain = makeOutputData(trainIndices, inputData)

        print("writing train dataset ( %d photons) to %s" % (len(trainIndices), outputFileTrain))
        # note that this is uncompressed
        np.savez(outputFileTrain, **outputDataTrain)

        print("filling test dataset")
        outputDataTest = makeOutputData(testIndices, inputData)

        print("writing test dataset ( %d photons) to %s" % (len(testIndices), outputFileTest))
        np.savez(outputFileTest, **outputDataTest)

    finally:
        # release the file handle of the input .npz file
        inputData.close()