-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathsample_set.py
35 lines (27 loc) · 1.05 KB
/
sample_set.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.utils.data as data
from pandas.io.parsers import read_csv
class Sample_set(data.Dataset):
    """Sliding-window dataset over the ``close_change``/``volume_change`` columns of a CSV.

    Sample ``i`` targets row ``t = i + window``.  Its input features are the
    ``window`` rows immediately before ``t``, laid out as::

        [close[t-1]*w1, ..., close[t-window]*w1,
         volume[t-1]*w2, ..., volume[t-window]*w2]

    i.e. most recent lag first, all close lags before all volume lags.  The
    target is ``[close[t]]``.  Both are returned as float32 ``torch.Tensor``.

    Args:
        filename: path or buffer passed to ``read_csv``; the data must contain
            ``close_change`` and ``volume_change`` columns.
        window: number of lagged rows used as features (default 5).
        w1: scale applied to the ``close_change`` features (default 1).
        w2: scale applied to the ``volume_change`` features (default 0.1).

    Raises:
        ValueError: if ``window`` is less than 1.
    """

    def __init__(self, filename, window=5, w1=1, w2=0.1):
        if window < 1:
            raise ValueError(f"window must be >= 1, got {window}")
        self.df = read_csv(filename)
        self.window = window
        self.w1 = w1
        self.w2 = w2

    def __getitem__(self, index):
        """Return ``(features, target)`` for sample ``index``.

        Negative indices count from the end, as for a sequence.

        Raises:
            IndexError: if ``index`` is out of range.  IndexError (not KeyError)
                lets ``for x in dataset`` / ``list(dataset)`` terminate cleanly.
        """
        n = len(self)
        if index < 0:
            index += n
        if not 0 <= index < n:
            raise IndexError(f"index {index} out of range for dataset of length {n}")

        t = index + self.window  # row holding the target value
        lags = slice(t - self.window, t)
        # Positional (.iloc) access so the window never depends on index labels;
        # reversing puts the most recent lag first, as in the original layout.
        close = self.df['close_change'].iloc[lags].tolist()[::-1]
        volume = self.df['volume_change'].iloc[lags].tolist()[::-1]

        features = [v * self.w1 for v in close] + [v * self.w2 for v in volume]
        target = [self.df['close_change'].iloc[t]]
        return torch.Tensor(features), torch.Tensor(target)

    def __len__(self):
        # The first ``window`` rows only serve as history, never as a target.
        # Clamp at 0 so a short file yields an empty dataset instead of making
        # len() raise on a negative value.
        return max(0, len(self.df) - self.window)