-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpetal_snorkel_train_golden.py
189 lines (165 loc) · 7.33 KB
/
petal_snorkel_train_golden.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# -*- coding: utf-8 -*-
# ------------------------------------------------- #
# Title: petal_snorkel.py
# Description: main file for running snorkel ML model
# ChangeLog: (Name, Date: MM-DD-YY, Changes)
# <ARalevski, 10-01-2021, created script>
# <PahtJ, 10-15-21, updated pickle>
# <ARalevski, 10-17-2021, updated cardinality>
# Authors: [email protected], [email protected]
# ------------------------------------------------- #
from genericpath import exists
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import CountVectorizer
import os.path as osp
from tqdm import trange
import sys
sys.path.insert(0, '../snorkel')
from create_labeling_functions import create_labeling_functions
from utils import smaller_models, single_model_to_dict, compare_single_model_dicts, normalize_L, compare_single_model_dict
from copy import deepcopy
from ast import literal_eval
from snorkel.labeling.model import MajorityLabelVoter
from snorkel.labeling import PandasLFApplier
from snorkel.labeling.model import LabelModel
import wget
import pickle
import json
# import csv file and load train/test/split of dataset
# Remote source of the hand-labeled ("golden") dataset and the local cache
# file names for it and for the two trained-model pickles written below.
golden_json_url = 'https://raw.githubusercontent.com/nasa-petal/data-collection-and-prep/main/golden.json'
filename = 'golden.json'
large_model = 'large_model_trained.pickle'  # single LabelModel over all labels
small_models = 'small_models_trained.pickle'  # list of per-cluster LabelModels
'''
Download golden json
'''
# Fetch the dataset only once; later runs reuse the local copy.
if not osp.exists(filename):
    wget.download(golden_json_url)
with open(filename, 'r') as f:
    # List of paper records; each is read below for 'title', 'abstract',
    # 'doi', 'paper' and 'level1' keys.
    golden_json = json.load(f)
'''
Train snorkel using Golden.json
'''
# Flatten each golden.json record into one row. 'title' and 'abstract' are
# stored as stringified token lists, so parse them with literal_eval before
# joining back into plain text.
datalist = []
for paper in golden_json:
    title_tokens = literal_eval(paper['title'])
    abstract_tokens = literal_eval(paper['abstract'])
    datalist.append({
        'text': ' '.join(title_tokens + abstract_tokens),
        'doi': paper['doi'],
        'paperid': paper['paper'],
        'title': ' '.join(title_tokens),
        'abstract': ' '.join(abstract_tokens),
        # Assign this because it's coming from golden json
        'label_level_1': paper['level1'],
    })
df = pd.DataFrame(datalist)
# Map enumerated function id -> human-readable function name.
df_bio = pd.read_csv(r'./biomimicry_functions_enumerated.csv')
labels = dict(
    zip(df_bio['function_enumerated'].tolist(), df_bio['function'].tolist()))
'''
Loop through all Golden JSON and create L-matrix
'''
# Apply every labeling function to the dataframe and split the result into
# smaller per-cluster matrices. The work is cached in golden_lf.pickle so
# later runs skip straight to the load below.
if not osp.exists('golden_lf.pickle'):
    lfs = create_labeling_functions(r'./biomimicry_functions_enumerated.csv', r'./biomimicry_function_rules.csv')
    L_golden = PandasLFApplier(lfs=lfs).apply(df=df)
    (labels_overlap, L_matches, translators, translators_to_str, L_match_all,
     global_translator, global_translator_str, dfs) = smaller_models(
        L_golden, 5, 2, labels_list=labels, df=df)
    cached = {'L_golden': L_golden, 'labels_overlap': labels_overlap, 'L_matches': L_matches,
              'translators': translators, 'translators_to_str': translators_to_str,
              'L_matches_all': L_match_all, 'global_translator': global_translator,
              'global_translator_str': global_translator_str, 'dfs': dfs}
    with open('golden_lf.pickle', 'wb') as f:
        pickle.dump(cached, f)
# Always read back from the cache so both code paths share one source of truth.
with open('golden_lf.pickle', 'rb') as f:
    data = pickle.load(f)
