
Commit

Ontology - remove cache
PrimozGodec committed Aug 30, 2022
1 parent 339ad59 commit 9303a25
Showing 3 changed files with 34 additions and 185 deletions.
129 changes: 12 additions & 117 deletions orangecontrib/text/ontology.py
@@ -1,17 +1,13 @@
from typing import List, Set, Dict, Tuple, Optional, Callable
from collections import Counter
from itertools import chain
import os
import pickle

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

from orangecontrib.text.vectorization.sbert import SBERT
from Orange.misc.environ import cache_dir
from Orange.util import dummy_callback, wrap_callback

EMB_DIM = 384


class Tree:

@@ -222,56 +218,9 @@ def generate_ontology(
return generation[best, :], roots[best]


def cos_sim(x: np.array, y: np.array) -> float:
dot = np.dot(x, y)
return 0 if np.allclose(dot, 0) else dot / (np.linalg.norm(x) * np.linalg.norm(y))


class EmbeddingStorage:

def __init__(self):
self.cache_dir = os.path.join(cache_dir(), 'ontology')
if not os.path.isdir(self.cache_dir):
os.makedirs(self.cache_dir)
self.similarities = dict()
try:
with open(os.path.join(self.cache_dir, 'sims.pkl'), 'rb') as file:
self.similarities = pickle.load(file)
except IOError:
self.similarities = dict()
self.embeddings = dict()

def save_similarities(self):
with open(os.path.join(self.cache_dir, 'sims.pkl'), 'wb') as file:
pickle.dump(self.similarities, file)

def get_embedding(self, word: str) -> Optional[np.array]:
if word in self.embeddings:
return self.embeddings[word]
try:
emb = np.load(os.path.join(self.cache_dir, f'{word}.npy'))
self.embeddings[word] = emb
return emb
except IOError:
return None

def save_embedding(self, word: str, emb: np.array) -> None:
self.embeddings[word] = emb
np.save(os.path.join(self.cache_dir, f'{word}.npy'), emb)

def clear_storage(self) -> None:
self.similarities = dict()
self.embeddings = dict()
if os.path.isdir(self.cache_dir):
for file in os.listdir(self.cache_dir):
os.remove(os.path.join(self.cache_dir, file))


class OntologyHandler:

def __init__(self):
self.embedder = SBERT()
self.storage = EmbeddingStorage()

def generate(
self,
@@ -284,11 +233,9 @@ def generate(
return {words[0]: {}}
if len(words) == 2:
return {sorted(words)[0]: {sorted(words)[1]: {}}}
sims = self._get_similarities(
words,
self._get_embeddings(words, wrap_callback(callback, end=0.1)),
wrap_callback(callback, start=0.1, end=0.2)
)
embeddings = self.embedder(words, wrap_callback(callback, end=0.1))
sims = self._get_similarities(embeddings)
callback(0.2)
if len(words) == 3:
root = np.argmin(np.sum(sims, axis=1))
rest = sorted([words[i] for i in range(3) if i != root])
@@ -307,18 +254,13 @@ def insert(
callback: Callable = dummy_callback
) -> Dict:
tree = Tree.from_dict(tree)
self._get_embeddings(words, wrap_callback(callback, end=0.3))
ticks = iter(np.linspace(0.3, 0.9, len(words)))

for word in words:
tick = next(ticks)
tree.adj_list.append(set())
tree.labels.append(word)
sims = self._get_similarities(
tree.labels,
self._get_embeddings(tree.labels, lambda x: callback(tick)),
lambda x: callback(tick)
)
embeddings = self.embedder(tree.labels)
sims = self._get_similarities(embeddings)
idx = len(tree.adj_list) - 1
fitness_function = FitnessFunction(tree.labels, sims).fitness
scores = list()
@@ -331,65 +273,18 @@
best = np.argmax(scores)
tree.adj_list[best].add(idx)
tree.adj_list[idx].add(best)
callback(tick)
callback(next(ticks))

return tree.to_dict()

def score(self, tree: Dict, callback: Callable = dummy_callback) -> float:
tree = Tree.from_dict(tree)
sims = self._get_similarities(
tree.labels,
self._get_embeddings(tree.labels, wrap_callback(callback, end=0.7)),
wrap_callback(callback, start=0.7, end=0.8)
)
embeddings = self.embedder(tree.labels, wrap_callback(callback, end=0.7))
sims = self._get_similarities(embeddings)
callback(0.9)
fitness_function = FitnessFunction(tree.labels, sims).fitness
return fitness_function(tree, tree.root)[0]

def _get_embeddings(
self,
words: List[str],
callback: Callable = dummy_callback
) -> np.array:
embeddings = np.zeros((len(words), EMB_DIM))
missing, missing_idx = list(), list()
ticks = iter(np.linspace(0.0, 0.6, len(words)))
for i, word in enumerate(words):
callback(next(ticks))
emb = self.storage.get_embedding(word)
if emb is None:
missing.append(word)
missing_idx.append(i)
else:
embeddings[i, :] = emb
if len(missing_idx) > 0:
embs = self.embedder(missing, callback=wrap_callback(callback, start=0.6, end=0.9))
if None in embs:
raise RuntimeError("Couldn't obtain embeddings.")
embeddings[missing_idx, :] = np.array(embs)
for i in missing_idx:
self.storage.save_embedding(words[i], embeddings[i, :])

return embeddings

def _get_similarities(
self,
words: List[str],
embeddings: np.array,
callback: Callable = dummy_callback
) -> np.array:
sims = np.zeros((len(words), len(words)))
ticks = iter(np.linspace(0.0, 1.0, int(len(words) * (len(words) - 1) / 2)))
for i in range(len(words)):
for j in range(i + 1, len(words)):
callback(next(ticks))
key = tuple(sorted((words[i], words[j])))
try:
sim = self.storage.similarities[key]
except KeyError:
sim = cos_sim(embeddings[i, :], embeddings[j, :])
self.storage.similarities[key] = sim
sims[i, j] = sim
sims[j, i] = sim
self.storage.save_similarities()
return sims
@staticmethod
def _get_similarities(embeddings: np.array) -> np.array:
return cosine_similarity(embeddings, embeddings)
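
The per-pair cos_sim loop and its pickled cache are replaced by a single vectorized scikit-learn call. A minimal sketch of the new similarity path, with a hypothetical three-word embedding matrix standing in for real SBERT output:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Hypothetical stand-in for SBERT output: one 384-dimensional vector per word.
embeddings = np.random.rand(3, 384)

# One call produces the full n x n matrix that the removed loop assembled
# pair by pair and cached in sims.pkl.
sims = cosine_similarity(embeddings, embeddings)

assert sims.shape == (3, 3)
assert np.allclose(np.diag(sims), 1.0)  # each vector is maximally similar to itself
assert np.allclose(sims, sims.T)        # symmetric, as the old loop guaranteed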
84 changes: 19 additions & 65 deletions orangecontrib/text/tests/test_ontology.py
@@ -1,22 +1,32 @@
import unittest
from typing import List, Union
from unittest.mock import patch
from collections.abc import Iterator
import os
from typing import Iterator
import asyncio

import numpy as np

from orangecontrib.text.ontology import Tree, EmbeddingStorage, OntologyHandler, EMB_DIM
from orangecontrib.text.ontology import Tree, OntologyHandler


EMB_DIM = 384
RESPONSE = [
f'{{ "embedding": {[i] * EMB_DIM} }}'.encode()
for i in range(4)
]
RESPONSE2 = [np.zeros(384), np.ones(384), np.zeros(384), np.ones(384)*2]
RESPONSE3 = [np.zeros(384), np.ones(384), np.arange(384), np.ones(384)*2]


class DummyResponse:
def arrays_to_response(array: List[Union[np.ndarray, List]]) -> Iterator[bytes]:
return iter(array_to_response(a) for a in array)


def array_to_response(array: Union[np.ndarray, List]) -> bytes:
return f'{{ "embedding": {array.tolist()} }}'.encode()


class DummyResponse:
def __init__(self, content):
self.content = content

@@ -72,54 +82,11 @@ def test_assertion_errors(self):
Tree.from_prufer_sequence([1, 0, 3], list(map(str, range(4))))


class TestEmbeddingStorage(unittest.TestCase):

def setUp(self):
self.storage = EmbeddingStorage()

def tearDown(self):
self.storage.clear_storage()

def test_clear_storage(self):
self.storage.save_embedding("testword", np.zeros(3))
self.assertEqual(len(self.storage.embeddings), 1)
self.storage.clear_storage()
self.assertEqual(len(self.storage.embeddings), 0)
self.assertEqual(len(os.listdir(self.storage.cache_dir)), 0)

def test_save_embedding(self):
self.storage.save_embedding("testword", np.zeros(3))
self.storage.save_embedding("testword2", np.zeros(3))
self.assertEqual(len(self.storage.embeddings), 2)
self.assertEqual(len(os.listdir(self.storage.cache_dir)), 2)

def test_get_embedding(self):
self.storage.save_embedding("testword", np.arange(3))
emb = self.storage.get_embedding("testword")
self.assertEqual(emb.tolist(), [0, 1, 2])

def test_get_from_cache(self):
self.storage.save_embedding("testword", np.arange(3))
self.storage.embeddings = dict()
emb = self.storage.get_embedding("testword")
self.assertEqual(emb.tolist(), [0, 1, 2])

def test_similarities(self):
self.storage.similarities['a', 'b'] = 0.75
self.storage.save_similarities()
storage = EmbeddingStorage()
self.assertEqual(len(storage.similarities), 1)
self.assertTrue(('a', 'b') in storage.similarities)
self.assertEqual(storage.similarities['a', 'b'], 0.75)


class TestOntologyHandler(unittest.TestCase):

def setUp(self):
self.handler = OntologyHandler()

def tearDown(self):
self.handler.storage.clear_storage()
self.handler.embedder.clear_cache()

@patch('orangecontrib.text.ontology.generate_ontology')
@@ -128,45 +95,32 @@ def test_small_trees(self, mock):
self.handler.generate(words)
mock.assert_not_called()

@patch('httpx.AsyncClient.post', make_dummy_post(arrays_to_response(RESPONSE3)))
def test_generate_small(self):
self.handler.storage.save_embedding('1', np.zeros(384))
self.handler.storage.save_embedding('2', np.ones(384))
self.handler.storage.save_embedding('3', np.arange(384))
tree = self.handler.generate(['1', '2', '3'])
self.assertTrue(isinstance(tree, dict))

@patch('httpx.AsyncClient.post')
def test_generate(self, mock):
self.handler.storage.save_embedding('1', np.zeros(384))
self.handler.storage.save_embedding('2', np.ones(384))
self.handler.storage.save_embedding('3', np.arange(384))
self.handler.storage.save_embedding('4', np.ones(384) * 2)
@patch('httpx.AsyncClient.post', make_dummy_post(arrays_to_response(RESPONSE3)))
def test_generate(self):
tree = self.handler.generate(['1', '2', '3', '4'])
self.assertTrue(isinstance(tree, dict))
mock.request.assert_not_called()
mock.get_response.assert_not_called()

@patch('httpx.AsyncClient.post', make_dummy_post(iter(RESPONSE)))
def test_generate_with_unknown_embeddings(self):
tree = self.handler.generate(['1', '2', '3', '4'])
self.assertTrue(isinstance(tree, dict))

@patch('httpx.AsyncClient.post', make_dummy_post(arrays_to_response(RESPONSE2)))
def test_insert(self):
self.handler.storage.save_embedding('1', np.zeros(384))
self.handler.storage.save_embedding('2', np.ones(384))
self.handler.storage.save_embedding('3', np.arange(384))
self.handler.storage.save_embedding('4', np.ones(384) * 2)
tree = self.handler.generate(['1', '2', '3'])
new_tree = self.handler.insert(tree, ['4'])
self.assertGreater(
len(Tree.from_dict(new_tree).adj_list),
len(Tree.from_dict(tree).adj_list)
)

@patch('httpx.AsyncClient.post', make_dummy_post(array_to_response(np.zeros(384))))
def test_score(self):
self.handler.storage.save_embedding('1', np.zeros(384))
self.handler.storage.save_embedding('2', np.ones(384))
self.handler.storage.save_embedding('3', np.arange(384))
tree = self.handler.generate(['1', '2', '3'])
score = self.handler.score(tree)
self.assertGreater(score, 0)
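
The tests no longer pre-seed a storage cache; they fabricate SBERT server replies directly from arrays instead. A rough usage sketch, assuming make_dummy_post and the helpers above from this test module, with a hypothetical word list:

import numpy as np
from unittest.mock import patch

from orangecontrib.text.ontology import OntologyHandler

# One fake 384-dimensional vector per word, serialized like a server response.
fake_vectors = [np.zeros(384), np.ones(384), np.arange(384)]
responses = arrays_to_response(fake_vectors)  # iterator of b'{ "embedding": [...] }' payloads

# Every embedder request hits the mock, so no disk cache or network is involved.
with patch('httpx.AsyncClient.post', make_dummy_post(responses)):
    tree = OntologyHandler().generate(['apple', 'pear', 'plum'])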
6 changes: 3 additions & 3 deletions orangecontrib/text/widgets/owontology.py
@@ -153,7 +153,7 @@ def __init__(self, data_changed_cb: Callable):

edit_triggers = QTreeView.DoubleClicked | QTreeView.EditKeyPressed
super().__init__(
editTriggers=int(edit_triggers),
editTriggers=edit_triggers,
selectionMode=QTreeView.ExtendedSelection,
dragEnabled=True,
acceptDrops=True,
@@ -165,7 +165,7 @@ def __init__(self, data_changed_cb: Callable):

self.__disconnected = False

def startDrag(self, actions: Qt.DropActions):
def startDrag(self, actions: Qt.DropAction):
with disconnected(self.model().dataChanged, self.__data_changed_cb):
super().startDrag(actions)
self.drop_finished.emit()
@@ -626,7 +626,7 @@ def _setup_gui(self):

edit_triggers = QListView.DoubleClicked | QListView.EditKeyPressed
self.__library_view = QListView(
editTriggers=int(edit_triggers),
editTriggers=edit_triggers,
minimumWidth=200,
sizePolicy=QSizePolicy(QSizePolicy.Ignored, QSizePolicy.Expanding),
)
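
The widget change passes the combined edit-trigger flags to the view constructors as-is instead of casting them to int. A minimal sketch of the pattern, assuming the AnyQt bindings used by Orange widgets and PyQt5-style enum access as in the diff:

from AnyQt.QtWidgets import QApplication, QTreeView

app = QApplication([])

# The OR of two triggers already yields an EditTriggers flag object; passing it
# directly avoids the int() cast, which stricter Qt bindings may reject.
edit_triggers = QTreeView.DoubleClicked | QTreeView.EditKeyPressed
view = QTreeView(editTriggers=edit_triggers)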
