From 01730c46a4f403fd5cf9245512c941176eef2428 Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Fri, 8 Sep 2023 13:36:48 -0500 Subject: [PATCH] Fix `Index.difference` to match with pandas (#14053) This PR fixes `Index.difference` in the following ways: - [x] Fixes `name` preservation by correctly evaluating the names of the two input objects, closes #14019 - [x] Fixes `is_mixed_with_object_dtype` handling, which resolves incorrect results for `CategoricalIndex`, closes #14022 - [x] Raises errors for invalid input types; the error messages exactly match the pandas error messages for parity. - [x] Introduces `RangeIndex._try_reconstruct_range_index`, which tries to reconstruct a `RangeIndex` out of an `Int..Index` to save memory; this is on par with pandas, closes #14013 Authors: - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Lawrence Mitchell (https://github.com/wence-) URL: https://github.com/rapidsai/cudf/pull/14053 --- python/cudf/cudf/core/_base_index.py | 12 ++++++++++-- python/cudf/cudf/core/index.py | 22 ++++++++++++++++++++++ python/cudf/cudf/tests/test_index.py | 21 +++++++++++++++++++++ python/cudf/cudf/utils/dtypes.py | 5 +++++ 4 files changed, 58 insertions(+), 2 deletions(-) diff --git a/python/cudf/cudf/core/_base_index.py b/python/cudf/cudf/core/_base_index.py index 829ca33d8a5..8091f3f7dd2 100644 --- a/python/cudf/cudf/core/_base_index.py +++ b/python/cudf/cudf/core/_base_index.py @@ -30,7 +30,7 @@ from cudf.core.column import ColumnBase, column from cudf.core.column_accessor import ColumnAccessor from cudf.utils import ioutils -from cudf.utils.dtypes import is_mixed_with_object_dtype +from cudf.utils.dtypes import can_convert_to_column, is_mixed_with_object_dtype from cudf.utils.utils import _is_same_name @@ -935,13 +935,21 @@ def difference(self, other, sort=None): >>> idx1.difference(idx2, sort=False) Int64Index([2, 1], dtype='int64') """ + if not can_convert_to_column(other): + raise TypeError("Input must be Index or 
array-like") + if sort not in {None, False}: raise ValueError( f"The 'sort' keyword only takes the values " f"of None or False; {sort} was passed." ) - other = cudf.Index(other) + other = cudf.Index(other, name=getattr(other, "name", self.name)) + + if not len(other): + return self._get_reconciled_name_object(other) + elif self.equals(other): + return self[:0]._get_reconciled_name_object(other) res_name = _get_result_name(self.name, other.name) diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index c7e25cdc430..4bb5428838f 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -724,6 +724,28 @@ def _intersection(self, other, sort=False): return new_index + @_cudf_nvtx_annotate + def difference(self, other, sort=None): + if isinstance(other, RangeIndex) and self.equals(other): + return self[:0]._get_reconciled_name_object(other) + + return self._try_reconstruct_range_index( + super().difference(other, sort=sort) + ) + + def _try_reconstruct_range_index(self, index): + if isinstance(index, RangeIndex) or index.dtype.kind == "f": + return index + # Evenly spaced values can return a + # RangeIndex instead of a materialized Index. 
+ if not index._column.has_nulls(): + uniques = cupy.unique(cupy.diff(index.values)) + if len(uniques) == 1 and uniques[0].get() != 0: + diff = uniques[0].get() + new_range = range(index[0], index[-1] + diff, diff) + return type(self)(new_range, name=index.name) + return index + def sort_values( self, return_indexer=False, diff --git a/python/cudf/cudf/tests/test_index.py b/python/cudf/cudf/tests/test_index.py index 506edd5b3f3..58dbc48e31e 100644 --- a/python/cudf/cudf/tests/test_index.py +++ b/python/cudf/cudf/tests/test_index.py @@ -789,6 +789,10 @@ def test_index_to_series(data): ["5", "6", "2", "a", "b", "c"], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [1.0, 5.0, 6.0, 0.0, 1.3], + ["ab", "cd", "ef"], + pd.Series(["1", "2", "a", "3", None], dtype="category"), + range(0, 10), + [], ], ) @pytest.mark.parametrize( @@ -799,8 +803,11 @@ def test_index_to_series(data): [10, 20, 30, 40, 50, 60], ["1", "2", "3", "4", "5", "6"], ["5", "6", "2", "a", "b", "c"], + ["ab", "ef", None], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [1.0, 5.0, 6.0, 0.0, 1.3], + range(2, 4), + pd.Series(["1", "a", "3", None], dtype="category"), [], ], ) @@ -818,9 +825,23 @@ def test_index_difference(data, other, sort, name_data, name_other): expected = pd_data.difference(pd_other, sort=sort) actual = gd_data.difference(gd_other, sort=sort) + assert_eq(expected, actual) +@pytest.mark.parametrize("other", ["a", 1, None]) +def test_index_difference_invalid_inputs(other): + pdi = pd.Index([1, 2, 3]) + gdi = cudf.Index([1, 2, 3]) + + assert_exceptions_equal( + pdi.difference, + gdi.difference, + ([other], {}), + ([other], {}), + ) + + def test_index_difference_sort_error(): pdi = pd.Index([1, 2, 3]) gdi = cudf.Index([1, 2, 3]) diff --git a/python/cudf/cudf/utils/dtypes.py b/python/cudf/cudf/utils/dtypes.py index ea96a0859ce..e50457b8e7b 100644 --- a/python/cudf/cudf/utils/dtypes.py +++ b/python/cudf/cudf/utils/dtypes.py @@ -426,6 +426,11 @@ def get_min_float_dtype(col): def is_mixed_with_object_dtype(lhs, rhs): + if 
cudf.api.types.is_categorical_dtype(lhs.dtype): + return is_mixed_with_object_dtype(lhs.dtype.categories, rhs) + elif cudf.api.types.is_categorical_dtype(rhs.dtype): + return is_mixed_with_object_dtype(lhs, rhs.dtype.categories) + return (lhs.dtype == "object" and rhs.dtype != "object") or ( rhs.dtype == "object" and lhs.dtype != "object" )