L_golden = data['L_golden']
labels_overlap = data['labels_overlap']
L_matches = data['L_matches']
L_match_all = data['L_matches_all']
translators = data['translators']
translators_to_str = data['translators_to_str']
global_translator = data['global_translator']
global_translator_str = data['global_translator_str']
dfs = data['dfs']
print("Unique Matches in golden.json: ")
print(*np.unique(L_golden).tolist(), sep=", ")
'''
Train small models
Note: some models are very small so splitting them into test and train can be tricky
'''
# Fit one snorkel LabelModel per label cluster produced by smaller_models()
# and pickle the list alongside its translators for the evaluation step.
if not osp.exists(small_models):
    models = list()
    for i in trange(len(L_matches), desc="training small models"):
        # TODO split the dataset so theres an equal amount of all labels;
        # until then the full matrix is used for training.
        L_train = L_matches[i]
        # Number of labels this cluster's model predicts. (The original
        # rebound the module-level `labels` dict here via
        # `labels = labels_overlap[i]`, shadowing it -- fixed by indexing
        # labels_overlap directly.)
        cardinality = len(labels_overlap[i])
        # NOTE(review): the original also ran MajorityLabelVoter.predict and
        # label_model.predict_proba each iteration and discarded both
        # results; that dead per-iteration work is removed.
        label_model = LabelModel(
            cardinality=cardinality, verbose=True, device='cpu')
        label_model.fit(L_train=L_train, n_epochs=350, log_freq=100, seed=123)
        # this label model can help predict the type of paper
        models.append(label_model)
    with open(small_models, 'wb') as f:
        pickle.dump({"Label_models": models, 'labels_overlap': labels_overlap,
                     'translators': translators, 'translators_to_str': translators_to_str,
                     'texts_df': dfs}, f)
'''
Train large model
'''
# Fit a single LabelModel over the full, globally-translated label space and
# pickle it with the translators needed to decode its predictions.
# (Header fixed: the original said "Train large models" and carried a
# copy-pasted small-models note that does not apply here.)
if not osp.exists(large_model):
    cardinality = len(global_translator)  # total number of distinct labels
    # NOTE(review): the original computed MajorityLabelVoter predictions here
    # (`preds_train`) and never used them; that dead work is removed.
    label_model = LabelModel(cardinality=cardinality, verbose=True, device='cpu')
    label_model.fit(L_train=L_match_all, n_epochs=300, log_freq=50, seed=123)
    with open(large_model, 'wb') as f:
        pickle.dump({"Label_model": label_model, 'global_translator': global_translator,
                     'global_translator_str': global_translator_str, 'text_df': df}, f)
'''
Evaluation using smaller models
'''
# Score every paper with each small model, then collapse rows that share a
# DOI into one merged record per paper.
if osp.exists(small_models):
    with open(small_models, 'rb') as f:
        smaller_model_data = pickle.load(f)
    results = list()
    for i, small_model in enumerate(smaller_model_data['Label_models']):
        results.extend(single_model_to_dict(
            L_matches[i], small_model,
            smaller_model_data['translators_to_str'][i], i, dfs[i]))
    # Filter papers by unique doi
    df_sm = pd.DataFrame(results)
    results = list()
    for doi in df_sm['doi'].unique():
        rows = df_sm[df_sm['doi'] == doi]
        merged = rows.iloc[0].to_dict()
        for j in range(1, len(rows)):
            merged = compare_single_model_dict(merged, rows.iloc[j].to_dict())
        results.append(merged)
    df_sm = pd.DataFrame(results)
    df_sm.to_csv("golden json matches small models.csv")
'''
Evaluate using larger model
'''
# Score the whole golden set with the single large model and dump to CSV.
if osp.exists(large_model):
    with open(large_model, 'rb') as f:
        large_model_data = pickle.load(f)
    large_label_model = large_model_data['Label_model']
    # Translators map the original label ids to the large model's label
    # space (old labels to new) and back to readable strings.
    global_translator = large_model_data['global_translator']
    global_translator_str = large_model_data['global_translator_str']
    large_model_L = normalize_L(L=L_golden, translator=global_translator)
    large_model_results = single_model_to_dict(
        large_model_L, large_label_model, global_translator_str, 0, df)
    pd.DataFrame(large_model_results).to_csv("golden json matches large model.csv")