Skip to content

Commit

Permalink
apply lint
Browse files Browse the repository at this point in the history
  • Loading branch information
rchan26 committed Aug 15, 2023
1 parent ac508fe commit bd7fb1d
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 33 deletions.
16 changes: 6 additions & 10 deletions src/nlpsig/data_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,9 @@ def _check_feature_exists(self, feature: str) -> bool:
# not in ._feature_list, but is a valid column name in self.df,
# so add to feature list
self._feature_list += [feature]

return feature in self._feature_list

def _obtain_feature_columns(
self,
features: list[str] | str | None,
Expand Down Expand Up @@ -369,8 +369,8 @@ def _obtain_feature_columns(
# convert to list of strings
if isinstance(features, str):
features = [features]
if isinstance(features, list):

if isinstance(features, list):
# check each item in features is in self._feature_list
# if it isn't, but is a column in self.df, it will add
# it to self._feature_list
Expand Down Expand Up @@ -777,9 +777,7 @@ def pad(
raise ValueError("`method` must be either 'k_last' or 'max'.")

# obtain feature colnames
feature_colnames = self._obtain_feature_columns(
features=features
)
feature_colnames = self._obtain_feature_columns(features=features)
if len(feature_colnames) > 0:
if isinstance(standardise_method, str):
standardise_method = [standardise_method] * len(feature_colnames)
Expand Down Expand Up @@ -881,9 +879,7 @@ def get_time_feature(
(can be found in `._feature_list` attribute).
"""
if time_feature not in self._feature_list:
raise ValueError(
f"`time_feature` should be in {self._feature_list}."
)
raise ValueError(f"`time_feature` should be in {self._feature_list}.")

if not self.time_features_added:
self.set_time_features()
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_df_with_datetime():
return pd.DataFrame(
{
"text": [f"text_{i}" for i in range(n_entries)],
"binary_var": [rng.choice([0,1]) for i in range(n_entries)],
"binary_var": [rng.choice([0, 1]) for i in range(n_entries)],
"continuous_var": rng.random(n_entries),
"id_col": [0 for i in range(100)]
+ [rng.integers(1, 5) for i in range(n_entries - 100)],
Expand All @@ -45,7 +45,7 @@ def test_df_no_time():
return pd.DataFrame(
{
"text": [f"text_{i}" for i in range(n_entries)],
"binary_var": [rng.choice([0,1]) for i in range(n_entries)],
"binary_var": [rng.choice([0, 1]) for i in range(n_entries)],
"continuous_var": rng.random(n_entries),
"id_col": [0 for i in range(100)]
+ [rng.integers(1, 5) for i in range(n_entries - 100)],
Expand All @@ -60,7 +60,7 @@ def test_df_to_pad():
return pd.DataFrame(
{
"text": [f"text_{i}" for i in range(n_entries)],
"binary_var": [rng.choice([0,1]) for i in range(n_entries)],
"binary_var": [rng.choice([0, 1]) for i in range(n_entries)],
"continuous_var": rng.random(n_entries),
"id_col": 0,
"label_col": [rng.integers(0, 4) for i in range(n_entries)],
Expand Down
34 changes: 16 additions & 18 deletions tests/test_data_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@ def test_default_initialisation_datetime(
# 1 dummy id column
assert obj.df.shape == (
len(obj.original_df.index),
1
+ len(obj.original_df.columns)
+ emb.shape[1]
+ len(obj._feature_list)
+ 1,
1 + len(obj.original_df.columns) + emb.shape[1] + len(obj._feature_list) + 1,
)
assert obj.pooled_embeddings is None
assert set(obj._feature_list) == {
Expand Down Expand Up @@ -67,10 +63,7 @@ def test_default_initialisation_no_time(
# 1 dummy id column
assert obj.df.shape == (
len(obj.original_df.index),
len(obj.original_df.columns)
+ emb.shape[1]
+ len(obj._feature_list)
+ 1,
len(obj.original_df.columns) + emb.shape[1] + len(obj._feature_list) + 1,
)
assert obj.pooled_embeddings is None
assert obj._feature_list == ["timeline_index"]
Expand Down Expand Up @@ -105,10 +98,7 @@ def test_initialisation_with_id_and_label_datetime(
# 3 time features
assert obj.df.shape == (
len(obj.original_df.index),
1
+ len(obj.original_df.columns)
+ emb.shape[1]
+ len(obj._feature_list),
1 + len(obj.original_df.columns) + emb.shape[1] + len(obj._feature_list),
)
assert obj.pooled_embeddings is None
assert set(obj._feature_list) == {
Expand Down Expand Up @@ -516,7 +506,10 @@ def test_obtain_colnames_both(test_df_with_datetime, emb, emb_reduced):
)
assert obj._obtain_embedding_colnames(embeddings="full") == emb_names
assert obj._obtain_embedding_colnames(embeddings="dim_reduced") == emb_reduced_names
assert obj._obtain_embedding_colnames(embeddings="both") == emb_reduced_names + emb_names
assert (
obj._obtain_embedding_colnames(embeddings="both")
== emb_reduced_names + emb_names
)


def test_obtain_feature_columns_string(test_df_with_datetime, emb):
Expand Down Expand Up @@ -548,9 +541,11 @@ def test_obtain_feature_columns_string_additional_binary(test_df_with_datetime,
"timeline_index",
"binary_var",
}


def test_obtain_feature_columns_string_additional_continuous(test_df_with_datetime, emb):

def test_obtain_feature_columns_string_additional_continuous(
test_df_with_datetime, emb
):
# default initialisation
obj = PrepareData(original_df=test_df_with_datetime, embeddings=emb)
# originally only have the time features
Expand Down Expand Up @@ -592,7 +587,9 @@ def test_obtain_feature_columns_list_additional(test_df_with_datetime, emb):
"time_diff",
"timeline_index",
}
assert obj._obtain_feature_columns(["time_encoding", "timeline_index", "binary_var", "continuous_var"]) == [
assert obj._obtain_feature_columns(
["time_encoding", "timeline_index", "binary_var", "continuous_var"]
) == [
"time_encoding",
"timeline_index",
"binary_var",
Expand Down Expand Up @@ -734,6 +731,7 @@ def test_standardise_pd_wrong_method(vec_to_standardise, test_df_no_time, emb):
obj = PrepareData(original_df=test_df_no_time, embeddings=emb)
incorrect_method = "fake_method"
with pytest.raises(
ValueError, match=re.escape(f"`method`: {incorrect_method} must be in {implemented}.")
ValueError,
match=re.escape(f"`method`: {incorrect_method} must be in {implemented}."),
):
obj._standardise_pd(vec=vec_to_standardise, method=incorrect_method)
4 changes: 2 additions & 2 deletions tests/test_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,7 @@ def test_pad_by_id_k_last_additional(test_df_with_datetime, emb):
assert type(obj.array_padded) == np.ndarray
assert np.array_equal(padded_array, obj.array_padded)
assert obj.array_padded.shape == (len(obj.original_df["id_col"].unique()), k, ncol)


def test_pad_by_id_max(test_df_with_datetime, emb):
obj = PrepareData(
Expand Down Expand Up @@ -1278,7 +1278,7 @@ def test_pad_by_id_max(test_df_with_datetime, emb):
assert type(obj.array_padded) == np.ndarray
assert np.array_equal(padded_array, obj.array_padded)
assert obj.array_padded.shape == (len(obj.original_df["id_col"].unique()), k, ncol)


def test_pad_by_id_max_additional(test_df_with_datetime, emb):
obj = PrepareData(
Expand Down

0 comments on commit bd7fb1d

Please sign in to comment.