-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathread_data.py
81 lines (68 loc) · 2.85 KB
/
read_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import numpy as np
import json
import os
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import random
import shutil
class TensorDataset(Dataset):
    """Image-folder dataset with a fixed, seeded train/validation split.

    ``data_path`` must contain one sub-directory per class; every entry inside
    a class directory is treated as an image of that class. Class directories
    are sorted by name and numbered from 0, so label ids do not depend on the
    (arbitrary) order in which the filesystem lists them.

    All samples are split 80/20 into train and validation parts using a
    private random generator seeded with 2016. ``usage`` selects which part
    ``__getitem__`` and ``__len__`` expose; it is re-read on every access.

    Args:
        data_path: Root directory holding the class sub-directories.
        usage: ``'train'`` or ``'val'``.
        input_size: Side length passed to ``CenterCrop`` and used for the
            final ``(3, input_size, input_size)`` view.
        mean_std_path: JSON file with the keys ``'means'`` and ``'stdevs'``
            used for normalization. Defaults to ``'mean_std.json'`` in the
            current working directory, as before.
    """

    # Seed and train fraction of the train/validation split.
    _SPLIT_SEED = 2016
    _TRAIN_FRACTION = 0.8

    def __init__(self, data_path, usage, input_size, mean_std_path='mean_std.json'):
        self.dataset_path = data_path
        self.usage = usage
        self.input_size = input_size
        self.Y = list()
        self.pics = list()

        # os.listdir returns entries in arbitrary order; sort so that the
        # class -> label mapping and the split are reproducible. Stray files
        # in the root (e.g. .DS_Store) are not classes and are skipped.
        class_names = sorted(
            name for name in os.listdir(self.dataset_path)
            if os.path.isdir(os.path.join(self.dataset_path, name))
        )
        for label, folder_name in enumerate(class_names):
            folder_path = os.path.join(self.dataset_path, folder_name)
            for pic_name in sorted(os.listdir(folder_path)):
                self.pics.append(os.path.join(folder_path, pic_name))
                self.Y.append(label)

        n_examples = len(self.pics)
        n_train = int(n_examples * self._TRAIN_FRACTION)
        # A private RandomState draws the same stream as seeding the global
        # generator, without clobbering NumPy's global random state.
        rng = np.random.RandomState(self._SPLIT_SEED)
        train_idx = rng.choice(n_examples, size=n_train, replace=False)
        # Everything not drawn for training, in ascending index order.
        val_idx = np.setdiff1d(np.arange(n_examples), train_idx)

        self.pics, self.Y = np.array(self.pics), np.array(self.Y)
        self.X_train, self.Y_train = self.pics[train_idx], self.Y[train_idx]
        self.X_val, self.Y_val = self.pics[val_idx], self.Y[val_idx]

        with open(mean_std_path, 'r', encoding='utf-8') as f:
            mean_std = json.load(f)
        means = mean_std['means']
        stdevs = mean_std['stdevs']
        print(means, stdevs)

        normalize = transforms.Normalize(mean=means, std=stdevs)
        self.transform = transforms.Compose(
            [transforms.CenterCrop(self.input_size), transforms.ToTensor(), normalize, ])

    def _current_split(self):
        """Return ``(paths, labels)`` of the split selected by ``self.usage``.

        Raises:
            ValueError: If ``self.usage`` is neither ``'train'`` nor ``'val'``.
        """
        if self.usage == 'train':
            return self.X_train, self.Y_train
        if self.usage == 'val':
            return self.X_val, self.Y_val
        raise ValueError(
            "usage must be 'train' or 'val', got {!r}".format(self.usage))

    # Overridden so that the i-th sample can be accessed as dataset[i].
    def __getitem__(self, index):
        """Load, crop and normalize the sample at ``index`` of the active split.

        Returns:
            A ``(image_tensor, label)`` pair; the tensor is viewed as
            ``(3, input_size, input_size)``.
        """
        paths, labels = self._current_split()
        # Close the file handle promptly and force three channels so that
        # grayscale / RGBA / palette images match the 3-channel view below.
        with Image.open(paths[index]) as img:
            x = img.convert('RGB')
        y = labels[index]
        x = self.transform(x).view(3, self.input_size, self.input_size)
        return x, y

    def __len__(self):
        """Return the number of samples in the active split."""
        paths, _ = self._current_split()
        return len(paths)
def mv_val_train(ratio=0.8, path='datasets/val', path_move='datasets/train', seed=2019):
    """Move a random fraction of every class's files from ``path`` to ``path_move``.

    Each sub-directory of ``path`` is treated as one class. For every class,
    ``int(len(files) * ratio)`` files are sampled and moved into the directory
    of the same name under ``path_move``, which is created if missing.

    Args:
        ratio: Fraction of each class's files to move, between 0 and 1.
        path: Source root containing one sub-directory per class.
        path_move: Destination root.
        seed: Seed of the private random generator used for sampling; the
            global ``random`` state is left untouched.

    Raises:
        ValueError: If ``ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= ratio <= 1:
        raise ValueError("ratio must be between 0 and 1, got {!r}".format(ratio))

    rng = random.Random(seed)
    # os.listdir order is arbitrary; sorting makes the seeded sample
    # reproducible across machines and filesystems.
    for each_class in sorted(os.listdir(path)):
        class_dir = os.path.join(path, each_class)
        if not os.path.isdir(class_dir):
            # Stray files in the source root are not classes.
            continue
        pic_names = sorted(os.listdir(class_dir))
        # Pick the requested share of this class's file names to move.
        choose_names = rng.sample(pic_names, int(len(pic_names) * ratio))
        if not choose_names:
            continue
        target_dir = os.path.join(path_move, each_class)
        # shutil.move fails if the destination directory does not exist.
        os.makedirs(target_dir, exist_ok=True)
        for choose_name in choose_names:
            shutil.move(os.path.join(class_dir, choose_name),
                        os.path.join(target_dir, choose_name))
# Script entry point: running this file directly moves 80% of each class's
# validation images into the training directory.
if __name__ == '__main__':
    mv_val_train(ratio=0.8)