From eede2a4c0e3a91e3b0fb9a8fcd3271b0f7fe0e88 Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Tue, 31 Oct 2023 10:30:26 +0100 Subject: [PATCH] Increase robustness of `is_split_into_words` check: resolves ValueError (#39) * Increase robustness of "is_split_into_words" check * Update the changelog * Add a test case --- CHANGELOG.md | 1 + span_marker/tokenizer.py | 8 ++++++-- tests/test_modeling.py | 7 +++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e184b39e..39c24bc5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ Types of changes ### Fixed - No longer override `language` metadata from the dataset if the language was also set manually via `SpanMarkerModelCardData`. +- No longer crash on `predict` with `ValueError: Failed to concatenate on axis=1 ...` if the first sentence in a list of sentences is just one word. ## [1.4.0] diff --git a/span_marker/tokenizer.py b/span_marker/tokenizer.py index 89135f2a..dedbc4e9 100644 --- a/span_marker/tokenizer.py +++ b/span_marker/tokenizer.py @@ -180,10 +180,14 @@ def __call__( ) -> Dict[str, List]: tokens = batch["tokens"] labels = batch.get("ner_tags", None) - # TODO: Increase robustness of this is_split_into_words = True - if isinstance(tokens, str) or (tokens and " " in tokens[0]): + if isinstance(tokens, str): is_split_into_words = False + elif tokens: + for token in tokens: + if " " in token: + is_split_into_words = False + break batch_encoding = self.tokenizer( tokens, diff --git a/tests/test_modeling.py b/tests/test_modeling.py index 954c1c93..b17313e8 100644 --- a/tests/test_modeling.py +++ b/tests/test_modeling.py @@ -195,6 +195,13 @@ def test_correct_predictions_with_document_level_context( ) +def test_predict_where_first_sentence_is_word(finetuned_conll_span_marker_model: SpanMarkerModel) -> None: + model = finetuned_conll_span_marker_model.try_cuda() + outputs = model.predict(["One", "Two Three Four Five"]) + assert len(outputs) == 2 + assert isinstance(outputs[0], list) + + def test_incorrect_predict_inputs(finetuned_conll_span_marker_model: SpanMarkerModel): model = finetuned_conll_span_marker_model.try_cuda() with pytest.raises(ValueError, match="could not recognize your input"):