diff --git a/flair/models/sequence_tagger_utils/viterbi.py b/flair/models/sequence_tagger_utils/viterbi.py index 52b59d1f2..1c1936267 100644 --- a/flair/models/sequence_tagger_utils/viterbi.py +++ b/flair/models/sequence_tagger_utils/viterbi.py @@ -232,10 +232,7 @@ def _all_scores_for_token( scores = scores.numpy() for i_batch, batch in enumerate(scores): for i, (tag_id, tag_scores) in enumerate(zip(tag_seq, batch)): - if isinstance(tag_id, int): - tag_id_int = tag_id - else: - tag_id_int = int(tag_id.item()) + tag_id_int = tag_id if isinstance(tag_id, int) else int(tag_id.item()) if tag_id_int != np.argmax(tag_scores): swap_index_score = int(np.argmax(tag_scores))