Skip to content

Commit

Permalink
Add helper method
Browse files Browse the repository at this point in the history
  • Loading branch information
galipremsagar committed Sep 19, 2023
1 parent ffce0f6 commit 88ccc48
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 24 deletions.
7 changes: 7 additions & 0 deletions python/cudf/cudf/core/column_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ def _to_flat_dict(d):
return {k: v for k, v in _to_flat_dict_inner(d)}


def _get_level_names(obj, default=None):
    """
    Pick the value to use for `level_names` based on `obj`.

    If `obj` is a `pd.Index`, its `names` are returned as a tuple.
    For any other object, `default` is returned unchanged.
    """
    # Anything that is not a pandas Index carries no names to extract,
    # so hand back the caller-supplied fallback as-is.
    if not isinstance(obj, pd.Index):
        return default
    return tuple(obj.names)


class ColumnAccessor(abc.MutableMapping):
"""
Parameters
Expand Down
31 changes: 11 additions & 20 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
column_empty,
concat_columns,
)
from cudf.core.column_accessor import ColumnAccessor
from cudf.core.column_accessor import ColumnAccessor, _get_level_names
from cudf.core.copy_types import BooleanMask
from cudf.core.groupby.groupby import DataFrameGroupBy, groupby_doc_template
from cudf.core.index import BaseIndex, RangeIndex, _index_from_data, as_index
Expand Down Expand Up @@ -666,9 +666,7 @@ def __init__(
)
for k in columns
},
level_names=tuple(columns.names)
if isinstance(columns, pd.Index)
else None,
level_names=_get_level_names(columns),
)
elif isinstance(data, ColumnAccessor):
raise TypeError(
Expand Down Expand Up @@ -715,10 +713,8 @@ def __init__(

self._data = new_df._data
self._index = new_df._index
self._data._level_names = (
tuple(columns.names)
if isinstance(columns, pd.Index)
else self._data._level_names
self._data._level_names = _get_level_names(
columns, self._data._level_names
)
elif len(data) > 0 and isinstance(data[0], Series):
self._init_from_series_list(
Expand Down Expand Up @@ -842,10 +838,8 @@ def _init_from_series_list(self, data, columns, index):
self._data[col_name] = column.column_empty(
row_count=len(self), dtype=None, masked=True
)
self._data._level_names = (
tuple(columns.names)
if isinstance(columns, pd.Index)
else self._data._level_names
self._data._level_names = _get_level_names(
columns, self._data._level_names
)
self._data = self._data.select_by_label(columns)

Expand Down Expand Up @@ -970,10 +964,8 @@ def _init_from_dict_like(
data[col_name],
nan_as_null=nan_as_null,
)
self._data._level_names = (
tuple(columns.names)
if isinstance(columns, pd.Index)
else self._data._level_names
self._data._level_names = _get_level_names(
columns, self._data._level_names
)

@classmethod
Expand Down Expand Up @@ -5390,8 +5382,8 @@ def from_records(cls, data, index=None, columns=None, nan_as_null=False):
df = df.set_index(index)
else:
df._index = as_index(index)
if isinstance(columns, pd.Index):
df._data._level_names = tuple(columns.names)

df._data._level_names = _get_level_names(columns)
return df

@classmethod
Expand Down Expand Up @@ -5448,8 +5440,7 @@ def _from_arrays(cls, data, index=None, columns=None, nan_as_null=False):
df._data[names[0]] = column.as_column(
data, nan_as_null=nan_as_null
)
if isinstance(columns, pd.Index):
df._data._level_names = tuple(columns.names)
df._data._level_names = _get_level_names(columns)

if index is None:
df._index = RangeIndex(start=0, stop=len(data))
Expand Down
6 changes: 2 additions & 4 deletions python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from cudf.core._base_index import BaseIndex
from cudf.core.buffer import acquire_spill_lock
from cudf.core.column import ColumnBase, as_column, full
from cudf.core.column_accessor import ColumnAccessor
from cudf.core.column_accessor import ColumnAccessor, _get_level_names
from cudf.core.copy_types import BooleanMask, GatherMap
from cudf.core.dtypes import ListDtype
from cudf.core.frame import Frame
Expand Down Expand Up @@ -2661,9 +2661,7 @@ def _reindex(
data=cudf.core.column_accessor.ColumnAccessor(
cols,
multiindex=self._data.multiindex,
level_names=tuple(column_names.names)
if isinstance(column_names, pd.Index)
else None,
level_names=_get_level_names(column_names),
),
index=index,
)
Expand Down

0 comments on commit 88ccc48

Please sign in to comment.