diff --git a/cpp/src/arrow/acero/task_util.cc b/cpp/src/arrow/acero/task_util.cc
index 85378eaeeb27c..8905208a9a568 100644
--- a/cpp/src/arrow/acero/task_util.cc
+++ b/cpp/src/arrow/acero/task_util.cc
@@ -212,11 +212,20 @@ std::vector<std::pair<int, int64_t>> TaskSchedulerImpl::PickTasks(int num_tasks,
 
 Status TaskSchedulerImpl::ExecuteTask(size_t thread_id, int group_id, int64_t task_id,
                                       bool* task_group_finished) {
+  // Keep the task's status instead of returning early, so that the bookkeeping
+  // below still runs when the task fails.
+  Status status;
   if (!aborted_) {
-    RETURN_NOT_OK(task_groups_[group_id].task_impl_(thread_id, task_id));
+    status = task_groups_[group_id].task_impl_(thread_id, task_id);
   }
   *task_group_finished = PostExecuteTask(thread_id, group_id);
-  return Status::OK();
+  if (*task_group_finished) {
+    bool all_task_groups_finished = false;
+    // operator&= keeps an existing error, so the task's own failure is reported
+    // in preference to a failure from the group-finished handling.
+    status &= OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished);
+  }
+  return status;
 }
 
 bool TaskSchedulerImpl::PostExecuteTask(size_t thread_id, int group_id) {
@@ -373,11 +382,8 @@ Status TaskSchedulerImpl::ScheduleMore(size_t thread_id, int num_tasks_finished)
 
       bool task_group_finished = false;
       RETURN_NOT_OK(ExecuteTask(thread_id, group_id, task_id, &task_group_finished));
-
-      if (task_group_finished) {
-        bool all_task_groups_finished = false;
-        return OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished);
-      }
+      // ExecuteTask already calls OnTaskGroupFinished when the group completes, so
+      // task_group_finished is not inspected here.
 
       return Status::OK();
     }));
diff --git a/cpp/src/arrow/acero/task_util_test.cc b/cpp/src/arrow/acero/task_util_test.cc
index d5196ad4e0a03..e5d48582a3ca1 100644
--- a/cpp/src/arrow/acero/task_util_test.cc
+++ b/cpp/src/arrow/acero/task_util_test.cc
@@ -231,5 +231,78 @@ TEST(TaskScheduler, StressTwo) {
   }
 }
 
+TEST(TaskScheduler, AbortContOnTaskErrorSerial) {
+  constexpr int kNumTasks = 16;
+
+  auto scheduler = TaskScheduler::Make();
+  auto task = [&](std::size_t, int64_t task_id) {
+    if (task_id == kNumTasks / 2) {
+      return Status::Invalid("Task failed");  // only task kNumTasks / 2 fails
+    }
+    return Status::OK();
+  };
+
+  int task_group =
+      scheduler->RegisterTaskGroup(task, [](std::size_t) { return Status::OK(); });
+  scheduler->RegisterEnd();
+
+  ASSERT_OK(scheduler->StartScheduling(
+      0, [](TaskScheduler::TaskGroupContinuationImpl) { return Status::OK(); }, 1, true));
+  ASSERT_RAISES_WITH_MESSAGE(Invalid, "Invalid: Task failed",
+                             scheduler->StartTaskGroup(0, task_group, kNumTasks));
+
+  bool abort_cont_called = false;
+  auto abort_cont = [&]() { abort_cont_called = true; };
+  scheduler->Abort(abort_cont);  // abort_cont must run despite the failed task
+  ASSERT_TRUE(abort_cont_called);
+}
+
+TEST(TaskScheduler, AbortContOnTaskErrorParallel) {
+#ifndef ARROW_ENABLE_THREADING
+  GTEST_SKIP() << "Test requires threading support";
+#endif
+  constexpr int kNumThreads = 16;
+
+  ThreadIndexer thread_indexer;
+  int num_threads = std::min(static_cast<int>(thread_indexer.Capacity()), kNumThreads);
+  ASSERT_OK_AND_ASSIGN(std::shared_ptr<ThreadPool> thread_pool,
+                       MakePrimedThreadPool(num_threads));
+  TaskScheduler::ScheduleImpl schedule =
+      [&](TaskScheduler::TaskGroupContinuationImpl task) {
+        return thread_pool->Spawn([&, task] {
+          std::size_t thread_id = thread_indexer();
+          auto status = task(thread_id);
+          ASSERT_TRUE(status.ok() || status.IsInvalid() || status.IsCancelled());
+        });
+      };
+
+  int num_tasks = num_threads * 2;
+  auto scheduler = TaskScheduler::Make();
+  auto task = [&](std::size_t, int64_t task_id) {
+    if (task_id % 2 == 0) {
+      return Status::Invalid("Task failed");  // every even-numbered task fails
+    }
+    return Status::OK();
+  };
+
+  int task_group =
+      scheduler->RegisterTaskGroup(task, [](std::size_t) { return Status::OK(); });
+  scheduler->RegisterEnd();
+
+  ASSERT_OK(scheduler->StartScheduling(0, schedule, num_tasks, false));
+  ASSERT_OK(scheduler->StartTaskGroup(0, task_group, num_tasks));
+
+  thread_pool->WaitForIdle();
+
+  bool abort_cont_called = false;
+  auto abort_cont = [&]() {
+    ASSERT_FALSE(abort_cont_called);  // abort_cont must not run twice
+    abort_cont_called = true;
+  };
+  scheduler->Abort(abort_cont);
+
+  ASSERT_TRUE(abort_cont_called);
+}
+
 }  // namespace acero
 }  // namespace arrow