-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathDataset.py
90 lines (80 loc) · 3.28 KB
/
Dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from PIL import Image
import os
import os.path
import glob
import numpy as np
from torch.utils.data import Dataset
import random
class MyDataset(Dataset):
    """Frame-folder video dataset driven by UCF101-style split files.

    Paths used by this class::

        <root_path>/<name_list>/classInd.txt
            one "<1-based class id> <class name>" pair per line
        <root_path>/<name_list>/<data_folder>list0<version>.txt
            one "<class name>/<video>.<ext> ..." entry per line
        <root_path>/<data_folder>/<class name>/<video>/*.jpg
            extracted frames of each video

    Each item is ``(clip, label_index)``: ``clip`` is ``num_frames`` PIL
    images (passed through ``transform`` when one is given) and
    ``label_index`` is the 0-based class id.
    """

    def __init__(self, root_path, data_folder='train', name_list='ucfTrainTestlist',
                 version=1, transform=None, num_frames=16, modality='RGB', random=False):
        """Read the label map and the filtered video list from disk.

        Args:
            root_path: dataset root directory.
            data_folder: frame sub-directory; also the split-file name prefix.
            name_list: sub-directory holding ``classInd.txt`` and split files.
            version: split number, used as ``list0<version>.txt``.
            transform: callable applied to the list of PIL images, or None to
                return the images unchanged.
            num_frames: frames per clip; shorter videos are dropped.
            modality: accepted for interface compatibility; unused here.
            random: if True, pick frames with ``random.sample``; otherwise take
                a contiguous window at a random offset.
        """
        self.root_path = root_path
        self.num_frames = num_frames
        self.data_folder = data_folder
        self.version = version
        self.transform = transform
        # The `random` parameter shadows the module only inside __init__.
        self.random = random
        self.split_file = os.path.join(self.root_path, name_list,
                                       str(data_folder) + 'list0' + str(version) + '.txt')
        self.label_file = os.path.join(self.root_path, name_list, 'classInd.txt')
        self.label_dict = self.get_labels()
        # NOTE: a list of split-file entries despite the name; kept for callers.
        self.video_dict = self.get_video_list()

    @staticmethod
    def _strip_ext(video):
        """Return *video* without its final file extension (e.g. ``.avi``)."""
        return os.path.splitext(video)[0]

    def _frames_dir(self, video):
        """Return the directory that holds the extracted frames of *video*."""
        return os.path.join(self.root_path, self.data_folder, self._strip_ext(video))

    def get_video_list(self):
        """Return split-file entries whose frame folder has >= ``num_frames`` jpgs.

        An entry is the first whitespace-separated field of a non-blank line,
        kept verbatim (extension included).
        """
        res = []
        with open(self.split_file) as fin:
            for line in fin:
                fields = line.split()
                if not fields:  # tolerate blank / trailing lines
                    continue
                video = fields[0]
                frames = self.get_all_images(self._frames_dir(video), sort_files=False)
                # Drop videos too short to yield a full clip.
                if len(frames) >= self.num_frames:
                    res.append(video)
        return res

    def get_labels(self):
        """Return a ``{class name: 0-based index}`` dict read from classInd.txt."""
        label_dict = {}
        with open(self.label_file) as fin:
            for row in fin:
                fields = row.split()
                if not fields:  # tolerate blank / trailing lines
                    continue
                # The file is 1-based; labels are shifted to start at 0.
                label_dict[fields[1]] = int(fields[0]) - 1
        return label_dict

    def get_all_images(self, dir, file_ext="jpg", sort_files=True):
        """Return paths of the ``*.<file_ext>`` files directly inside *dir*.

        Args:
            dir: directory to scan.
            file_ext: extension to match, without the dot.
            sort_files: sort the paths lexicographically when True.

        Returns:
            List of matching paths (empty if none or if *dir* is missing).
        """
        # Escape so '[', '*', '?' in directory names are matched literally.
        allfiles = glob.glob(os.path.join(glob.escape(dir), '*.' + file_ext))
        if sort_files:
            # NOTE(review): lexicographic order equals temporal order only if
            # frame names are zero-padded — confirm against the extractor.
            allfiles.sort()
        return allfiles

    def _frame_indices(self, num_images):
        """Return the ``num_frames`` indices to load from *num_images* frames.

        Raises:
            ValueError: if ``num_images < num_frames``.
        """
        if num_images < self.num_frames:
            raise ValueError('video has {} frames, need at least {}'.format(
                num_images, self.num_frames))
        if self.random:
            # Distinct indices in sampling order (not temporal order).
            # NOTE(review): confirm unordered frames are intended here.
            return random.sample(range(num_images), self.num_frames)
        # randint's high is exclusive: start is uniform in [0, num_images - num_frames].
        start = int(np.random.randint(0, num_images - self.num_frames + 1))
        return list(range(start, start + self.num_frames))

    def get_video_tensor(self, dir):
        """Load one clip of ``num_frames`` frames from the frame folder *dir*.

        Returns:
            ``self.transform(images)`` when a transform is set, otherwise the
            list of PIL images.
        """
        images = self.get_all_images(dir)
        clip = []
        for idx in self._frame_indices(len(images)):
            img = Image.open(images[idx])
            # Image.open is lazy; load() decodes now and lets Pillow close the file.
            img.load()
            clip.append(img)
        if self.transform is not None:
            clip = self.transform(clip)
        return clip

    def __getitem__(self, index):
        """Return ``(clip, label_index)`` for the *index*-th video."""
        video = self.video_dict[index]
        clip = self.get_video_tensor(self._frames_dir(video))
        # Split-file entries start with "<class name>/".
        label_name = self._strip_ext(video).split('/')[0]
        label_index = self.label_dict[label_name]
        return clip, label_index

    def __len__(self):
        """Return the number of videos kept after frame-count filtering."""
        return len(self.video_dict)