Skip to content

Commit

Permalink
GH-45167: [C++] Implement Compute Equals for List Types
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd committed Jan 15, 2025
1 parent f4e4ed3 commit 573867e
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 0 deletions.
55 changes: 55 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ struct GetViewType<Type, enable_if_t<is_base_binary_type<Type>::value ||
static T LogicalValue(PhysicalType value) { return value; }
};

// GetViewType for list types: values are viewed as the type's scalar class
// (TypeTraits<Type>::ScalarType), and the logical value is the view itself.
template <typename Type>
struct GetViewType<Type, enable_if_list_type<Type>> {
  using T = typename TypeTraits<Type>::ScalarType;

  // Identity conversion: the scalar that comes in is handed back unchanged.
  static T LogicalValue(T scalar) {
    return scalar;
  }
};

template <>
struct GetViewType<Decimal32Type> {
using T = Decimal32;
Expand Down Expand Up @@ -322,6 +329,32 @@ struct ArrayIterator<Type, enable_if_base_binary<Type>> {
}
};

// Iterator over the slots of a list array, yielding each slot as a list scalar.
//
// The span is converted to an Array once, at construction, so that each call
// only slices the child values instead of rebuilding the Array every time.
template <typename Type>
struct ArrayIterator<Type, enable_if_list_type<Type>> {
  using T = typename TypeTraits<Type>::ScalarType;
  using ArrayT = typename TypeTraits<Type>::ArrayType;

  const ArraySpan& arr;
  // Array built from `arr`. It carries the span's offset itself, so its
  // accessors take indices relative to the first slot of the span.
  const std::shared_ptr<Array> array;
  // Index of the next slot to return, relative to the first slot of the span.
  // It must start at 0 rather than arr.offset: adding the offset here as well
  // would apply it twice on sliced input.
  int64_t position;

  explicit ArrayIterator(const ArraySpan& arr)
      : arr(arr), array(arr.ToArray()), position(0) {}

  // Returns the current slot wrapped in a scalar, then advances by one slot.
  T operator()() {
    const auto& list_array = checked_cast<const ArrayT&>(*array);
    return T{list_array.value_slice(position++)};
  }
};

template <>
struct ArrayIterator<FixedSizeBinaryType> {
const ArraySpan& arr;
Expand Down Expand Up @@ -390,6 +423,12 @@ struct UnboxScalar<Type, enable_if_has_string_view<Type>> {
}
};

// UnboxScalar for list types: the unboxed value is the scalar object itself,
// viewed through its concrete scalar class.
template <typename Type>
struct UnboxScalar<Type, enable_if_list_type<Type>> {
  using T = typename TypeTraits<Type>::ScalarType;

  // Downcasts the generic Scalar reference to T; no copy is made.
  static const T& Unbox(const Scalar& val) {
    const auto& list_scalar = checked_cast<const T&>(val);
    return list_scalar;
  }
};

template <>
struct UnboxScalar<Decimal32Type> {
using T = Decimal32;
Expand Down Expand Up @@ -1382,6 +1421,22 @@ ArrayKernelExec GenerateDecimal(detail::GetTypeId get_id) {
}
}

// Generate a kernel given a templated functor for list types
//
// The functor is instantiated with ListType or LargeListType according to the
// requested type id. Any other id trips a DCHECK and yields nullptr.
//
// See "Numeric" above for description of the generator functor
template <template <typename...> class Generator, typename Type0, typename... Args>
ArrayKernelExec GenerateList(detail::GetTypeId get_id) {
  const auto type_id = get_id.id;
  if (type_id == Type::LIST) {
    return Generator<Type0, ListType, Args...>::Exec;
  }
  if (type_id == Type::LARGE_LIST) {
    return Generator<Type0, LargeListType, Args...>::Exec;
  }
  // Not a list type id supported by this dispatcher.
  DCHECK(false);
  return nullptr;
}

// END of kernel generator-dispatchers
// ----------------------------------------------------------------------
// BEGIN of DispatchBest helpers
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,14 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDo
DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
}

// Only Equal and NotEqual get list kernels: for those two operators, register
// one kernel per list type id (LIST, LARGE_LIST). Inputs are matched by type id
// on both sides and the output is boolean.
if constexpr (std::is_same_v<Op, Equal> || std::is_same_v<Op, NotEqual>) {
for (const auto id : {Type::LIST, Type::LARGE_LIST}) {
auto exec = GenerateList<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id);
DCHECK_OK(
func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec)));
}
}

