-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathcub.py
111 lines (95 loc) · 4.14 KB
/
cub.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import numpy as np
import os
from PIL import Image, TarIO
import pickle
import tarfile
class cub200(torch.utils.data.Dataset):
    """Dataset that serves CUB_200_2011 images from pickled numpy arrays.

    On first use the images inside ``<root>/CUB_200_2011.tgz`` are decoded
    and written to ``<root>/processed/train.pkl`` and
    ``<root>/processed/test.pkl``. Later instantiations load the pickle that
    matches ``train`` directly.

    Args:
        root: Directory that contains ``CUB_200_2011.tgz``.
        train: If True serve the training split, otherwise the test split.
        transform: Optional callable applied to each PIL image before it is
            returned by ``__getitem__``.

    Raises:
        AssertionError: If ``root`` is not a directory or the archive is
            missing from it.
        RuntimeError: If a required archive member cannot be read as a file.
    """

    def __init__(self, root, train=True, transform=None):
        super(cub200, self).__init__()
        self.root = root
        self.train = train
        self.transform = transform

        if self._check_processed():
            print('Train file has been extracted' if self.train else 'Test file has been extracted')
        else:
            self._extract()

        split_file = 'processed/train.pkl' if self.train else 'processed/test.pkl'
        # NOTE: pickle can execute arbitrary code while loading. These files
        # are expected to be the ones written by _extract(); do not point
        # `root` at a directory whose contents are not trusted.
        with open(os.path.join(self.root, split_file), 'rb') as handle:
            data, labels = pickle.load(handle)

        # Only the attributes of the selected split are created.
        if self.train:
            self.train_data, self.train_label = data, labels
        else:
            self.test_data, self.test_label = data, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.train_data) if self.train else len(self.test_data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx`` of the selected split.

        The stored array is wrapped in a PIL image and, when a transform was
        given, passed through it.
        """
        if self.train:
            img, label = self.train_data[idx], self.train_label[idx]
        else:
            img, label = self.test_data[idx], self.test_label[idx]
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def _check_processed(self):
        """Return True when both processed pickle files already exist.

        Raises:
            AssertionError: If ``root`` is not a directory or does not
                contain ``CUB_200_2011.tgz``.
        """
        # Raised explicitly instead of using `assert` so the validation is
        # not stripped under `python -O`; the exception type is unchanged.
        if not os.path.isdir(self.root):
            raise AssertionError(
                'root is not a directory: {}'.format(self.root))
        if not os.path.isfile(os.path.join(self.root, 'CUB_200_2011.tgz')):
            raise AssertionError(
                'CUB_200_2011.tgz not found in {}'.format(self.root))
        return (os.path.isfile(os.path.join(self.root, 'processed/train.pkl')) and
                os.path.isfile(os.path.join(self.root, 'processed/test.pkl')))

    def _extract(self):
        """Decode every image in the archive and write the two split pickles.

        ``images.txt`` supplies the relative path of each image and
        ``train_test_split.txt`` supplies a flag per image (1 means training
        split). The label is the first three characters of the relative
        image path parsed as an integer, minus one.

        Raises:
            RuntimeError: If an index file or an image member cannot be
                opened as a regular file inside the archive.
            KeyError: If a required member is absent from the archive.
        """
        # exist_ok avoids a failure if the directory appears between a check
        # and the creation.
        os.makedirs(os.path.join(self.root, 'processed'), exist_ok=True)

        cub_tgz_path = os.path.join(self.root, 'CUB_200_2011.tgz')
        images_txt_path = 'CUB_200_2011/images.txt'
        train_test_split_txt_path = 'CUB_200_2011/train_test_split.txt'

        train_data = []
        train_labels = []
        test_data = []
        test_labels = []

        # The context manager closes the archive even if extraction fails.
        with tarfile.open(cub_tgz_path, 'r:gz') as tar:
            # Index the members once; TarFile.getmember() scans the member
            # list linearly on every call, which is quadratic over all images.
            # Later duplicates overwrite earlier ones, matching getmember().
            members = {member.name: member for member in tar.getmembers()}

            images_txt = tar.extractfile(members[images_txt_path])
            train_test_split_txt = tar.extractfile(members[train_test_split_txt_path])
            # extractfile() returns None for members that are not files.
            if images_txt is None or train_test_split_txt is None:
                print('Extract image.txt and train_test_split.txt Error!')
                raise RuntimeError(
                    'could not read images.txt / train_test_split.txt from '
                    + cub_tgz_path)

            images_txt = images_txt.read().decode('utf-8').splitlines()
            train_test_split_txt = train_test_split_txt.read().decode('utf-8').splitlines()
            id2name = np.genfromtxt(images_txt, dtype=str)
            id2train = np.genfromtxt(train_test_split_txt, dtype=int)
            print('Finish loading images.txt and train_test_split.txt')

            print('Start extract images..')
            cnt = 0
            train_cnt = 0
            test_cnt = 0
            for _id in range(id2name.shape[0]):
                cnt += 1
                relative_name = id2name[_id, 1]
                image_path = 'CUB_200_2011/images/' + relative_name
                image_file = tar.extractfile(members[image_path])
                if image_file is None:
                    print('get image: ' + image_path + ' error')
                    raise RuntimeError('get image: ' + image_path + ' error')

                label = int(relative_name[:3]) - 1
                image = Image.open(image_file)
                try:
                    # Store every sample with three channels so downstream
                    # transforms see a uniform layout (covers grayscale and
                    # any other non-RGB mode).
                    rgb = image if image.mode == 'RGB' else image.convert('RGB')
                    image_np = np.array(rgb)
                finally:
                    image.close()

                if id2train[_id, 1] == 1:
                    train_cnt += 1
                    train_data.append(image_np)
                    train_labels.append(label)
                else:
                    test_cnt += 1
                    test_data.append(image_np)
                    test_labels.append(label)
                if cnt % 1000 == 0:
                    print('{} images have been extracted'.format(cnt))
            print('Total images: {}, training images: {}. testing images: {}'.format(cnt, train_cnt, test_cnt))

        # `with` guarantees the pickles are flushed and closed.
        with open(os.path.join(self.root, 'processed/train.pkl'), 'wb') as handle:
            pickle.dump((train_data, train_labels), handle)
        with open(os.path.join(self.root, 'processed/test.pkl'), 'wb') as handle:
            pickle.dump((test_data, test_labels), handle)