-
Notifications
You must be signed in to change notification settings - Fork 6
/
process_kitti.py
executable file
·102 lines (85 loc) · 4.07 KB
/
process_kitti.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
'''
Code for downloading and processing KITTI data (Geiger et al. 2013, http://www.cvlibs.net/datasets/kitti/)
Based on code related to PredNet - Lotter et al. 2016 (https://arxiv.org/abs/1605.08104 https://github.com/coxlab/prednet).
The image resizing method is explicitly set to bicubic interpolation.
'''
import os
import subprocess
import urllib.request

import hickle as hkl
import numpy as np
import requests
from bs4 import BeautifulSoup
from imageio import imread
from scipy.misc import imresize

from kitti_settings import *
# Target (height, width) in pixels of every processed frame.
desired_im_sz = (128, 160)

# Recording categories; each is used as the `type` query of the KITTI
# raw-data page and as a sub-directory name under raw_kitti/.
categories = [
    'city',
    'residential',
    'road',
]

# Held-out recordings, given as (category, recording folder) pairs.
# They were originally picked at random so that one city recording is used
# for validation and one recording of each category is used for testing.
val_recordings = [
    ('city', '2011_09_26_drive_0005_sync'),
]
test_recordings = [
    ('city', '2011_09_26_drive_0104_sync'),
    ('residential', '2011_09_26_drive_0079_sync'),
    ('road', '2011_09_26_drive_0070_sync'),
]
# Make sure the output root exists. makedirs also creates missing parent
# directories and tolerates the directory already existing, avoiding the
# check-then-create race of exists()+mkdir().
os.makedirs(DATA_DIR, exist_ok=True)
# Download raw zip files by scraping the KITTI website.
def download_data():
    """Download the raw KITTI ``*_sync.zip`` archives for every category.

    For each entry of ``categories`` the KITTI raw-data listing page is
    fetched and every ``<h3>`` heading is treated as one drive; the heading
    text up to the first space is used as the drive id. Each drive's
    ``<drive>_sync.zip`` is then downloaded from the KITTI S3 bucket into
    ``DATA_DIR/raw_kitti/<category>/``, overwriting any existing file.

    Raises:
        requests.HTTPError: if a listing page returns an HTTP error status.
    """
    base_dir = os.path.join(DATA_DIR, 'raw_kitti/')
    os.makedirs(base_dir, exist_ok=True)
    for c in categories:
        url = "http://www.cvlibs.net/datasets/kitti/raw_data.php?type=" + c
        r = requests.get(url)
        # Fail loudly rather than scraping an error page and silently
        # downloading nothing.
        r.raise_for_status()
        # Name the parser explicitly so the result does not depend on which
        # optional parsers happen to be installed (and bs4 does not warn).
        soup = BeautifulSoup(r.content, 'html.parser')
        # partition() keeps the whole text when there is no space; the old
        # text[:text.find(' ')] silently dropped the last character then.
        drive_list = [h.text.partition(' ')[0] for h in soup.find_all("h3")]
        # Skip headings that yield no drive id (would build a bogus URL).
        drive_list = [d for d in drive_list if d]
        print("Downloading set: " + c)
        c_dir = os.path.join(base_dir, c)
        os.makedirs(c_dir, exist_ok=True)
        for i, d in enumerate(drive_list):
            print(str(i + 1) + '/' + str(len(drive_list)) + ": " + d)
            url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/" + d + "/" + d + "_sync.zip"
            urllib.request.urlretrieve(url, filename=os.path.join(c_dir, d + "_sync.zip"))
# Unzip images.
def extract_data():
    """Extract the ``image_03`` frames from every downloaded KITTI archive.

    For each category, every ``*.zip`` file directly inside
    ``DATA_DIR/raw_kitti/<category>/`` is unpacked with the external
    ``unzip`` tool. Only members matching
    ``<date>/<archive stem>/image_03/data*`` are extracted (``<date>`` is the
    first 10 characters of the archive name). They are extracted into
    ``DATA_DIR/raw_kitti/<category>/<archive stem>/``.

    Extraction is best-effort: a failing ``unzip`` is reported and the
    remaining archives are still processed.
    """
    for c in categories:
        c_dir = os.path.join(DATA_DIR, 'raw_kitti', c)
        # Only the archives at the top level: earlier runs leave extracted
        # sub-directories (and possibly stray files) here.
        zip_files = sorted(f for f in os.listdir(c_dir) if f.endswith('.zip'))
        for f in zip_files:
            print('unpacking: ' + f)
            stem = f[:-4]  # archive name without the ".zip" suffix
            spec_folder = f[:10] + '/' + stem + '/image_03/data*'
            # Argument list instead of a shell string: paths are passed
            # verbatim (no word-splitting or shell injection), and the
            # "data*" pattern reaches unzip unexpanded so unzip matches it.
            # -o overwrites without prompting, so re-runs do not block on an
            # interactive "replace?" question.
            result = subprocess.run(
                ['unzip', '-o', '-qq', os.path.join(c_dir, f), spec_folder,
                 '-d', os.path.join(c_dir, stem)])
            if result.returncode != 0:
                print('unzip failed (exit code ' + str(result.returncode) + '): ' + f)
