-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgenerate_test_splits.py
62 lines (44 loc) · 1.91 KB
/
generate_test_splits.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import random
import os
from collections import defaultdict
def generate_hold_out_split(dataset, training=0.8, base_dir="splits"):
    """Shuffle the article ids and write a training / hold-out split to disk.

    The ids are the keys of ``dataset.articles``. They are shuffled with a
    fixed seed, then the first ``int(training * n)`` ids are written to
    ``training_ids.txt`` and the remaining ids to ``hold_out_ids.txt`` inside
    ``base_dir`` (one id per line, no trailing newline).

    Args:
        dataset: Object exposing an ``articles`` mapping keyed by article id.
        training: Fraction of the ids that goes into the training file.
        base_dir: Directory that receives the two id files. It is created
            when it does not exist yet.
    """
    # Fixed seed: the shuffle is repeatable for a given order of article ids.
    rng = random.Random(1489215)

    article_ids = list(dataset.articles.keys())
    rng.shuffle(article_ids)

    cut = int(training * len(article_ids))
    splits = (
        ("training_ids.txt", article_ids[:cut]),
        ("hold_out_ids.txt", article_ids[cut:]),
    )

    # open(..., "w") does not create missing parent directories, so make sure
    # the target directory exists before writing.
    if not os.path.isdir(base_dir):
        os.makedirs(base_dir)

    # Write the split ids out to files for future use. The path is built with
    # "/" exactly as the readers in this module build it.
    for file_name, ids in splits:
        with open(base_dir + "/" + file_name, "w") as f:
            f.write("\n".join(str(article_id) for article_id in ids))
def read_ids(file, base):
    """Read integer ids, one per line, from ``base + "/" + file``.

    Blank and whitespace-only lines are skipped, so a stray newline at the
    end of the file does not raise.

    Args:
        file: Name of the id file inside ``base``.
        base: Directory containing the file.

    Returns:
        The ids as a list of ints, in file order.

    Raises:
        ValueError: If a non-blank line is not a valid integer.
    """
    with open(base + "/" + file, "r") as f:
        # int() tolerates surrounding whitespace but not an empty string,
        # hence the explicit filter on blank lines.
        return [int(line) for line in f if line.strip()]
def kfold_split(dataset, training=0.8, n_folds=10, base_dir="splits"):
    """Return ``n_folds`` contiguous folds of training ids plus the hold-out ids.

    When either id file is missing from ``base_dir`` the hold-out split is
    generated first. The training ids are then cut into ``n_folds``
    consecutive slices whose boundaries are ``int(k * n / n_folds)``.

    Args:
        dataset: Passed through to ``generate_hold_out_split`` when the id
            files have to be created.
        training: Training fraction used when the split is generated.
        n_folds: Number of folds to cut the training ids into.
        base_dir: Directory holding ``training_ids.txt`` and
            ``hold_out_ids.txt``.

    Returns:
        A ``(folds, hold_out_ids)`` tuple; ``folds`` is a list of id lists.
    """
    train_path = base_dir + "/" + "training_ids.txt"
    hold_out_path = base_dir + "/" + "hold_out_ids.txt"
    if not os.path.exists(train_path) or not os.path.exists(hold_out_path):
        generate_hold_out_split(dataset, training, base_dir)

    training_ids = read_ids("training_ids.txt", base_dir)
    hold_out_ids = read_ids("hold_out_ids.txt", base_dir)

    total = len(training_ids)

    def boundary(k):
        # Index of the first id belonging to fold k.
        return int(k * total / n_folds)

    folds = [
        training_ids[boundary(k):boundary(k + 1)]
        for k in range(n_folds)
    ]
    return folds, hold_out_ids
def get_stances_for_folds(dataset, folds, hold_out):
    """Group the dataset's stances by the fold their body id belongs to.

    Each stance in ``dataset.stances`` is routed by its ``'Body ID'``: to the
    hold-out list when the id is in ``hold_out``, otherwise to every fold
    whose id collection contains it. Stances whose id appears nowhere are
    dropped.

    Args:
        dataset: Object exposing ``stances``, an iterable of mappings that
            each have a ``'Body ID'`` key.
        folds: Sequence of id collections, one per fold.
        hold_out: Collection of hold-out ids.

    Returns:
        A ``(stances_folds, stances_hold_out)`` tuple. ``stances_folds`` is a
        ``defaultdict(list)`` mapping fold index to that fold's stances;
        ``stances_hold_out`` is a list of the hold-out stances.
    """
    # Build the lookup sets once: membership tests then cost O(1) per stance
    # instead of a scan over the id lists.
    hold_out_ids = set(hold_out)
    fold_id_sets = [set(fold) for fold in folds]

    stances_folds = defaultdict(list)
    stances_hold_out = []
    for stance in dataset.stances:
        body_id = stance['Body ID']
        if body_id in hold_out_ids:
            stances_hold_out.append(stance)
            continue
        for fold_id, fold_ids in enumerate(fold_id_sets):
            if body_id in fold_ids:
                stances_folds[fold_id].append(stance)
    return stances_folds, stances_hold_out