Merge pull request #3432 from flairNLP/span_classifier_ner_type
add prediction label type for span classifier
alanakbik authored Jul 8, 2024
2 parents 355e2c0 + 1aec900 commit c6a2643
Showing 3 changed files with 17 additions and 5 deletions.
@@ -154,7 +154,7 @@ from flair.nn.multitask import make_multitask_model_and_corpus
 
 # 1. get the corpus
 ner_corpus = NER_MULTI_WIKINER()
-nel_corpus = ZELDA(column_format={0: "text", 2: "ner"}) # need to set the label type to be the same as the ner one
+nel_corpus = ZELDA(column_format={0: "text", 2: "nel"}) # need to set the label type to be the same as the ner one
 
 # --- Embeddings that are shared by both models --- #
 shared_embeddings = TransformerWordEmbeddings("distilbert-base-uncased", fine_tune=True)
@@ -171,12 +171,13 @@ ner_model = SequenceTagger(
 )
 
 
-nel_label_dict = nel_corpus.make_label_dictionary("ner", add_unk=True)
+nel_label_dict = nel_corpus.make_label_dictionary("nel", add_unk=True)
 
 nel_model = SpanClassifier(
     embeddings=shared_embeddings,
     label_dictionary=nel_label_dict,
-    label_type="ner",
+    label_type="nel",
+    span_label_type="ner",
     decoder=PrototypicalDecoder(
         num_prototypes=len(nel_label_dict),
         embeddings_size=shared_embeddings.embedding_length * 2, # we use "first_last" encoding for spans
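
Note: a minimal prediction sketch for the snippet above, assuming ner_model and nel_model have already been trained (for example via make_multitask_model_and_corpus and a trainer, not shown here). The sequence tagger first writes "ner" spans into the sentence, and the span classifier then picks those spans up via span_label_type="ner" and adds its own "nel" labels; the example sentence and outputs are illustrative only.

from flair.data import Sentence

sentence = Sentence("Kirk Hammett plays guitar in Metallica .")

# 1. the sequence tagger adds spans labeled with type "ner"
ner_model.predict(sentence)

# 2. the span classifier reads those "ner" spans (because span_label_type="ner")
#    and attaches its predictions under its own label type "nel"
nel_model.predict(sentence)

for span in sentence.get_spans("ner"):
    print(span.text, span.get_label("ner").value, span.get_label("nel").value)
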
flair/data.py: 4 changes (2 additions, 2 deletions)
@@ -1363,7 +1363,7 @@ def __init__(
             test, train = randomly_split_into_two_datasets(train, test_size, random_seed)
             log.warning(
                 "No test split found. Using %.0f%% (i.e. %d samples) of the train split as test data",
-                test_portion,
+                test_portion * 100,
                 test_size,
             )
 
@@ -1375,7 +1375,7 @@ def __init__(
             dev, train = randomly_split_into_two_datasets(train, dev_size, random_seed)
             log.warning(
                 "No dev split found. Using %.0f%% (i.e. %d samples) of the train split as dev data",
-                dev_portion,
+                dev_portion * 100,
                 dev_size,
             )
 
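
Note: the two edits above fix the logged percentage. test_portion and dev_portion are fractions (e.g. 0.1), so the %.0f%% placeholder previously rendered them as "0%"; multiplying by 100 yields the intended "10%". A standalone formatting sketch (plain Python, not flair code):

portion = 0.1  # fraction of the train split used for the missing split

print("No test split found. Using %.0f%% of the train split as test data" % portion)          # prints "... 0% ..."
print("No test split found. Using %.0f%% of the train split as test data" % (portion * 100))  # prints "... 10% ..."
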
flair/models/entity_linker_model.py: 11 changes (11 additions, 0 deletions)
@@ -94,6 +94,7 @@ def __init__(
         label_dictionary: Dictionary,
         pooling_operation: str = "first_last",
         label_type: str = "nel",
+        span_label_type: Optional[str] = None,
         candidates: Optional[CandidateGenerator] = None,
         **classifierargs,
     ) -> None:
@@ -107,6 +108,7 @@ def __init__(
                 text representation we take the average of the embeddings of the token in the mention.
                 `first_last` concatenates the embedding of the first and the embedding of the last token.
             label_type: name of the label you use.
+            span_label_type: name of the label you use for inputs of predictions.
             candidates: If provided, use a :class:`CandidateGenerator` for prediction candidates.
             **classifierargs: The arguments propagated to :meth:`flair.nn.DefaultClassifier.__init__`
         """
@@ -121,6 +123,7 @@ def __init__(
 
         self.pooling_operation = pooling_operation
         self._label_type = label_type
+        self._span_label_type = span_label_type
 
         cases: Dict[str, Callable[[Span, List[str]], torch.Tensor]] = {
             "average": self.emb_mean,
@@ -153,9 +156,16 @@ def emb_mean(self, span, embedding_names):
         return torch.mean(torch.stack([token.get_embedding(embedding_names) for token in span], 0), 0)
 
     def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Span]:
+        if self._span_label_type is not None:
+            spans = sentence.get_spans(self._span_label_type)
+            # only use span label type if there are predictions, otherwise search for output label type (training labels)
+            if spans:
+                return spans
         return sentence.get_spans(self.label_type)
 
     def _filter_data_point(self, data_point: Sentence) -> bool:
+        if self._span_label_type is not None and bool(data_point.get_labels(self._span_label_type)):
+            return True
         return bool(data_point.get_labels(self.label_type))
 
     def _get_embedding_for_data_point(self, prediction_data_point: Span) -> torch.Tensor:
@@ -170,6 +180,7 @@ def _get_state_dict(self):
             "pooling_operation": self.pooling_operation,
             "loss_weights": self.weight_dict,
             "candidates": self.candidates,
+            "span_label_type": self._span_label_type,
         }
         return model_state
 
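
Note: a small sketch of the fallback introduced in _get_data_points_from_sentence, using only flair's Sentence API (span texts and label values are illustrative). If a sentence carries no spans of the configured span_label_type, the classifier keeps using the gold spans of its own label_type (the training case); once an upstream tagger has added "ner" spans, those are used as prediction inputs instead.

from flair.data import Sentence

# training case: only gold "nel" spans exist, no "ner" spans yet
sentence = Sentence("Kirk Hammett plays in Metallica .")
sentence[0:2].add_label("nel", "Kirk_Hammett")

print(sentence.get_spans("ner"))  # [] -> a SpanClassifier with span_label_type="ner" falls back to its "nel" spans
print(sentence.get_spans("nel"))  # the gold "Kirk Hammett" span

# prediction case: an upstream tagger has added "ner" spans
sentence[4:5].add_label("ner", "ORG")  # stand-in for ner_model.predict(sentence)
print(sentence.get_spans("ner"))  # now non-empty -> these spans are classified and receive "nel" labels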