Update pre-processing tests
ColinDaglish committed Jul 21, 2023
1 parent 9b8ee7a commit a442636
Showing 2 changed files with 56 additions and 115 deletions.
49 changes: 10 additions & 39 deletions src/modules/preprocessing.py
@@ -99,12 +99,12 @@ def _replace_blanks(series: Series) -> Series:
return blanks_replaced


def shorten_tokens(word_tokens: list, lemmatize: bool = True) -> list:
def shorten_tokens(word_tokens: Series, lemmatize: bool = True) -> list:
"""Shorten tokens to root words
Parameters
----------
word_tokens:list
list of word tokens to shorten
word_tokens:Series
Series of lists of word tokens
lemmatize: bool, default = True
whether to use lemmatizer or revert back to False (stemmer)"""
if lemmatize:
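
With the updated signature, word_tokens is a pandas Series whose elements are lists of tokens. A minimal usage sketch, assuming shorten_tokens maps lemmatizer/stemmer across the Series (as the updated tests further below exercise):

from pandas import Series

tokens = Series([["houses"]])
shorten_tokens(tokens, lemmatize=True)   # Series([["house"]]) per the new lemmatize test
shorten_tokens(tokens, lemmatize=False)  # Series([["hous"]]) per the new stemmer test
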
@@ -165,43 +165,14 @@ def _initialise_nltk_component(extension: str, download_object: str):
None
"""
if sys.platform.startswith("linux"):
_initialise_nltk_linux(download_object)
else:
_initialise_nltk_windows(extension, download_object)


def _initialise_nltk_linux(download_object: str) -> None:
"""initialise nltk component for linux environment (for github actions)
Parameters
----------
download_object: str
nltk object to download
Returns
-------
None
"""
nltk.download(download_object)
nltk.data.path.append("../home/runner/nltk_data")
return None


def _initialise_nltk_windows(extension: str, download_object: str):
"""initialise nltk component for a windows environment
Parameters
----------
extension: str
the filepath extension leading to where the model is saved
download_object: str
the object to download from nltk
Returns
-------
None
"""
username = os.getenv("username")
path = "C:/Users/" + username + "/AppData/Roaming/nltk_data/" + extension
if not os.path.exists(path):
nltk.download(download_object)
nltk.data.path.append("../local_packages/nltk_data")
nltk.data.path.append("../home/runner/nltk_data")
else:
username = os.getenv("username")
path = "C:/Users/" + username + "/AppData/Roaming/nltk_data/" + extension
if not os.path.exists(path):
nltk.download(download_object)
nltk.data.path.append("../local_packages/nltk_data")

Codecov / codecov/patch warning on src/modules/preprocessing.py#L171-L175: added lines #L171 - L175 were not covered by tests.
return None
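
For readability, the consolidated _initialise_nltk_component appears to read roughly as follows after this change. This is a sketch reassembled from the added lines above; the exact position of the Linux-branch download call is an assumption, since the copied diff does not mark additions and deletions:

def _initialise_nltk_component(extension: str, download_object: str):
    if sys.platform.startswith("linux"):
        # GitHub Actions (Linux) runners: download and register the data path
        nltk.download(download_object)
        nltk.data.path.append("../home/runner/nltk_data")
    else:
        # Windows: only download if the component is not already present
        username = os.getenv("username")
        path = "C:/Users/" + username + "/AppData/Roaming/nltk_data/" + extension
        if not os.path.exists(path):
            nltk.download(download_object)
            nltk.data.path.append("../local_packages/nltk_data")
    return None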


122 changes: 46 additions & 76 deletions tests/modules/test_preprocessing.py
@@ -3,26 +3,15 @@
from nltk.corpus import stopwords as sw
from pandas import Series

from src.modules.preprocessing import ( # _correct_spelling,; _remove_punctuation_string,; _update_spelling_words,; remove_punctuation,; spellcorrect_series, # noqa:E501
_initialise_nltk_component,
_replace_blanks,
_update_nltk_stopwords,
initialise_update_stopwords,
lemmatizer,
load_config,
rejoin_tokens,
remove_blank_rows,
remove_nltk_stopwords,
stemmer,
)
from src.modules import preprocessing as prep


class TestLoadConfig:
def test_input_type_error(self):
"""test for assertion error"""
bad_input = 123
with pytest.raises(Exception) as e_info:
load_config(bad_input)
prep.load_config(bad_input)
assert (
str(e_info.value) == "filepath must be a string"
), "Did not raise TypeError"
@@ -31,31 +20,50 @@ def test_input_file_not_found(self):
"""test for error feedback on file not found"""
bad_input = "src/superman.yaml"
with pytest.raises(Exception) as e_info:
load_config(bad_input)
prep.load_config(bad_input)
assert (
str(e_info.value.args[1]) == "No such file or directory"
), "Did not raise file not found error"

def test_return_dict(self):
assert (
type(load_config("src/config.yaml")) is dict
type(prep.load_config("src/config.yaml")) is dict
), "output is not <class 'dict'>"


class TestPrependStrToListObjects:
def test_prepend_str_to_list_objects(self):
list_x = [1, 2, 3]
string_x = "qu_"
expected = ["qu_1", "qu_2", "qu_3"]
actual = prep.prepend_str_to_list_objects(list_x, string_x)
assert actual == expected, "did not correctly prepend string to list objects"


class TestGetResponseLength:
def test_get_response_length(self):
series = Series(["hello", "world"])
expected = Series([5, 5])
actual = prep.get_response_length(series)
assert all(
expected == actual
), "Did not correctly identify the response lengths"


class TestRemoveBlankRows:
def test_blank_rows_removed(self):
"""test that blank rows are removed"""
series_with_empty_row = Series([1.0, "", 3.0])
expected_outcome = Series([1.0, 3.0])
actual = remove_blank_rows(series_with_empty_row)
actual = prep.remove_blank_rows(series_with_empty_row)
actual_reindexed = actual.reset_index(drop=True)
assert all(
actual_reindexed == expected_outcome
), "function does not remove blank rows"

def test_return_series(self):
"""test that function returns a Series"""
actual = remove_blank_rows(Series([1.0, "", 3.0]))
actual = prep.remove_blank_rows(Series([1.0, "", 3.0]))
assert (
type(actual) is Series
), "output is not <class 'pandas.core.series.Series'>"
@@ -65,120 +73,82 @@ class TestReplaceBlanks:
def test_blank_replacement(self):
"""test replace blanks with NaN"""
series_with_empty_row = Series([1.0, "", 3.0])
actual = _replace_blanks(series_with_empty_row)
actual = prep._replace_blanks(series_with_empty_row)
assert np.isnan(actual[1]), "did not replace blank with NaN"

def test_return_series(self):
"""test that function returns a Series"""
actual = remove_blank_rows(Series([1.0, "", 3.0]))
actual = prep.remove_blank_rows(Series([1.0, "", 3.0]))
assert (
type(actual) is Series
), "output is not <class 'pandas.core.series.Series'>"


# class TestSpellCorrectSeries:
# def test_spell_correct_series(self):
# series = Series(["I live in a housr", "I own a housr"])
# actual = spellcorrect_series(series)
# expected = Series(["I live in a house", "I own a house"])
# assert all(actual == expected), "Not fixed spelling across series"

# def test_update_spelling_on_series(self):
# series = Series(["I live in a housr", "I own a housr"])
# additional_words = {"housr": 1}
# actual = spellcorrect_series(series, additional_words)
# expected = Series(["I live in a housr", "I own a housr"])
# assert all(actual == expected), "Updated spelling doesn't work across series" # noqa:E501


# class TestCorrectSpelling:
# def test_spelling_fixed(self):
# house_str = "I live flar away"
# corrected = _correct_spelling(house_str)
# assert corrected == "I live far away", "spelling not fixed correctly"
class TestShortenTokens:
def test_shorten_tokens_lemmatize(self):
words = Series([["houses"]])
expected = Series([["house"]])
actual = prep.shorten_tokens(words, True)
assert all(actual == expected), "Did not lemmatize correctly over series"


# class TestUpdateSpellingWords:
# def test_update_word_list(self):
# additional_words = {"monsterp": 1}
# tb.en.spelling = _update_spelling_words(additional_words)
# assert (
# "monsterp" in tb.en.spelling.keys()
# ), "spelling word list not updated correctly"


# class TestRemovePunctuation:
# def test_remove_punctuation(self):
# series = Series(["this is!", "my series?"])
# actual = remove_punctuation(series)
# expected = Series(["this is", "my series"])
# assert all(actual == expected), "Remove punctuation not working on series"


# class TestRemovePunctuationstring:
# def test_remove_punctuation(self):
# test_string = "my #$%&()*+,-./:;<=>?@[]^_`{|}~?name"
# actual = _remove_punctuation_string(test_string)
# expected = "my name"
# assert actual == expected, "punctuation not removed correctly"
def test_shorten_tokens_stemmer(self):
words = Series([["houses"]])
expected = Series([["hous"]])
actual = prep.shorten_tokens(words, False)
assert all(actual == expected), "Did not stemmer correctly over series"


class TestStemmer:
def test_stemmer(self):
word_list = ["flying", "fly", "Beautiful", "Beauty"]
actual = stemmer(word_list)
actual = prep.stemmer(word_list)
expected = ["fli", "fli", "beauti", "beauti"]
assert actual == expected, "words are not being stemmed correctly"


class TestLemmatizer:
def test_lemmatization(self):
word_list = ["house", "houses", "housing"]
actual = lemmatizer(word_list)
actual = prep.lemmatizer(word_list)
expected = ["house", "house", "housing"]
assert actual == expected, "words are not being lemmatized correctly"


class TestRemoveNLTKStopwords:
def test_remove_standard_stopwords(self):
tokens = ["my", "name", "is", "elf", "who", "are", "you"]
actual = remove_nltk_stopwords(tokens)
actual = prep.remove_nltk_stopwords(tokens)
expected = ["name", "elf"]
assert actual == expected, "core stopwords not being removed correctly"

def test_remove_additional_stopwords(self):
tokens = ["my", "name", "is", "elf", "who", "are", "you"]
actual = remove_nltk_stopwords(tokens, ["elf"])
actual = prep.remove_nltk_stopwords(tokens, ["elf"])
expected = ["name"]
assert actual == expected, "additional stopwords not being removed correctly"


class TestInitialiseUpdateStopwords:
def test_add_word_to_stopwords(self):
additional_words = ["elf", "santa"]
new_stopwords = initialise_update_stopwords(additional_words)
new_stopwords = prep.initialise_update_stopwords(additional_words)
actual = [word in new_stopwords for word in additional_words]
assert all(actual), "new words not added to stopwords"


class TestUpdateNLTKStopwords:
def test_add_word_to_stopwords(self):
_initialise_nltk_component("corpora/stopwords", "stopwords")
prep._initialise_nltk_component("corpora/stopwords", "stopwords")
stopwords = sw.words("english")
additional_words = ["elf", "santa"]
new_stopwords = _update_nltk_stopwords(stopwords, additional_words)
new_stopwords = prep._update_nltk_stopwords(stopwords, additional_words)
actual = [word in new_stopwords for word in additional_words]
assert all(actual), "new words not added to stopwords"


class TestRejoinTokens:
def test_region_tokens(self):
tokens = ["my", "name", "is", "elf"]
actual = rejoin_tokens(tokens)
actual = prep.rejoin_tokens(tokens)
expected = "my name is elf"
assert actual == expected, "did not rejoin tokens correctly"


class TestInitialiseNLTKComponent:
def test_initialise_component(self):
pass
