-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdataset.py
83 lines (60 loc) · 2.5 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import pickle
import torch
from torch.utils.data import Dataset
from typing import Tuple, List, Any, Dict, Union
from preprocess import PREFIX_TO_TRAFFIC_ID, PREFIX_TO_APP_ID, AUX_ID
def load_data(data_dir: str = 'data') -> Tuple[Any, Any, Any]:
    """Load the train/validation/test splits from pickle files.

    Each split is read from ``<data_dir>/<split>_data_rows.pkl`` and the
    length of every split is printed to stdout once all three are loaded.

    Args:
        data_dir: Directory holding the three pickle files. Defaults to
            ``'data'`` (relative to the current working directory), which
            is the location that used to be hard-coded.

    Returns:
        A ``(train_data_rows, val_data_rows, test_data_rows)`` tuple with
        the unpickled objects, in that order.

    Raises:
        FileNotFoundError: If any of the three pickle files is missing.
    """
    # Imported locally so the function does not rely on the module header.
    import os

    def _read_split(split: str) -> Any:
        """Unpickle the rows of one split from ``data_dir``."""
        path = os.path.join(data_dir, f'{split}_data_rows.pkl')
        with open(path, 'rb') as f:
            # SECURITY: unpickling can execute arbitrary code; only point
            # this at files produced by a trusted preprocessing run.
            return pickle.load(f)

    train_data_rows = _read_split('train')
    val_data_rows = _read_split('val')
    test_data_rows = _read_split('test')

    print(f'Amount of train data: {len(train_data_rows)}')
    print(f'Amount of val data: {len(val_data_rows)}')
    print(f'Amount of test data: {len(test_data_rows)}')
    return train_data_rows, val_data_rows, test_data_rows
def id_to_one_hot_tensor(
    id_value: Union[int, torch.Tensor],
    num_classes: int
) -> torch.Tensor:
    """
    Convert an ID to a one-hot encoded tensor using PyTorch.

    Parameters:
    - id_value (int or Tensor): The ID value(s) to be converted to a one-hot
      tensor. Besides Python ints and int64 tensors, NumPy integer
      scalars/arrays and narrower integer tensors (uint8/int8/int16/int32)
      are accepted and widened to int64.
    - num_classes (int): Total number of classes/categories.

    Returns:
    - one_hot_tensor (Tensor): The one-hot encoded tensor, as float32, with
      a trailing dimension of size ``num_classes``.

    Raises:
    - RuntimeError: Propagated from ``torch.nn.functional.one_hot`` when the
      value is not an integer index or lies outside ``[0, num_classes)``.
    """
    # Anything that is not already a tensor (Python int, NumPy integer
    # scalar, ...) is wrapped; checking only ``isinstance(..., int)`` would
    # let NumPy scalars fall through and fail inside ``one_hot``.
    if not isinstance(id_value, torch.Tensor):
        id_value = torch.as_tensor(id_value)

    # ``one_hot`` only accepts int64 index tensors. Widen narrower integer
    # dtypes losslessly; float/bool input is deliberately left alone so it
    # is still rejected instead of being silently truncated.
    if id_value.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
        id_value = id_value.to(torch.long)

    one_hot_tensor = torch.nn.functional.one_hot(id_value, num_classes=num_classes)
    return one_hot_tensor.to(torch.float32)
class CustomListDataset(Dataset):
    """Map-style dataset backed by an in-memory list of row dicts."""

    def __init__(self, rows: List[Dict[str, Any]]):
        """Keep the rows and record the one-hot width of each label.

        Args:
            rows: Samples as dicts. ``__getitem__`` reads the keys
                ``'feature'``, ``'traffic_label'``, ``'app_label'`` and
                ``'aux_label'`` from each of them.
        """
        self.data = rows
        # Number of classes per label = size of the matching ID table.
        self.n_traffic, self.n_app, self.n_aux = (
            len(PREFIX_TO_TRAFFIC_ID),
            len(PREFIX_TO_APP_ID),
            len(AUX_ID),
        )

    def __len__(self):
        """Return the number of stored rows."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(feature, y_traffic, y_app, y_aux)`` for one row.

        The three labels are one-hot encoded; the feature is densified via
        its ``toarray()`` method and wrapped as a tensor.
        """
        row = self.data[index]

        # (row key, number of classes) for each label, in output order.
        label_specs = (
            ('traffic_label', self.n_traffic),
            ('app_label', self.n_app),
            ('aux_label', self.n_aux),
        )
        one_hots = [
            id_to_one_hot_tensor(row[key], width)
            for key, width in label_specs
        ]

        # Labels are encoded before the feature matrix is densified.
        dense_feature = torch.from_numpy(row['feature'].toarray())
        return (dense_feature, *one_hots)
def get_dataset(data_rows: List[Dict[str, Any]]) -> Dataset:
    """Wrap a list of row dicts in a ``CustomListDataset``."""
    return CustomListDataset(data_rows)