-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathvector_scaling.py
30 lines (27 loc) · 1.03 KB
/
vector_scaling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import sys
import numpy
def load_weights(lines):
    """Parse per-word scaling weights.

    Args:
        lines: Iterable of text lines, each ``"<word> <weight>"`` separated by
            whitespace. Blank lines are skipped.

    Returns:
        dict mapping each word to its weight as a float. If a word appears on
        more than one line, the last occurrence wins.

    Raises:
        ValueError: if a non-blank line does not have exactly two fields, or
            if the weight field cannot be parsed as a float.
    """
    weights = {}
    for line_number, line in enumerate(lines, start=1):
        line_parts = line.split()
        if not line_parts:
            # Tolerate blank lines (e.g. an extra newline at end of file).
            continue
        if len(line_parts) != 2:
            # Raise rather than ``assert``: asserts are stripped under -O.
            raise ValueError(
                "Incorrect weight format on line {}: {!r}".format(
                    line_number, line.rstrip("\n")))
        word, weight = line_parts
        weights[word] = float(weight)
    return weights


def scale_embeddings(lines, weights, out):
    """Copy embedding vectors to *out*, multiplying listed words by their weight.

    Args:
        lines: Iterable of embedding lines, each ``"<word> <v1> <v2> ..."``.
        weights: dict mapping word -> float multiplier. Words that are not in
            the dict are written unscaled.
        out: Writable text stream receiving the result.

    Every vector line is re-emitted with each component formatted to 8
    decimal places, whether or not it was scaled. Lines with two or fewer
    whitespace-separated fields (such as a ``"<count> <dim>"`` header line,
    or a blank line) are copied through with surrounding whitespace stripped.
    """
    for line in lines:
        line_parts = line.split()
        if len(line_parts) <= 2:
            # NOTE(review): this header heuristic would also pass through a
            # genuine 1-dimensional vector unscaled; confirm that is acceptable.
            out.write(line.strip() + "\n")
            continue
        word = line_parts[0]
        vector = numpy.array([float(val) for val in line_parts[1:]])
        if word in weights:
            vector = weights[word] * vector
        formatted = " ".join("{:.8f}".format(val) for val in vector)
        out.write(word + " " + formatted + "\n")


def main(argv=None):
    """Command-line entry point.

    Usage: ``vector_scaling.py EMBEDDINGS WEIGHTS OUTPUT``

    Loads WEIGHTS, then streams EMBEDDINGS to OUTPUT with each weighted
    word's vector scaled (see ``scale_embeddings``). Arguments beyond the
    third are ignored.

    Args:
        argv: Argument list excluding the program name; defaults to
            ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 2 on a usage error.
    """
    if argv is None:
        argv = sys.argv[1:]
    if len(argv) < 3:
        sys.stderr.write(
            "usage: {} EMBEDDINGS WEIGHTS OUTPUT\n".format(sys.argv[0]))
        return 2
    embeddings_path, weights_path, output_path = argv[:3]
    # Load the weights before opening the output, so a malformed weights
    # file aborts before OUTPUT is created or truncated.
    with open(weights_path, "r", encoding="utf-8") as weights_file:
        weights = load_weights(weights_file)
    with open(embeddings_path, "r", encoding="utf-8") as embeddings_file, \
            open(output_path, "w", encoding="utf-8") as output_file:
        scale_embeddings(embeddings_file, weights, output_file)
    return 0


if __name__ == "__main__":
    sys.exit(main())