import os
import pickle
import string

import flair
import numpy as np
import torch as tr
from flair.data import Sentence
from torch.utils.data import Dataset
from tqdm import tqdm


def compute_embeddings(labels, emb_name, embedding_model, max_len, publications_dir):
    """Precompute embeddings for each PMID in the dataset."""
    cache = f"/media/DATOS/lbugnon/emb_{emb_name}.pk"
    if os.path.isfile(cache):
        with open(cache, "rb") as fin:
            return pickle.load(fin)
    embeddings = {}
    for pmid in tqdm(labels["PMID"].unique()):
        with open(f"{publications_dir}{pmid}.txt", encoding="utf8") as fin:
            text = fin.read()
        # The original token count (third return value) is not cached.
        tokens, keywords, _ = embed(text, max_len, embedding_model)
        embeddings[pmid] = tokens, keywords
    with open(cache, "wb") as fout:
        pickle.dump(embeddings, fout)
    return embeddings

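
# Example usage (a minimal sketch; `labels.csv`, the embedding choice, and
# max_len are placeholders, not values fixed by this module):
#
#   import pandas as pd
#   from flair.embeddings import WordEmbeddings
#   labels = pd.read_csv("labels.csv")
#   model = WordEmbeddings("en")  # English fastText vectors
#   embeddings = compute_embeddings(labels, "fastext", model, max_len=400,
#                                   publications_dir="publications/")
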
def embed(text, max_len, embedding_model):
    """Tokenize a publication and embed its first `max_len` tokens."""
    # Strip punctuation, lowercase, and drop very short tokens.
    tokens = text.translate(str.maketrans("", "", string.punctuation)).lower().split()
    tokens = Sentence([t for t in tokens if len(t) > 2])
    tsize = len(tokens)
    tokens.tokens = tokens.tokens[:max_len]
    keywords = get_keywords(tokens)
    # Apply the word embedding on CPU.
    flair.device = "cpu"
    if embedding_model:
        embedding_model.embed(tokens)
        word_emb = tr.cat([token.embedding.unsqueeze(0).detach().cpu()
                           for token in tokens])
    else:
        word_emb = tr.zeros(len(tokens), 1)
    return word_emb, keywords, tsize

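
# A sketch of the return contract (the text is illustrative; passing None as
# the model yields zero placeholders instead of real embeddings):
#
#   word_emb, keywords, tsize = embed("some publication text ...",
#                                     max_len=400, embedding_model=None)
#   # word_emb: (n_tokens, emb_dim) tensor
#   # keywords: dict mapping each keyword marker to its token positions
#   # tsize: token count before truncation to max_len
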
def get_keywords(tokens):
    """Get the positions of each keyword marker in the token list."""
    keywords = {}
    for k, token in enumerate(tokens):
        if "xxx" in token.text:  # keyword, either gene or drug
            keywords.setdefault(token.text, []).append(k)
    return keywords

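
# For instance (a sketch; the marker format xxxg...xxx / xxxd...xxx for genes
# and drugs is inferred from the "g"/"d" checks in InteractionsDataset):
#
#   s = Sentence(["the", "drug", "xxxd10xxx", "targets", "xxxg7xxx"])
#   get_keywords(s)  # -> {"xxxd10xxx": [2], "xxxg7xxx": [4]}
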
class InteractionsDataset(Dataset):
    def __init__(self, labels, embeddings, max_len, work_dir="tmp",
                 publications_dir=None, emb_path=None):
        """
        :param labels: DataFrame with the examples to load. Columns are: PMID,
            keyterm1, keyterm2, interaction.
        :param embeddings: Dict mapping each PMID to its precomputed
            (word_emb, keywords) pair, as built by compute_embeddings. If
            empty, embeddings are loaded per example from emb_path instead.
        :param max_len: Maximum number of tokens per publication.
        :param work_dir: Temporary directory.
        :param publications_dir: Directory with the whole texts.
        :param emb_path: Directory with one pickled embedding file per PMID.
        """
        self.embeddings = embeddings
        if not embeddings:
            # Index the per-PMID pickle files by PMID.
            self.files = {}
            for f in os.listdir(emb_path):
                if "pk" in f:
                    self.files[int(f.split(".")[0])] = emb_path + f
        self.max_len = max_len
        self.publications_dir = publications_dir
        self.labels = labels
        self.interactions = sorted(np.unique(labels["interaction"]))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, item):
        pmid, target_gene, target_drug, label = self.labels.iloc[item, :]
        if self.embeddings:
            word_emb, keywords = self.embeddings[pmid]
            if isinstance(word_emb, np.ndarray):
                word_emb = tr.tensor(word_emb)
            else:
                word_emb = word_emb.clone().detach()
        else:
            with open(self.files[pmid], "rb") as fin:
                word_emb, keywords, _ = pickle.load(fin)
        # One-hot keyword features per token: is_gene, is_drug, is_target.
        keyword_emb = tr.zeros((word_emb.shape[0], 3))
        for k in keywords:
            if "g" in k:
                keyword_emb[keywords[k], 0] = 1
            if "d" in k:
                keyword_emb[keywords[k], 1] = 1
            if target_drug == k or target_gene == k:
                keyword_emb[keywords[k], 2] = 1
        emb_size = word_emb.shape[1] + 3
        # Concatenate word embeddings and keyword features, zero-padded to max_len.
        embedding = tr.zeros((self.max_len, emb_size))
        embedding[:word_emb.shape[0], :] = tr.cat((word_emb, keyword_emb), dim=1)
        return (embedding.T, self.interactions.index(label),
                f"{pmid}_{target_gene}_{target_drug}")

    def get_class_weight(self):
        """Inverse-frequency class weights, normalized to sum to 1."""
        w = tr.zeros(len(self.interactions))
        for k, label in enumerate(self.interactions):
            w[k] = np.sum(self.labels["interaction"] == label)
        weight = (1 / w) / tr.sum(1 / w)
        return weight

    def get_samples_weights(self):
        """Get a sampling weight for each example, given its label."""
        w = self.get_class_weight()
        weights = [w[self.interactions.index(l)]
                   for l in self.labels["interaction"]]
        return weights
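

# A usage sketch (assuming `labels` and `embeddings` were built as in the
# compute_embeddings example above; batch size and max_len are illustrative):
#
#   from torch.utils.data import DataLoader, WeightedRandomSampler
#   dataset = InteractionsDataset(labels, embeddings, max_len=400)
#   sampler = WeightedRandomSampler(dataset.get_samples_weights(),
#                                   num_samples=len(dataset))
#   loader = DataLoader(dataset, batch_size=32, sampler=sampler)
#   for emb, label, key in loader:
#       ...  # emb: (batch, emb_size, max_len), label: class indices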