Skip to content

Commit

Permalink
Increase robustness of is_split_into_words check: resolves ValueErr…
Browse files Browse the repository at this point in the history
…or (#39)

* Increase robustness of "is_split_into_words" check

* Update the changelog

* Add a test case
  • Loading branch information
tomaarsen authored Oct 31, 2023
1 parent 2644068 commit eede2a4
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Types of changes
### Fixed

- No longer override `language` metadata from the dataset if the language was also set manually via `SpanMarkerModelCardData`.
- No longer crash on `predict` with `ValueError: Failed to concatenate on axis=1 ...` if the first sentence in a list of sentences is just one word.

## [1.4.0]

Expand Down
8 changes: 6 additions & 2 deletions span_marker/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,14 @@ def __call__(
) -> Dict[str, List]:
tokens = batch["tokens"]
labels = batch.get("ner_tags", None)
# TODO: Increase robustness of this
is_split_into_words = True
if isinstance(tokens, str) or (tokens and " " in tokens[0]):
if isinstance(tokens, str):
is_split_into_words = False
elif tokens:
for token in tokens:
if " " in token:
is_split_into_words = False
break

batch_encoding = self.tokenizer(
tokens,
Expand Down
7 changes: 7 additions & 0 deletions tests/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,13 @@ def test_correct_predictions_with_document_level_context(
)


def test_predict_where_first_sentence_is_word(finetuned_conll_span_marker_model: SpanMarkerModel) -> None:
model = finetuned_conll_span_marker_model.try_cuda()
outputs = model.predict(["One", "Two Three Four Five"])
assert len(outputs) == 2
assert isinstance(outputs[0], list)


def test_incorrect_predict_inputs(finetuned_conll_span_marker_model: SpanMarkerModel):
model = finetuned_conll_span_marker_model.try_cuda()
with pytest.raises(ValueError, match="could not recognize your input"):
Expand Down

0 comments on commit eede2a4

Please sign in to comment.