From 1b5df97522d4e13a8581836d514e37c5f8c5a31c Mon Sep 17 00:00:00 2001
From: Benedikt Fuchs
Date: Fri, 28 Jun 2024 16:34:18 +0200
Subject: [PATCH 1/6] character embeddings store their embedding name too

---
 .gitignore                | 2 ++
 flair/embeddings/token.py | 4 +++-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 53e248b6d1..cbd5be92d3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -109,3 +109,5 @@ venv.bak/
 resources/taggers/
 regression_train/
 /doc_build/
+
+scripts/
\ No newline at end of file

diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py
index 76173bac80..669d3162a5 100644
--- a/flair/embeddings/token.py
+++ b/flair/embeddings/token.py
@@ -454,13 +454,14 @@ def __init__(
         path_to_char_dict: Optional[Union[str, Dictionary]] = None,
         char_embedding_dim: int = 25,
         hidden_size_char: int = 25,
+        name: str = "Char"
     ) -> None:
         """Instantiates a bidirectional LSTM layer to encode words by their character representation.

         Uses the default character dictionary if none provided.
         """
         super().__init__()
-        self.name = "Char"
+        self.name = name
         self.static_embeddings = False

         self.instance_parameters = self.get_instance_parameters(locals=locals())
@@ -556,6 +557,7 @@ def to_params(self) -> Dict[str, Any]:
             "path_to_char_dict": self.char_dictionary,
             "char_embedding_dim": self.char_embedding_dim,
             "hidden_size_char": self.hidden_size_char,
+            "name": self.name,
         }

From 90479f547f4593c7aa7a1c9c938205791bb13eb7 Mon Sep 17 00:00:00 2001
From: Benedikt Fuchs
Date: Fri, 28 Jun 2024 18:55:19 +0200
Subject: [PATCH 2/6] black

---
 flair/embeddings/token.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py
index 669d3162a5..7cfbd73b9f 100644
--- a/flair/embeddings/token.py
+++ b/flair/embeddings/token.py
@@ -454,7 +454,7 @@ def __init__(
         path_to_char_dict: Optional[Union[str, Dictionary]] = None,
         char_embedding_dim: int = 25,
         hidden_size_char: int = 25,
-        name: str = "Char"
+        name: str = "Char",
     ) -> None:
         """Instantiates a bidirectional LSTM layer to encode words by their character representation.

From 52c7b151ef4c5c1fc784c200d9f9a1adcc2a7e48 Mon Sep 17 00:00:00 2001
From: Matt Buchovecky
Date: Wed, 3 Jul 2024 15:29:32 -0700
Subject: [PATCH 3/6] fix: memory leak in TextPairRegressor when
 embed_separately=False

Keep a reference to the concatenated sentence that is created when not
embedding data points separately in a DataPair.
Those embeddings can then be cleared in clear_embeddings, freeing the
GPU memory.
---
 flair/data.py                             |  3 +++
 flair/models/pairwise_regression_model.py | 19 +++++++++++--------
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/flair/data.py b/flair/data.py
index bc35c83c5c..e4150cd736 100644
--- a/flair/data.py
+++ b/flair/data.py
@@ -1223,6 +1223,7 @@ def __init__(self, first: DT, second: DT2) -> None:
         super().__init__()
         self.first = first
         self.second = second
+        self.concatenated_data: Optional[Union[DT, DT2]] = None

     def to(self, device: str, pin_memory: bool = False):
         self.first.to(device, pin_memory)
@@ -1231,6 +1232,8 @@ def to(self, device: str, pin_memory: bool = False):
     def clear_embeddings(self, embedding_names: Optional[List[str]] = None):
         self.first.clear_embeddings(embedding_names)
         self.second.clear_embeddings(embedding_names)
+        if self.concatenated_data is not None:
+            self.concatenated_data.clear_embeddings(embedding_names)

     @property
     def embedding(self):

diff --git a/flair/models/pairwise_regression_model.py b/flair/models/pairwise_regression_model.py
index c267ea5a80..7527612441 100644
--- a/flair/models/pairwise_regression_model.py
+++ b/flair/models/pairwise_regression_model.py
@@ -178,14 +178,17 @@ def _get_embedding_for_data_point(self, prediction_data_point: TextPair) -> torc
                 0,
             )
         else:
-            concatenated_sentence = Sentence(
-                prediction_data_point.first.to_tokenized_string()
-                + self.sep
-                + prediction_data_point.second.to_tokenized_string(),
-                use_tokenizer=False,
-            )
-            self.embeddings.embed(concatenated_sentence)
-            return concatenated_sentence.get_embedding(embedding_names)
+            # If the concatenated version of the text pair does not exist yet, create it
+            if prediction_data_point.concatenated_data is None:
+                concatenated_sentence = Sentence(
+                    prediction_data_point.first.to_tokenized_string()
+                    + self.sep
+                    + prediction_data_point.second.to_tokenized_string(),
+                    use_tokenizer=False,
+                )
+                prediction_data_point.concatenated_data = concatenated_sentence
+            self.embeddings.embed(prediction_data_point.concatenated_data)
+            return prediction_data_point.concatenated_data.get_embedding(embedding_names)

     def _get_state_dict(self):
         model_state = {

From b534186b3b0ff3c1c76c00626cce8cbd8a264bd4 Mon Sep 17 00:00:00 2001
From: Matt Buchovecky
Date: Wed, 3 Jul 2024 15:41:55 -0700
Subject: [PATCH 4/6] feat: improve exception handling in data classes

---
 flair/data.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/flair/data.py b/flair/data.py
index e4150cd736..4137e72913 100644
--- a/flair/data.py
+++ b/flair/data.py
@@ -1090,7 +1090,8 @@ def get_language_code(self) -> str:

             try:
                 self.language_code = langdetect.detect(self.to_plain_string())
-            except Exception:
+            except Exception as e:
+                log.debug(e)
                 self.language_code = "en"

         return self.language_code
@@ -1668,7 +1669,9 @@ def make_label_dictionary(
                 [f"'{label[0]}' (in {label[1]} sentences)" for label in sentence_label_type_counter.most_common()]
             )
             log.error(f"ERROR: The corpus contains the following label types: {contained_labels}")
-            raise Exception
+            raise ValueError(
+                f"You specified a label type ({label_type}) that is not contained in the corpus:\n{contained_labels}"
+            )

         log.info(
             f"Dictionary created for label '{label_type}' with {len(label_dictionary)} "

From 1a6f11067d9dd8b7295ca8110cee9c01144879bb Mon Sep 17 00:00:00 2001
From: Matt Buchovecky
Date: Wed, 3 Jul 2024 16:32:06 -0700
Subject: [PATCH 5/6] refactor: add type hints to relevant files

---
 flair/data.py                             | 85 ++++++++++++-----------
 flair/models/lemmatizer_model.py          |  2 +-
 flair/models/pairwise_regression_model.py |  2 +-
 3 files changed, 46 insertions(+), 43 deletions(-)

diff --git a/flair/data.py b/flair/data.py
index 4137e72913..69d85baf92 100644
--- a/flair/data.py
+++ b/flair/data.py
@@ -5,8 +5,9 @@
 from abc import ABC, abstractmethod
 from collections import Counter, defaultdict
 from operator import itemgetter
+from os import PathLike
 from pathlib import Path
-from typing import Dict, Iterable, List, NamedTuple, Optional, Tuple, Union, cast
+from typing import Any, DefaultDict, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union, cast

 import torch
 from deprecated.sphinx import deprecated
@@ -49,7 +50,7 @@ class BoundingBox(NamedTuple):
 class Dictionary:
     """This class holds a dictionary that maps strings to IDs, used to generate one-hot encodings of strings."""

-    def __init__(self, add_unk=True) -> None:
+    def __init__(self, add_unk: bool = True) -> None:
         # init dictionaries
         self.item2idx: Dict[bytes, int] = {}
         self.idx2item: List[bytes] = []
@@ -143,21 +144,21 @@ def is_span_prediction_problem(self) -> bool:
     def start_stop_tags_are_set(self) -> bool:
         return {b"<START>", b"<STOP>"}.issubset(self.item2idx.keys())

-    def save(self, savefile):
+    def save(self, savefile: PathLike):
         import pickle

         with open(savefile, "wb") as f:
             mappings = {"idx2item": self.idx2item, "item2idx": self.item2idx}
             pickle.dump(mappings, f)

-    def __setstate__(self, d):
+    def __setstate__(self, d: Dict) -> None:
         self.__dict__ = d
         # set 'add_unk' if the dictionary was created with a version of Flair older than 0.9
         if "add_unk" not in self.__dict__:
             self.__dict__["add_unk"] = b"<unk>" in self.__dict__["idx2item"]

     @classmethod
-    def load_from_file(cls, filename: Union[str, Path]):
+    def load_from_file(cls, filename: Union[str, Path]) -> "Dictionary":
         import pickle

         with Path(filename).open("rb") as f:
@@ -174,7 +175,7 @@ def load_from_file(cls, filename: Union[str, Path]):
         return dictionary

     @classmethod
-    def load(cls, name: str):
+    def load(cls, name: str) -> "Dictionary":
         from flair.file_utils import cached_path

         hu_path: str = "https://flair.informatik.hu-berlin.de/resources/characters"
@@ -282,11 +283,11 @@ class DataPoint:
     def __init__(self) -> None:
         self.annotation_layers: Dict[str, List[Label]] = {}
         self._embeddings: Dict[str, torch.Tensor] = {}
-        self._metadata: Dict[str, typing.Any] = {}
+        self._metadata: Dict[str, Any] = {}

     @property
     @abstractmethod
-    def embedding(self):
+    def embedding(self) -> torch.Tensor:
         pass

     def set_embedding(self, name: str, vector: torch.Tensor):
@@ -316,7 +317,7 @@ def get_each_embedding(self, embedding_names: Optional[List[str]] = None) -> Lis
             embeddings.append(embed)
         return embeddings

-    def to(self, device: str, pin_memory: bool = False):
+    def to(self, device: str, pin_memory: bool = False) -> None:
         for name, vector in self._embeddings.items():
             if str(vector.device) != str(device):
                 if pin_memory:
@@ -324,7 +325,7 @@ def to(self, device: str, pin_memory: bool = False):
             else:
                 self._embeddings[name] = vector.to(device, non_blocking=True)

-    def clear_embeddings(self, embedding_names: Optional[List[str]] = None):
+    def clear_embeddings(self, embedding_names: Optional[List[str]] = None) -> None:
         if embedding_names is None:
             self._embeddings = {}
         else:
@@ -332,19 +333,19 @@ def clear_embeddings(self, embedding_names: Optional[List[str]] = None):
             if name in self._embeddings:
                 del self._embeddings[name]

-    def has_label(self, type) -> bool:
+    def has_label(self, type: str) -> bool:
         return type in self.annotation_layers
-    def add_metadata(self, key: str, value: typing.Any) -> None:
+    def add_metadata(self, key: str, value: Any) -> None:
         self._metadata[key] = value

-    def get_metadata(self, key: str) -> typing.Any:
+    def get_metadata(self, key: str) -> Any:
         return self._metadata[key]

     def has_metadata(self, key: str) -> bool:
         return key in self._metadata

-    def add_label(self, typename: str, value: str, score: float = 1.0, **metadata):
+    def add_label(self, typename: str, value: str, score: float = 1.0, **metadata) -> "DataPoint":
         label = Label(self, value, score, **metadata)

         if typename not in self.annotation_layers:
@@ -358,16 +359,16 @@ def set_label(self, typename: str, value: str, score: float = 1.0, **metadata):
             self.annotation_layers[typename] = [Label(self, value, score, **metadata)]
         return self

-    def remove_labels(self, typename: str):
+    def remove_labels(self, typename: str) -> None:
         if typename in self.annotation_layers:
             del self.annotation_layers[typename]

-    def get_label(self, label_type: Optional[str] = None, zero_tag_value="O"):
+    def get_label(self, label_type: Optional[str] = None, zero_tag_value: str = "O") -> Label:
         if len(self.get_labels(label_type)) == 0:
             return Label(self, zero_tag_value)
         return self.get_labels(label_type)[0]

-    def get_labels(self, typename: Optional[str] = None):
+    def get_labels(self, typename: Optional[str] = None) -> List[Label]:
         if typename is None:
             return self.labels
@@ -385,7 +386,7 @@ def labels(self) -> List[Label]:
     def unlabeled_identifier(self):
         raise NotImplementedError

-    def _printout_labels(self, main_label=None, add_score: bool = True, add_metadata: bool = True):
+    def _printout_labels(self, main_label=None, add_score: bool = True, add_metadata: bool = True) -> str:
         all_labels = []
         keys = [main_label] if main_label is not None else self.annotation_layers.keys()
@@ -431,7 +432,7 @@ def tag(self):
     def score(self):
         return self.labels[0].score

-    def __lt__(self, other):
+    def __lt__(self, other: "DataPoint"):
         return self.start_position < other.start_position

     def __len__(self) -> int:
@@ -482,7 +483,7 @@ def __str__(self) -> str:
     def __repr__(self) -> str:
         return str(self)

-    def to_dict(self) -> Dict[str, typing.Any]:
+    def to_dict(self) -> Dict[str, Any]:
         return {
             "concept_id": self.concept_id,
             "concept_name": self.concept_name,
@@ -517,7 +518,7 @@ def set_label(self, typename: str, value: str, score: float = 1.0, **metadata):
         super().set_label(typename, value, score, **metadata)
         return self

-    def remove_labels(self, typename: str):
+    def remove_labels(self, typename: str) -> None:
         # labels also need to be deleted at Sentence object
         for label in self.get_labels(typename):
             self.sentence.annotation_layers[typename].remove(label)
@@ -567,7 +568,7 @@ def text(self) -> str:
     def unlabeled_identifier(self) -> str:
         return f'Token[{self.idx - 1}]: "{self.text}"'

-    def add_tags_proba_dist(self, tag_type: str, tags: List[Label]):
+    def add_tags_proba_dist(self, tag_type: str, tags: List[Label]) -> None:
         self.tags_proba_dist[tag_type] = tags

     def get_tags_proba_dist(self, tag_type: str) -> List[Label]:
@@ -616,7 +617,7 @@ def set_label(self, typename: str, value: str, score: float = 1.0, **metadata):
         else:
             DataPoint.set_label(self, typename=typename, value=value, score=score, **metadata)

-    def to_dict(self, tag_type: Optional[str] = None):
+    def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]:
         return {
             "text": self.text,
             "start_pos": self.start_position,
@@ -958,7 +959,7 @@ def right_context(self, context_length: int, respect_document_boundaries: bool =
     def __str__(self) -> str:
         return self.to_tagged_string()

-    def to_tagged_string(self, main_label=None) -> str:
+    def to_tagged_string(self, main_label: Optional[str] = None) -> str:
         already_printed = [self]

         output = super().__str__()
@@ -978,7 +979,7 @@ def to_tagged_string(self, main_label=None) -> str:
         return output

     @property
-    def text(self):
+    def text(self) -> str:
         return self.to_original_text()

     def to_tokenized_string(self) -> str:
@@ -987,7 +988,7 @@ def to_tokenized_string(self) -> str:

         return self.tokenized

-    def to_plain_string(self):
+    def to_plain_string(self) -> str:
         plain = ""
         for token in self.tokens:
             plain += token.text
@@ -1036,7 +1037,7 @@ def to_original_text(self) -> str:
             [t.text + t.whitespace_after * " " for t in self.tokens]
         ).strip()

-    def to_dict(self, tag_type: Optional[str] = None):
+    def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]:
         return {
             "text": self.to_original_text(),
             "labels": [label.to_dict() for label in self.get_labels(tag_type) if label.data_point is self],
@@ -1045,7 +1046,7 @@ def to_dict(self, tag_type: Optional[str] = None):
             "tokens": [token.to_dict(tag_type) for token in self.tokens],
         }

-    def get_span(self, start: int, stop: int):
+    def get_span(self, start: int, stop: int) -> Span:
         span_slice = slice(start, stop)
         return self[span_slice]
@@ -1308,7 +1309,7 @@ def text(self):

 class Image(DataPoint):
-    def __init__(self, data=None, imageURL=None) -> None:
+    def __init__(self, data=None, imageURL=None):
         super().__init__()

         self.data = data
@@ -1407,7 +1408,7 @@ def downsample(
         downsample_dev: bool = True,
         downsample_test: bool = True,
         random_seed: Optional[int] = None,
-    ):
+    ) -> "Corpus":
         """Reduce all datasets in corpus proportionally to the given percentage."""
         if downsample_train and self._train is not None:
             self._train = self._downsample_to_proportion(self._train, percentage, random_seed)
@@ -1474,7 +1475,7 @@ def _filter_empty_sentences(dataset) -> Dataset:

         return subset

-    def make_vocab_dictionary(self, max_tokens=-1, min_freq=1) -> Dictionary:
+    def make_vocab_dictionary(self, max_tokens: int = -1, min_freq: int = 1) -> Dictionary:
         """Creates a dictionary of all tokens contained in the corpus.

         By defining `max_tokens` you can set the maximum number of tokens that should be contained in the dictionary.
@@ -1496,7 +1497,7 @@ def make_vocab_dictionary(self, max_tokens=-1, min_freq=1) -> Dictionary:

         return vocab_dictionary

-    def _get_most_common_tokens(self, max_tokens, min_freq) -> List[str]:
+    def _get_most_common_tokens(self, max_tokens: int, min_freq: int) -> List[str]:
         tokens_and_frequencies = Counter(self._get_all_tokens())

         tokens: List[str] = []
@@ -1565,20 +1566,20 @@ def _obtain_statistics_for(sentences, name, tag_type) -> dict:
         }

     @staticmethod
-    def _get_tokens_per_sentence(sentences):
+    def _get_tokens_per_sentence(sentences: Iterable[Sentence]) -> List[int]:
         return [len(x.tokens) for x in sentences]

     @staticmethod
-    def _count_sentence_labels(sentences):
-        label_count = defaultdict(lambda: 0)
+    def _count_sentence_labels(sentences: Iterable[Sentence]) -> DefaultDict[str, int]:
+        label_count: DefaultDict[str, int] = defaultdict(lambda: 0)
         for sent in sentences:
             for label in sent.labels:
                 label_count[label.value] += 1
         return label_count

     @staticmethod
-    def _count_token_labels(sentences, label_type):
-        label_count = defaultdict(lambda: 0)
+    def _count_token_labels(sentences: Iterable[Sentence], label_type: str) -> DefaultDict[str, int]:
+        label_count: DefaultDict[str, int] = defaultdict(lambda: 0)
         for sent in sentences:
             for token in sent.tokens:
                 if label_type in token.annotation_layers:
@@ -1894,7 +1895,7 @@ def __init__(self, datasets: Iterable[Dataset], ids: Iterable[str]) -> None:
     def __len__(self) -> int:
         return self.cumulative_sizes[-1]

-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int) -> Sentence:
         if idx < 0:
             if -idx > len(self):
                 raise ValueError("absolute value of index should not exceed dataset length")
@@ -1906,11 +1907,11 @@ def __getitem__(self, idx):
         return sentence

     @property
-    def cummulative_sizes(self):
+    def cummulative_sizes(self) -> List[int]:
         return self.cumulative_sizes


-def iob2(tags):
+def iob2(tags: List) -> bool:
     """Converts the tags to the IOB2 format.

     Check that tags have a valid IOB format.
@@ -1957,7 +1958,9 @@ def randomly_split_into_two_datasets(
     return Subset(dataset, first_dataset), Subset(dataset, second_dataset)


-def get_spans_from_bio(bioes_tags: List[str], bioes_scores=None) -> List[typing.Tuple[List[int], float, str]]:
+def get_spans_from_bio(
+    bioes_tags: List[str], bioes_scores: Optional[List[float]] = None
+) -> List[typing.Tuple[List[int], float, str]]:
     # add a dummy "O" to close final prediction
     bioes_tags.append("O")
     # return complex list

diff --git a/flair/models/lemmatizer_model.py b/flair/models/lemmatizer_model.py
index 6700b089d0..6f0854d4b5 100644
--- a/flair/models/lemmatizer_model.py
+++ b/flair/models/lemmatizer_model.py
@@ -474,7 +474,7 @@ def predict(
         # option 1: greedy decoding
         if self.beam_size == 1:
             # predictions
-            predicted: List[List[int]] = [[] for _ in range(number_tokens)]
+            predicted: List[List[Union[int, float]]] = [[] for _ in range(number_tokens)]

             for _decode_step in range(max_length):
                 # decode next character

diff --git a/flair/models/pairwise_regression_model.py b/flair/models/pairwise_regression_model.py
index 7527612441..c657a4fe6d 100644
--- a/flair/models/pairwise_regression_model.py
+++ b/flair/models/pairwise_regression_model.py
@@ -153,7 +153,7 @@ def _prepare_target_tensor(self, pairs: List[TextPair]):
     def _filter_data_point(self, pair: TextPair) -> bool:
         return len(pair) > 0

-    def _encode_data_points(self, data_points: List[TextPair]):
+    def _encode_data_points(self, data_points: List[TextPair]) -> torch.Tensor:
         # get a tensor of data points
         data_point_tensor = torch.stack([self._get_embedding_for_data_point(data_point) for data_point in data_points])

From 05622e6f024dbb7e9f9269b6bef2ae0bb4fb3b1d Mon Sep 17 00:00:00 2001
From: Matt Buchovecky
Date: Fri, 12 Jul 2024 01:15:40 -0700
Subject: [PATCH 6/6] fix: memory leak in TextPairClassifier when
 embed_separately=False

Same as the fix for TextPairRegressor.

---
 flair/models/pairwise_classification_model.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/flair/models/pairwise_classification_model.py b/flair/models/pairwise_classification_model.py
index d449a83eae..262fd08cb5 100644
--- a/flair/models/pairwise_classification_model.py
+++ b/flair/models/pairwise_classification_model.py
@@ -84,14 +84,17 @@ def _get_embedding_for_data_point(self, prediction_data_point: TextPair) -> torc
                 0,
             )
         else:
-            concatenated_sentence = Sentence(
-                prediction_data_point.first.to_tokenized_string()
-                + self.sep
-                + prediction_data_point.second.to_tokenized_string(),
-                use_tokenizer=False,
-            )
-            self.embeddings.embed(concatenated_sentence)
-            return concatenated_sentence.get_embedding(embedding_names)
+            # If the concatenated version of the text pair does not exist yet, create it
+            if prediction_data_point.concatenated_data is None:
+                concatenated_sentence = Sentence(
+                    prediction_data_point.first.to_tokenized_string()
+                    + self.sep
+                    + prediction_data_point.second.to_tokenized_string(),
+                    use_tokenizer=False,
+                )
+                prediction_data_point.concatenated_data = concatenated_sentence
+            self.embeddings.embed(prediction_data_point.concatenated_data)
+            return prediction_data_point.concatenated_data.get_embedding(embedding_names)

     def _get_state_dict(self):
         model_state = {
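A quick sketch of the behavior the memory-leak fixes in patches 3 and 6 enable
(the pair below is illustrative; the model forward pass that would populate
concatenated_data is assumed, not shown):

    from flair.data import DataPair, Sentence

    # A text pair as consumed by TextPairRegressor / TextPairClassifier.
    pair = DataPair(Sentence("A man walks his dog."), Sentence("Someone is outside."))

    # Before these patches, a model with embed_separately=False built a fresh
    # concatenated Sentence on every call to _get_embedding_for_data_point and
    # dropped the reference, so its embedding tensors could never be cleared
    # and accumulated on the GPU. The concatenation is now cached on the pair:
    assert pair.concatenated_data is None  # populated lazily by the model

    # clear_embeddings now also clears the cached concatenation (if any),
    # releasing the GPU memory after each batch:
    pair.clear_embeddings()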