Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-32276: [C++][FlightRPC] Align RecordBatch buffers given to IPC #44279

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ TEST(FlightIntegration, AuthBasicProto) { ASSERT_OK(RunScenario("auth:basic_prot

// Runs the "middleware" integration scenario and requires an OK status.
TEST(FlightIntegration, Middleware) { ASSERT_OK(RunScenario("middleware")); }

// Runs the "alignment" integration scenario and requires an OK status.
TEST(FlightIntegration, Alignment) { ASSERT_OK(RunScenario("alignment")); }

// Runs the "ordered" integration scenario and requires an OK status.
TEST(FlightIntegration, Ordered) { ASSERT_OK(RunScenario("ordered")); }

TEST(FlightIntegration, ExpirationTimeDoGet) {
Expand Down
142 changes: 142 additions & 0 deletions cpp/src/arrow/flight/integration_tests/test_integration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "arrow/table.h"
#include "arrow/table_builder.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/align_util.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/string.h"
#include "arrow/util/value_parsing.h"
Expand Down Expand Up @@ -281,6 +282,144 @@ class MiddlewareScenario : public Scenario {
std::shared_ptr<TestClientMiddlewareFactory> client_middleware_;
};

/// \brief Flight server backing the "alignment" integration scenario.
///
/// Every successful DoGet answers with the same small, fixed record batch
/// whose columns have different value widths (int32, int64 and boolean).
/// A client that sets FlightCallOptions.read_options.ensure_memory_alignment
/// is expected to return data aligned according to each column's data type.
///
/// The scenario only passes when the client returns aligned data.
class AlignmentServer : public FlightServerBase {
  /// \brief Describe the served flight.
  ///
  /// The info carries the schema from BuildSchema(), echoes the caller's
  /// descriptor and advertises a single endpoint whose ticket is "foo".
  /// Both totals are passed as -1 and the ordered flag as false.
  Status GetFlightInfo(const ServerCallContext& context,
                       const FlightDescriptor& descriptor,
                       std::unique_ptr<FlightInfo>* result) override {
    const std::shared_ptr<Schema> flight_schema = BuildSchema();
    std::vector<FlightEndpoint> flight_endpoints;
    flight_endpoints.push_back(FlightEndpoint{{"foo"}, {}, std::nullopt, ""});
    ARROW_ASSIGN_OR_RAISE(
        auto flight_info,
        FlightInfo::Make(*flight_schema, descriptor, flight_endpoints, -1, -1, false));
    *result = std::make_unique<FlightInfo>(flight_info);
    return Status::OK();
  }

  /// \brief Stream the fixed three-row batch for ticket "foo".
  ///
  /// Any other ticket is rejected with a KeyError naming the ticket.
  Status DoGet(const ServerCallContext& context, const Ticket& request,
               std::unique_ptr<FlightDataStream>* stream) override {
    ARROW_ASSIGN_OR_RAISE(
        auto batch_builder,
        RecordBatchBuilder::Make(BuildSchema(), arrow::default_memory_pool()));
    if (request.ticket != "foo") {
      return Status::KeyError("Could not find flight: ", request.ticket);
    }

    // Column 0: int32 values 1, 2, 3.
    auto int32_column = batch_builder->GetFieldAs<Int32Builder>(0);
    for (const int32_t value : {1, 2, 3}) {
      ARROW_RETURN_NOT_OK(int32_column->Append(value));
    }
    // Column 1: int64 values 1, 2, 3.
    auto int64_column = batch_builder->GetFieldAs<Int64Builder>(1);
    for (const int64_t value : {1, 2, 3}) {
      ARROW_RETURN_NOT_OK(int64_column->Append(value));
    }
    // Column 2: boolean values false, true, false.
    auto bool_column = batch_builder->GetFieldAs<BooleanBuilder>(2);
    for (const bool value : {false, true, false}) {
      ARROW_RETURN_NOT_OK(bool_column->Append(value));
    }

    ARROW_ASSIGN_OR_RAISE(auto batch, batch_builder->Flush());
    std::vector<std::shared_ptr<RecordBatch>> batches;
    batches.push_back(batch);
    ARROW_ASSIGN_OR_RAISE(auto batch_reader, RecordBatchReader::Make(batches));
    *stream = std::make_unique<RecordBatchStream>(batch_reader);
    return Status::OK();
  }

 private:
  /// \brief Schema shared by GetFlightInfo and DoGet: three non-nullable
  /// fields named "int32", "int64" and "bool".
  std::shared_ptr<Schema> BuildSchema() {
    constexpr bool kNullable = false;
    return arrow::schema({
        arrow::field("int32", arrow::int32(), kNullable),
        arrow::field("int64", arrow::int64(), kNullable),
        arrow::field("bool", arrow::boolean(), kNullable),
    });
  }
};

/// \brief The alignment scenario.
///
/// This tests that the client provides aligned data when
/// FlightCallOptions.read_options.ensure_memory_alignment is requested.
class AlignmentScenario : public Scenario {
  /// \brief Install an AlignmentServer as the server under test.
  ///
  /// \param[out] server receives the newly created server
  /// \param[in,out] options left untouched
  Status MakeServer(std::unique_ptr<FlightServerBase>* server,
                    FlightServerOptions* options) override {
    *server = std::make_unique<AlignmentServer>();
    return Status::OK();
  }

  /// \brief No client-side options need adjusting for this scenario.
  Status MakeClient(FlightClientOptions* options) override { return Status::OK(); }

  /// \brief Fetch the "alignment" flight and read it into a single table.
  ///
  /// \param[in] client the client used for GetFlightInfo and DoGet
  /// \param[in] call_options options forwarded to every DoGet call
  /// \return the concatenation of the tables read from each endpoint, or
  ///   Invalid if an endpoint advertises any location (the scenario expects
  ///   empty locations so that the original service is used)
  arrow::Result<std::shared_ptr<Table>> GetTable(FlightClient* client,
                                                 const FlightCallOptions& call_options) {
    ARROW_ASSIGN_OR_RAISE(auto info,
                          client->GetFlightInfo(FlightDescriptor::Command("alignment")));
    std::vector<std::shared_ptr<arrow::Table>> tables;
    for (const auto& endpoint : info->endpoints()) {
      if (!endpoint.locations.empty()) {
        // Render the unexpected locations as "[a, b, ...]" for the error.
        std::stringstream ss;
        ss << "[";
        bool is_first = true;
        for (const auto& location : endpoint.locations) {
          if (!is_first) {
            ss << ", ";
          }
          is_first = false;
          ss << location.ToString();
        }
        ss << "]";
        return Status::Invalid(
            "Expected to receive empty locations to use the original service: ",
            ss.str());
      }
      ARROW_ASSIGN_OR_RAISE(auto reader, client->DoGet(call_options, endpoint.ticket));
      ARROW_ASSIGN_OR_RAISE(auto table, reader->ToTable());
      tables.push_back(table);
    }
    return ConcatenateTables(tables);
  }

  /// \brief Read the flight with ensure_memory_alignment on and then off,
  /// checking the table size and the alignment of the returned data.
  ///
  /// \param[in] client the connected client
  /// \return OK when both passes meet expectations, Invalid otherwise
  Status RunClient(std::unique_ptr<FlightClient> client) override {
    constexpr int64_t kExpectedRowCount = 3;
    constexpr int kExpectedColumnCount = 3;

    for (bool ensure_alignment : {true, false}) {
      auto call_options = FlightCallOptions();
      call_options.read_options.ensure_memory_alignment = ensure_alignment;
      ARROW_ASSIGN_OR_RAISE(auto table, GetTable(client.get(), call_options));

      // Check read data
      if (table->num_rows() != kExpectedRowCount) {
        return Status::Invalid("Read table row count isn't expected\n",
                               "Expected rows: ", kExpectedRowCount,
                               "\nActual rows: ", table->num_rows());
      }
      if (table->num_columns() != kExpectedColumnCount) {
        return Status::Invalid("Read table column count isn't expected\n",
                               "Expected columns: ", kExpectedColumnCount,
                               "\nActual columns: ", table->num_columns());
      }

      // Check data alignment
      std::vector<bool> needs_alignment;
      const bool is_aligned = util::CheckAlignment(*table, arrow::util::kValueAlignment,
                                                   &needs_alignment);
      if (ensure_alignment) {
        // with ensure_alignment=true, we require data to be aligned
        if (!is_aligned) {
          return Status::Invalid("Read table has unaligned data");
        }
      } else if (is_aligned) {
        // this is not a requirement but merely an observation:
        // with ensure_alignment=false, flight client returns mis-aligned data
        // if this is not the case any more, feel free to remove this assertion
        // NOTE(review): presumably this depends on how the transport lays out
        // received buffers; confirm it holds for every server implementation
        // this client is run against.
        return Status::Invalid(
            "Read table has aligned data, which is good, but unprecedented");
      }
    }

    return Status::OK();
  }
};

/// \brief The server used for testing FlightInfo.ordered.
///
/// If the given command is "ordered", the server sets
Expand Down Expand Up @@ -2382,6 +2521,9 @@ Status GetScenario(const std::string& scenario_name, std::shared_ptr<Scenario>*
} else if (scenario_name == "middleware") {
*out = std::make_shared<MiddlewareScenario>();
return Status::OK();
} else if (scenario_name == "alignment") {
*out = std::make_shared<AlignmentScenario>();
return Status::OK();
} else if (scenario_name == "ordered") {
*out = std::make_shared<OrderedScenario>();
return Status::OK();
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/ipc/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ struct ARROW_EXPORT IpcReadOptions {
/// RecordBatchStreamReader and StreamDecoder classes.
bool ensure_native_endian = true;

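/// \brief Whether to re-align incoming data whose buffers are mis-aligned
///
/// When true, record batches that are read with mis-aligned buffers are
/// copied into memory allocated from memory_pool so that the data is aligned
/// according to its data type.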
bool ensure_memory_alignment = true;

/// \brief Options to control caching behavior when pre-buffering is requested
///
/// The lazy property will always be reset to true to deliver the expected behavior
Expand Down
10 changes: 8 additions & 2 deletions cpp/src/arrow/ipc/reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "arrow/table.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/align_util.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/checked_cast.h"
Expand Down Expand Up @@ -636,8 +637,13 @@ Result<std::shared_ptr<RecordBatch>> LoadRecordBatchSubset(
arrow::internal::SwapEndianArrayData(filtered_column));
}
}
return RecordBatch::Make(std::move(filtered_schema), metadata->length(),
std::move(filtered_columns));
auto batch = RecordBatch::Make(std::move(filtered_schema), metadata->length(),
std::move(filtered_columns));
if (context.options.ensure_memory_alignment) {
return util::EnsureAlignment(batch, arrow::util::kValueAlignment,
context.options.memory_pool);
}
return batch;
}

Result<std::shared_ptr<RecordBatch>> LoadRecordBatch(
Expand Down
13 changes: 10 additions & 3 deletions cpp/src/arrow/util/align_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "arrow/array.h"
#include "arrow/chunked_array.h"
#include "arrow/extension_type.h"
#include "arrow/record_batch.h"
#include "arrow/table.h"
#include "arrow/type_fwd.h"
Expand All @@ -28,6 +29,8 @@

namespace arrow {

using internal::checked_pointer_cast;

namespace util {

bool CheckAlignment(const Buffer& buffer, int64_t alignment) {
Expand All @@ -44,9 +47,13 @@ namespace {
/// Return the type id that applies to the buffers of `array`.
///
/// This is the storage type id of array.type, except when that storage type
/// is a dictionary: then the id of the dictionary's index type is returned.
Type::type GetTypeForBuffers(const ArrayData& array) {
  const Type::type storage_type_id = array.type->storage_id();
  if (storage_type_id != Type::DICTIONARY) {
    return storage_type_id;
  }

  // The DictionaryType is either array.type itself or, when array.type is an
  // ExtensionType, the extension's storage type.
  std::shared_ptr<DataType> dictionary_type;
  if (array.type->id() == Type::EXTENSION) {
    dictionary_type = checked_pointer_cast<ExtensionType>(array.type)->storage_type();
  } else {
    dictionary_type = array.type;
  }
  return checked_pointer_cast<DictionaryType>(dictionary_type)->index_type()->id();
}
Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -1855,6 +1855,7 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
vector[int] included_fields
c_bool use_threads
c_bool ensure_native_endian
c_bool ensure_memory_alignment

@staticmethod
CIpcReadOptions Defaults()
Expand Down
12 changes: 12 additions & 0 deletions python/pyarrow/ipc.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ cdef class IpcReadOptions(_Weakrefable):
----------
ensure_native_endian : bool, default True
Whether to convert incoming data to platform-native endianness.
ensure_memory_alignment : bool, default True
Whether to align incoming data if mis-aligned.
use_threads : bool
Whether to use the global CPU thread pool to parallelize any
computational tasks like decompression
Expand All @@ -133,9 +135,11 @@ cdef class IpcReadOptions(_Weakrefable):
# cdef block is in lib.pxd

def __init__(self, *, bint ensure_native_endian=True,
bint ensure_memory_alignment=True,
bint use_threads=True, list included_fields=None):
self.c_options = CIpcReadOptions.Defaults()
self.ensure_native_endian = ensure_native_endian
self.ensure_memory_alignment = ensure_memory_alignment
self.use_threads = use_threads
if included_fields is not None:
self.included_fields = included_fields
Expand All @@ -148,6 +152,14 @@ cdef class IpcReadOptions(_Weakrefable):
def ensure_native_endian(self, bint value):
self.c_options.ensure_native_endian = value

@property
def ensure_memory_alignment(self):
    """Whether to align incoming data if mis-aligned (bool)."""
    return self.c_options.ensure_memory_alignment

@ensure_memory_alignment.setter
def ensure_memory_alignment(self, bint value):
    # Store the flag on the wrapped CIpcReadOptions held in c_options.
    self.c_options.ensure_memory_alignment = value

@property
def use_threads(self):
return self.c_options.use_threads
Expand Down
7 changes: 6 additions & 1 deletion python/pyarrow/tests/test_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,11 +548,15 @@ def test_read_options():
options = pa.ipc.IpcReadOptions()
assert options.use_threads is True
assert options.ensure_native_endian is True
assert options.ensure_memory_alignment is True
assert options.included_fields == []

options.ensure_native_endian = False
assert options.ensure_native_endian is False

options.ensure_memory_alignment = False
assert options.ensure_memory_alignment is False

options.use_threads = False
assert options.use_threads is False

Expand All @@ -564,10 +568,11 @@ def test_read_options():

options = pa.ipc.IpcReadOptions(
use_threads=False, ensure_native_endian=False,
included_fields=[1]
ensure_memory_alignment=False, included_fields=[1]
)
assert options.use_threads is False
assert options.ensure_native_endian is False
assert options.ensure_memory_alignment is False
assert options.included_fields == [1]


Expand Down
Loading