diff --git a/orangecontrib/text/ontology.py b/orangecontrib/text/ontology.py index f734e3216..3da9a7a5d 100644 --- a/orangecontrib/text/ontology.py +++ b/orangecontrib/text/ontology.py @@ -1,17 +1,13 @@ from typing import List, Set, Dict, Tuple, Optional, Callable from collections import Counter from itertools import chain -import os -import pickle import numpy as np +from sklearn.metrics.pairwise import cosine_similarity from orangecontrib.text.vectorization.sbert import SBERT -from Orange.misc.environ import cache_dir from Orange.util import dummy_callback, wrap_callback -EMB_DIM = 384 - class Tree: @@ -222,103 +218,58 @@ def generate_ontology( return generation[best, :], roots[best] -def cos_sim(x: np.array, y: np.array) -> float: - dot = np.dot(x, y) - return 0 if np.allclose(dot, 0) else dot / (np.linalg.norm(x) * np. linalg.norm(y)) - - -class EmbeddingStorage: - - def __init__(self): - self.cache_dir = os.path.join(cache_dir(), 'ontology') - if not os.path.isdir(self.cache_dir): - os.makedirs(self.cache_dir) - self.similarities = dict() - try: - with open(os.path.join(self.cache_dir, 'sims.pkl'), 'rb') as file: - self.similarities = pickle.load(file) - except IOError: - self.similarities = dict() - self.embeddings = dict() - - def save_similarities(self): - with open(os.path.join(self.cache_dir, 'sims.pkl'), 'wb') as file: - pickle.dump(self.similarities, file) - - def get_embedding(self, word: str) -> Optional[np.array]: - if word in self.embeddings: - return self.embeddings[word] - try: - emb = np.load(os.path.join(self.cache_dir, f'{word}.npy')) - self.embeddings[word] = emb - return emb - except IOError: - return None - - def save_embedding(self, word: str, emb: np.array) -> None: - self.embeddings[word] = emb - np.save(os.path.join(self.cache_dir, f'{word}.npy'), emb) - - def clear_storage(self) -> None: - self.similarities = dict() - self.embeddings = dict() - if os.path.isdir(self.cache_dir): - for file in os.listdir(self.cache_dir): - 
os.remove(os.path.join(self.cache_dir, file)) - - class OntologyHandler: - def __init__(self): self.embedder = SBERT() - self.storage = EmbeddingStorage() def generate( self, words: List[str], callback: Callable = dummy_callback - ) -> Dict: - if len(words) == 0: - return {} + ) -> Tuple[Dict, int]: + embeddings = self.embedder(words, wrap_callback(callback, end=0.1)) + non_none = [(w, e) for w, e in zip(words, embeddings) if e is not None] + skipped = len(words) - len(non_none) + if len(non_none) == 0: + return {}, skipped + words, embeddings = zip(*non_none) + sims = self._get_similarities(embeddings) + callback(0.2) if len(words) == 1: - return {words[0]: {}} - if len(words) == 2: - return {sorted(words)[0]: {sorted(words)[1]: {}}} - sims = self._get_similarities( - words, - self._get_embeddings(words, wrap_callback(callback, end=0.1)), - wrap_callback(callback, start=0.1, end=0.2) - ) - if len(words) == 3: + return {words[0]: {}}, skipped + elif len(words) == 2: + return {sorted(words)[0]: {sorted(words)[1]: {}}}, skipped + elif len(words) == 3: root = np.argmin(np.sum(sims, axis=1)) rest = sorted([words[i] for i in range(3) if i != root]) - return {words[root]: {rest[0]: {}, rest[1]: {}}} + return {words[root]: {rest[0]: {}, rest[1]: {}}}, skipped ontology, root = generate_ontology( words, sims, callback=wrap_callback(callback, start=0.2) ) - return Tree.from_prufer_sequence(ontology, words, root).to_dict() + return Tree.from_prufer_sequence(ontology, words, root).to_dict(), skipped def insert( self, tree: Dict, words: List[str], callback: Callable = dummy_callback - ) -> Dict: + ) -> Tuple[Dict, int]: tree = Tree.from_dict(tree) - self._get_embeddings(words, wrap_callback(callback, end=0.3)) ticks = iter(np.linspace(0.3, 0.9, len(words))) + skipped = 0 for word in words: - tick = next(ticks) tree.adj_list.append(set()) tree.labels.append(word) - sims = self._get_similarities( - tree.labels, - self._get_embeddings(tree.labels, lambda x: callback(tick)), - 
lambda x: callback(tick) - ) + embeddings = self.embedder(tree.labels) + if embeddings[-1] is None: + # the last embedding is for the newly provided word + # if embedding is not successful skip it + skipped += 1 + continue + sims = self._get_similarities(embeddings) idx = len(tree.adj_list) - 1 fitness_function = FitnessFunction(tree.labels, sims).fitness scores = list() @@ -331,65 +282,20 @@ def insert( best = np.argmax(scores) tree.adj_list[best].add(idx) tree.adj_list[idx].add(best) - callback(tick) + callback(next(ticks)) - return tree.to_dict() + return tree.to_dict(), skipped def score(self, tree: Dict, callback: Callable = dummy_callback) -> float: + if not tree: + return 0 tree = Tree.from_dict(tree) - sims = self._get_similarities( - tree.labels, - self._get_embeddings(tree.labels, wrap_callback(callback, end=0.7)), - wrap_callback(callback, start=0.7, end=0.8) - ) + embeddings = self.embedder(tree.labels, wrap_callback(callback, end=0.7)) + sims = self._get_similarities(embeddings) callback(0.9) fitness_function = FitnessFunction(tree.labels, sims).fitness return fitness_function(tree, tree.root)[0] - def _get_embeddings( - self, - words: List[str], - callback: Callable = dummy_callback - ) -> np.array: - embeddings = np.zeros((len(words), EMB_DIM)) - missing, missing_idx = list(), list() - ticks = iter(np.linspace(0.0, 0.6, len(words))) - for i, word in enumerate(words): - callback(next(ticks)) - emb = self.storage.get_embedding(word) - if emb is None: - missing.append(word) - missing_idx.append(i) - else: - embeddings[i, :] = emb - if len(missing_idx) > 0: - embs = self.embedder(missing, callback=wrap_callback(callback, start=0.6, end=0.9)) - if None in embs: - raise RuntimeError("Couldn't obtain embeddings.") - embeddings[missing_idx, :] = np.array(embs) - for i in missing_idx: - self.storage.save_embedding(words[i], embeddings[i, :]) - - return embeddings - - def _get_similarities( - self, - words: List[str], - embeddings: np.array, - callback: 
Callable = dummy_callback - ) -> np.array: - sims = np.zeros((len(words), len(words))) - ticks = iter(np.linspace(0.0, 1.0, int(len(words) * (len(words) - 1) / 2))) - for i in range(len(words)): - for j in range(i + 1, len(words)): - callback(next(ticks)) - key = tuple(sorted((words[i], words[j]))) - try: - sim = self.storage.similarities[key] - except KeyError: - sim = cos_sim(embeddings[i, :], embeddings[j, :]) - self.storage.similarities[key] = sim - sims[i, j] = sim - sims[j, i] = sim - self.storage.save_similarities() - return sims + @staticmethod + def _get_similarities(embeddings: np.array) -> np.array: + return cosine_similarity(embeddings, embeddings) diff --git a/orangecontrib/text/tests/test_ontology.py b/orangecontrib/text/tests/test_ontology.py index cd0251377..2076dabea 100644 --- a/orangecontrib/text/tests/test_ontology.py +++ b/orangecontrib/text/tests/test_ontology.py @@ -1,22 +1,32 @@ import unittest +from typing import List, Union from unittest.mock import patch -from collections.abc import Iterator -import os +from typing import Iterator import asyncio import numpy as np -from orangecontrib.text.ontology import Tree, EmbeddingStorage, OntologyHandler, EMB_DIM +from orangecontrib.text.ontology import Tree, OntologyHandler +EMB_DIM = 384 RESPONSE = [ f'{{ "embedding": {[i] * EMB_DIM} }}'.encode() for i in range(4) ] +RESPONSE2 = [np.zeros(384), np.ones(384), np.zeros(384), np.ones(384)*2] +RESPONSE3 = [np.zeros(384), np.ones(384), np.arange(384), np.ones(384)*2] -class DummyResponse: +def arrays_to_response(array: List[Union[np.ndarray, List]]) -> Iterator[bytes]: + return iter(array_to_response(a) for a in array) + + +def array_to_response(array: Union[np.ndarray, List]) -> bytes: + return f'{{ "embedding": {array.tolist()} }}'.encode() + +class DummyResponse: def __init__(self, content): self.content = content @@ -72,54 +82,11 @@ def test_assertion_errors(self): Tree.from_prufer_sequence([1, 0, 3], list(map(str, range(4)))) -class 
TestEmbeddingStorage(unittest.TestCase): - - def setUp(self): - self.storage = EmbeddingStorage() - - def tearDown(self): - self.storage.clear_storage() - - def test_clear_storage(self): - self.storage.save_embedding("testword", np.zeros(3)) - self.assertEqual(len(self.storage.embeddings), 1) - self.storage.clear_storage() - self.assertEqual(len(self.storage.embeddings), 0) - self.assertEqual(len(os.listdir(self.storage.cache_dir)), 0) - - def test_save_embedding(self): - self.storage.save_embedding("testword", np.zeros(3)) - self.storage.save_embedding("testword2", np.zeros(3)) - self.assertEqual(len(self.storage.embeddings), 2) - self.assertEqual(len(os.listdir(self.storage.cache_dir)), 2) - - def test_get_embedding(self): - self.storage.save_embedding("testword", np.arange(3)) - emb = self.storage.get_embedding("testword") - self.assertEqual(emb.tolist(), [0, 1, 2]) - - def test_get_from_cache(self): - self.storage.save_embedding("testword", np.arange(3)) - self.storage.embeddings = dict() - emb = self.storage.get_embedding("testword") - self.assertEqual(emb.tolist(), [0, 1, 2]) - - def test_similarities(self): - self.storage.similarities['a', 'b'] = 0.75 - self.storage.save_similarities() - storage = EmbeddingStorage() - self.assertEqual(len(storage.similarities), 1) - self.assertTrue(('a', 'b') in storage.similarities) - self.assertEqual(storage.similarities['a', 'b'], 0.75) - - class TestOntologyHandler(unittest.TestCase): - def setUp(self): self.handler = OntologyHandler() def tearDown(self): - self.handler.storage.clear_storage() self.handler.embedder.clear_cache() @patch('orangecontrib.text.ontology.generate_ontology') @@ -128,48 +95,94 @@ def test_small_trees(self, mock): self.handler.generate(words) mock.assert_not_called() + @patch('httpx.AsyncClient.post', make_dummy_post(arrays_to_response(RESPONSE3))) def test_generate_small(self): - self.handler.storage.save_embedding('1', np.zeros(384)) - self.handler.storage.save_embedding('2', np.ones(384)) - 
self.handler.storage.save_embedding('3', np.arange(384)) - tree = self.handler.generate(['1', '2', '3']) + tree, skipped = self.handler.generate(['1', '2', '3']) self.assertTrue(isinstance(tree, dict)) + self.assertEqual(skipped, 0) - @patch('httpx.AsyncClient.post') - def test_generate(self, mock): - self.handler.storage.save_embedding('1', np.zeros(384)) - self.handler.storage.save_embedding('2', np.ones(384)) - self.handler.storage.save_embedding('3', np.arange(384)) - self.handler.storage.save_embedding('4', np.ones(384) * 2) - tree = self.handler.generate(['1', '2', '3', '4']) + @patch('httpx.AsyncClient.post', make_dummy_post(arrays_to_response(RESPONSE3))) + def test_generate(self): + tree, skipped = self.handler.generate(['1', '2', '3', '4']) self.assertTrue(isinstance(tree, dict)) - mock.request.assert_not_called() - mock.get_response.assert_not_called() + self.assertEqual(skipped, 0) @patch('httpx.AsyncClient.post', make_dummy_post(iter(RESPONSE))) def test_generate_with_unknown_embeddings(self): - tree = self.handler.generate(['1', '2', '3', '4']) + tree, skipped = self.handler.generate(['1', '2', '3', '4']) self.assertTrue(isinstance(tree, dict)) + self.assertEqual(skipped, 0) + @patch('httpx.AsyncClient.post', make_dummy_post(arrays_to_response(RESPONSE2))) def test_insert(self): - self.handler.storage.save_embedding('1', np.zeros(384)) - self.handler.storage.save_embedding('2', np.ones(384)) - self.handler.storage.save_embedding('3', np.arange(384)) - self.handler.storage.save_embedding('4', np.ones(384) * 2) - tree = self.handler.generate(['1', '2', '3']) - new_tree = self.handler.insert(tree, ['4']) + tree, skipped = self.handler.generate(['1', '2', '3']) + self.assertEqual(skipped, 0) + new_tree, skipped = self.handler.insert(tree, ['4']) self.assertGreater( len(Tree.from_dict(new_tree).adj_list), len(Tree.from_dict(tree).adj_list) ) + self.assertEqual(skipped, 0) + @patch('httpx.AsyncClient.post', make_dummy_post(array_to_response(np.zeros(384)))) 
def test_score(self): - self.handler.storage.save_embedding('1', np.zeros(384)) - self.handler.storage.save_embedding('2', np.ones(384)) - self.handler.storage.save_embedding('3', np.arange(384)) - tree = self.handler.generate(['1', '2', '3']) + tree, skipped = self.handler.generate(['1', '2', '3']) + score = self.handler.score(tree) + self.assertGreater(score, 0) + self.assertEqual(skipped, 0) + + @patch('httpx.AsyncClient.post', make_dummy_post(b'')) + def test_embedding_fails_generate(self): + """ Tests the case when embedding fails totally - return empty tree """ + tree, skipped = self.handler.generate(['1', '2', '3']) + score = self.handler.score(tree) + self.assertDictEqual(tree, {}) + self.assertEqual(score, 0) + self.assertEqual(skipped, 3) + + @patch('httpx.AsyncClient.post', make_dummy_post( + iter(list(arrays_to_response([np.arange(384), np.ones(384)])) + [b""] * 3) + )) + def test_some_embedding_fails_generate(self): + """ + Tests the case when embedding fail partially + - consider only successfully embedded words + """ + tree, skipped = self.handler.generate(['1', '2', '3']) score = self.handler.score(tree) + self.assertDictEqual(tree, {'1': {'2': {}}}) self.assertGreater(score, 0) + self.assertEqual(skipped, 1) + + @patch('httpx.AsyncClient.post', make_dummy_post( + # success for generate part and fail for insert part + iter(list(arrays_to_response(RESPONSE3)) + [b""] * 3) + )) + def test_embedding_fails_insert(self): + """ + Tests the case when embedding fails for word that is tried to be inserted + - don't insert it + """ + tree, skipped = self.handler.generate(['1', '2', '3', '4']) + self.assertEqual(skipped, 0) + new_tree, skipped = self.handler.insert(tree, ['5']) + self.assertDictEqual(tree, new_tree) + self.assertEqual(skipped, 1) + + @patch('httpx.AsyncClient.post', make_dummy_post( + # success for generate part and fail for part of new inputs + iter(list(arrays_to_response(RESPONSE3)) + [b""] * 3) + )) + def 
test_some_embedding_fails_insert(self): + """ + Tests the case when embedding fails for some words that are tried to be + inserted - insert only successfully embedded words + """ + tree, skipped = self.handler.generate(['1', '2', '3']) + self.assertEqual(skipped, 0) + new_tree, skipped = self.handler.insert(tree, ['4', '5']) + self.assertDictEqual(new_tree, {'1': {'2': {'4': {}}, '3': {}}}) + self.assertEqual(skipped, 1) if __name__ == '__main__': diff --git a/orangecontrib/text/widgets/owontology.py b/orangecontrib/text/widgets/owontology.py index b306befb2..6263bbc3a 100644 --- a/orangecontrib/text/widgets/owontology.py +++ b/orangecontrib/text/widgets/owontology.py @@ -153,7 +153,7 @@ def __init__(self, data_changed_cb: Callable): edit_triggers = QTreeView.DoubleClicked | QTreeView.EditKeyPressed super().__init__( - editTriggers=int(edit_triggers), + editTriggers=edit_triggers, selectionMode=QTreeView.ExtendedSelection, dragEnabled=True, acceptDrops=True, @@ -165,7 +165,7 @@ def __init__(self, data_changed_cb: Callable): self.__disconnected = False - def startDrag(self, actions: Qt.DropActions): + def startDrag(self, actions: Qt.DropAction): with disconnected(self.model().dataChanged, self.__data_changed_cb): super().startDrag(actions) self.drop_finished.emit() @@ -598,6 +598,7 @@ class Outputs: class Warning(OWWidget.Warning): no_words_column = Msg("Input is missing 'Words' column.") + skipped_words = Msg("{} terms are skipped due to server connection error.") class Error(OWWidget.Error): load_error = Msg("{}") @@ -626,7 +627,7 @@ def _setup_gui(self): edit_triggers = QListView.DoubleClicked | QListView.EditKeyPressed self.__library_view = QListView( - editTriggers=int(edit_triggers), + editTriggers=edit_triggers, minimumWidth=200, sizePolicy=QSizePolicy(QSizePolicy.Ignored, QSizePolicy.Expanding), ) @@ -830,23 +831,28 @@ def _cancel_tasks(self): def _run(self): self.__run_button.setText("Stop") + self.Warning.skipped_words.clear() words = 
self.__ontology_view.get_words() handler = self.__onto_handler.generate self.start(_run, handler, (words,)) def _run_insert(self): self.__inc_button.setText("Stop") + self.Warning.skipped_words.clear() tree = self.__ontology_view.get_data() words = self.__get_selected_input_words() handler = self.__onto_handler.insert self.start(_run, handler, (tree, words)) - def on_done(self, data: Dict): + def on_done(self, result: Tuple[Dict, int]): + data, num_skipped = result self.__inc_button.setText(self.INC_BUTTON) self.__run_button.setText(self.RUN_BUTTON) self.__ontology_view.set_data(data, keep_history=True) self.__set_current_modified(self.CACHED) self.__update_score() + if num_skipped > 0: + self.Warning.skipped_words(num_skipped) def __update_score(self): tree = self.__ontology_view.get_data() diff --git a/orangecontrib/text/widgets/tests/test_owontology.py b/orangecontrib/text/widgets/tests/test_owontology.py index 31b967369..6156d5d93 100644 --- a/orangecontrib/text/widgets/tests/test_owontology.py +++ b/orangecontrib/text/widgets/tests/test_owontology.py @@ -5,9 +5,10 @@ import unittest from unittest.mock import Mock, patch +import numpy as np from AnyQt.QtCore import Qt, QItemSelectionModel, QItemSelection, \ QItemSelectionRange -from AnyQt.QtWidgets import QFileDialog +from AnyQt.QtWidgets import QFileDialog, QPushButton from Orange.data import Table from Orange.widgets.tests.base import WidgetTest @@ -35,15 +36,15 @@ def setUp(self): def test_run(self): result = _run(self.handler.generate, (self.words,), self.state) - self.assertEqual(result, {"bar": {"foo": {}}}) + self.assertEqual(result, ({"bar": {"foo": {}}}, 0)) def test_run_single_word(self): result = _run(self.handler.generate, (["foo"],), self.state) - self.assertEqual(result, {"foo": {}}) + self.assertEqual(result, ({"foo": {}}, 0)) def test_run_empty(self): result = _run(self.handler.generate, ([],), self.state) - self.assertEqual(result, {}) + self.assertEqual(result, ({}, 0)) def 
test_run_interrupt(self): state = Mock() @@ -261,6 +262,85 @@ def test_library_save(self): def test_send_report(self): self.widget.send_report() + def test_skipped_words_generate(self): + """ + Test case when embedding fails when generating the ontology. It results + in exclusion of non-embedded terms and warning. + """ + get_ontology_data = self.widget._OWOntology__ontology_view.get_data + self.assertDictEqual(get_ontology_data(), {"foo1": {"bar1": {}, "baz1": {}}}) + + # generate with embedding error - two skipped + with patch( + "orangecontrib.text.vectorization.sbert.SBERT.__call__", + return_value=[np.ones(300), None, None], + ): + self.widget._OWOntology__run_button.click() + self.wait_until_finished() + self.assertDictEqual(get_ontology_data(), {"foo1": {}}) + self.assertTrue(self.widget.Warning.skipped_words.is_shown()) + self.assertEqual( + str(self.widget.Warning.skipped_words), + "2 terms are skipped due to server connection error.", + ) + + # generate without embedding error + with patch( + "orangecontrib.text.vectorization.sbert.SBERT.__call__", + return_value=[np.ones(300)], + ): + self.widget._OWOntology__run_button.click() + self.wait_until_finished() + self.assertDictEqual(get_ontology_data(), {"foo1": {}}) + self.assertFalse(self.widget.Warning.skipped_words.is_shown()) + + def test_skipped_words_insert(self): + """ + Test case when embedding fails when inserting the term. It results + in exclusion of non-embedded terms and warning. 
+ """ + words = create_words_table(["foo2", "foo3"]) + self.send_signal(self.widget.Inputs.words, words) + + # insert with an embedding error + with patch( + "orangecontrib.text.vectorization.sbert.SBERT.__call__", + side_effect=[ + [np.ones(300), np.ones(300), np.ones(300), None], + [np.ones(300), np.ones(300), np.ones(300)], + ], + ): + get_ontology_data = self.widget._OWOntology__ontology_view.get_data + self.assertDictEqual( + get_ontology_data(), {"foo1": {"bar1": {}, "baz1": {}}} + ) + + self.widget._OWOntology__input_view.setCurrentIndex( + self.widget._OWOntology__input_model.index(0, 0) + ) + self.widget._OWOntology__inc_button.click() + self.wait_until_finished() + self.assertDictEqual( + get_ontology_data(), {"foo1": {"bar1": {}, "baz1": {}}} + ) + self.assertTrue(self.widget.Warning.skipped_words.is_shown()) + self.assertEqual( + str(self.widget.Warning.skipped_words), + "1 terms are skipped due to server connection error.", + ) + + # insert without embedding error + with patch( + "orangecontrib.text.vectorization.sbert.SBERT.__call__", + return_value=[np.ones(300)] * 4, + ): + self.widget._OWOntology__inc_button.click() + self.wait_until_finished() + self.assertDictEqual( + get_ontology_data(), {"foo1": {"bar1": {}, "baz1": {}, "foo2": {}}} + ) + self.assertFalse(self.widget.Warning.skipped_words.is_shown()) + if __name__ == "__main__": unittest.main()