Skip to content

Commit

Permalink
Merge pull request #896 from PrimozGodec/fix-ontology-cache
Browse files Browse the repository at this point in the history
[FIX] Ontology - remove cache and other fixes
  • Loading branch information
ajdapretnar authored Oct 5, 2022
2 parents d450f48 + ef2ee05 commit 6a055b3
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 207 deletions.
162 changes: 34 additions & 128 deletions orangecontrib/text/ontology.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
from typing import List, Set, Dict, Tuple, Optional, Callable
from collections import Counter
from itertools import chain
import os
import pickle

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

from orangecontrib.text.vectorization.sbert import SBERT
from Orange.misc.environ import cache_dir
from Orange.util import dummy_callback, wrap_callback

EMB_DIM = 384


class Tree:

Expand Down Expand Up @@ -222,103 +218,58 @@ def generate_ontology(
return generation[best, :], roots[best]


def cos_sim(x: np.array, y: np.array) -> float:
    """Return the cosine similarity of two 1-D vectors.

    Parameters
    ----------
    x, y
        Vectors of equal length.

    Returns
    -------
    float
        ``dot(x, y) / (|x| * |y|)``, or ``0.0`` when the product of the
        norms is zero (at least one vector has no direction).
    """
    norm_product = np.linalg.norm(x) * np.linalg.norm(y)
    # Guard on the norms rather than on the dot product: a tiny dot product
    # can simply mean tiny vectors (e.g. two parallel vectors of length 1e-5
    # still have similarity 1), whereas a zero norm product is the only case
    # in which the ratio is actually undefined.
    if norm_product == 0:
        return 0.0
    return np.dot(x, y) / norm_product


class EmbeddingStorage:
    """Disk-backed store for word embeddings and pairwise similarities.

    Everything lives in the ``ontology`` sub-directory of ``cache_dir()``:
    one ``<word>.npy`` file per embedding and a single pickled dict,
    ``sims.pkl``, holding the similarities. Embeddings are additionally
    memoised in ``self.embeddings`` once loaded or saved.
    """

    def __init__(self):
        self.cache_dir = os.path.join(cache_dir(), 'ontology')
        # exist_ok avoids the check-then-create race when two processes
        # start at the same time.
        os.makedirs(self.cache_dir, exist_ok=True)
        self.embeddings = dict()
        self.similarities = self._load_similarities()

    def _load_similarities(self) -> dict:
        """Read ``sims.pkl``; a missing or unreadable file yields ``{}``."""
        try:
            with open(os.path.join(self.cache_dir, 'sims.pkl'), 'rb') as file:
                # NOTE(security): pickle can execute arbitrary code while
                # loading. This is acceptable only as long as the cache
                # directory is writable by the current user alone.
                loaded = pickle.load(file)
        except (OSError, EOFError, pickle.UnpicklingError,
                AttributeError, ImportError, IndexError):
            # A truncated or corrupted cache is treated like an empty one
            # instead of preventing the storage from being constructed.
            return dict()
        return loaded if isinstance(loaded, dict) else dict()

    def _embedding_path(self, word: str) -> str:
        """Return the ``.npy`` path used to store the embedding of ``word``.

        The word is percent-encoded so that path separators and other
        characters that are not valid in file names cannot break (or escape
        from) the cache directory. Words made only of ASCII letters, digits,
        spaces and ``_.-~`` keep the file name they had before the encoding
        was introduced.
        """
        from urllib.parse import quote

        return os.path.join(self.cache_dir, f"{quote(word, safe=' ')}.npy")

    def save_similarities(self):
        """Write the in-memory similarity dict to ``sims.pkl``."""
        with open(os.path.join(self.cache_dir, 'sims.pkl'), 'wb') as file:
            pickle.dump(self.similarities, file)

    def get_embedding(self, word: str) -> Optional[np.array]:
        """Return the stored embedding of ``word`` or ``None`` if absent."""
        if word in self.embeddings:
            return self.embeddings[word]
        try:
            emb = np.load(self._embedding_path(word))
        except (OSError, ValueError, EOFError):
            # Missing, truncated or otherwise unreadable file: cache miss.
            return None
        self.embeddings[word] = emb
        return emb

    def save_embedding(self, word: str, emb: np.array) -> None:
        """Remember ``emb`` for ``word`` in memory and on disk."""
        self.embeddings[word] = emb
        np.save(self._embedding_path(word), emb)

    def clear_storage(self) -> None:
        """Forget all cached data and delete the files in the cache dir."""
        self.similarities = dict()
        self.embeddings = dict()
        if os.path.isdir(self.cache_dir):
            for file in os.listdir(self.cache_dir):
                path = os.path.join(self.cache_dir, file)
                # os.remove cannot delete directories; skip anything that
                # is not a regular file instead of raising.
                if os.path.isfile(path):
                    os.remove(path)


class OntologyHandler:

# Set up the collaborators used by the other methods: an SBERT embedder and
# an EmbeddingStorage instance.
# NOTE(review): this listing is a rendered diff without +/- markers; the
# commit title ("remove cache") suggests the `storage` line is one of the
# removed lines - confirm against the repository.
def __init__(self):
self.embedder = SBERT()
self.storage = EmbeddingStorage()

# Build an ontology, returned as a nested dict, from a list of words.
#
# NOTE(review): this listing is a rendered diff whose +/- markers and
# indentation were lost, so pre-change and post-change lines are interleaved
# (two return annotations, duplicated size checks and return statements).
# Read it as two versions of the same method:
#   * the `-> Dict` version obtains similarities through
#     `self._get_embeddings` / the three-argument `self._get_similarities`
#     and returns only the dict;
#   * the `-> Tuple[Dict, int]` version calls `self.embedder` directly, drops
#     words whose embedding is None and additionally returns how many words
#     were dropped (`skipped`).
# Both versions special-case one, two and three words and otherwise delegate
# to `generate_ontology` followed by `Tree.from_prufer_sequence(...).to_dict()`.
def generate(
self,
words: List[str],
callback: Callable = dummy_callback
) -> Dict:
if len(words) == 0:
return {}
) -> Tuple[Dict, int]:
embeddings = self.embedder(words, wrap_callback(callback, end=0.1))
# keep only the (word, embedding) pairs whose embedding is not None
non_none = [(w, e) for w, e in zip(words, embeddings) if e is not None]
skipped = len(words) - len(non_none)
if len(non_none) == 0:
return {}, skipped
words, embeddings = zip(*non_none)
sims = self._get_similarities(embeddings)
callback(0.2)
if len(words) == 1:
return {words[0]: {}}
if len(words) == 2:
return {sorted(words)[0]: {sorted(words)[1]: {}}}
sims = self._get_similarities(
words,
self._get_embeddings(words, wrap_callback(callback, end=0.1)),
wrap_callback(callback, start=0.1, end=0.2)
)
if len(words) == 3:
return {words[0]: {}}, skipped
elif len(words) == 2:
return {sorted(words)[0]: {sorted(words)[1]: {}}}, skipped
elif len(words) == 3:
# the root is the word whose summed similarity to the others is smallest
root = np.argmin(np.sum(sims, axis=1))
rest = sorted([words[i] for i in range(3) if i != root])
return {words[root]: {rest[0]: {}, rest[1]: {}}}
return {words[root]: {rest[0]: {}, rest[1]: {}}}, skipped
ontology, root = generate_ontology(
words,
sims,
callback=wrap_callback(callback, start=0.2)
)
return Tree.from_prufer_sequence(ontology, words, root).to_dict()
return Tree.from_prufer_sequence(ontology, words, root).to_dict(), skipped

