Skip to content

Commit

Permalink
GH-45167: [C++] Implement Compute Equals for List Types
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd committed Jan 21, 2025
1 parent e434536 commit 68bb513
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 0 deletions.
49 changes: 49 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ struct GetViewType<Type, enable_if_t<is_base_binary_type<Type>::value ||
static T LogicalValue(PhysicalType value) { return value; }
};

template <typename Type>
struct GetViewType<Type, enable_if_list_type<Type>> {
using T = typename TypeTraits<Type>::ScalarType;

static T LogicalValue(T value) { return value; }
};

template <>
struct GetViewType<Decimal32Type> {
using T = Decimal32;
Expand Down Expand Up @@ -322,6 +329,26 @@ struct ArrayIterator<Type, enable_if_base_binary<Type>> {
}
};

template <typename Type>
struct ArrayIterator<Type, enable_if_list_type<Type>> {
using T = typename TypeTraits<Type>::ScalarType;
using ArrayT = typename TypeTraits<Type>::ArrayType;
using offset_type = typename Type::offset_type;

const ArraySpan& arr;
int64_t position;

explicit ArrayIterator(const ArraySpan& arr) : arr(arr), position(0) {}

T operator()() {
const auto array_ptr = arr.ToArray();
const auto array = checked_cast<const ArrayT*>(array_ptr.get());

T result{array->value_slice(position++)};
return result;
}
};

template <>
struct ArrayIterator<FixedSizeBinaryType> {
const ArraySpan& arr;
Expand Down Expand Up @@ -390,6 +417,12 @@ struct UnboxScalar<Type, enable_if_has_string_view<Type>> {
}
};

template <typename Type>
struct UnboxScalar<Type, enable_if_list_type<Type>> {
using T = typename TypeTraits<Type>::ScalarType;
static const T& Unbox(const Scalar& val) { return checked_cast<const T&>(val); }
};

template <>
struct UnboxScalar<Decimal32Type> {
using T = Decimal32;
Expand Down Expand Up @@ -1383,6 +1416,22 @@ ArrayKernelExec GenerateDecimal(detail::GetTypeId get_id) {
}
}

// Generate a kernel given a templated functor for list types
//
// See "Numeric" above for description of the generator functor
template <template <typename...> class Generator, typename Type0, typename... Args>
ArrayKernelExec GenerateList(detail::GetTypeId get_id) {
switch (get_id.id) {
case Type::LIST:
return Generator<Type0, ListType, Args...>::Exec;
case Type::LARGE_LIST:
return Generator<Type0, LargeListType, Args...>::Exec;
default:
DCHECK(false);
return nullptr;
}
}

// END of kernel generator-dispatchers
// ----------------------------------------------------------------------
// BEGIN of DispatchBest helpers
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,14 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDo
DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
}

if constexpr (std::is_same_v<Op, Equal> || std::is_same_v<Op, NotEqual>) {
for (const auto id : {Type::LIST, Type::LARGE_LIST}) {
auto exec = GenerateList<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id);
DCHECK_OK(
func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec)));
}
}

return func;
}

Expand Down
93 changes: 93 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_compare_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,99 @@ TYPED_TEST(TestCompareDecimal, DifferentParameters) {
}
}

template <typename ArrowType>
class TestCompareList : public ::testing::Test {};
TYPED_TEST_SUITE(TestCompareList, ListArrowTypes);

TYPED_TEST(TestCompareList, ArrayScalar) {
const auto int_value_typ = std::make_shared<Int32Type>();
const auto int_ty = std::make_shared<TypeParam>(std::move(int_value_typ));
const auto bin_value_typ = std::make_shared<StringType>();
const auto bin_ty = std::make_shared<TypeParam>(std::move(bin_value_typ));

const std::vector<std::pair<std::string, std::string>> cases = {
{"equal", "[1, 0, 0, null]"},
{"not_equal", "[0, 1, 1, null]"},
};
const auto lhs_int = ArrayFromJSON(int_ty, R"([[1, 2, 3], [4, 5, 6], [42], null])");
const auto lhs_bin = ArrayFromJSON(
bin_ty, R"([["a", "b", "c"], ["foo", "bar", "baz"], ["hello"], null])");
const auto rhs_int = ScalarFromJSON(int_ty, R"([1, 2, 3])");
const auto rhs_bin = ScalarFromJSON(bin_ty, R"(["a", "b", "c"])");
for (const auto& op : cases) {
const auto& function = op.first;
const auto& expected = op.second;

SCOPED_TRACE(function);
CheckScalarBinary(function, lhs_int, rhs_int, ArrayFromJSON(boolean(), expected));
CheckScalarBinary(function, lhs_bin, rhs_bin, ArrayFromJSON(boolean(), expected));
}
}

