-
Notifications
You must be signed in to change notification settings - Fork 1
/
vectorize.py
60 lines (42 loc) · 1.75 KB
/
vectorize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from reach import Reach
from tqdm import tqdm
import numpy as np
import fasttext
#############################################################
#############################################################
################## Sampling #######################
#############################################################
#############################################################
# fasttext_model_path = '/home/fivez/disorder_linking/data/embeddings/pub2fast.bin'
class Vectorize:
    """Embed strings with a fastText model and index them in a Reach object.

    Token vectors come from ``fasttext_model.get_word_vector``; a multi-token
    string is represented by the element-wise mean of its token vectors.
    """

    def __init__(self, fasttext_model_path):
        """Load the fastText model stored at *fasttext_model_path*.

        Args:
            fasttext_model_path: Path to a fastText model file, passed
                unchanged to ``fasttext.FastText.load_model``.
        """
        self.fasttext_model_path = fasttext_model_path
        print('Loading fastText model...')
        self.fasttext_model = fasttext.FastText.load_model(self.fasttext_model_path)
        print('Done')
        # Not read anywhere in this class; kept for external callers.
        self.pretrained_name_embeddings = None
        # Set by allow_construct_oov(); not consulted inside this class —
        # presumably read by callers (verify before removing).
        self.construct_oov = False

    def allow_construct_oov(self):
        """Set the ``construct_oov`` flag to True (calling it again is a no-op)."""
        self.construct_oov = True

    def vectorize_string(self, string, norm):
        """Return one fastText vector per whitespace-separated token of *string*.

        Args:
            string: Text to embed; tokenized with ``str.split()``.
            norm: If true, each token vector is passed through
                ``Reach.normalize`` before stacking.

        Returns:
            A NumPy array built from the per-token vectors, one row per token.
            An empty or whitespace-only string yields an empty array.
        """
        token_embeddings = []
        for token in string.split():
            vector = self.fasttext_model.get_word_vector(token)
            if norm:
                vector = Reach.normalize(vector)
            token_embeddings.append(vector)
        return np.array(token_embeddings)

    def create_reach_object(self, names, outfile=''):
        """Build a Reach object of mean-pooled (unnormalized) name embeddings.

        Args:
            names: Iterable of strings to embed. They are sorted first, so
                items appear in the Reach object in sorted order.
            outfile: If non-empty, the Reach object is also saved to this path
                with ``Reach.save_fast_format``.

        Returns:
            The constructed ``Reach`` object.

        Raises:
            ValueError: If any name has no tokens (empty or whitespace-only),
                since there is no token vector to average for it.
        """
        names = sorted(names)
        # Fail fast, before the slow embedding loop: a mean over zero tokens
        # is a scalar NaN rather than a vector, so it cannot be stacked with
        # the other names' vectors.
        tokenless = [name for name in names if not name.split()]
        if tokenless:
            raise ValueError(
                f'Cannot embed names without any tokens: {tokenless!r}'
            )
        vectors = []
        for name in tqdm(names):
            token_embs = self.vectorize_string(name, norm=False)
            # Unweighted mean over the token rows (same as np.average here).
            vectors.append(token_embs.mean(axis=0))
        reach_object = Reach(vectors, names)
        if outfile:
            reach_object.save_fast_format(outfile)
        return reach_object