From d34c32b004845663989d1e9f37273ab66908fde5 Mon Sep 17 00:00:00 2001
From: Tom Pollard
Date: Tue, 4 Jun 2024 00:35:09 -0400
Subject: [PATCH 1/6] temp

---
 tableone/tableone.py | 67 +++++++++++++++++++++++++++-----------------
 1 file changed, 41 insertions(+), 26 deletions(-)

diff --git a/tableone/tableone.py b/tableone/tableone.py
index 553c5df..75598cb 100644
--- a/tableone/tableone.py
+++ b/tableone/tableone.py
@@ -259,6 +259,12 @@ def __init__(self, data: pd.DataFrame,
         self._normal_test = normal_test
         self._tukey_test = tukey_test

+        # if columns are not specified, use all columns
+        if not columns:
+            columns = data.columns.values  # type: ignore
+
+        self._validate_data(data, columns)
+
         # groupby should be a string
         if not groupby:
             groupby = ''
@@ -276,32 +282,6 @@ def __init__(self, data: pd.DataFrame,
             warnings.warn("min_max should specify a list of variables.")
             min_max = None

-        # if the input dataframe is empty, raise error
-        if data.empty:
-            raise InputError("Input data is empty.")
-
-        # if the input dataframe has a non-unique index, raise error
-        if not data.index.is_unique:
-            raise InputError("Input data contains duplicate values in the "
-                             "index. Reset the index and try again.")
-
-        # if columns are not specified, use all columns
-        if not columns:
-            columns = data.columns.values  # type: ignore
-
-        # check that the columns exist in the dataframe
-        if not set(columns).issubset(data.columns):  # type: ignore
-            notfound = list(set(columns) - set(data.columns))  # type: ignore
-            raise InputError("""Columns not found in
-                             dataset: {}""".format(notfound))
-
-        # check for duplicate columns
-        dups = data[columns].columns[
-            data[columns].columns.duplicated()].unique()
-        if not dups.empty:
-            raise InputError("""Input data contains duplicate
-                             columns: {}""".format(dups))
-
         # if categorical not specified, try to identify categorical
         if not categorical and type(categorical) != list:
             categorical = self._detect_categorical_columns(data[columns])
@@ -445,6 +425,41 @@ def __init__(self, data: pd.DataFrame,
         if display_all:
             self._set_display_options()

+    def _handle_deprecations(self):
+        """
+        Raise deprecation warnings.
+        """
+        pass
+
+    def _validate_arguments(self):
+        """
+        Run validation checks on the arguments.
+        """
+        pass
+
+    def _validate_data(self, data, columns):
+        """
+        Run validation checks on the input dataframe.
+        """
+        if data.empty:
+            raise ValueError("Input data is empty.")
+
+        if not data.index.is_unique:
+            raise InputError("Input data contains duplicate values in the "
+                             "index. Reset the index and try again.")
+
+        if columns and not set(columns).issubset(data.columns):  # type: ignore
+            missing_cols = list(set(columns) - set(data.columns))  # type: ignore
+            raise InputError("""Columns not found in
+                             dataset: {}""".format(missing_cols))
+
+        # check for duplicate columns
+        dups = data[columns].columns[
+            data[columns].columns.duplicated()].unique()
+        if not dups.empty:
+            raise InputError("""Input data contains duplicate
+                             columns: {}""".format(dups))
+
     def __str__(self) -> str:
         return self.tableone.to_string() + self._generate_remarks('\n')


From db6b1acd9d1e1aa0e00cc64b2ff2c48bb4e6e6c7 Mon Sep 17 00:00:00 2001
From: Tom Pollard
Date: Tue, 4 Jun 2024 00:39:23 -0400
Subject: [PATCH 2/6] Move data validation checks to _validate_data method.

---
 tableone/tableone.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tableone/tableone.py b/tableone/tableone.py
index 75598cb..d0a0c6e 100644
--- a/tableone/tableone.py
+++ b/tableone/tableone.py
@@ -448,9 +448,9 @@ def _validate_data(self, data, columns):
             raise InputError("Input data contains duplicate values in the "
                              "index. Reset the index and try again.")

-        if columns and not set(columns).issubset(data.columns):  # type: ignore
+        if not set(columns).issubset(data.columns):  # type: ignore
             missing_cols = list(set(columns) - set(data.columns))  # type: ignore
-            raise InputError("""Columns not found in
+            raise InputError("""The following columns were not found in the
                              dataset: {}""".format(missing_cols))

         # check for duplicate columns

From 5bf08cc35529c427cd0fcae3665bd3c2a6ad2ec9 Mon Sep 17 00:00:00 2001
From: Tom Pollard
Date: Tue, 4 Jun 2024 01:19:55 -0400
Subject: [PATCH 3/6] Require 'groupby' to be str.

---
 tableone/tableone.py        | 23 +++++++++++++++--------
 tests/unit/test_tableone.py | 32 ++++++++++++++++----------------
 2 files changed, 31 insertions(+), 24 deletions(-)

diff --git a/tableone/tableone.py b/tableone/tableone.py
index d0a0c6e..91c3990 100644
--- a/tableone/tableone.py
+++ b/tableone/tableone.py
@@ -259,17 +259,15 @@ def __init__(self, data: pd.DataFrame,
         self._normal_test = normal_test
         self._tukey_test = tukey_test

-        # if columns are not specified, use all columns
+        self._handle_deprecations()
+
+        # Default assignment for columns if not provided
         if not columns:
             columns = data.columns.values  # type: ignore

         self._validate_data(data, columns)

-        # groupby should be a string
-        if not groupby:
-            groupby = ''
-        elif groupby and type(groupby) == list:
-            groupby = groupby[0]
+        groupby = self._validate_arguments(groupby)

         # nonnormal should be a string
         if not nonnormal:
@@ -431,11 +429,20 @@ def _handle_deprecations(self):
         """
         pass

-    def _validate_arguments(self):
+    def _validate_arguments(self, groupby):
         """
         Run validation checks on the arguments.
         """
-        pass
+        if groupby:
+            if isinstance(groupby, list):
+                raise ValueError(f"Invalid 'groupby' type: expected a string, received a list. Use '{groupby[0]}' if it's the intended group.")
+            elif not isinstance(groupby, str):
+                raise TypeError(f"Invalid 'groupby' type: expected a string, received {type(groupby).__name__}.")
+        else:
+            # If 'groupby' is not provided or is explicitly None, treat it as an empty string.
+            groupby = ''
+
+        return groupby

     def _validate_data(self, data, columns):
         """

diff --git a/tests/unit/test_tableone.py b/tests/unit/test_tableone.py
index 971e41f..625c91f 100644
--- a/tests/unit/test_tableone.py
+++ b/tests/unit/test_tableone.py
@@ -146,7 +146,7 @@ def test_examples_used_in_the_readme_run_without_raising_error_pn(

         columns = ['Age', 'SysABP', 'Height', 'Weight', 'ICU', 'death']
         categorical = ['ICU', 'death']
-        groupby = ['death']
+        groupby = 'death'
         nonnormal = ['Age']
         TableOne(data_pn, columns=columns,
                  categorical=categorical, groupby=groupby,
@@ -353,7 +353,7 @@ def test_input_data_not_modified(self, data_groups):

         df_groupby = data_groups.copy()
         TableOne(df_groupby, columns=['group', 'age', 'weight'],
-                 categorical=['group'], groupby=['group'])
+                 categorical=['group'], groupby='group')
         assert (df_groupby['group'] == df_orig['group']).all()
         assert (df_groupby['age'] == df_orig['age']).all()
         assert (df_groupby['weight'] == df_orig['weight']).all()
@@ -361,7 +361,7 @@ def test_input_data_not_modified(self, data_groups):
         # sorted
         df_sorted = data_groups.copy()
         TableOne(df_sorted, columns=['group', 'age', 'weight'],
-                 categorical=['group'], groupby=['group'],
+                 categorical=['group'], groupby='group',
                  sort=True)
         assert (df_sorted['group'] == df_orig['group']).all()
         assert (df_groupby['age'] == df_orig['age']).all()
@@ -370,7 +370,7 @@ def test_input_data_not_modified(self, data_groups):
         # pval
         df_pval = data_groups.copy()
         TableOne(df_pval, columns=['group', 'age', 'weight'],
-                 categorical=['group'], groupby=['group'],
+                 categorical=['group'], groupby='group',
                  sort=True, pval=True)
         assert (df_pval['group'] == df_orig['group']).all()
         assert (df_groupby['age'] == df_orig['age']).all()
@@ -381,7 +381,7 @@ def test_input_data_not_modified(self, data_groups):
         TableOne(df_pval_adjust,
                  columns=['group', 'age', 'weight'],
                  categorical=['group'],
-                 groupby=['group'], sort=True, pval=True,
+                 groupby='group', sort=True, pval=True,
                  pval_adjust='bonferroni')
         assert (df_pval_adjust['group'] == df_orig['group']).all()
         assert (df_groupby['age'] == df_orig['age']).all()
@@ -391,7 +391,7 @@ def test_input_data_not_modified(self, data_groups):

         df_labels = data_groups.copy()
         TableOne(df_labels, columns=['group', 'age', 'weight'],
-                 categorical=['group'], groupby=['group'],
+                 categorical=['group'], groupby='group',
                  rename={'age': 'age, years'})
         assert (df_labels['group'] == df_orig['group']).all()
         assert (df_groupby['age'] == df_orig['age']).all()
@@ -401,7 +401,7 @@ def test_input_data_not_modified(self, data_groups):

         df_limit = data_groups.copy()
         TableOne(df_limit, columns=['group', 'age', 'weight'],
-                 categorical=['group'], groupby=['group'],
+                 categorical=['group'], groupby='group',
                  limit=2)
         assert (df_limit['group'] == df_orig['group']).all()
         assert (df_groupby['age'] == df_orig['age']).all()
@@ -411,7 +411,7 @@ def test_input_data_not_modified(self, data_groups):

         df_nonnormal = data_groups.copy()
         TableOne(df_nonnormal, columns=['group', 'age', 'weight'],
-                 categorical=['group'], groupby=['group'],
+                 categorical=['group'], groupby='group',
                  nonnormal=['age'])
         assert (df_nonnormal['group'] == df_orig['group']).all()
         assert (df_groupby['age'] == df_orig['age']).all()
@@ -517,7 +517,7 @@ def test_tableone_columns_in_consistent_order_pn(self, data_pn):
         """
         df = data_pn.copy()
         columns = ['Age', 'SysABP', 'Height', 'Weight', 'ICU', 'death']
-        groupby = ['death']
+        groupby = 'death'

         table = TableOne(df, columns=columns, groupby=groupby, pval=True,
                          htest_name=True, overall=False)
@@ -553,7 +553,7 @@ def test_check_null_counts_are_correct_pn(self, data_pn):
         """
         columns = ['Age', 'SysABP', 'Height', 'Weight', 'ICU', 'death']
         categorical = ['ICU', 'death']
-        groupby = ['death']
+        groupby = 'death'

         # test when not grouping
         table = TableOne(data_pn, columns=columns,
@@ -612,7 +612,7 @@ def test_the_decimals_argument_for_continuous_variables(self, data_pn):
         """
         columns = ['Age', 'SysABP', 'Height', 'Weight', 'ICU', 'death']
         categorical = ['ICU', 'death']
-        groupby = ['death']
+        groupby = 'death'
         nonnormal = ['Age']

         # no decimals argument
@@ -692,7 +692,7 @@ def test_the_decimals_argument_for_categorical_variables(self, data_pn):
         """
         columns = ['Age', 'SysABP', 'Height', 'Weight', 'ICU', 'death']
         categorical = ['ICU', 'death']
-        groupby = ['death']
+        groupby = 'death'
         nonnormal = ['Age']

         # decimals = 1
@@ -1065,7 +1065,7 @@ def test_min_max_for_nonnormal_variables(self, data_pn):
         nonnormal = ['Age']

         # optionally, a categorical variable for stratification
-        groupby = ['death']
+        groupby = 'death'
         t1 = TableOne(data_pn, columns=columns,
                       categorical=categorical, groupby=groupby,
                       nonnormal=nonnormal, decimals=decimals,
@@ -1096,7 +1096,7 @@ def test_row_percent_false(self, data_pn):
         nonnormal = ['Age']

         # optionally, a categorical variable for stratification
-        groupby = ['death']
+        groupby = 'death'
         group = "Grouped by death"

         # row_percent = False
@@ -1146,7 +1146,7 @@ def test_row_percent_true(self, data_pn):
         nonnormal = ['Age']

         # optionally, a categorical variable for stratification
-        groupby = ['death']
+        groupby = 'death'
         group = "Grouped by death"

         # row_percent = True
@@ -1196,7 +1196,7 @@ def test_row_percent_true_and_overall_false(self, data_pn):
         nonnormal = ['Age']

         # optionally, a categorical variable for stratification
-        groupby = ['death']
+        groupby = 'death'
         group = "Grouped by death"

         # row_percent = True

From cfc534e5937b181e20f3fef5fae827b968822114 Mon Sep 17 00:00:00 2001
From: Tom Pollard
Date: Tue, 4 Jun 2024 01:55:16 -0400
Subject: [PATCH 4/6] Move validation of input arguments to _validate_arguments method.

---
 tableone/tableone.py        | 81 ++++++++++++++++++++++++-------------
 tests/unit/test_tableone.py |  6 +--
 2 files changed, 57 insertions(+), 30 deletions(-)

diff --git a/tableone/tableone.py b/tableone/tableone.py
index 91c3990..0fe560f 100644
--- a/tableone/tableone.py
+++ b/tableone/tableone.py
@@ -267,18 +267,8 @@ def __init__(self, data: pd.DataFrame,

         self._validate_data(data, columns)

-        groupby = self._validate_arguments(groupby)
-
-        # nonnormal should be a string
-        if not nonnormal:
-            nonnormal = []
-        elif nonnormal and type(nonnormal) == str:
-            nonnormal = [nonnormal]
-
-        # min_max should be a list
-        if min_max and isinstance(min_max, bool):
-            warnings.warn("min_max should specify a list of variables.")
-            min_max = None
+        (groupby, nonnormal, min_max, pval_adjust, order) = self._validate_arguments(
+            groupby, nonnormal, min_max, pval_adjust, order, pval)

         # if categorical not specified, try to identify categorical
         if not categorical and type(categorical) != list:
@@ -287,16 +277,6 @@ def __init__(self, data: pd.DataFrame,
         if groupby:
             categorical = [x for x in categorical if x != groupby]

-        if isinstance(pval_adjust, bool) and pval_adjust:
-            msg = ("pval_adjust expects a string, but a boolean was specified."
-                   " Defaulting to the 'bonferroni' correction.")
-            warnings.warn(msg)
-            pval_adjust = "bonferroni"
-
-        # if custom order is provided, ensure that values are strings
-        if order:
-            order = {k: ["{}".format(v) for v in order[k]] for k in order}
-
         # if input df has ordered categorical variables, get the order.
         order_cats = [x for x in data.select_dtypes("category")
                       if data[x].dtype.ordered]  # type: ignore
@@ -314,9 +294,6 @@ def __init__(self, data: pd.DataFrame,
         elif order_cats:
             order = d_order_cats  # type: ignore

-        if pval and not groupby:
-            raise InputError("If pval=True then groupby must be specified.")
-
         self._columns = list(columns)  # type: ignore
         self._continuous = [c for c in columns  # type: ignore
                             if c not in categorical + [groupby]]
@@ -429,10 +406,11 @@ def _handle_deprecations(self):
         """
         pass

-    def _validate_arguments(self, groupby):
+    def _validate_arguments(self, groupby, nonnormal, min_max, pval_adjust, order, pval):
         """
         Run validation checks on the arguments.
         """
+        # validate 'groupby' argument
         if groupby:
             if isinstance(groupby, list):
                 raise ValueError(f"Invalid 'groupby' type: expected a string, received a list. Use '{groupby[0]}' if it's the intended group.")
@@ -442,7 +420,56 @@ def _validate_arguments(self, groupby):
             # If 'groupby' is not provided or is explicitly None, treat it as an empty string.
             groupby = ''

-        return groupby
+        # Validate 'nonnormal' argument
+        if nonnormal is None:
+            nonnormal = []
+        elif isinstance(nonnormal, str):
+            nonnormal = [nonnormal]
+        elif not isinstance(nonnormal, list):
+            raise TypeError(f"Invalid 'nonnormal' type: expected a list or a string, received {type(nonnormal).__name__}.")
+        else:
+            # Ensure all elements in the list are strings
+            if not all(isinstance(item, str) for item in nonnormal):
+                raise ValueError("All items in 'nonnormal' list must be strings.")
+
+        # Validate 'min_max' argument
+        if min_max is None:
+            min_max = []
+        elif isinstance(min_max, list):
+            # Optionally, further validate that the list contains only strings (if needed)
+            if not all(isinstance(item, str) for item in min_max):
+                raise ValueError("All items in 'min_max' list must be strings representing column names.")
+        else:
+            raise TypeError(f"Invalid 'min_max' type: expected a list, received {type(min_max).__name__}.")

+        # Validate 'pval_adjust' argument
+        if pval_adjust is not None:
+            # Keep the set to strings only; joining a set containing None into the
+            # error message below would itself raise a TypeError.
+            valid_methods = {"bonferroni", "sidak", "holm-sidak", "simes-hochberg", "hommel"}
+            if isinstance(pval_adjust, str):
+                if pval_adjust.lower() not in valid_methods:
+                    raise ValueError(f"Invalid 'pval_adjust' value: '{pval_adjust}'. "
+                                     f"Expected one of {', '.join(valid_methods)} or None.")
+            else:
+                raise TypeError(f"Invalid type for 'pval_adjust': expected a string or None, "
+                                f"received {type(pval_adjust).__name__}.")
+
+        # Validate 'order' argument
+        if order is not None:
+            if not isinstance(order, dict):
+                raise TypeError("The 'order' parameter must be a dictionary where keys are column names and values are lists of ordered categories.")
+
+            for key, values in order.items():
+                if not isinstance(values, list):
+                    raise TypeError(f"The value for '{key}' in 'order' must be a list of categories.")
+
+                # Convert all items in the list to strings safely and efficiently
+                order[key] = [str(v) for v in values]
+
+        if pval and not groupby:
+            raise ValueError("The 'pval' parameter is set to True, but no 'groupby' parameter was specified. "
" + "Please provide a 'groupby' column name to perform p-value calculations.") + + return groupby, nonnormal, min_max, pval_adjust, order def _validate_data(self, data, columns): """ diff --git a/tests/unit/test_tableone.py b/tests/unit/test_tableone.py index 625c91f..4234aac 100644 --- a/tests/unit/test_tableone.py +++ b/tests/unit/test_tableone.py @@ -434,7 +434,7 @@ def test_groupby_with_group_named_isnull_pn(self, data_pn): tableone_columns = list(table.tableone.columns.levels[1]) table = TableOne(df, columns=columns, groupby=groupby, pval=True, - pval_adjust='b') + pval_adjust='bonferroni') tableone_columns = (tableone_columns + list(table.tableone.columns.levels[1])) tableone_columns = np.unique(tableone_columns) @@ -445,7 +445,7 @@ def test_groupby_with_group_named_isnull_pn(self, data_pn): # for each output column name in tableone, try them as a group df.loc[0:20, 'ICU'] = c if 'adjust' in c: - pval_adjust = 'b' + pval_adjust = 'bonferroni' else: pval_adjust = None @@ -836,7 +836,7 @@ def test_pval_correction(self): # catch the pval_adjust=True with warnings.catch_warnings(record=False): warnings.simplefilter('ignore', category=UserWarning) - TableOne(df, groupby="even", pval=True, pval_adjust=True) + TableOne(df, groupby="even", pval=True, pval_adjust='bonferroni') for k in pvals_expected: assert (t1.tableone.loc[k][group][col].values[0] == From bb5fafcc4ef12714c2d564ebf97d897add08a370 Mon Sep 17 00:00:00 2001 From: Tom Pollard Date: Tue, 4 Jun 2024 02:23:46 -0400 Subject: [PATCH 5/6] handle deprecations in _handle_deprecations method. --- tableone/tableone.py | 69 ++++++++++++++++++-------------------------- 1 file changed, 28 insertions(+), 41 deletions(-) diff --git a/tableone/tableone.py b/tableone/tableone.py index 0fe560f..c16aaea 100644 --- a/tableone/tableone.py +++ b/tableone/tableone.py @@ -220,46 +220,14 @@ def __init__(self, data: pd.DataFrame, tukey_test: bool = False, pval_threshold: Optional[float] = None) -> None: - # labels is now rename - if labels is not None and rename is not None: - raise TypeError("TableOne received both labels and rename.") - elif labels is not None: - warnings.warn("The labels argument is deprecated; use " - "rename instead.", DeprecationWarning) - self._alt_labels = labels - else: - self._alt_labels = rename - - # isnull is now missing - if isnull is not None: - warnings.warn("The isnull argument is deprecated; use " - "missing instead.", DeprecationWarning) - self._isnull = isnull - else: - self._isnull = missing - - # pval_test_name is now htest_name - if pval_test_name: - warnings.warn("The pval_test_name argument is deprecated; use " - "htest_name instead.", DeprecationWarning) - self._pval_test_name = pval_test_name - else: - self._pval_test_name = htest_name + self._handle_deprecations(labels, rename, isnull, pval_test_name, remarks) - # remarks are now specified by individual test names - if remarks: - warnings.warn("The remarks argument is deprecated; specify tests " - "by name instead (e.g. 
+                          DeprecationWarning, stacklevel=2)

     def _validate_arguments(self, groupby, nonnormal, min_max, pval_adjust, order, pval):
         """

From 18136971162c139db2fab2754999d4519f57d66d Mon Sep 17 00:00:00 2001
From: Tom Pollard
Date: Tue, 4 Jun 2024 02:43:26 -0400
Subject: [PATCH 6/6] reorder attributes alphabetically.

---
 tableone/tableone.py | 45 ++++++++++++++++++++++----------------------
 1 file changed, 22 insertions(+), 23 deletions(-)

diff --git a/tableone/tableone.py b/tableone/tableone.py
index c16aaea..814a8e0 100644
--- a/tableone/tableone.py
+++ b/tableone/tableone.py
@@ -222,13 +222,6 @@ def __init__(self, data: pd.DataFrame,

         self._handle_deprecations(labels, rename, isnull, pval_test_name, remarks)

-        self._alt_labels = rename
-        self._isnull = missing
-        self._pval_test_name = htest_name
-        self._dip_test = dip_test
-        self._normal_test = normal_test
-        self._tukey_test = tukey_test
-
         # Default assignment for columns if not provided
         if not columns:
             columns = data.columns.values  # type: ignore
@@ -262,35 +255,41 @@ def __init__(self, data: pd.DataFrame,
         elif order_cats:
             order = d_order_cats  # type: ignore

+        self._alt_labels = rename
         self._columns = list(columns)  # type: ignore
         self._continuous = [c for c in columns  # type: ignore
                             if c not in categorical + [groupby]]
         self._categorical = categorical
-        self._nonnormal = nonnormal
-        self._min_max = min_max
-        self._pval = pval
-        self._pval_adjust = pval_adjust
-        self._htest = htest
-        self._sort = sort
-        self._groupby = groupby
-        # degrees of freedom for standard deviation
         self._ddof = ddof
+        self._decimals = decimals
+        self._dip_test = dip_test
+        self._groupby = groupby
+        self._htest = htest
+        self._isnull = missing
+        self._label_suffix = label_suffix
         self._limit = limit
+        self._min_max = min_max
+        self._nonnormal = nonnormal
+        self._normal_test = normal_test
         self._order = order
-        self._label_suffix = label_suffix
-        self._decimals = decimals
-        self._smd = smd
-        self._pval_threshold = pval_threshold
         self._overall = overall
+        self._pval = pval
+        self._pval_adjust = pval_adjust
+        self._pval_test_name = htest_name
+        self._pval_threshold = pval_threshold
+
+        # column names that cannot be contained in a groupby
+        self._reserved_columns = ['Missing', 'P-Value', 'Test',
+                                  'P-Value (adjusted)', 'SMD', 'Overall']
+        self._row_percent = row_percent
+        self._smd = smd
+        self._sort = sort
+        self._tukey_test = tukey_test

         # display notes and warnings below the table
         self._warnings = {}

-        # output column names that cannot be contained in a groupby
-        self._reserved_columns = ['Missing', 'P-Value', 'Test',
-                                  'P-Value (adjusted)', 'SMD', 'Overall']
-
         if self._groupby:
             self._groupbylvls = sorted(data.groupby(groupby).groups.keys())  # type: ignore
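
Not part of the patch series above: the short sketch below illustrates the argument validation that PATCH 3 and PATCH 4 introduce. The DataFrame is invented for the example, and the expected exceptions simply mirror the checks added to _validate_arguments; treat it as a rough usage sketch rather than part of the change itself.

import pandas as pd
from tableone import TableOne

# Tiny synthetic dataset, made up for this sketch only.
data = pd.DataFrame({"group": ["a", "b", "a", "b"],
                     "age": [34, 47, 29, 52]})

# 'groupby' is now expected to be a plain string.
TableOne(data, columns=["group", "age"], categorical=["group"], groupby="group")

# Passing a list, as some older examples did, should raise ValueError (PATCH 3).
try:
    TableOne(data, columns=["group", "age"], categorical=["group"], groupby=["group"])
except ValueError as err:
    print(err)

# 'pval_adjust' must be a named correction such as 'bonferroni' or None;
# a bare True should now raise TypeError (PATCH 4).
try:
    TableOne(data, columns=["group", "age"], categorical=["group"],
             groupby="group", pval=True, pval_adjust=True)
except TypeError as err:
    print(err)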