From d83d76bcc8961c84016bff17bfce65790da0d7c9 Mon Sep 17 00:00:00 2001 From: Tom Pollard Date: Tue, 9 Jul 2024 13:39:04 -0400 Subject: [PATCH 1/2] Move formatting related functions to a formatting module. --- tableone/formatting.py | 292 +++++++++++++++++++++++++++++++++++++ tableone/tableone.py | 317 +++-------------------------------------- 2 files changed, 311 insertions(+), 298 deletions(-) create mode 100644 tableone/formatting.py diff --git a/tableone/formatting.py b/tableone/formatting.py new file mode 100644 index 0000000..a453f0f --- /dev/null +++ b/tableone/formatting.py @@ -0,0 +1,292 @@ +import warnings + +import numpy as np +import pandas as pd + + +def docstring_copier(*sub): + """ + Wrap the TableOne docstring (not ideal :/) + """ + def dec(obj): + obj.__doc__ = obj.__doc__.format(*sub) + return obj + return dec + + +def set_display_options(max_rows=None, + max_columns=None, + width=None, + max_colwidth=None): + """ + Set pandas display options. Display all rows and columns by default. + """ + display_options = {'display.max_rows': max_rows, + 'display.max_columns': max_columns, + 'display.width': width, + 'display.max_colwidth': max_colwidth} + + for k in display_options: + try: + pd.set_option(k, display_options[k]) + except ValueError: + msg = """Newer version of Pandas required to set the '{}' + option.""".format(k) + warnings.warn(msg) + + +def format_pvalues(table, pval, pval_adjust, pval_threshold): + """ + Formats the p value columns, applying rounding rules and adding + significance markers based on defined thresholds. + """ + # round pval column and convert to string + if pval and pval_adjust: + if pval_threshold: + asterisk_mask = table['P-Value (adjusted)'] < pval_threshold + + table['P-Value (adjusted)'] = table['P-Value (adjusted)'].apply( + '{:.3f}'.format).astype(str) + table.loc[table['P-Value (adjusted)'] == '0.000', + 'P-Value (adjusted)'] = '<0.001' + + if pval_threshold: + table.loc[asterisk_mask, 'P-Value (adjusted)'] = ( + table['P-Value (adjusted)'][asterisk_mask].astype(str)+"*" # type: ignore + ) + + elif pval: + if pval_threshold: + asterisk_mask = table['P-Value'] < pval_threshold + + table['P-Value'] = table['P-Value'].apply( + '{:.3f}'.format).astype(str) + table.loc[table['P-Value'] == '0.000', 'P-Value'] = '<0.001' + + if pval_threshold: + table.loc[asterisk_mask, 'P-Value'] = ( + table['P-Value'][asterisk_mask].astype(str)+"*" # type: ignore + ) + + return table + + +def format_smd_columns(table, smd, smd_table): + """ + Formats the SMD (Standardized Mean Differences) columns. Rounds the SMD values + and ensures they are presented as strings. + """ + # round smd columns and convert to string + if smd and smd_table is not None: + for c in list(smd_table.columns): + table[c] = table[c].apply('{:.3f}'.format).astype(str) + table.loc[table[c] == '0.000', c] = '<0.001' + + return table + + +def apply_limits(table, data, limits, categorical, order): + """ + Applies limits to the number of categories shown for each categorical variable + in the DataFrame, based on specified requirements. + """ + # set the limit on the number of categorical variables + if limits: + levelcounts = data[categorical].nunique() + + for k, _ in levelcounts.items(): + # set the limit for the variable + if (isinstance(limits, int) + and levelcounts[k] >= limits): + limit = limits + elif isinstance(limits, dict) and k in limits: + limit = limits[k] + else: + continue + + if not order or (order and k not in order): + # re-order the variables by frequency + count = data[k].value_counts().sort_values(ascending=False) + new_idx = [(k, '{}'.format(i)) for i in count.index] + else: + # apply order + all_var = table.loc[k].index.unique(level='value') + new_idx = [(k, '{}'.format(v)) for v in order[k]] + new_idx += [(k, '{}'.format(v)) for v in all_var + if v not in order[k]] + + # restructure to match the original idx + new_idx_array = np.empty((len(new_idx),), dtype=object) + new_idx_array[:] = [tuple(i) for i in new_idx] + orig_idx = table.index.values.copy() + orig_idx[table.index.get_loc(k)] = new_idx_array + table = table.reindex(orig_idx) + + # drop the rows > the limit + table = table.drop(new_idx_array[limit:]) # type: ignore + + return table + + +def sort_and_reindex(table, smd, smd_table, sort, columns): + """ + Sorts and reindexes the table to meet requirements. + """ + # sort the table rows + sort_columns = ['Missing', 'P-Value', 'P-Value (adjusted)', 'Test'] + + if smd and smd_table is not None: + sort_columns = sort_columns + list(smd_table.columns) + + if sort and isinstance(sort, bool): + new_index = sorted(table.index.values, key=lambda x: x[0].lower()) + elif sort and isinstance(sort, str) and (sort in sort_columns): + try: + new_index = table.sort_values(sort).index + except KeyError: + new_index = sorted(table.index.values, + key=lambda x: columns.index(x[0])) + warnings.warn(f'Sort variable not found: {sort}') + elif sort and isinstance(sort, str) and (sort not in sort_columns): + new_index = sorted(table.index.values, + key=lambda x: columns.index(x[0])) + warnings.warn(f'Sort must be in the following list: {sort}') + else: + # sort by the columns argument + new_index = sorted(table.index.values, + key=lambda x: columns.index(x[0])) + table = table.reindex(new_index) + + return table + + +def apply_order(table, order, groupby): + """ + Applies a predefined order to rows based on specified requirements. + May include reordering based on categorical group levels or other criteria. + """ + # if an order is specified, apply it + if order: + for k in order: + # Skip if the variable isn't present + try: + all_var = table.loc[k].index.unique(level='value') + except KeyError: + if k not in groupby: # type: ignore + warnings.warn(f"Order variable not found: {k}") + continue + + # Remove value from order if it is not present + if [i for i in order[k] if i not in all_var]: + rm_var = [i for i in order[k] if i not in all_var] + order[k] = [i for i in order[k] if i in all_var] + warnings.warn(f'Order value not found: "{k}: {rm_var}"') + + new_seq = [(k, '{}'.format(v)) for v in order[k]] + new_seq += [(k, '{}'.format(v)) for v in all_var + if v not in order[k]] + + # restructure to match the original idx + new_idx_array = np.empty((len(new_seq),), dtype=object) + new_idx_array[:] = [tuple(i) for i in new_seq] + orig_idx = table.index.values.copy() + orig_idx[table.index.get_loc(k)] = new_idx_array + table = table.reindex(orig_idx) + + return table + + +def mask_duplicate_values(table, optional_columns, smd, smd_table): + """ + Masks duplicate values, ensuring that repeated values (e.g. counts of + missing values) are only displayed once. + """ + # only display data in first level row + dupe_mask = table.groupby(level=[0]).cumcount().ne(0) # type: ignore + dupe_columns = ['Missing'] + + if smd and smd_table is not None: + optional_columns = optional_columns + list(smd_table.columns) + for col in optional_columns: + if col in table.columns.values: + dupe_columns.append(col) + + table[dupe_columns] = table[dupe_columns].mask(dupe_mask).fillna('') + + return table + + +def create_row_labels(columns, alt_labels, label_suffix, nonnormal, + min_max, categorical) -> dict: + """ + Take the original labels for rows. Rename if alternative labels are + provided. Append label suffix if label_suffix is True. + + Returns + ---------- + labels : dictionary + Dictionary, keys are original column name, values are final label. + + """ + # start with the original column names + labels = {} + for c in columns: + labels[c] = c + + # replace column names with alternative names if provided + if alt_labels: + for k in alt_labels.keys(): + labels[k] = alt_labels[k] + + # append the label suffix + if label_suffix: + for k in labels.keys(): + if k in nonnormal: + if min_max and k in min_max: + labels[k] = "{}, {}".format(labels[k], + "median [min,max]") + else: + labels[k] = "{}, {}".format(labels[k], + "median [Q1,Q3]") + elif k in categorical: + labels[k] = "{}, {}".format(labels[k], "n (%)") + else: + if min_max and k in min_max: + labels[k] = "{}, {}".format(labels[k], + "mean [min,max]") + else: + labels[k] = "{}, {}".format(labels[k], + "mean (SD)") + + return labels + + +def reorder_columns(table, optional_columns, groupby, order, overall): + """ + Reorder columns for consistent, predictable formatting. + """ + if groupby and order and (groupby in order): + header = ['{}'.format(v) for v in table.columns.levels[1].values] # type: ignore + cols = order[groupby] + ['{}'.format(v) for v in header if v not in order[groupby]] + elif groupby: + cols = ['{}'.format(v) for v in table.columns.levels[1].values] # type: ignore + else: + cols = ['{}'.format(v) for v in table.columns.values] + + if groupby and overall: + cols = ['Overall'] + [x for x in cols if x != 'Overall'] + + if 'Missing' in cols: + cols = ['Missing'] + [x for x in cols if x != 'Missing'] + + # move optional_columns to the end of the dataframe + for col in optional_columns: + if col in cols: + cols = [x for x in cols if x != col] + [col] + + if groupby: + table = table.reindex(cols, axis=1, level=1) + else: + table = table.reindex(cols, axis=1) + + return table diff --git a/tableone/tableone.py b/tableone/tableone.py index c196bab..94796bb 100644 --- a/tableone/tableone.py +++ b/tableone/tableone.py @@ -10,6 +10,10 @@ from tabulate import tabulate from tableone.deprecations import handle_deprecated_parameters +from tableone.formatting import (docstring_copier, set_display_options, format_pvalues, + format_smd_columns, apply_limits, sort_and_reindex, + apply_order, mask_duplicate_values, create_row_labels, + reorder_columns) from tableone.preprocessors import (ensure_list, detect_categorical, order_categorical, get_groups, handle_categorical_nulls) from tableone.statistics import Statistics @@ -42,16 +46,6 @@ def load_dataset(name: str) -> pd.DataFrame: return df -def docstring_copier(*sub): - """ - Wrap the TableOne docstring (not ideal :/) - """ - def dec(obj): - obj.__doc__ = obj.__doc__.format(*sub) - return obj - return dec - - class TableOne: """ @@ -266,7 +260,7 @@ def __init__(self, data: pd.DataFrame, # set display options if display_all: - self._set_display_options() + set_display_options() def __str__(self) -> str: return self.tableone.to_string() + self._generate_remarks('\n') @@ -436,23 +430,6 @@ def create_intermediate_tables(self, data): self.smd_table, self._groupby) - def _set_display_options(self): - """ - Set pandas display options. Display all rows and columns by default. - """ - display_options = {'display.max_rows': None, - 'display.max_columns': None, - 'display.width': None, - 'display.max_colwidth': None} - - for k in display_options: - try: - pd.set_option(k, display_options[k]) - except ValueError: - msg = """Newer version of Pandas required to set the '{}' - option.""".format(k) - warnings.warn(msg) - def tabulate(self, headers=None, tablefmt='grid', **kwargs) -> str: """ Pretty-print tableone data. Wrapper for the Python 'tabulate' library. @@ -600,174 +577,6 @@ def _combine_tables(self): return table - def _sort_and_reindex(self, table): - """ - Sorts and reindexes the table to meet requirements. - """ - # sort the table rows - sort_columns = ['Missing', 'P-Value', 'P-Value (adjusted)', 'Test'] - - if self._smd and self.smd_table is not None: - sort_columns = sort_columns + list(self.smd_table.columns) - - if self._sort and isinstance(self._sort, bool): - new_index = sorted(table.index.values, key=lambda x: x[0].lower()) - elif self._sort and isinstance(self._sort, str) and (self._sort in - sort_columns): - try: - new_index = table.sort_values(self._sort).index - except KeyError: - new_index = sorted(table.index.values, - key=lambda x: self._columns.index(x[0])) - warnings.warn('Sort variable not found: {}'.format(self._sort)) - elif self._sort and isinstance(self._sort, str) and (self._sort not in - sort_columns): - new_index = sorted(table.index.values, - key=lambda x: self._columns.index(x[0])) - warnings.warn('Sort must be in the following ' + - 'list: {}.'.format(self._sort)) - else: - # sort by the columns argument - new_index = sorted(table.index.values, - key=lambda x: self._columns.index(x[0])) - table = table.reindex(new_index) - - return table - - def _format_values(self, table): - """ - Formats the numerical values in the table, specifically focusing on the p value - and SMD (Standardized Mean Differences) columns. It applies rounding and - converts numbers to strings for better presentation. - """ - table = self._format_pvalues(table) - table = self._format_smd_columns(table) - return table - - def _format_pvalues(self, table): - """ - Formats the p value columns, applying rounding rules and adding - significance markers based on defined thresholds. - """ - # round pval column and convert to string - if self._pval and self._pval_adjust: - if self._pval_threshold: - asterisk_mask = table['P-Value (adjusted)'] < self._pval_threshold - - table['P-Value (adjusted)'] = table['P-Value (adjusted)'].apply( - '{:.3f}'.format).astype(str) - table.loc[table['P-Value (adjusted)'] == '0.000', - 'P-Value (adjusted)'] = '<0.001' - - if self._pval_threshold: - table.loc[asterisk_mask, 'P-Value (adjusted)'] = table['P-Value (adjusted)'][asterisk_mask].astype(str)+"*" # type: ignore - - elif self._pval: - if self._pval_threshold: - asterisk_mask = table['P-Value'] < self._pval_threshold - - table['P-Value'] = table['P-Value'].apply( - '{:.3f}'.format).astype(str) - table.loc[table['P-Value'] == '0.000', 'P-Value'] = '<0.001' - - if self._pval_threshold: - table.loc[asterisk_mask, 'P-Value'] = table['P-Value'][asterisk_mask].astype(str)+"*" # type: ignore - - return table - - def _format_smd_columns(self, table): - """ - Formats the SMD (Standardized Mean Differences) columns. Rounds the SMD values - and ensures they are presented as strings. - """ - # round smd columns and convert to string - if self._smd and self.smd_table is not None: - for c in list(self.smd_table.columns): - table[c] = table[c].apply('{:.3f}'.format).astype(str) - table.loc[table[c] == '0.000', c] = '<0.001' - - return table - - def _apply_order(self, table): - """ - Applies a predefined order to rows based on specified requirements. - May include reordering based on categorical group levels or other criteria. - """ - # if an order is specified, apply it - if self._order: - for k in self._order: - - # Skip if the variable isn't present - try: - all_var = table.loc[k].index.unique(level='value') - except KeyError: - if k not in self._groupby: # type: ignore - warnings.warn("Order variable not found: {}".format(k)) - continue - - # Remove value from order if it is not present - if [i for i in self._order[k] if i not in all_var]: - rm_var = [i for i in self._order[k] if i not in all_var] - self._order[k] = [i for i in self._order[k] - if i in all_var] - warnings.warn(("Order value not found: " - "{}: {}").format(k, rm_var)) - - new_seq = [(k, '{}'.format(v)) for v in self._order[k]] - new_seq += [(k, '{}'.format(v)) for v in all_var - if v not in self._order[k]] - - # restructure to match the original idx - new_idx_array = np.empty((len(new_seq),), dtype=object) - new_idx_array[:] = [tuple(i) for i in new_seq] - orig_idx = table.index.values.copy() - orig_idx[table.index.get_loc(k)] = new_idx_array - table = table.reindex(orig_idx) - - return table - - def _apply_limits(self, table, data): - """ - Applies limits to the number of categories shown for each categorical variable - in the DataFrame, based on specified requirements. - """ - # set the limit on the number of categorical variables - if self._limit: - levelcounts = data[self._categorical].nunique() - - for k, _ in levelcounts.items(): - # set the limit for the variable - if (isinstance(self._limit, int) - and levelcounts[k] >= self._limit): - limit = self._limit - elif isinstance(self._limit, dict) and k in self._limit: - limit = self._limit[k] - else: - continue - - if not self._order or (self._order and k not in self._order): - # re-order the variables by frequency - count = data[k].value_counts().sort_values(ascending=False) - new_idx = [(k, '{}'.format(i)) for i in count.index] - else: - # apply order - all_var = table.loc[k].index.unique(level='value') - new_idx = [(k, '{}'.format(v)) for v in self._order[k]] - new_idx += [(k, '{}'.format(v)) for v in all_var - if v not in self._order[k]] - - # restructure to match the original idx - new_idx_array = np.empty((len(new_idx),), dtype=object) - new_idx_array[:] = [tuple(i) for i in new_idx] - orig_idx = table.index.values.copy() - orig_idx[table.index.get_loc(k)] = new_idx_array - table = table.reindex(orig_idx) - - # drop the rows > the limit - table = table.drop(new_idx_array[limit:]) # type: ignore - - return table - def _insert_n_row(self, table, data): """ Inserts a row that shows 'n', the total number or count of items @@ -795,64 +604,18 @@ def _insert_n_row(self, table, data): return table - def _mask_duplicate_values(self, table, optional_columns): - """ - Masks duplicate values, ensuring that repeated values (e.g. counts of - missing values) are only displayed once. - """ - # only display data in first level row - dupe_mask = table.groupby(level=[0]).cumcount().ne(0) # type: ignore - dupe_columns = ['Missing'] - - if self._smd and self.smd_table is not None: - optional_columns = optional_columns + list(self.smd_table.columns) - for col in optional_columns: - if col in table.columns.values: - dupe_columns.append(col) - - table[dupe_columns] = table[dupe_columns].mask(dupe_mask).fillna('') - - return table - def _apply_alt_labels(self, table): """ Applies alternative labels to the variables if required. """ # display alternative labels if assigned - table = table.rename(index=self._create_row_labels(), level=0) - - return table - - def _reorder_columns(self, table, optional_columns): - """ - Reorder columns for consistent, predictable formatting. - """ - if self._groupby and self._order and (self._groupby in self._order): - header = ['{}'.format(v) for v in table.columns.levels[1].values] # type: ignore - cols = self._order[self._groupby] + ['{}'.format(v) - for v in header - if v not in - self._order[self._groupby]] - elif self._groupby: - cols = ['{}'.format(v) for v in table.columns.levels[1].values] # type: ignore - else: - cols = ['{}'.format(v) for v in table.columns.values] - - if self._groupby and self._overall: - cols = ['Overall'] + [x for x in cols if x != 'Overall'] - - if 'Missing' in cols: - cols = ['Missing'] + [x for x in cols if x != 'Missing'] - - # move optional_columns to the end of the dataframe - for col in optional_columns: - if col in cols: - cols = [x for x in cols if x != col] + [col] - - if self._groupby: - table = table.reindex(cols, axis=1, level=1) - else: - table = table.reindex(cols, axis=1) + table = table.rename(index=create_row_labels(self._columns, + self._alt_labels, + self._label_suffix, + self._nonnormal, + self._min_max, + self._categorical + ), level=0) return table @@ -890,12 +653,13 @@ def _create_tableone(self, data): table = table.reset_index().set_index(['variable', 'value']) # type: ignore table.columns = table.columns.values.astype(str) - table = self._sort_and_reindex(table) - table = self._format_values(table) - table = self._apply_order(table) - table = self._apply_limits(table, data) + table = sort_and_reindex(table, self._smd, self.smd_table, self._sort, self._columns) + table = format_pvalues(table, self._pval, self._pval_adjust, self._pval_threshold) + table = format_smd_columns(table, self._smd, self.smd_table) + table = apply_order(table, self._order, self._groupby) + table = apply_limits(table, data, self._limit, self._categorical, self._order) table = self._insert_n_row(table, data) - table = self._mask_duplicate_values(table, optional_columns) + table = mask_duplicate_values(table, optional_columns, self._smd, self.smd_table) # remove unwanted columns if not self._isnull: @@ -909,7 +673,7 @@ def _create_tableone(self, data): table = self._add_groupby_columns(table) table = self._apply_alt_labels(table) - table = self._reorder_columns(table, optional_columns) + table = reorder_columns(table, optional_columns, self._groupby, self._order, self._overall) try: if 'Missing' in self._alt_labels or 'Overall' in self._alt_labels: # type: ignore @@ -922,49 +686,6 @@ def _create_tableone(self, data): return table - def _create_row_labels(self) -> dict: - """ - Take the original labels for rows. Rename if alternative labels are - provided. Append label suffix if label_suffix is True. - - Returns - ---------- - labels : dictionary - Dictionary, keys are original column name, values are final label. - - """ - # start with the original column names - labels = {} - for c in self._columns: - labels[c] = c - - # replace column names with alternative names if provided - if self._alt_labels: - for k in self._alt_labels.keys(): - labels[k] = self._alt_labels[k] - - # append the label suffix - if self._label_suffix: - for k in labels.keys(): - if k in self._nonnormal: - if self._min_max and k in self._min_max: - labels[k] = "{}, {}".format(labels[k], - "median [min,max]") - else: - labels[k] = "{}, {}".format(labels[k], - "median [Q1,Q3]") - elif k in self._categorical: - labels[k] = "{}, {}".format(labels[k], "n (%)") - else: - if self._min_max and k in self._min_max: - labels[k] = "{}, {}".format(labels[k], - "mean [min,max]") - else: - labels[k] = "{}, {}".format(labels[k], - "mean (SD)") - - return labels - # Allow TableOne to be called as a function. # Refactor this out at some point! From a07f7c539f2326595e1a0187f39cadd8ec06518e Mon Sep 17 00:00:00 2001 From: Tom Pollard Date: Tue, 9 Jul 2024 13:46:15 -0400 Subject: [PATCH 2/2] Refactor to avoid asterisk_mask is possibly unbound error. --- tableone/formatting.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/tableone/formatting.py b/tableone/formatting.py index a453f0f..548f048 100644 --- a/tableone/formatting.py +++ b/tableone/formatting.py @@ -42,28 +42,22 @@ def format_pvalues(table, pval, pval_adjust, pval_threshold): """ # round pval column and convert to string if pval and pval_adjust: - if pval_threshold: - asterisk_mask = table['P-Value (adjusted)'] < pval_threshold - - table['P-Value (adjusted)'] = table['P-Value (adjusted)'].apply( - '{:.3f}'.format).astype(str) + table['P-Value (adjusted)'] = table['P-Value (adjusted)'].apply('{:.3f}'.format).astype(str) table.loc[table['P-Value (adjusted)'] == '0.000', 'P-Value (adjusted)'] = '<0.001' if pval_threshold: + asterisk_mask = table['P-Value (adjusted)'] < pval_threshold table.loc[asterisk_mask, 'P-Value (adjusted)'] = ( table['P-Value (adjusted)'][asterisk_mask].astype(str)+"*" # type: ignore ) elif pval: - if pval_threshold: - asterisk_mask = table['P-Value'] < pval_threshold - - table['P-Value'] = table['P-Value'].apply( - '{:.3f}'.format).astype(str) + table['P-Value'] = table['P-Value'].apply('{:.3f}'.format).astype(str) table.loc[table['P-Value'] == '0.000', 'P-Value'] = '<0.001' if pval_threshold: + asterisk_mask = table['P-Value'] < pval_threshold table.loc[asterisk_mask, 'P-Value'] = ( table['P-Value'][asterisk_mask].astype(str)+"*" # type: ignore )