# Notifications · You must be signed in to change notification settings · Fork 0
# embed_flair.py — 46 lines (32 loc) · 1.14 KB
"""Precompute Flair embedding for the DGIsinc dataset"""
from flair.embeddings import FlairEmbeddings, StackedEmbeddings
import json
import os
import pandas as pd
from dataset import embed
import pickle
from tqdm import tqdm
import sys
from sklearn.decomposition import IncrementalPCA
import torch as tr
# Combine the PubMed-trained forward and backward Flair models into a single
# embedder; per the Flair docs, StackedEmbeddings concatenates the vectors
# produced by each member for every token.
_pubmed_models = [
    FlairEmbeddings("pubmed-forward"),
    FlairEmbeddings("pubmed-backward"),
]
embedding_model = StackedEmbeddings(_pubmed_models)
# Load the run configuration from the JSON file named by the first CLI
# argument. Keys read by this script: "flair_path" (directory that receives
# the pickled embeddings) and "base_dir" (dataset root).
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} CONFIG.json")
# Use a context manager so the config handle is closed, and decode as UTF-8
# to match how the publication texts are read below.
with open(sys.argv[1], encoding="utf8") as conf_file:
    conf = json.load(conf_file)

# Create the output directory (and any missing parents). exist_ok=True makes
# this a no-op when it already exists and avoids a check-then-create race.
os.makedirs(conf["flair_path"], exist_ok=True)

# The trailing "/" is kept because downstream code may use this as a prefix.
PUBLICATIONS_DIR = os.path.join(conf["base_dir"], "publications/")
LABELS_PATH = os.path.join(conf["base_dir"], "labels.csv")
# Length limit forwarded to dataset.embed; presumably a cap on the text
# length (characters or tokens) — TODO confirm against dataset.embed.
MAX_LEN = 10000
labels = pd.read_csv(LABELS_PATH)
# Distinct interaction labels. Not used in this script; retained so the
# module-level names stay unchanged.
interactions = labels["interaction"].unique().tolist()

# Transform and save embeddings. `.unique()` ensures each PMID is embedded
# once even if it appears in several label rows; each result is pickled to
# <flair_path>/<PMID>.pk.
for pmid in tqdm(labels["PMID"].unique()):
    txt_path = os.path.join(PUBLICATIONS_DIR, f"{pmid}.txt")
    with open(txt_path, encoding="utf8") as fin:
        text = fin.read()
    # Return type of dataset.embed is not visible here; it is pickled as-is.
    embeddings = embed(text, MAX_LEN, embedding_model)
    # os.path.join places the file inside flair_path whether or not the
    # configured path ends with a separator; `with` guarantees the pickle
    # is flushed and the handle closed.
    out_path = os.path.join(conf["flair_path"], f"{pmid}.pk")
    with open(out_path, "wb") as fout:
        pickle.dump(embeddings, fout)