return func;
}

Expand Down
66 changes: 66 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_compare_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,72 @@ TYPED_TEST(TestCompareDecimal, DifferentParameters) {
}
}

// Typed-test fixture for comparison kernels on list inputs. The suite is
// instantiated once for each type in ListArrowTypes.
template <typename ArrowType>
class TestCompareList : public ::testing::Test {};
TYPED_TEST_SUITE(TestCompareList, ListArrowTypes);

// Compares a list array (left) against a list scalar (right).
TYPED_TEST(TestCompareList, ArrayScalar) {
  // The list type under test, holding int32 values.
  auto list_type = std::make_shared<TypeParam>(std::make_shared<Int32Type>());

  // (function name, expected boolean result as JSON)
  const std::vector<std::pair<std::string, std::string>> expectations = {
      {"equal", "[1, 0, 0, null]"},
      {"not_equal", "[0, 1, 1, null]"},
  };

  // Only the first slot matches the scalar; the null slot yields null.
  auto array_arg = ArrayFromJSON(list_type, R"([[1, 2, 3], [4, 5, 6], [42], null])");
  auto scalar_arg = ScalarFromJSON(list_type, R"([1, 2, 3])");

  for (const auto& [func_name, expected_json] : expectations) {
    SCOPED_TRACE(func_name);
    CheckScalarBinary(func_name, array_arg, scalar_arg,
                      ArrayFromJSON(boolean(), expected_json));
  }
}

// Compares a list scalar (left) against a list array (right).
TYPED_TEST(TestCompareList, ScalarArray) {
  // The list type under test, holding int32 values.
  auto list_type = std::make_shared<TypeParam>(std::make_shared<Int32Type>());

  // (function name, expected boolean result as JSON)
  const std::vector<std::pair<std::string, std::string>> expectations = {
      {"equal", "[1, 0, 0, null]"},
      {"not_equal", "[0, 1, 1, null]"},
  };

  // Only the first array slot matches the scalar; the null slot yields null.
  auto scalar_arg = ScalarFromJSON(list_type, R"([1, 2, 3])");
  auto array_arg = ArrayFromJSON(list_type, R"([[1, 2, 3], [4, 5, 6], [42], null])");

  for (const auto& [func_name, expected_json] : expectations) {
    SCOPED_TRACE(func_name);
    CheckScalarBinary(func_name, scalar_arg, array_arg,
                      ArrayFromJSON(boolean(), expected_json));
  }
}

// Compares two list arrays slot by slot, including empty and all-null inputs.
TYPED_TEST(TestCompareList, ArrayArray) {
  // The list type under test, holding int32 values.
  auto list_type = std::make_shared<TypeParam>(std::make_shared<Int32Type>());

  // (function name, expected boolean result as JSON)
  const std::vector<std::pair<std::string, std::string>> expectations = {
      {"equal", "[1, 0, 0, null]"},
      {"not_equal", "[0, 1, 1, null]"},
  };

  // Slot 0 is identical on both sides, slots 1 and 2 differ, slot 3 is null.
  auto left = ArrayFromJSON(list_type, R"([[1, 2, 3], [4, 5, 6], [7], null])");
  auto right = ArrayFromJSON(list_type, R"([[1, 2, 3], [4, 5], [6, 7, 8], null])");

  for (const auto& [func_name, expected_json] : expectations) {
    SCOPED_TRACE(func_name);

    // Empty inputs give an empty result.
    CheckScalarBinary(func_name, ArrayFromJSON(list_type, R"([])"),
                      ArrayFromJSON(list_type, R"([])"), ArrayFromJSON(boolean(), "[]"));
    // Null compared with null gives null for both functions.
    CheckScalarBinary(func_name, ArrayFromJSON(list_type, R"([null])"),
                      ArrayFromJSON(list_type, R"([null])"),
                      ArrayFromJSON(boolean(), "[null]"));

    CheckScalarBinary(func_name, left, right, ArrayFromJSON(boolean(), expected_json));
  }
}

// Helper to organize tests for fixed size binary comparisons
struct CompareCase {
std::shared_ptr<DataType> lhs_type;
Expand Down

0 comments on commit 573867e

Please sign in to comment.