Commit
Merge pull request #166 from tompollard/tp/refactor1
Refactoring for readability: Add _handle_deprecations, _validate_arguments, _validate_data methods.
tompollard authored Jun 4, 2024
2 parents ee563de + 1813697 commit 767a88c
Showing 2 changed files with 159 additions and 124 deletions.
245 changes: 140 additions & 105 deletions tableone/tableone.py
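For orientation before the diff: the constructor now delegates its up-front checks to three private helpers, called in sequence. The class below is a stripped-down, illustrative sketch of that pattern only (the names and arguments are trimmed; the real implementation follows in the diff).

class _ConstructorSketch:
    """Illustrates the delegation pattern introduced by this commit; not the real class."""

    def __init__(self, data, columns=None, groupby='', labels=None, rename=None):
        # 1. Deprecated arguments are warned about or rejected first.
        self._handle_deprecations(labels, rename)
        # 2. The dataframe and the requested columns are checked next.
        columns = list(columns) if columns else list(data.columns)
        self._validate_data(data, columns)
        # 3. The remaining arguments are normalised and type-checked.
        self._groupby = self._validate_arguments(groupby)
        self._columns = columns

    def _handle_deprecations(self, labels, rename):
        if labels is not None and rename is not None:
            raise TypeError("received both 'labels' and 'rename'; use only 'rename'")

    def _validate_data(self, data, columns):
        if data.empty:
            raise ValueError("Input data is empty.")

    def _validate_arguments(self, groupby):
        if isinstance(groupby, list):
            raise ValueError("'groupby' should be a string, not a list")
        return groupby or ''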
@@ -220,87 +220,16 @@ def __init__(self, data: pd.DataFrame,
tukey_test: bool = False,
pval_threshold: Optional[float] = None) -> None:

# labels is now rename
if labels is not None and rename is not None:
raise TypeError("TableOne received both labels and rename.")
elif labels is not None:
warnings.warn("The labels argument is deprecated; use "
"rename instead.", DeprecationWarning)
self._alt_labels = labels
else:
self._alt_labels = rename

# isnull is now missing
if isnull is not None:
warnings.warn("The isnull argument is deprecated; use "
"missing instead.", DeprecationWarning)
self._isnull = isnull
else:
self._isnull = missing

# pval_test_name is now htest_name
if pval_test_name:
warnings.warn("The pval_test_name argument is deprecated; use "
"htest_name instead.", DeprecationWarning)
self._pval_test_name = pval_test_name
else:
self._pval_test_name = htest_name

# remarks are now specified by individual test names
if remarks:
warnings.warn("The remarks argument is deprecated; specify tests "
"by name instead (e.g. diptest = True)",
DeprecationWarning)
self._dip_test = remarks
self._normal_test = remarks
self._tukey_test = remarks
else:
self._dip_test = dip_test
self._normal_test = normal_test
self._tukey_test = tukey_test

# groupby should be a string
if not groupby:
groupby = ''
elif groupby and type(groupby) == list:
groupby = groupby[0]
self._handle_deprecations(labels, rename, isnull, pval_test_name, remarks)

# nonnormal should be a string
if not nonnormal:
nonnormal = []
elif nonnormal and type(nonnormal) == str:
nonnormal = [nonnormal]

# min_max should be a list
if min_max and isinstance(min_max, bool):
warnings.warn("min_max should specify a list of variables.")
min_max = None

# if the input dataframe is empty, raise error
if data.empty:
raise InputError("Input data is empty.")

# if the input dataframe has a non-unique index, raise error
if not data.index.is_unique:
raise InputError("Input data contains duplicate values in the "
"index. Reset the index and try again.")

# if columns are not specified, use all columns
# Default assignment for columns if not provided
if not columns:
columns = data.columns.values # type: ignore

# check that the columns exist in the dataframe
if not set(columns).issubset(data.columns): # type: ignore
notfound = list(set(columns) - set(data.columns)) # type: ignore
raise InputError("""Columns not found in
dataset: {}""".format(notfound))
self._validate_data(data, columns)

# check for duplicate columns
dups = data[columns].columns[
data[columns].columns.duplicated()].unique()
if not dups.empty:
raise InputError("""Input data contains duplicate
columns: {}""".format(dups))
(groupby, nonnormal, min_max, pval_adjust, order) = self._validate_arguments(
groupby, nonnormal, min_max, pval_adjust, order, pval)

# if categorical not specified, try to identify categorical
if not categorical and type(categorical) != list:
@@ -309,16 +238,6 @@ def __init__(self, data: pd.DataFrame,
if groupby:
categorical = [x for x in categorical if x != groupby]

if isinstance(pval_adjust, bool) and pval_adjust:
msg = ("pval_adjust expects a string, but a boolean was specified."
" Defaulting to the 'bonferroni' correction.")
warnings.warn(msg)
pval_adjust = "bonferroni"

# if custom order is provided, ensure that values are strings
if order:
order = {k: ["{}".format(v) for v in order[k]] for k in order}

# if input df has ordered categorical variables, get the order.
order_cats = [x for x in data.select_dtypes("category")
if data[x].dtype.ordered] # type: ignore
@@ -336,38 +255,41 @@ def __init__(self, data: pd.DataFrame,
elif order_cats:
order = d_order_cats # type: ignore

if pval and not groupby:
raise InputError("If pval=True then groupby must be specified.")

self._alt_labels = rename
self._columns = list(columns) # type: ignore
self._continuous = [c for c in columns # type: ignore
if c not in categorical + [groupby]]
self._categorical = categorical
self._nonnormal = nonnormal
self._min_max = min_max
self._pval = pval
self._pval_adjust = pval_adjust
self._htest = htest
self._sort = sort
self._groupby = groupby
# degrees of freedom for standard deviation
self._ddof = ddof
self._decimals = decimals
self._dip_test = dip_test
self._groupby = groupby
self._htest = htest
self._isnull = missing
self._label_suffix = label_suffix
self._limit = limit
self._min_max = min_max
self._nonnormal = nonnormal
self._normal_test = normal_test
self._order = order
self._label_suffix = label_suffix
self._decimals = decimals
self._smd = smd
self._pval_threshold = pval_threshold
self._overall = overall
self._pval = pval
self._pval_adjust = pval_adjust
self._pval_test_name = htest_name
self._pval_threshold = pval_threshold

# column names that cannot be contained in a groupby
self._reserved_columns = ['Missing', 'P-Value', 'Test',
'P-Value (adjusted)', 'SMD', 'Overall']

self._row_percent = row_percent
self._smd = smd
self._sort = sort
self._tukey_test = tukey_test

# display notes and warnings below the table
self._warnings = {}

# output column names that cannot be contained in a groupby
self._reserved_columns = ['Missing', 'P-Value', 'Test',
'P-Value (adjusted)', 'SMD', 'Overall']

if self._groupby:
self._groupbylvls = sorted(data.groupby(groupby).groups.keys()) # type: ignore

@@ -445,6 +367,119 @@ def __init__(self, data: pd.DataFrame,
if display_all:
self._set_display_options()

def _handle_deprecations(self, labels, rename, isnull, pval_test_name, remarks):
"""
Raise deprecation warnings.
"""
if labels is not None:
warnings.warn("The 'labels' argument of TableOne() is deprecated and will be removed in a future version. "
"Use 'rename' instead.", DeprecationWarning, stacklevel=3)

if labels is not None and rename is not None:
raise TypeError("TableOne received both 'labels' and 'rename'. Please use only 'rename'.")

if isnull is not None:
warnings.warn("The 'isnull' argument is deprecated; use 'missing' instead.",
DeprecationWarning, stacklevel=3)

# pval_test_name is now htest_name
if pval_test_name:
warnings.warn("The pval_test_name argument is deprecated; use htest_name instead.",
DeprecationWarning, stacklevel=3)

if remarks:
warnings.warn("The 'remarks' argument is deprecated; specify tests "
"by name instead (e.g. diptest = True)",
DeprecationWarning, stacklevel=2)

def _validate_arguments(self, groupby, nonnormal, min_max, pval_adjust, order, pval):
"""
Run validation checks on the arguments.
"""
# validate 'groupby' argument
if groupby:
if isinstance(groupby, list):
raise ValueError(f"Invalid 'groupby' type: expected a string, received a list. Use '{groupby[0]}' if it's the intended group.")
elif not isinstance(groupby, str):
raise TypeError(f"Invalid 'groupby' type: expected a string, received {type(groupby).__name__}.")
else:
# If 'groupby' is not provided or is explicitly None, treat it as an empty string.
groupby = ''

# Validate 'nonnormal' argument
if nonnormal is None:
nonnormal = []
elif isinstance(nonnormal, str):
nonnormal = [nonnormal]
elif not isinstance(nonnormal, list):
raise TypeError(f"Invalid 'nonnormal' type: expected a list or a string, received {type(nonnormal).__name__}.")
else:
# Ensure all elements in the list are strings
if not all(isinstance(item, str) for item in nonnormal):
raise ValueError("All items in 'nonnormal' list must be strings.")

# Validate 'min_max' argument
if min_max is None:
min_max = []
elif isinstance(min_max, list):
# Optionally, further validate that the list contains only strings (if needed)
if not all(isinstance(item, str) for item in min_max):
raise ValueError("All items in 'min_max' list must be strings representing column names.")
else:
raise TypeError(f"Invalid 'min_max' type: expected a list, received {type(min_max).__name__}.")

# Validate 'pval_adjust' argument
if pval_adjust is not None:
valid_methods = {"bonferroni", "sidak", "holm-sidak", "simes-hochberg", "hommel", None}
if isinstance(pval_adjust, str):
if pval_adjust.lower() not in valid_methods:
raise ValueError(f"Invalid 'pval_adjust' value: '{pval_adjust}'. "
f"Expected one of {', '.join(valid_methods)} or None.")
else:
raise TypeError(f"Invalid type for 'pval_adjust': expected a string or None, "
f"received {type(pval_adjust).__name__}.")

# Validate 'order' argument
if order is not None:
if not isinstance(order, dict):
raise TypeError("The 'order' parameter must be a dictionary where keys are column names and values are lists of ordered categories.")

for key, values in order.items():
if not isinstance(values, list):
raise TypeError(f"The value for '{key}' in 'order' must be a list of categories.")

# Convert all items in the list to strings safely and efficiently
order[key] = [str(v) for v in values]

if pval and not groupby:
raise ValueError("The 'pval' parameter is set to True, but no 'groupby' parameter was specified. "
"Please provide a 'groupby' column name to perform p-value calculations.")

return groupby, nonnormal, min_max, pval_adjust, order

def _validate_data(self, data, columns):
"""
Run validation checks on the input dataframe.
"""
if data.empty:
raise ValueError("Input data is empty.")

if not data.index.is_unique:
raise InputError("Input data contains duplicate values in the "
"index. Reset the index and try again.")

if not set(columns).issubset(data.columns): # type: ignore
missing_cols = list(set(columns) - set(data.columns)) # type: ignore
raise InputError("""The following columns were not found in the
dataset: {}""".format(missing_cols))

# check for duplicate columns
dups = data[columns].columns[
data[columns].columns.duplicated()].unique()
if not dups.empty:
raise InputError("""Input data contains duplicate
columns: {}""".format(dups))

def __str__(self) -> str:
return self.tableone.to_string() + self._generate_remarks('\n')

