-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_organizer.lua
179 lines (146 loc) · 6.4 KB
/
data_organizer.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
-- Build train/validation splits of tool-joint labels, with joint positions normalized to [0,1].
require 'json'
require 'image'
require 'colormap'
require 'data_utils_new'
torch.setdefaulttensortype('torch.FloatTensor')
-- Group one frame's raw joint annotations by joint class.
-- Returns a table mapping each joint `class` to an array of { id, x, y }
-- entries, plus the number of distinct (non-nil) tool ids seen in the frame.
local function groupJointsByClass(annotations)
  local frame_anno = {}
  local tool_ids = {}
  for j = 1, #annotations do
    local joint_anno = annotations[j]
    local joints = frame_anno[joint_anno.class]
    if joints == nil then
      joints = {}
      frame_anno[joint_anno.class] = joints
    end
    joints[#joints + 1] = { id = joint_anno.id, x = joint_anno.x, y = joint_anno.y }
    -- A joint without an id is kept but cannot be attributed to a tool;
    -- using nil as a table key would raise "table index is nil".
    if joint_anno.id ~= nil then
      tool_ids[joint_anno.id] = true
    end
  end
  local tool_num = 0
  for _ in pairs(tool_ids) do
    tool_num = tool_num + 1
  end
  return frame_anno, tool_num
end

--- Read tool-joint label files and return per-sequence annotation tables.
-- Each label file is a JSON array of frame records with a `filename` and an
-- `annotations` list. Each annotation has a joint `class` (e.g.
-- LeftClasperPoint, RightClasperPoint, HeadPoint, ShaftPoint, TrackedPoint,
-- EndPoint), a tool `id` (e.g. tool1, tool2) and `x`/`y` coordinates.
-- Frames with no (or missing) annotations are dropped. Each kept frame's
-- image is loaded to obtain its width/height, and the joint positions are
-- passed through normalizeToolPos01 (presumably to [0,1]).
-- @tparam table label_file_tab array of JSON label file paths, one per sequence
-- @tparam[opt] table opts optional path-rewriting settings:
--   old_root     prefix of the stored filenames (default '/Users/xiaofeidu/mData')
--   new_root     replacement prefix (default '/home/xiaofei/public_datasets')
--   frame_format frame name format (default 'img_%06d_raw.png')
-- @treturn table array (one entry per label file) of arrays of frame records
--   { filename, annotations = { [class] = { {id, x, y}, ... } }, jointNum, toolNum }
local function readtoolLabelFile(label_file_tab, opts)
  opts = opts or {}
  local old_root = opts.old_root or '/Users/xiaofeidu/mData'
  local new_root = opts.new_root or '/home/xiaofei/public_datasets'
  local frame_format = opts.frame_format or 'img_%06d_raw.png'

  local multi_seq_anno_tab = {}
  for seq_idx = 1, #label_file_tab do
    local json_tab = json.load(label_file_tab[seq_idx])
    local anno_tab = {}
    for i = 1, #json_tab do
      -- a record without an `annotations` key is treated as unannotated
      local annotations = json_tab[i].annotations or {}
      if #annotations ~= 0 then
        -- stored filenames are rooted at old_root; point them at new_root
        -- and the local frame naming scheme
        local frame_name = point2newFileLocation(json_tab[i].filename, old_root, new_root)
        frame_name = changeFrameFormat(frame_name, frame_format)
        local frame_anno, tool_num = groupJointsByClass(annotations)
        anno_tab[#anno_tab + 1] = {
          filename = frame_name,
          annotations = frame_anno,
          jointNum = #annotations,
          toolNum = tool_num,
        }
      end
    end
    -- normalize joint positions using each frame's image size
    -- (dim 3 is used as width, dim 2 as height)
    for i = 1, #anno_tab do
      local frame = image.load(anno_tab[i].filename, 3, 'byte')
      anno_tab[i].annotations = normalizeToolPos01(frame:size(3), frame:size(2), anno_tab[i].annotations)
    end
    multi_seq_anno_tab[#multi_seq_anno_tab + 1] = anno_tab
  end
  return multi_seq_anno_tab
end
--- Split one sequence's frames into training and validation sets.
-- The first floor(train_percentage * #anno_tab) frames (but at least one) go
-- to training in their original order; the remaining frames go to validation.
-- @tparam table anno_tab non-empty array of frame records
-- @tparam[opt=0.8] number train_percentage fraction of frames used for training
-- @treturn table training frames
-- @treturn table validation frames
local function sepTrainingData(anno_tab, train_percentage)
  local ratio = train_percentage or 0.8
  local total = #anno_tab
  assert(total >= 1)
  -- always keep at least one frame for training
  local cut = math.max(math.floor(ratio * total), 1)
  local train, val = {}, {}
  for idx = 1, cut do
    train[#train + 1] = anno_tab[idx]
  end
  for idx = cut + 1, total do
    val[#val + 1] = anno_tab[idx]
  end
  return train, val
end
--- Split several sequences into pooled training and validation sets.
-- Each sequence is split independently: its first
-- floor(train_percentage * #sequence) frames (at least one) go to training,
-- the rest to validation, and the per-sequence counts are printed.
-- Frames are appended to the pooled sets in sequence order.
-- @tparam table multi_seq_anno_tab array of non-empty per-sequence frame arrays
-- @tparam[opt=0.8] number train_percentage fraction of each sequence used for training
-- @treturn table training frames from all sequences
-- @treturn table validation frames from all sequences
local function internalSepTrainingData(multi_seq_anno_tab, train_percentage)
  local ratio = train_percentage or 0.8
  local train, val = {}, {}
  for s = 1, #multi_seq_anno_tab do
    local seq = multi_seq_anno_tab[s]
    local total = #seq
    assert(total >= 1)
    local cut = math.max(math.floor(ratio * total), 1)
    for idx = 1, cut do
      train[#train + 1] = seq[idx]
    end
    for idx = cut + 1, total do
      val[#val + 1] = seq[idx]
    end
    -- report this sequence's train/val sizes
    print(cut, total - cut)
  end
  return train, val
end
--- Randomly split several sequences into pooled training and validation sets.
-- Each sequence is shuffled with torch.randperm; the first
-- floor(train_percentage * #sequence) shuffled frames (at least one) go to
-- training and the rest to validation. Per-sequence counts are printed.
-- The result depends on torch's global RNG state at call time.
-- @tparam table multi_seq_anno_tab array of non-empty per-sequence frame arrays
-- @tparam[opt=0.8] number train_percentage fraction of each sequence used for training
-- @treturn table training frames from all sequences
-- @treturn table validation frames from all sequences
local function internalRandomSepTrainingData(multi_seq_anno_tab, train_percentage)
  local ratio = train_percentage or 0.8
  local train, val = {}, {}
  for s = 1, #multi_seq_anno_tab do
    local seq = multi_seq_anno_tab[s]
    local total = #seq
    assert(total >= 1)
    local order = torch.randperm(total)
    local cut = math.max(math.floor(ratio * total), 1)
    for idx = 1, cut do
      train[#train + 1] = seq[order[idx]]
    end
    for idx = cut + 1, total do
      val[#val + 1] = seq[order[idx]]
    end
    -- report this sequence's train/val sizes
    print(cut, total - cut)
  end
  return train, val
end
-- Train dataset: read the per-sequence tool labels, split each sequence
-- randomly into train/validation, and save both splits next to the labels.
local trainBaseDir = '/home/xiaofei/public_datasets/MICCAI_tool/Tracking_Robotic_Training/tool_label'
-- Fixed RNG seed for the random split. Without it the permutation depends on
-- torch's initial RNG state, so regenerating the .t7 files could produce a
-- different train/val partition.
local SPLIT_SEED = 1234
local json_files = {}
for seq_idx = 1, 4 do
  -- local json_file_path = paths.concat(trainBaseDir, 'endo' .. seq_idx .. '_labels.json') -- original label
  local json_file_path = paths.concat(trainBaseDir, 'train' .. seq_idx .. '_labels.json') -- improved label (head)
  json_files[#json_files + 1] = json_file_path
end
local anno_tab = readtoolLabelFile(json_files)
-- seed immediately before the split so nothing else consumes random numbers in between
torch.manualSeed(SPLIT_SEED)
local train_anno_tab, val_anno_tab = internalRandomSepTrainingData(anno_tab)
print(#train_anno_tab)
print(#val_anno_tab)
torch.save(paths.concat(trainBaseDir, 'train_random_toolpos_head.t7'), train_anno_tab)
torch.save(paths.concat(trainBaseDir, 'val_random_toolpos_head.t7'), val_anno_tab)
print('===========================================================================')