TYPED_TEST(TestCompareList, ScalarArray) {
const auto int_value_typ = std::make_shared<Int32Type>();
const auto int_ty = std::make_shared<TypeParam>(std::move(int_value_typ));
const auto bin_value_typ = std::make_shared<StringType>();
const auto bin_ty = std::make_shared<TypeParam>(std::move(bin_value_typ));

const std::vector<std::pair<std::string, std::string>> cases = {
{"equal", "[1, 0, 0, null]"},
{"not_equal", "[0, 1, 1, null]"},
};
const auto lhs_int = ScalarFromJSON(int_ty, R"([1, 2, 3])");
const auto lhs_bin = ScalarFromJSON(bin_ty, R"(["a", "b", "c"])");
const auto rhs_int = ArrayFromJSON(int_ty, R"([[1, 2, 3], [4, 5, 6], [42], null])");
const auto rhs_bin = ArrayFromJSON(
bin_ty, R"([["a", "b", "c"], ["foo", "bar"], ["baz", "hello", "world"], null])");
for (const auto& op : cases) {
const auto& function = op.first;
const auto& expected = op.second;

SCOPED_TRACE(function);
CheckScalarBinary(function, lhs_int, rhs_int, ArrayFromJSON(boolean(), expected));
CheckScalarBinary(function, lhs_bin, rhs_bin, ArrayFromJSON(boolean(), expected));
}
}

TYPED_TEST(TestCompareList, ArrayArray) {
const auto int_value_typ = std::make_shared<Int32Type>();
const auto int_ty = std::make_shared<TypeParam>(std::move(int_value_typ));
const auto bin_value_typ = std::make_shared<StringType>();
const auto bin_ty = std::make_shared<TypeParam>(std::move(bin_value_typ));

const std::vector<std::pair<std::string, std::string>> cases = {
{"equal", "[1, 0, 0, null]"},
{"not_equal", "[0, 1, 1, null]"},
};
const auto lhs_int = ArrayFromJSON(int_ty, R"([[1, 2, 3], [4, 5, 6], [7], null])");
const auto lhs_bin = ArrayFromJSON(
bin_ty, R"([["a", "b", "c"], ["foo", "bar", "baz"], ["hello"], null])");
const auto rhs_int = ArrayFromJSON(int_ty, R"([[1, 2, 3], [4, 5], [6, 7, 8], null])");
const auto rhs_bin = ArrayFromJSON(
bin_ty, R"([["a", "b", "c"], ["foo", "bar"], ["baz", "hello", "world"], null])");
for (const auto& op : cases) {
const auto& function = op.first;
const auto& expected = op.second;

SCOPED_TRACE(function);
CheckScalarBinary(function, ArrayFromJSON(int_ty, R"([])"),
ArrayFromJSON(int_ty, R"([])"), ArrayFromJSON(boolean(), "[]"));
CheckScalarBinary(function, ArrayFromJSON(int_ty, R"([null])"),
ArrayFromJSON(int_ty, R"([null])"),
ArrayFromJSON(boolean(), "[null]"));

CheckScalarBinary(function, lhs_int, rhs_int, ArrayFromJSON(boolean(), expected));

CheckScalarBinary(function, ArrayFromJSON(bin_ty, R"([])"),
ArrayFromJSON(int_ty, R"([])"), ArrayFromJSON(boolean(), "[]"));
CheckScalarBinary(function, ArrayFromJSON(int_ty, R"([null])"),
ArrayFromJSON(bin_ty, R"([null])"),
ArrayFromJSON(boolean(), "[null]"));

CheckScalarBinary(function, lhs_bin, rhs_bin, ArrayFromJSON(boolean(), expected));
}
}

// Helper to organize tests for fixed size binary comparisons
struct CompareCase {
std::shared_ptr<DataType> lhs_type;
Expand Down

0 comments on commit 68bb513

Please sign in to comment.