This repository has been archived by the owner on Feb 1, 2023. It is now read-only.

Commit edc8b47: Spacy compatibility improvements
oroszgy committed Oct 28, 2021 (1 parent: 99d5a82)
Showing 3 changed files with 10 additions and 8 deletions.
lemmy/__init__.py: 4 additions & 3 deletions

```diff
@@ -7,6 +7,7 @@
 from lemmy.lemmatizer import Lemmatizer
 
 
-@Language.factory("lemmy3", default_config={"data": None})
-def create(nlp: Language, name: str, data: Union[str, Path]) -> Lemmatizer:
-    return Lemmatizer
+# noinspection PyUnusedLocal
+@Language.factory("lemmy3")
+def create(nlp: Language, name: str, model_path: Union[str, Path]) -> "HunLemmatizer":
+    return Lemmatizer.from_disk(model_path)
```
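Under spaCy 3's registry, the factory no longer carries a `default_config`, so the renamed `model_path` argument has to be supplied through `add_pipe`'s `config` when the component is added. A minimal wiring sketch under that assumption; the blank pipeline and the model path below are illustrative, not taken from this commit:

```python
import spacy

import lemmy  # importing the package runs @Language.factory("lemmy3")

# A blank Hungarian pipeline is enough to demonstrate the wiring.
nlp = spacy.blank("hu")

# "path/to/lemmy.bin" is a hypothetical path to a trained lemmy model;
# spaCy hands it to the factory, which calls Lemmatizer.from_disk().
nlp.add_pipe("lemmy3", config={"model_path": "path/to/lemmy.bin"})
```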
lemmy/cli/__main__.py: 3 additions & 4 deletions

```diff
@@ -1,5 +1,4 @@
 import re
 import timeit
 from pathlib import Path
 from typing import Tuple, List
 
@@ -37,7 +36,7 @@ def read_file(path: Path, ignore_contractions: bool = True) -> Tuple[TaggedWords
 def debug(model_path: Path, test_data: Path, ignore_contractions: bool = True):
     lemmatizer: Lemmatizer = Lemmatizer.from_disk(model_path)
     tagged_words, lemmata = read_file(test_data, ignore_contractions)
-    predicted: List[Lemma] = [lemmatizer.lemmatize(tag, word, position) for tag, word, position in tagged_words]
+    predicted: List[Lemma] = [lemmatizer.lemmatize(tag, word, position == 0) for tag, word, position in tagged_words]
     for lemma, pred, (tag, word, position) in zip(lemmata, predicted, tagged_words):
         if lemma != pred:
             print(f"Wrong lemma for {word}[{tag}]: '{pred}', should be '{lemma}'")
@@ -47,7 +46,7 @@ def debug(model_path: Path, test_data: Path, ignore_contractions: bool = True):
 def evaluate(model_path: Path, test_data: Path, ignore_contractions: bool = True):
     lemmatizer: Lemmatizer = Lemmatizer.from_disk(model_path)
     tagged_words, lemmata = read_file(test_data, ignore_contractions)
-    predicted: List[Lemma] = [lemmatizer.lemmatize(tag, word, position) for tag, word, position in tagged_words]
+    predicted: List[Lemma] = [lemmatizer.lemmatize(tag, word, position == 0) for tag, word, position in tagged_words]
     accuracy = sum([gt == pred for gt, pred in zip(lemmata, predicted)]) / float(len(lemmata))
     print(f"Accuracy: {accuracy:.2%}")
 
@@ -61,7 +60,7 @@ def train(train_path: Path, model_path: Path, max_iterations: int = typer.Option
     lemmatizer: Lemmatizer = Lemmatizer()
     lemmatizer.fit(tagged_words, lemmata, max_iteration=max_iterations)
 
-    predicted: List[Lemma] = [lemmatizer.lemmatize(tag, word, position) for tag, word, position in tagged_words]
+    predicted: List[Lemma] = [lemmatizer.lemmatize(tag, word, position == 0) for tag, word, position in tagged_words]
     accuracy = sum([gt == pred for gt, pred in zip(lemmata, predicted)]) / float(len(lemmata))
     print(f"Accuracy: {accuracy:.2%}")
 
```
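All three commands now pass `position == 0` as the third argument to `lemmatize`, a boolean (presumably "is this the first token of the sentence") rather than the raw integer position. A hedged sketch of a direct call under that reading; the model path, the tag format, and the Hungarian words are guesses for illustration:

```python
from lemmy import Lemmatizer

# Hypothetical model path; the commit itself ships no model file.
lemmatizer = Lemmatizer.from_disk("path/to/lemmy.bin")

# The third argument is now a boolean rather than a position index.
first = lemmatizer.lemmatize("NOUN", "Almát", True)   # sentence-initial token
later = lemmatizer.lemmatize("NOUN", "almát", False)  # token anywhere else
```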
lemmy/serialization.py: 3 additions & 1 deletion

```diff
@@ -23,7 +23,7 @@ def _to_bytes(self) -> Dict[str, Any]:
 
     @classmethod
     def from_bytes(cls, bytes_data: bytes) -> C:
-        msg = srsly.msgpack_loads(bytes_data)
+        msg = srsly.msgpack_loads(bytes_data, use_list=False)
         version = msg["_version"]
         if version != cls._version():
             raise TypeError(f"Incompatible versions: expected: {cls._version()} but got {version}")
@@ -34,11 +34,13 @@ def from_bytes(cls, bytes_data: bytes) -> C:
     def _from_bytes(cls, msg: Dict[str, Any]) -> C:
         raise NotImplementedError
 
+    # noinspection PyUnusedLocal
     def to_disk(self, path: Union[str, Path], exclude=tuple()) -> None:
         path = ensure_path(path)
        with path.open("wb") as file_:
             file_.write(self.to_bytes())
 
+    # noinspection PyUnusedLocal
     @classmethod
     def from_disk(cls, path: Union[str, Path], exclude=tuple()) -> C:
         path = ensure_path(path)
```
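Passing `use_list=False` makes msgpack decode arrays as tuples instead of lists, so data that was serialized from tuples (for example rule entries used as dictionary keys) comes back hashable after a round trip. A standalone sketch of the behavior; the `rules` payload is invented, not the library's actual data layout:

```python
import srsly

data = {"rules": (("N", "ák"), ("V", "ett"))}  # tuples before serialization
blob = srsly.msgpack_dumps(data)

as_lists = srsly.msgpack_loads(blob)                   # arrays decode as lists
as_tuples = srsly.msgpack_loads(blob, use_list=False)  # arrays decode as tuples

print(type(as_lists["rules"][0]))   # <class 'list'>
print(type(as_tuples["rules"][0]))  # <class 'tuple'>
```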