def insert(
self,
tree: Dict,
words: List[str],
callback: Callable = dummy_callback
) -> Dict:
) -> Tuple[Dict, int]:
tree = Tree.from_dict(tree)
self._get_embeddings(words, wrap_callback(callback, end=0.3))
ticks = iter(np.linspace(0.3, 0.9, len(words)))
skipped = 0

for word in words:
tick = next(ticks)
tree.adj_list.append(set())
tree.labels.append(word)
sims = self._get_similarities(
tree.labels,
self._get_embeddings(tree.labels, lambda x: callback(tick)),
lambda x: callback(tick)
)
embeddings = self.embedder(tree.labels)
if embeddings[-1] is None:
# the last embedding is for the newly provided word
# if embedding is not successful skip it
skipped += 1
continue
sims = self._get_similarities(embeddings)
idx = len(tree.adj_list) - 1
fitness_function = FitnessFunction(tree.labels, sims).fitness
scores = list()
Expand All @@ -331,65 +282,20 @@ def insert(
best = np.argmax(scores)
tree.adj_list[best].add(idx)
tree.adj_list[idx].add(best)
callback(tick)
callback(next(ticks))

return tree.to_dict()
return tree.to_dict(), skipped

# Score an existing ontology given in dict form. An empty tree scores 0;
# otherwise the dict is converted with `Tree.from_dict`, a similarity matrix
# is computed for the tree's labels, and the first element of
# `FitnessFunction(...).fitness(tree, tree.root)` is returned.
#
# NOTE(review): this listing is a rendered diff without +/- markers, so two
# alternative ways of computing `sims` appear back to back: one through
# `self._get_embeddings` and the three-argument `self._get_similarities`,
# the other through `self.embedder` and the one-argument
# `self._get_similarities` followed by `callback(0.9)`. Only one of them
# belongs to any given revision of the file.
def score(self, tree: Dict, callback: Callable = dummy_callback) -> float:
if not tree:
return 0
tree = Tree.from_dict(tree)
sims = self._get_similarities(
tree.labels,
self._get_embeddings(tree.labels, wrap_callback(callback, end=0.7)),
wrap_callback(callback, start=0.7, end=0.8)
)
embeddings = self.embedder(tree.labels, wrap_callback(callback, end=0.7))
sims = self._get_similarities(embeddings)
callback(0.9)
fitness_function = FitnessFunction(tree.labels, sims).fitness
return fitness_function(tree, tree.root)[0]

def _get_embeddings(
    self,
    words: List[str],
    callback: Callable = dummy_callback
) -> np.array:
    """Return a ``(len(words), EMB_DIM)`` array of embeddings for ``words``.

    Embeddings already known to ``self.storage`` are reused; the remaining
    words are sent to ``self.embedder`` in one call and the results are
    saved to the storage.

    Progress is reported through ``callback``: values from 0.0 to 0.6 while
    the storage is consulted, and the embedder's own progress mapped onto
    0.6-0.9.

    Raises
    ------
    RuntimeError
        If the embedder returns ``None`` for any of the missing words.
    """
    embeddings = np.zeros((len(words), EMB_DIM))
    missing, missing_idx = [], []
    ticks = np.linspace(0.0, 0.6, len(words))
    for i, (word, tick) in enumerate(zip(words, ticks)):
        callback(tick)
        emb = self.storage.get_embedding(word)
        if emb is None:
            missing.append(word)
            missing_idx.append(i)
        else:
            embeddings[i, :] = emb
    if missing_idx:
        embs = self.embedder(
            missing, callback=wrap_callback(callback, start=0.6, end=0.9)
        )
        # Identity check on purpose: `None in embs` compares with `==`,
        # which raises "truth value of an array is ambiguous" as soon as
        # the embedder hands back numpy arrays.
        if any(emb is None for emb in embs):
            raise RuntimeError("Couldn't obtain embeddings.")
        embeddings[missing_idx, :] = np.array(embs)
        for i in missing_idx:
            self.storage.save_embedding(words[i], embeddings[i, :])

    return embeddings

def _get_similarities(
    self,
    words: List[str],
    embeddings: np.array,
    callback: Callable = dummy_callback
) -> np.array:
    """Return the symmetric word-by-word similarity matrix.

    Each unordered pair of distinct positions is looked up in
    ``self.storage.similarities`` under the sorted pair of words; pairs not
    found there are computed with ``cos_sim`` and added to that dict. The
    diagonal is left at zero. ``callback`` receives evenly spaced values
    from 0.0 to 1.0, one per pair, and the similarity dict is saved through
    the storage before returning.
    """
    count = len(words)
    matrix = np.zeros((count, count))
    progress = np.linspace(0.0, 1.0, int(count * (count - 1) / 2))
    upper_pairs = (
        (row, col) for row in range(count) for col in range(row + 1, count)
    )
    known = self.storage.similarities
    for tick, (row, col) in zip(progress, upper_pairs):
        callback(tick)
        pair_key = tuple(sorted((words[row], words[col])))
        if pair_key not in known:
            known[pair_key] = cos_sim(embeddings[row, :], embeddings[col, :])
        value = known[pair_key]
        matrix[row, col] = value
        matrix[col, row] = value
    self.storage.save_similarities()
    return matrix
@staticmethod
def _get_similarities(embeddings: np.array) -> np.array:
    """Return pairwise cosine similarities between the given embeddings.

    The result is whatever ``cosine_similarity`` yields when the embeddings
    are compared against themselves.
    """
    pairwise = cosine_similarity(embeddings, embeddings)
    return pairwise
Loading

0 comments on commit 6a055b3

Please sign in to comment.