# Create image datasets.
# Processes images and saves them in train, val, test splits.
def process_data():
    """Build the train/test/val arrays from the extracted KITTI frames.

    Every recording folder under ``DATA_DIR/raw_kitti/<category>/`` that is
    not listed in ``val_recordings`` or ``test_recordings`` goes to the
    training split. For each split, the frames in each recording's
    ``image_03/data`` directory are loaded in file-name order and passed
    through ``process_im`` with ``desired_im_sz``. They are saved with
    hickle to ``DATA_DIR/X_kitti_<split>_bic.hkl`` as a uint8 array of
    shape ``(n_frames, height, width, 3)``. A parallel list naming each
    frame's recording (``"<category>-<folder>"``) is saved to
    ``DATA_DIR/sources_kitti_<split>_bic.hkl``.
    """
    # Copies, so the module-level recording lists are never aliased.
    splits = {
        'train': [],
        'test': list(test_recordings),
        'val': list(val_recordings),
    }
    held_out = set(splits['val'] + splits['test'])
    for c in categories:
        c_dir = os.path.join(DATA_DIR, 'raw_kitti', c)
        # Sorted so the training split has the same order on every machine
        # (raw directory listing order is filesystem dependent).
        folders = sorted(
            d for d in os.listdir(c_dir) if os.path.isdir(os.path.join(c_dir, d)))
        splits['train'] += [(c, f) for f in folders if (c, f) not in held_out]

    for split, recordings in splits.items():
        im_list = []
        source_list = []  # corresponds to recording that image came from
        for category, folder in recordings:
            # Layout produced by extract_data:
            # <category>/<folder>/<date prefix>/<folder>/image_03/data/
            im_dir = os.path.join(DATA_DIR, 'raw_kitti', category, folder,
                                  folder[:10], folder, 'image_03', 'data')
            files = sorted(
                f for f in os.listdir(im_dir) if os.path.isfile(os.path.join(im_dir, f)))
            im_list += [os.path.join(im_dir, f) for f in files]
            source_list += [category + '-' + folder] * len(files)
        print('Creating ' + split + ' data: ' + str(len(im_list)) + ' images')
        X = np.zeros((len(im_list),) + desired_im_sz + (3,), np.uint8)
        for i, im_file in enumerate(im_list):
            X[i] = process_im(imread(im_file), desired_im_sz)
        hkl.dump(X, os.path.join(DATA_DIR, 'X_kitti_' + split + '_bic.hkl'))
        hkl.dump(source_list, os.path.join(DATA_DIR, 'sources_kitti_' + split + '_bic.hkl'))
        # Free this split's array before the next split allocates its own;
        # otherwise both are alive at once and peak memory doubles.
        del X
# Resize and crop an image.
def process_im(im, desired_sz):
    """Scale *im* to the target height, then center-crop it to the target width.

    The image is resized with bicubic interpolation so that its height equals
    ``desired_sz[0]`` while its aspect ratio is preserved (the new width is
    rounded to the nearest pixel). The central ``desired_sz[1]`` columns are
    then kept.

    Args:
        im: image array whose first two axes are (rows, columns).
        desired_sz: target ``(height, width)`` in pixels.

    Returns:
        The resized, center-cropped image array.

    Raises:
        ValueError: if the image, once scaled to the target height, is
            narrower than ``desired_sz[1]`` and therefore cannot be cropped.

    NOTE(review): ``scipy.misc.imresize`` was deprecated in SciPy 1.0 and
    removed in SciPy 1.3 (it also needs Pillow), so this requires an older
    SciPy.
    """
    height, width = desired_sz[0], desired_sz[1]
    scale = float(height) / im.shape[0]
    im = imresize(im, (height, int(np.round(scale * im.shape[1]))), interp='bicubic')
    excess = im.shape[1] - width
    if excess < 0:
        # Previously a negative offset silently produced a wrong-size crop.
        raise ValueError(
            'image is %d px wide after scaling to height %d; cannot crop to width %d'
            % (im.shape[1], height, width))
    # Left offset of the centered crop window; when the excess is odd the
    # extra column is dropped on the right.
    d = excess // 2
    return im[:, d:d + width]
if __name__ == '__main__':
    # Full pipeline, in order: fetch the archives, unpack the frames, then
    # build and save the train/test/val datasets.
    for step in (download_data, extract_data, process_data):
        step()