-
Notifications
You must be signed in to change notification settings - Fork 7
/
split_train_test.py
76 lines (55 loc) · 1.98 KB
/
split_train_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import pandas as pd
import os
import numpy as np
from configs import argHandler
FLAGS = argHandler()
FLAGS.setDefaults()

dataset_df = pd.read_excel('./data/annotations.xlsx')
# Fraction of each class's records routed to the test set.
test_set_fraction = 0.2
shuffle = True
# Seed for the row shuffle. Set an int to make the train/test split
# reproducible; None draws fresh randomness from NumPy's global generator on
# every run (previously a seed in [1, 100) was drawn from that same global
# generator, which was neither reproducible nor able to reach most orderings).
random_seed = None

if shuffle:
    # frac=1. samples every row exactly once, i.e. returns a random permutation.
    dataset_df = dataset_df.sample(frac=1., random_state=random_seed)
def get_sparse_labels(y, mapping=None):
    """Convert class-name labels to integer class ids and count each class.

    Args:
        y: 1-D sequence of class-name labels (e.g. a pandas Series); must
            expose ``shape[0]`` and be iterable.
        mapping: Optional dict from class name to integer id in
            ``range(n_classes)``. Defaults to
            ``{'Normal': 0, 'Benign': 1, 'Malignant': 2}``.

    Returns:
        ``(labels, class_counts)``: an int array of length ``y.shape[0]``
        holding the class id of every entry, and an int array holding the
        number of entries found for each class id.

    Labels missing from ``mapping`` (e.g. blank/NaN cells) are not counted in
    ``class_counts`` and are left as id 0 in ``labels``; a warning reports how
    many such labels were seen.
    """
    import warnings

    if mapping is None:
        mapping = {'Normal': 0, 'Benign': 1, 'Malignant': 2}
    labels = np.zeros(y.shape[0], dtype=int)
    # One counter per class id; derived from the mapping instead of a fixed 3.
    class_counts = np.zeros(max(mapping.values(), default=-1) + 1, dtype=int)
    unknown = 0
    for index, label in enumerate(y):
        try:
            class_id = mapping[label]
        except (KeyError, TypeError):
            # Unrecognised (or unhashable) label: keep the default id 0 and
            # do not count it, as before, but no longer silently.
            unknown += 1
            continue
        labels[index] = class_id
        class_counts[class_id] += 1
    if unknown:
        warnings.warn(
            "{} label(s) not found in mapping were left as class 0 "
            "and not counted".format(unknown)
        )
    return labels, class_counts
def make_dict(dataset_df):
    """Return a dict mapping each column name of ``dataset_df`` to a new empty list.

    Every column gets its own list, so appending to one does not affect the
    others. Keys keep the DataFrame's column order.
    """
    return {column: [] for column in dataset_df}
def add_row(dict, df_row):
    """Append ``df_row``'s value for every key of ``dict`` to that key's list.

    ``dict`` is modified in place and nothing is returned. Entries of
    ``df_row`` whose key is not in ``dict`` are ignored.
    """
    # NOTE: the parameter name shadows the builtin ``dict``; it is kept so
    # keyword callers continue to work.
    for column_name, values in dict.items():
        values.append(df_row[column_name])
def split_train_test(dataset_df, test_fraction=None, label_column=None):
    """Split ``dataset_df`` into stratified train and test column dictionaries.

    Rows are visited in their current order. Each row goes to the test set
    while its class still has test slots left, otherwise to the training set.
    A class with ``n`` records gets ``int(n * test_fraction)`` test slots. The
    count is truncated, so a class with fewer than ``1 / test_fraction``
    records contributes no test rows.

    Args:
        dataset_df: DataFrame with one record per row.
        test_fraction: Fraction of each class to place in the test set.
            Defaults to the module-level ``test_set_fraction``.
        label_column: Column whose class names are passed to
            ``get_sparse_labels``. Defaults to ``FLAGS.csv_label_columns[0]``.

    Returns:
        ``(train_dict, test_dict)``: dicts mapping every column name of
        ``dataset_df`` to the list of that column's values for the rows in
        the respective split, in row order.
    """
    if test_fraction is None:
        test_fraction = test_set_fraction
    if label_column is None:
        label_column = FLAGS.csv_label_columns[0]
    labels = dataset_df[label_column]
    sparse_labels, class_counts = get_sparse_labels(labels)
    test_fraction_count = (class_counts * test_fraction).astype(int)
    print("Number of records for each class: {}".format(class_counts))
    print("Number of records for each class in test set: {}".format(test_fraction_count))

    # Sized from the counts get_sparse_labels returned rather than a fixed 3,
    # so the per-class counters always match the number of classes.
    test_count_so_far = np.zeros_like(test_fraction_count)
    is_test = np.zeros(len(sparse_labels), dtype=bool)
    for position, label in enumerate(sparse_labels):
        if test_count_so_far[label] < test_fraction_count[label]:
            test_count_so_far[label] += 1
            is_test[position] = True
    # NOTE(review): get_sparse_labels leaves unrecognised labels as id 0, so
    # such rows consume class 0's test slots — confirm this is intended.

    # Select each split as a whole instead of copying rows one at a time
    # through ``iloc``: a single-row Series is coerced to one common dtype
    # (e.g. ints become floats when every column is numeric), which a
    # column-wise selection avoids.
    train_dict = dataset_df[~is_test].to_dict(orient='list')
    test_dict = dataset_df[is_test].to_dict(orient='list')
    return train_dict, test_dict
# Split the (possibly shuffled) dataset, then write each split to ./data as a
# CSV without the DataFrame index column.
train_dict, test_dict = split_train_test(dataset_df)
training_df, testing_df = pd.DataFrame(data=train_dict), pd.DataFrame(data=test_dict)
training_df.to_csv(path_or_buf=os.path.join("./data", "train_set.csv"), index=False)
testing_df.to_csv(path_or_buf=os.path.join("./data", "test_set.csv"), index=False)