Skip to content

Commit

Permalink
Merge pull request #3372 from flairNLP/fix_has_unknown_label
Browse files Browse the repository at this point in the history
fix has unknown label is not always initialized
  • Loading branch information
alanakbik authored Nov 12, 2023
2 parents 16c88a5 + 7d20d4f commit f30f580
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion flair/nn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,7 @@ def predict(

overall_loss = torch.zeros(1, device=flair.device)
label_count = 0
has_any_unknown_label = False
for batch in batches:
# filter data points in batch
batch = [dp for dp in batch if self._filter_data_point(dp)]
Expand Down Expand Up @@ -865,6 +866,7 @@ def predict(
has_unknown_label = True

if has_unknown_label:
has_any_unknown_label = True
scores = torch.index_select(scores, 0, torch.tensor(filtered_indices, device=flair.device))

gold_labels = self._prepare_label_tensor([data_points[index] for index in filtered_indices])
Expand Down Expand Up @@ -908,7 +910,7 @@ def predict(
self._post_process_batch_after_prediction(batch, label_name)

if return_loss:
if has_unknown_label:
if has_any_unknown_label:
log.info(
"During evaluation, encountered labels that are not in the label_dictionary:"
"Evaluation loss is computed without them."
Expand Down

0 comments on commit f30f580

Please sign in to comment.