Skip to content

Commit

Permalink
Fix Index.difference to match with pandas (#14053)
Browse files Browse the repository at this point in the history
This PR fixes `Index.difference` in following ways:

- [x] Fixes `name` preservation by correctly evaluating the name of two input objects, closes #14019
- [x] Fixes `is_mixed_with_object_dtype` handling that will resolve incorrect results for `CategoricalIndex`, closes #14022
- [x] Raises errors for invalid input types, the error messages are an exact match to pandas error messages for parity.
- [x] Introduce a `Range._try_reconstruct_range_index` that will try to re-construct a `RangeIndex` out of an `Int..Index` to save memory- this is on parity with pandas. closes #14013

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: #14053
  • Loading branch information
galipremsagar authored Sep 8, 2023
1 parent e43809e commit 01730c4
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 2 deletions.
12 changes: 10 additions & 2 deletions python/cudf/cudf/core/_base_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from cudf.core.column import ColumnBase, column
from cudf.core.column_accessor import ColumnAccessor
from cudf.utils import ioutils
from cudf.utils.dtypes import is_mixed_with_object_dtype
from cudf.utils.dtypes import can_convert_to_column, is_mixed_with_object_dtype
from cudf.utils.utils import _is_same_name


Expand Down Expand Up @@ -935,13 +935,21 @@ def difference(self, other, sort=None):
>>> idx1.difference(idx2, sort=False)
Int64Index([2, 1], dtype='int64')
"""
if not can_convert_to_column(other):
raise TypeError("Input must be Index or array-like")

if sort not in {None, False}:
raise ValueError(
f"The 'sort' keyword only takes the values "
f"of None or False; {sort} was passed."
)

other = cudf.Index(other)
other = cudf.Index(other, name=getattr(other, "name", self.name))

if not len(other):
return self._get_reconciled_name_object(other)
elif self.equals(other):
return self[:0]._get_reconciled_name_object(other)

res_name = _get_result_name(self.name, other.name)

Expand Down
22 changes: 22 additions & 0 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,28 @@ def _intersection(self, other, sort=False):

return new_index

@_cudf_nvtx_annotate
def difference(self, other, sort=None):
if isinstance(other, RangeIndex) and self.equals(other):
return self[:0]._get_reconciled_name_object(other)

return self._try_reconstruct_range_index(
super().difference(other, sort=sort)
)

def _try_reconstruct_range_index(self, index):
if isinstance(index, RangeIndex) or index.dtype.kind == "f":
return index
# Evenly spaced values can return a
# RangeIndex instead of a materialized Index.
if not index._column.has_nulls():
uniques = cupy.unique(cupy.diff(index.values))
if len(uniques) == 1 and uniques[0].get() != 0:
diff = uniques[0].get()
new_range = range(index[0], index[-1] + diff, diff)
return type(self)(new_range, name=index.name)
return index

def sort_values(
self,
return_indexer=False,
Expand Down
21 changes: 21 additions & 0 deletions python/cudf/cudf/tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,10 @@ def test_index_to_series(data):
["5", "6", "2", "a", "b", "c"],
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
[1.0, 5.0, 6.0, 0.0, 1.3],
["ab", "cd", "ef"],
pd.Series(["1", "2", "a", "3", None], dtype="category"),
range(0, 10),
[],
],
)
@pytest.mark.parametrize(
Expand All @@ -799,8 +803,11 @@ def test_index_to_series(data):
[10, 20, 30, 40, 50, 60],
["1", "2", "3", "4", "5", "6"],
["5", "6", "2", "a", "b", "c"],
["ab", "ef", None],
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
[1.0, 5.0, 6.0, 0.0, 1.3],
range(2, 4),
pd.Series(["1", "a", "3", None], dtype="category"),
[],
],
)
Expand All @@ -818,9 +825,23 @@ def test_index_difference(data, other, sort, name_data, name_other):

expected = pd_data.difference(pd_other, sort=sort)
actual = gd_data.difference(gd_other, sort=sort)

assert_eq(expected, actual)


@pytest.mark.parametrize("other", ["a", 1, None])
def test_index_difference_invalid_inputs(other):
pdi = pd.Index([1, 2, 3])
gdi = cudf.Index([1, 2, 3])

assert_exceptions_equal(
pdi.difference,
gdi.difference,
([other], {}),
([other], {}),
)


def test_index_difference_sort_error():
pdi = pd.Index([1, 2, 3])
gdi = cudf.Index([1, 2, 3])
Expand Down
5 changes: 5 additions & 0 deletions python/cudf/cudf/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,11 @@ def get_min_float_dtype(col):


def is_mixed_with_object_dtype(lhs, rhs):
if cudf.api.types.is_categorical_dtype(lhs.dtype):
return is_mixed_with_object_dtype(lhs.dtype.categories, rhs)
elif cudf.api.types.is_categorical_dtype(rhs.dtype):
return is_mixed_with_object_dtype(lhs, rhs.dtype.categories)

return (lhs.dtype == "object" and rhs.dtype != "object") or (
rhs.dtype == "object" and lhs.dtype != "object"
)
Expand Down

0 comments on commit 01730c4

Please sign in to comment.