From 6f0b184178a14e785f0feb2cf3ec562c23eaef09 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 22 Mar 2024 11:35:06 +0100 Subject: [PATCH 1/4] add prediction label type for span classifier --- .../tutorial-training/how-to-train-span-classifier.md | 7 ++++--- flair/models/entity_linker_model.py | 9 +++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/docs/tutorial/tutorial-training/how-to-train-span-classifier.md b/docs/tutorial/tutorial-training/how-to-train-span-classifier.md index e1d916ff7d..e5d32cb426 100644 --- a/docs/tutorial/tutorial-training/how-to-train-span-classifier.md +++ b/docs/tutorial/tutorial-training/how-to-train-span-classifier.md @@ -154,7 +154,7 @@ from flair.nn.multitask import make_multitask_model_and_corpus # 1. get the corpus ner_corpus = NER_MULTI_WIKINER() -nel_corpus = ZELDA(column_format={0: "text", 2: "ner"}) # need to set the label type to be the same as the ner one +nel_corpus = ZELDA(column_format={0: "text", 2: "nel"}) # need to set the label type to be the same as the ner one # --- Embeddings that are shared by both models --- # shared_embeddings = TransformerWordEmbeddings("distilbert-base-uncased", fine_tune=True) @@ -171,12 +171,13 @@ ner_model = SequenceTagger( ) -nel_label_dict = nel_corpus.make_label_dictionary("ner", add_unk=True) +nel_label_dict = nel_corpus.make_label_dictionary("nel", add_unk=True) nel_model = SpanClassifier( embeddings=shared_embeddings, label_dictionary=nel_label_dict, - label_type="ner", + label_type="nel", + span_label_type="ner", decoder=PrototypicalDecoder( num_prototypes=len(nel_label_dict), embeddings_size=shared_embeddings.embedding_length * 2, # we use "first_last" encoding for spans diff --git a/flair/models/entity_linker_model.py b/flair/models/entity_linker_model.py index 1d716e7904..d003a99f56 100644 --- a/flair/models/entity_linker_model.py +++ b/flair/models/entity_linker_model.py @@ -94,6 +94,7 @@ def __init__( label_dictionary: Dictionary, pooling_operation: str = "first_last", label_type: str = "nel", + span_label_type: Optional[str] = None, candidates: Optional[CandidateGenerator] = None, **classifierargs, ) -> None: @@ -107,6 +108,7 @@ def __init__( text representation we take the average of the embeddings of the token in the mention. `first_last` concatenates the embedding of the first and the embedding of the last token. label_type: name of the label you use. + span_label_type: name of the label you use for inputs of predictions. candidates: If provided, use a :class:`CandidateGenerator` for prediction candidates. **classifierargs: The arguments propagated to :meth:`flair.nn.DefaultClassifier.__init__` """ @@ -121,6 +123,7 @@ def __init__( self.pooling_operation = pooling_operation self._label_type = label_type + self._span_label_type = span_label_type cases: Dict[str, Callable[[Span, List[str]], torch.Tensor]] = { "average": self.emb_mean, @@ -153,6 +156,11 @@ def emb_mean(self, span, embedding_names): return torch.mean(torch.stack([token.get_embedding(embedding_names) for token in span], 0), 0) def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Span]: + if self._span_label_type is not None: + spans = sentence.get_spans(self._span_label_type) + # only use span label type if there are predictions, otherwise search for output label type (training labels) + if spans: + return spans return sentence.get_spans(self.label_type) def _filter_data_point(self, data_point: Sentence) -> bool: @@ -170,6 +178,7 @@ def _get_state_dict(self): "pooling_operation": self.pooling_operation, "loss_weights": self.weight_dict, "candidates": self.candidates, + "span_label_type": self._span_label_type, } return model_state From efc13da6911ae2d3218fd3efc1a79c29db1736f3 Mon Sep 17 00:00:00 2001 From: Alan Akbik Date: Sun, 30 Jun 2024 20:52:26 +0200 Subject: [PATCH 2/4] Adapt _filter_data_point to enable prediction --- flair/models/entity_linker_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/flair/models/entity_linker_model.py b/flair/models/entity_linker_model.py index d003a99f56..c613fe952f 100644 --- a/flair/models/entity_linker_model.py +++ b/flair/models/entity_linker_model.py @@ -164,7 +164,10 @@ def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Span]: return sentence.get_spans(self.label_type) def _filter_data_point(self, data_point: Sentence) -> bool: - return bool(data_point.get_labels(self.label_type)) + if self._span_label_type is not None: + return bool(data_point.get_labels(self._span_label_type)) + else: + return bool(data_point.get_labels(self.label_type)) def _get_embedding_for_data_point(self, prediction_data_point: Span) -> torch.Tensor: return self.aggregated_embedding(prediction_data_point, self.embeddings.get_names()) From 1447745a9231b6b252b228929c7329775f39c09d Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 5 Jul 2024 18:55:11 +0200 Subject: [PATCH 3/4] allow training on labels even if span label type is already set. --- flair/data.py | 4 ++-- flair/models/entity_linker_model.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/flair/data.py b/flair/data.py index 77fff1f200..bc35c83c5c 100644 --- a/flair/data.py +++ b/flair/data.py @@ -1363,7 +1363,7 @@ def __init__( test, train = randomly_split_into_two_datasets(train, test_size, random_seed) log.warning( "No test split found. Using %.0f%% (i.e. %d samples) of the train split as test data", - test_portion, + test_portion * 100, test_size, ) @@ -1375,7 +1375,7 @@ def __init__( dev, train = randomly_split_into_two_datasets(train, dev_size, random_seed) log.warning( "No dev split found. Using %.0f%% (i.e. %d samples) of the train split as dev data", - dev_portion, + dev_portion * 100, dev_size, ) diff --git a/flair/models/entity_linker_model.py b/flair/models/entity_linker_model.py index c613fe952f..0154883446 100644 --- a/flair/models/entity_linker_model.py +++ b/flair/models/entity_linker_model.py @@ -165,9 +165,9 @@ def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Span]: def _filter_data_point(self, data_point: Sentence) -> bool: if self._span_label_type is not None: - return bool(data_point.get_labels(self._span_label_type)) - else: - return bool(data_point.get_labels(self.label_type)) + if bool(data_point.get_labels(self._span_label_type)): + return True + return bool(data_point.get_labels(self.label_type)) def _get_embedding_for_data_point(self, prediction_data_point: Span) -> torch.Tensor: return self.aggregated_embedding(prediction_data_point, self.embeddings.get_names()) From 1aec900b9a039251b6875535a3c61378b18012f6 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 5 Jul 2024 19:00:21 +0200 Subject: [PATCH 4/4] code formatting --- flair/models/entity_linker_model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/flair/models/entity_linker_model.py b/flair/models/entity_linker_model.py index 0154883446..9f516a703c 100644 --- a/flair/models/entity_linker_model.py +++ b/flair/models/entity_linker_model.py @@ -164,9 +164,8 @@ def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Span]: return sentence.get_spans(self.label_type) def _filter_data_point(self, data_point: Sentence) -> bool: - if self._span_label_type is not None: - if bool(data_point.get_labels(self._span_label_type)): - return True + if self._span_label_type is not None and bool(data_point.get_labels(self._span_label_type)): + return True return bool(data_point.get_labels(self.label_type)) def _get_embedding_for_data_point(self, prediction_data_point: Span) -> torch.Tensor: