From 88ccc48103ec17dfb6f35b934280f4bbfcabd21f Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Tue, 19 Sep 2023 07:55:22 -0700 Subject: [PATCH] Add helper method --- python/cudf/cudf/core/column_accessor.py | 7 ++++++ python/cudf/cudf/core/dataframe.py | 31 +++++++++--------------- python/cudf/cudf/core/indexed_frame.py | 6 ++--- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/python/cudf/cudf/core/column_accessor.py b/python/cudf/cudf/core/column_accessor.py index cb79a30422e..758f0115a6f 100644 --- a/python/cudf/cudf/core/column_accessor.py +++ b/python/cudf/cudf/core/column_accessor.py @@ -82,6 +82,13 @@ def _to_flat_dict(d): return {k: v for k, v in _to_flat_dict_inner(d)} +def _get_level_names(obj, default=None): + """ + Helper function to return names for `level_names` + """ + return tuple(obj.names) if isinstance(obj, pd.Index) else default + + class ColumnAccessor(abc.MutableMapping): """ Parameters diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index fc9721066a8..d9c0fdba93f 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -68,7 +68,7 @@ column_empty, concat_columns, ) -from cudf.core.column_accessor import ColumnAccessor +from cudf.core.column_accessor import ColumnAccessor, _get_level_names from cudf.core.copy_types import BooleanMask from cudf.core.groupby.groupby import DataFrameGroupBy, groupby_doc_template from cudf.core.index import BaseIndex, RangeIndex, _index_from_data, as_index @@ -666,9 +666,7 @@ def __init__( ) for k in columns }, - level_names=tuple(columns.names) - if isinstance(columns, pd.Index) - else None, + level_names=_get_level_names(columns), ) elif isinstance(data, ColumnAccessor): raise TypeError( @@ -715,10 +713,8 @@ def __init__( self._data = new_df._data self._index = new_df._index - self._data._level_names = ( - tuple(columns.names) - if isinstance(columns, pd.Index) - else self._data._level_names + self._data._level_names = _get_level_names( + columns, self._data._level_names ) elif len(data) > 0 and isinstance(data[0], Series): self._init_from_series_list( @@ -842,10 +838,8 @@ def _init_from_series_list(self, data, columns, index): self._data[col_name] = column.column_empty( row_count=len(self), dtype=None, masked=True ) - self._data._level_names = ( - tuple(columns.names) - if isinstance(columns, pd.Index) - else self._data._level_names + self._data._level_names = _get_level_names( + columns, self._data._level_names ) self._data = self._data.select_by_label(columns) @@ -970,10 +964,8 @@ def _init_from_dict_like( data[col_name], nan_as_null=nan_as_null, ) - self._data._level_names = ( - tuple(columns.names) - if isinstance(columns, pd.Index) - else self._data._level_names + self._data._level_names = _get_level_names( + columns, self._data._level_names ) @classmethod @@ -5390,8 +5382,8 @@ def from_records(cls, data, index=None, columns=None, nan_as_null=False): df = df.set_index(index) else: df._index = as_index(index) - if isinstance(columns, pd.Index): - df._data._level_names = tuple(columns.names) + + df._data._level_names = _get_level_names(columns) return df @classmethod @@ -5448,8 +5440,7 @@ def _from_arrays(cls, data, index=None, columns=None, nan_as_null=False): df._data[names[0]] = column.as_column( data, nan_as_null=nan_as_null ) - if isinstance(columns, pd.Index): - df._data._level_names = tuple(columns.names) + df._data._level_names = _get_level_names(columns) if index is None: df._index = RangeIndex(start=0, stop=len(data)) diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index 1796baa6147..b3ecfd5eb15 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -52,7 +52,7 @@ from cudf.core._base_index import BaseIndex from cudf.core.buffer import acquire_spill_lock from cudf.core.column import ColumnBase, as_column, full -from cudf.core.column_accessor import ColumnAccessor +from cudf.core.column_accessor import ColumnAccessor, _get_level_names from cudf.core.copy_types import BooleanMask, GatherMap from cudf.core.dtypes import ListDtype from cudf.core.frame import Frame @@ -2661,9 +2661,7 @@ def _reindex( data=cudf.core.column_accessor.ColumnAccessor( cols, multiindex=self._data.multiindex, - level_names=tuple(column_names.names) - if isinstance(column_names, pd.Index) - else None, + level_names=_get_level_names(column_names), ), index=index, )