Merge pull request #3528 from flairNLP/fix-doc-build
fix doc build
alanakbik authored Aug 19, 2024
2 parents 091e62b + 975931c commit b16a2f3
Showing 4 changed files with 9 additions and 247 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -29,7 +29,7 @@ jobs:
         uses: actions/cache@v3
         with:
           path: ./cache
-          key: cache-v1.1
+          key: cache-v1.2
       - name: Run tests
         run: |
           python -c 'import flair'
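Context on the cache change: actions/cache matches the key as an exact string, so renaming it from cache-v1.1 to cache-v1.2 simply forces a cache miss and a fresh cache on the next run; nothing else about the step changes. A hypothetical alternative (not what this commit does) is to derive the key from the dependency spec, so it rotates automatically:

      # hypothetical sketch: the key changes whenever requirements.txt does,
      # so stale caches are dropped without manual version bumps
      - uses: actions/cache@v3
        with:
          path: ./cache
          key: cache-${{ hashFiles('requirements.txt') }}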
2 changes: 1 addition & 1 deletion .gitignore
@@ -110,4 +110,4 @@ resources/taggers/
 regression_train/
 /doc_build/
 
-scripts/
+scripts/
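Aside: the paired removal and re-addition of an identical scripts/ line is a whitespace-only change, most plausibly adding a missing newline at the end of the file; the ignore rule itself is unchanged.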
7 changes: 6 additions & 1 deletion docs/requirements.txt
@@ -5,4 +5,9 @@ sphinx
 importlib-metadata
 sphinx-multiversion
 pydata-sphinx-theme<0.14
-sphinx_design
+sphinx_design
+
+# previous dependencies that are required to build docs for later versions too.
+semver
+gensim
+bpemb
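The re-added pins matter because sphinx-multiversion checks out and builds the documentation of earlier tagged releases inside the same environment, so packages that current flair no longer depends on (here semver, gensim, and bpemb) must still be installed for those older versions' docs to import cleanly. A minimal sketch of such a docs build step (hypothetical, assuming docs sources in docs/ and the gitignored doc_build/ as output; the actual docs workflow is not part of this commit):

      - name: Build versioned docs
        run: |
          pip install -r docs/requirements.txt
          sphinx-multiversion docs doc_build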
245 changes: 1 addition & 244 deletions tests/test_datasets_biomedical.py
@@ -1,21 +1,12 @@
-import inspect
 import logging
-import os
 import tempfile
-from operator import itemgetter
-from pathlib import Path
-from typing import Callable, List, Optional, Type
+from typing import List, Optional
 
 import pytest
-from tqdm import tqdm
 
-import flair
-from flair.data import Sentence, Token, _iter_dataset
-from flair.datasets import ColumnCorpus, biomedical
 from flair.datasets.biomedical import (
     CoNLLWriter,
     Entity,
-    HunerDataset,
     InternalBioNerDataset,
     filter_nested_entities,
 )
@@ -26,80 +17,6 @@
 logger.propagate = True
 
 
-def has_balanced_parantheses(text: str) -> bool:
-    stack = []
-    opening = ["(", "[", "{"]
-    closing = [")", "]", "}"]
-    for c in text:
-        if c in opening:
-            stack.append(c)
-        elif c in closing:
-            if not stack:
-                return False
-            last_paren = stack.pop()
-            if opening.index(last_paren) != closing.index(c):
-                return False
-
-    return len(stack) == 0
-
-
-def gene_predicate(member):
-    return inspect.isclass(member) and "HUNER_GENE_" in str(member)
-
-
-def chemical_predicate(member):
-    return inspect.isclass(member) and "HUNER_CHEMICAL_" in str(member)
-
-
-def disease_predicate(member):
-    return inspect.isclass(member) and "HUNER_DISEASE_" in str(member)
-
-
-def species_predicate(member):
-    return inspect.isclass(member) and "HUNER_SPECIES_" in str(member)
-
-
-def cellline_predicate(member):
-    return inspect.isclass(member) and "HUNER_CELL_LINE_" in str(member)
-
-
-CELLLINE_DATASETS = [
-    i[1] for i in sorted(inspect.getmembers(biomedical, predicate=cellline_predicate), key=itemgetter(0))
-]
-CHEMICAL_DATASETS = [
-    i[1] for i in sorted(inspect.getmembers(biomedical, predicate=chemical_predicate), key=itemgetter(0))
-]
-DISEASE_DATASETS = [
-    i[1] for i in sorted(inspect.getmembers(biomedical, predicate=disease_predicate), key=itemgetter(0))
-]
-GENE_DATASETS = [i[1] for i in sorted(inspect.getmembers(biomedical, predicate=gene_predicate), key=itemgetter(0))]
-SPECIES_DATASETS = [
-    i[1] for i in sorted(inspect.getmembers(biomedical, predicate=species_predicate), key=itemgetter(0))
-]
-ALL_DATASETS = CELLLINE_DATASETS + CHEMICAL_DATASETS + DISEASE_DATASETS + GENE_DATASETS + SPECIES_DATASETS
-
-
-def simple_tokenizer(text: str) -> List[str]:
-    tokens: List[str] = []
-    word = ""
-    index = -1
-    for index, char in enumerate(text):
-        if char == " " or char == "-":
-            if len(word) > 0:
-                tokens.append(word)
-
-            word = ""
-        else:
-            word += char
-
-    # increment for last token in sentence if not followed by whitespace
-    index += 1
-    if len(word) > 0:
-        tokens.append(word)
-
-    return tokens
-
-
 def test_write_to_conll():
     text = "This is entity1 entity2 and a long entity3"
     dataset = InternalBioNerDataset(
@@ -220,163 +137,3 @@ def test_filter_nested_entities(caplog):
         sorted(entities, key=lambda x: str(x)),
     ):
         assert str(e1) == str(e2)
-
-
-def sanity_check_all_corpora(check: Callable[[ColumnCorpus], None]):
-    for _, CorpusType in tqdm(ALL_DATASETS):
-        corpus = CorpusType()
-        check(corpus)
-
-
-@pytest.mark.skip(reason="We skip this test because it's only relevant for development purposes")
-@pytest.mark.parametrize("CorpusType", ALL_DATASETS)
-def test_sanity_not_starting_with_minus(CorpusType: Type[ColumnCorpus]):
-    corpus = CorpusType()  # type: ignore[call-arg]
-    entities_starting_with_minus = []
-    for sentence in _iter_dataset(corpus.get_all_sentences()):
-        entities = sentence.get_spans("ner")
-        for entity in entities:
-            if str(entity.tokens[0].text).startswith("-"):
-                entities_starting_with_minus.append(" ".join([t.text for t in entity.tokens]))
-
-    assert len(entities_starting_with_minus) == 0, "|".join(entities_starting_with_minus)
-
-
-@pytest.mark.parametrize("CorpusType", ALL_DATASETS)
-@pytest.mark.skip(reason="We skip this test because it's only relevant for development purposes")
-def test_sanity_no_repeating_Bs(CorpusType: Type[ColumnCorpus]):
-    corpus = CorpusType()  # type: ignore[call-arg]
-    longest_repeat_tokens: List[Token] = []
-    repeat_tokens: List[Token] = []
-    for sentence in _iter_dataset(corpus.get_all_sentences()):
-        for token in sentence.tokens:
-            if token.get_labels()[0].value.startswith("B") or token.get_labels()[0].value.startswith("S"):
-                repeat_tokens.append(token)
-            else:
-                if len(repeat_tokens) > len(longest_repeat_tokens):
-                    longest_repeat_tokens = repeat_tokens
-                repeat_tokens = []
-
-    assert len(longest_repeat_tokens) < 4
-
-
-@pytest.mark.parametrize("CorpusType", ALL_DATASETS)
-@pytest.mark.skip(reason="We skip this test because it's only relevant for development purposes")
-def test_sanity_no_long_entities(CorpusType: Type[ColumnCorpus]):
-    corpus = CorpusType()  # type: ignore[call-arg]
-    longest_entity: List[str] = []
-    for sentence in _iter_dataset(corpus.get_all_sentences()):
-        entities = sentence.get_spans("ner")
-        for entity in entities:
-            if len(entity.tokens) > len(longest_entity):
-                longest_entity = [t.text for t in entity.tokens]
-
-    assert len(longest_entity) < 10, " ".join(longest_entity)
-
-
-@pytest.mark.parametrize("CorpusType", ALL_DATASETS)
-@pytest.mark.skip(reason="We skip this test because it's only relevant for development purposes")
-def test_sanity_no_unmatched_parentheses(CorpusType: Type[ColumnCorpus]):
-    corpus = CorpusType()  # type: ignore[call-arg]
-    unbalanced_entities = []
-    for sentence in _iter_dataset(corpus.get_all_sentences()):
-        entities = sentence.get_spans("ner")
-        for entity in entities:
-            entity_text = "".join(t.text for t in entity.tokens)
-            if not has_balanced_parantheses(entity_text):
-                unbalanced_entities.append(entity_text)
-
-    assert unbalanced_entities == []
-
-
-@pytest.mark.parametrize("CorpusType", ALL_DATASETS)
-@pytest.mark.skip(reason="We skip this test because it's only relevant for development purposes")
-def test_sanity_not_too_many_entities(CorpusType: Type[ColumnCorpus]):
-    corpus = CorpusType()  # type: ignore[call-arg]
-    n_entities_per_sentence = []
-    for sentence in _iter_dataset(corpus.get_all_sentences()):
-        entities = sentence.get_spans("ner")
-        n_entities_per_sentence.append(len(entities))
-    avg_entities_per_sentence = sum(n_entities_per_sentence) / len(n_entities_per_sentence)
-
-    assert avg_entities_per_sentence <= 5
-
-
-@pytest.mark.parametrize("CorpusType", ALL_DATASETS)
-@pytest.mark.skip(reason="We skip this test because it's only relevant for development purposes")
-def test_sanity_no_misaligned_entities(CorpusType: Type[HunerDataset]):
-    dataset_name = CorpusType.__class__.__name__.lower()
-    base_path = flair.cache_root / "datasets"
-    data_folder = base_path / dataset_name
-
-    corpus = CorpusType()
-    internal = corpus.to_internal(data_folder)
-    for doc_id, _doc_text in internal.documents.items():
-        misaligned_starts = []
-        misaligned_ends: List[int] = []
-
-        entities = internal.entities_per_document[doc_id]
-        entity_starts = [i.char_span.start for i in entities]
-        entity_ends = [i.char_span.stop for i in entities]
-
-        for start in entity_starts:
-            if start not in entity_starts:
-                misaligned_starts.append(start)
-
-        for end in entity_ends:
-            if end not in entity_ends:
-                misaligned_starts.append(end)
-
-        assert len(misaligned_starts) <= len(entities) // 10
-        assert len(misaligned_ends) <= len(entities) // 10
-
-
-@pytest.mark.skip(reason="We skip this test because it's only relevant for development purposes")
-def test_scispacy_tokenization():
-    from flair.tokenization import SciSpacyTokenizer
-
-    spacy_tokenizer = SciSpacyTokenizer()
-
-    sentence = Sentence("HBeAg(+) patients", use_tokenizer=spacy_tokenizer)
-    assert len(sentence) == 5
-    assert sentence[0].text == "HBeAg"
-    assert sentence[0].start_position == 0
-    assert sentence[1].text == "("
-    assert sentence[1].start_position == 5
-    assert sentence[2].text == "+"
-    assert sentence[2].start_position == 6
-    assert sentence[3].text == ")"
-    assert sentence[3].start_position == 7
-    assert sentence[4].text == "patients"
-    assert sentence[4].start_position == 9
-
-    sentence = Sentence("HBeAg(+)/HBsAg(+)", use_tokenizer=spacy_tokenizer)
-    assert len(sentence) == 9
-
-    assert sentence[0].text == "HBeAg"
-    assert sentence[0].start_position == 0
-    assert sentence[1].text == "("
-    assert sentence[1].start_position == 5
-    assert sentence[2].text == "+"
-    assert sentence[2].start_position == 6
-    assert sentence[3].text == ")"
-    assert sentence[3].start_position == 7
-    assert sentence[4].text == "/"
-    assert sentence[4].start_position == 8
-    assert sentence[5].text == "HBsAg"
-    assert sentence[5].start_position == 9
-    assert sentence[6].text == "("
-    assert sentence[6].start_position == 14
-    assert sentence[7].text == "+"
-    assert sentence[7].start_position == 15
-    assert sentence[8].text == ")"
-    assert sentence[8].start_position == 16
-
-    sentence = Sentence("doxorubicin (DOX)-induced", use_tokenizer=spacy_tokenizer)
-
-    assert len(sentence) == 5
-    assert sentence[0].text == "doxorubicin"
-    assert sentence[1].text == "("
-    assert sentence[2].text == "DOX"
-    assert sentence[3].text == ")"
-    assert sentence[4].text == "-induced"
