diff --git a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc
index 15318a8d7a465..83c4a30902ffd 100644
--- a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc
+++ b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc
@@ -53,6 +53,8 @@ TEST(FlightIntegration, AuthBasicProto) { ASSERT_OK(RunScenario("auth:basic_prot
 
 TEST(FlightIntegration, Middleware) { ASSERT_OK(RunScenario("middleware")); }
 
+TEST(FlightIntegration, Alignment) { ASSERT_OK(RunScenario("alignment")); }
+
 TEST(FlightIntegration, Ordered) { ASSERT_OK(RunScenario("ordered")); }
 
 TEST(FlightIntegration, ExpirationTimeDoGet) {
diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc
index f38076822c778..ec05db867561d 100644
--- a/cpp/src/arrow/flight/integration_tests/test_integration.cc
+++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc
@@ -45,6 +45,7 @@
 #include "arrow/table.h"
 #include "arrow/table_builder.h"
 #include "arrow/testing/gtest_util.h"
+#include "arrow/util/align_util.h"
 #include "arrow/util/checked_cast.h"
 #include "arrow/util/string.h"
 #include "arrow/util/value_parsing.h"
@@ -281,6 +282,144 @@ class MiddlewareScenario : public Scenario {
   std::shared_ptr client_middleware_;
 };
 
+/// \brief The server used for testing FlightClient data alignment.
+///
+/// The server always returns the same data of various byte widths (int32, int64, bool).
+/// The client should return data that is aligned according to the data type
+/// if FlightCallOptions.read_options.ensure_memory_alignment is true.
+///
+/// The scenario passes only when the client returns aligned data if asked to.
+class AlignmentServer : public FlightServerBase {
+  Status GetFlightInfo(const ServerCallContext& context,
+                       const FlightDescriptor& descriptor,
+                       std::unique_ptr<FlightInfo>* result) override {
+    auto schema = BuildSchema();
+    std::vector<FlightEndpoint> endpoints{FlightEndpoint{{"foo"}, {}, std::nullopt, ""}};
+    ARROW_ASSIGN_OR_RAISE(
+        auto info, FlightInfo::Make(*schema, descriptor, endpoints, -1, -1, false));
+    *result = std::make_unique<FlightInfo>(info);
+    return Status::OK();
+  }
+
+  Status DoGet(const ServerCallContext& context, const Ticket& request,
+               std::unique_ptr<FlightDataStream>* stream) override {
+    ARROW_ASSIGN_OR_RAISE(auto builder, RecordBatchBuilder::Make(
+                                            BuildSchema(), arrow::default_memory_pool()));
+    if (request.ticket == "foo") {  // the only ticket advertised by GetFlightInfo
+      auto int32_builder = builder->GetFieldAs<Int32Builder>(0);
+      ARROW_RETURN_NOT_OK(int32_builder->Append(1));
+      ARROW_RETURN_NOT_OK(int32_builder->Append(2));
+      ARROW_RETURN_NOT_OK(int32_builder->Append(3));
+      auto int64_builder = builder->GetFieldAs<Int64Builder>(1);
+      ARROW_RETURN_NOT_OK(int64_builder->Append(1L));
+      ARROW_RETURN_NOT_OK(int64_builder->Append(2L));
+      ARROW_RETURN_NOT_OK(int64_builder->Append(3L));
+      auto bool_builder = builder->GetFieldAs<BooleanBuilder>(2);
+      ARROW_RETURN_NOT_OK(bool_builder->Append(false));
+      ARROW_RETURN_NOT_OK(bool_builder->Append(true));
+      ARROW_RETURN_NOT_OK(bool_builder->Append(false));
+    } else {
+      return Status::KeyError("Could not find flight: ", request.ticket);
+    }
+    ARROW_ASSIGN_OR_RAISE(auto record_batch, builder->Flush());
+    std::vector<std::shared_ptr<RecordBatch>> record_batches{record_batch};
+    ARROW_ASSIGN_OR_RAISE(auto record_batch_reader,
+                          RecordBatchReader::Make(record_batches));
+    *stream = std::make_unique<RecordBatchStream>(record_batch_reader);
+    return Status::OK();
+  }
+
+ private:
+  std::shared_ptr<Schema> BuildSchema() {  // one column per tested value width
+    return arrow::schema({
+        arrow::field("int32", arrow::int32(), false),
+        arrow::field("int64", arrow::int64(), false),
+        arrow::field("bool", arrow::boolean(), false),
+    });
+  }
+};
+
+/// \brief The alignment scenario.
+///
+/// Verifies that the client returns aligned data when alignment is requested.
+class AlignmentScenario : public Scenario {
+  Status MakeServer(std::unique_ptr<FlightServerBase>* server,
+                    FlightServerOptions* options) override {
+    server->reset(new AlignmentServer());
+    return Status::OK();
+  }
+
+  Status MakeClient(FlightClientOptions* options) override { return Status::OK(); }
+
+  arrow::Result<std::shared_ptr<Table>> GetTable(FlightClient* client,
+                                                 const FlightCallOptions& call_options) {
+    ARROW_ASSIGN_OR_RAISE(auto info,
+                          client->GetFlightInfo(FlightDescriptor::Command("alignment")));
+    std::vector<std::shared_ptr<Table>> tables;
+    for (const auto& endpoint : info->endpoints()) {
+      if (!endpoint.locations.empty()) {  // data must come from this very server
+        std::stringstream ss;
+        ss << "[";
+        for (const auto& location : endpoint.locations) {
+          if (ss.str().size() != 1) {
+            ss << ", ";
+          }
+          ss << location.ToString();
+        }
+        ss << "]";
+        return Status::Invalid(
+            "Expected to receive empty locations to use the original service: ",
+            ss.str());
+      }
+      ARROW_ASSIGN_OR_RAISE(auto reader, client->DoGet(call_options, endpoint.ticket));
+      ARROW_ASSIGN_OR_RAISE(auto table, reader->ToTable());
+      tables.push_back(table);
+    }
+    return ConcatenateTables(tables);
+  }
+
+  Status RunClient(std::unique_ptr<FlightClient> client) override {
+    for (bool ensure_alignment : {true, false}) {  // exercise both settings
+      auto call_options = FlightCallOptions();
+      call_options.read_options.ensure_memory_alignment = ensure_alignment;
+      ARROW_ASSIGN_OR_RAISE(auto table, GetTable(client.get(), call_options));
+
+      // Check the shape of the data that was read
+      auto expected_row_count = 3;
+      if (table->num_rows() != expected_row_count) {
+        return Status::Invalid("Read table size isn't expected\n", "Expected rows: ",
+                               expected_row_count, "\nActual rows: ", table->num_rows());
+      }
+      auto expected_column_count = 3;
+      if (table->num_columns() != expected_column_count) {
+        return Status::Invalid("Read table size isn't expected\n", "Expected columns: ",
+                               expected_column_count, "\nActual columns: ",
+                               table->num_columns());
+      }
+      // Check data alignment
+      std::vector<bool> needs_alignment;
+      if (ensure_alignment) {
+        // With ensure_alignment=true, we require the data to be aligned
+        if (!util::CheckAlignment(*table, arrow::util::kValueAlignment,
+                                  &needs_alignment)) {
+          return Status::Invalid("Read table has unaligned data");
+        }
+      } else {
+        // Not a requirement, merely an observation: with ensure_alignment=false
+        // the Flight client currently returns mis-aligned data. If that is not
+        // the case any more, feel free to remove this assertion.
+        if (util::CheckAlignment(*table, arrow::util::kValueAlignment,
+                                 &needs_alignment)) {
+          return Status::Invalid(
+              "Read table has aligned data, which is good, but unprecedented");
+        }
+      }
+    }
+
+    return Status::OK();
+  }
+};
+
 /// \brief The server used for testing FlightInfo.ordered.
 ///
 /// If the given command is "ordered", the server sets
@@ -2382,6 +2521,9 @@ Status GetScenario(const std::string& scenario_name, std::shared_ptr<Scenario>*
   } else if (scenario_name == "middleware") {
     *out = std::make_shared<MiddlewareScenario>();
     return Status::OK();
+  } else if (scenario_name == "alignment") {
+    *out = std::make_shared<AlignmentScenario>();
+    return Status::OK();
   } else if (scenario_name == "ordered") {
     *out = std::make_shared<OrderedScenario>();
     return Status::OK();
diff --git a/cpp/src/arrow/ipc/options.h b/cpp/src/arrow/ipc/options.h
index 48b6758212bd5..e6a690b3590d2 100644
--- a/cpp/src/arrow/ipc/options.h
+++ b/cpp/src/arrow/ipc/options.h
@@ -161,6 +161,11 @@ struct ARROW_EXPORT IpcReadOptions {
   /// RecordBatchStreamReader and StreamDecoder classes.
   bool ensure_native_endian = true;
 
+  /// \brief Whether to align incoming data if mis-aligned
+  ///
+  /// Received mis-aligned data is copied to aligned memory locations.
+  bool ensure_memory_alignment = true;
+
   /// \brief Options to control caching behavior when pre-buffering is requested
   ///
   /// The lazy property will always be reset to true to deliver the expected behavior
diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc
index 98214c1debb86..12d682f59c2e3 100644
--- a/cpp/src/arrow/ipc/reader.cc
+++ b/cpp/src/arrow/ipc/reader.cc
@@ -47,6 +47,7 @@
 #include "arrow/table.h"
 #include "arrow/type.h"
 #include "arrow/type_traits.h"
+#include "arrow/util/align_util.h"
 #include "arrow/util/bit_util.h"
 #include "arrow/util/bitmap_ops.h"
 #include "arrow/util/checked_cast.h"
@@ -636,8 +637,13 @@ Result<std::shared_ptr<RecordBatch>> LoadRecordBatchSubset(
                             arrow::internal::SwapEndianArrayData(filtered_column));
     }
   }
-  return RecordBatch::Make(std::move(filtered_schema), metadata->length(),
-                           std::move(filtered_columns));
+  auto batch = RecordBatch::Make(std::move(filtered_schema), metadata->length(),
+                                 std::move(filtered_columns));
+  if (context.options.ensure_memory_alignment) {
+    return util::EnsureAlignment(batch, arrow::util::kValueAlignment,
+                                 context.options.memory_pool);
+  }
+  return batch;
 }
 
 Result<std::shared_ptr<RecordBatch>> LoadRecordBatch(
diff --git a/cpp/src/arrow/util/align_util.cc b/cpp/src/arrow/util/align_util.cc
index a327afa7a5cc3..ef224ebbd9255 100644
--- a/cpp/src/arrow/util/align_util.cc
+++ b/cpp/src/arrow/util/align_util.cc
@@ -19,6 +19,7 @@
 
 #include "arrow/array.h"
 #include "arrow/chunked_array.h"
+#include "arrow/extension_type.h"
 #include "arrow/record_batch.h"
 #include "arrow/table.h"
 #include "arrow/type_fwd.h"
@@ -28,6 +29,8 @@
 
 namespace arrow {
 
+using internal::checked_pointer_cast;
+
 namespace util {
 
 bool CheckAlignment(const Buffer& buffer, int64_t alignment) {
@@ -44,9 +47,13 @@ namespace {
 Type::type GetTypeForBuffers(const ArrayData& array) {
   Type::type type_id = array.type->storage_id();
   if (type_id == Type::DICTIONARY) {
-    return ::arrow::internal::checked_pointer_cast<DictionaryType>(array.type)
-        ->index_type()
-        ->id();
+    // Return the index type id of the DictionaryType, which is array.type itself
+    // or, if array.type is an ExtensionType, its storage_type()
+    std::shared_ptr<DataType> dict_type = array.type;
+    if (array.type->id() == Type::EXTENSION) {
+      dict_type = checked_pointer_cast<ExtensionType>(array.type)->storage_type();
+    }
+    return checked_pointer_cast<DictionaryType>(dict_type)->index_type()->id();
   }
   return type_id;
 }
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index b2edeb0b4192f..8fe1f5ea9e10f 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1855,6 +1855,7 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
         vector[int] included_fields
         c_bool use_threads
         c_bool ensure_native_endian
+        c_bool ensure_memory_alignment
 
         @staticmethod
         CIpcReadOptions Defaults()
diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi
index e15b0ea40ed2e..60f0ef61dfcaa 100644
--- a/python/pyarrow/ipc.pxi
+++ b/python/pyarrow/ipc.pxi
@@ -120,6 +120,8 @@ cdef class IpcReadOptions(_Weakrefable):
     ----------
     ensure_native_endian : bool, default True
         Whether to convert incoming data to platform-native endianness.
+    ensure_memory_alignment : bool, default True
+        Whether to align incoming data if mis-aligned.
     use_threads : bool
         Whether to use the global CPU thread pool to parallelize any
         computational tasks like decompression
@@ -133,9 +135,11 @@ cdef class IpcReadOptions(_Weakrefable):
     # cdef block is in lib.pxd
 
     def __init__(self, *, bint ensure_native_endian=True,
+                 bint ensure_memory_alignment=True,
                  bint use_threads=True, list included_fields=None):
         self.c_options = CIpcReadOptions.Defaults()
         self.ensure_native_endian = ensure_native_endian
+        self.ensure_memory_alignment = ensure_memory_alignment
         self.use_threads = use_threads
         if included_fields is not None:
             self.included_fields = included_fields
@@ -148,6 +152,15 @@ cdef class IpcReadOptions(_Weakrefable):
     def ensure_native_endian(self, bint value):
         self.c_options.ensure_native_endian = value
 
+    @property
+    def ensure_memory_alignment(self):
+        """Whether mis-aligned incoming data is copied to aligned memory."""
+        return self.c_options.ensure_memory_alignment
+
+    @ensure_memory_alignment.setter
+    def ensure_memory_alignment(self, bint value):
+        self.c_options.ensure_memory_alignment = value
+
     @property
     def use_threads(self):
         return self.c_options.use_threads
diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py
index 4be5792a92f6d..74a4f0fc514b1 100644
--- a/python/pyarrow/tests/test_ipc.py
+++ b/python/pyarrow/tests/test_ipc.py
@@ -548,11 +548,15 @@ def test_read_options():
     options = pa.ipc.IpcReadOptions()
     assert options.use_threads is True
     assert options.ensure_native_endian is True
+    assert options.ensure_memory_alignment is True
     assert options.included_fields == []
 
     options.ensure_native_endian = False
     assert options.ensure_native_endian is False
 
+    options.ensure_memory_alignment = False
+    assert options.ensure_memory_alignment is False
+
     options.use_threads = False
     assert options.use_threads is False
 
@@ -564,10 +568,11 @@ def test_read_options():
 
     options = pa.ipc.IpcReadOptions(
         use_threads=False, ensure_native_endian=False,
-        included_fields=[1]
+        ensure_memory_alignment=False, included_fields=[1]
     )
     assert options.use_threads is False
     assert options.ensure_native_endian is False
+    assert options.ensure_memory_alignment is False
     assert options.included_fields == [1]
 
 