diff --git a/pjrt-plugin/iree/integrations/pjrt/common/BUILD b/pjrt-plugin/iree/integrations/pjrt/common/BUILD
index 340c8133..21c8ef2d 100644
--- a/pjrt-plugin/iree/integrations/pjrt/common/BUILD
+++ b/pjrt-plugin/iree/integrations/pjrt/common/BUILD
@@ -26,6 +26,7 @@ iree_pjrt_cc_library(
     deps = [
         ":compiler",
         ":debugging",
+        "@iree_core//runtime/src/iree/base:tracing",
         "@iree_core//runtime/src/iree/hal",
         "@iree_core//runtime/src/iree/modules/hal",
         "@iree_core//runtime/src/iree/vm",
diff --git a/pjrt-plugin/iree/integrations/pjrt/common/api_impl.cc b/pjrt-plugin/iree/integrations/pjrt/common/api_impl.cc
index be45e9af..6201f2f0 100644
--- a/pjrt-plugin/iree/integrations/pjrt/common/api_impl.cc
+++ b/pjrt-plugin/iree/integrations/pjrt/common/api_impl.cc
@@ -9,8 +9,11 @@
 #include
 #include
 
+#include "iree/base/tracing.h"
 #include "iree/hal/api.h"
 
+using iree::vm::retain_ref;
+
 namespace iree::pjrt {
 
 // Chopped down utilities from various TPU support libraries. Basically all for
@@ -298,7 +301,7 @@ const std::string& ErrorInstance::message() const {
   if (cached_message_.empty()) {
     std::string buffer;
     iree_host_size_t actual_len;
-    buffer.resize(128);
+    buffer.resize(1024);  // TODO: Actually reallocate to full size on trunc.
     if (!iree_status_format(status_, buffer.size(), buffer.data(),
                             &actual_len)) {
       buffer.resize(actual_len);
@@ -349,14 +352,25 @@ iree_status_t BufferInstance::GetXlaShape(xla::Shape** out_shape) {
   return iree_ok_status();
 }
 
+BufferInstance::BufferInstance(
+    DeviceInstance& device, iree::vm::ref<iree_hal_buffer_view_t> buffer_view)
+    : device_(device), buffer_view_(std::move(buffer_view)) {
+  IREE_CHECK_OK(device.CreateFence(&ready_fence_));
+  IREE_CHECK_OK(device.CreateFence(&done_fence_));
+}
+
 void BufferInstance::BindApi(PJRT_Api* api) {
   api->PJRT_Buffer_Destroy =
       +[](PJRT_Buffer_Destroy_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Buffer_Destroy");
+    iree_status_t status =
+        BufferInstance::Unwrap(args->buffer)->AsyncDeallocate();
     delete BufferInstance::Unwrap(args->buffer);
-    return nullptr;
+    return MakeError(status);
   };
   api->PJRT_Buffer_OnDeviceTrimmedShape =
       +[](PJRT_Buffer_OnDeviceTrimmedShape_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Buffer_OnDeviceTrimmedShape");
     auto impl = [&]() -> iree_status_t {
       // TODO: This function is terrible and not exposed properly to C.
       // It is slated to be deleted...
@@ -382,6 +396,7 @@ void BufferInstance::BindApi(PJRT_Api* api) {
   };
   api->PJRT_Buffer_ToHostBuffer =
       +[](PJRT_Buffer_ToHostBuffer_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Buffer_ToHostBuffer");
     BufferInstance* buffer = BufferInstance::Unwrap(args->src);
     if (!args->dst) {
       // Size query.
@@ -395,10 +410,12 @@ void BufferInstance::BindApi(PJRT_Api* api) {
   };
   api->PJRT_Buffer_OnDeviceSizeInBytes =
       +[](PJRT_Buffer_OnDeviceSizeInBytes_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Buffer_OnDeviceSizeInBytes");
     return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                                       "PJRT_Buffer_OnDeviceSizeInBytes"));
   };
   api->PJRT_Buffer_Delete = +[](PJRT_Buffer_Delete_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Buffer_Delete");
     return MakeError(
         iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Buffer_Delete"));
   };
@@ -409,6 +426,7 @@ void BufferInstance::BindApi(PJRT_Api* api) {
   };
   api->PJRT_Buffer_CopyToDevice =
       +[](PJRT_Buffer_CopyToDevice_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Buffer_CopyToDevice");
     return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                                       "PJRT_Buffer_CopyToDevice"));
   };
@@ -423,6 +441,7 @@ void BufferInstance::BindApi(PJRT_Api* api) {
   };
   api->PJRT_Buffer_ReadyEvent =
       +[](PJRT_Buffer_ReadyEvent_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Buffer_ReadyEvent");
     return MakeError(
         iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Buffer_ReadyEvent"));
   };
@@ -438,19 +457,83 @@ iree_status_t BufferInstance::GetHostSizeInBytes(iree_host_size_t* host_size) {
   return iree_ok_status();
 }
 
+iree_status_t BufferInstance::AsyncDeallocate() {
+  IREE_TRACE_SCOPE();
+  return iree_hal_device_queue_dealloca(
+      device().device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*wait_semaphore_list=*/iree_hal_fence_semaphore_list(done_fence()),
+      /*signal_semaphore_list=*/iree_hal_semaphore_list_empty(),
+      iree_hal_buffer_view_buffer(buffer_view_.get()));
+}
+
 iree_status_t BufferInstance::CopyToHost(void* dst, iree_host_size_t dst_size,
-                                         EventInstance** done_event) {
-  // TODO: Do an async transfer on a transfer queue like a grown up.
-  iree_hal_device_t* hal_device;
-  IREE_RETURN_IF_ERROR(device_.GetHalDevice(&hal_device));
-  IREE_RETURN_IF_ERROR(iree_hal_device_transfer_d2h(
-      hal_device, iree_hal_buffer_view_buffer(buffer_view()), 0, dst, dst_size,
-      IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()));
-
-  *done_event = new EventInstance();
+                                         EventInstance** out_done_event) {
+  // Set up an event for external triggering. While a little wonky, we
+  // trigger it in the host buffer release callback, which happens once the
+  // transfer is done. I don't love this option but it seems to match what
+  // I'm looking for.
+  EventInstance* capture_done_event =
+      new EventInstance(EventInstance::Type::EXTERNAL);
+  *out_done_event = capture_done_event;
+
+  // Import the destination (host) buffer as an iree_hal_buffer_t so that we
+  // can issue copy commands.
+  iree::vm::ref<iree_hal_buffer_t> dst_buffer;
+  iree_hal_buffer_params_t dst_buffer_params = {
+      IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET};
+  iree_hal_external_buffer_t dst_external_buffer;
+  memset(&dst_external_buffer, 0, sizeof(dst_external_buffer));
+  dst_external_buffer.type = IREE_HAL_EXTERNAL_BUFFER_TYPE_HOST_ALLOCATION;
+  dst_external_buffer.flags = IREE_HAL_EXTERNAL_BUFFER_FLAG_NONE;
+  dst_external_buffer.size = dst_size;
+  dst_external_buffer.handle.host_allocation.ptr = dst;
+  auto release_callback = +[](void* user_data, iree_hal_buffer_t* buffer) {
+    IREE_TRACE_SCOPE0("PJRT_CopyToHost_ReleaseCallback");
+    auto* local_done_event = static_cast<EventInstance*>(user_data);
+    local_done_event->ExternalSignalReady(iree_ok_status());
+  };
+  IREE_RETURN_IF_ERROR(iree_hal_allocator_import_buffer(
+      device_.device_allocator(), dst_buffer_params, &dst_external_buffer,
+      /*release_callback=*/{release_callback, capture_done_event},
+      &dst_buffer));
+
+  // Create the transfer command buffer.
+  iree::vm::ref<iree_hal_command_buffer_t> transfer_cb;
+  iree_hal_transfer_command_t transfer_command;
+  memset(&transfer_command, 0, sizeof(transfer_command));
+  transfer_command.type = IREE_HAL_TRANSFER_COMMAND_TYPE_COPY;
+  transfer_command.copy.source_buffer =
+      iree_hal_buffer_view_buffer(buffer_view());
+  transfer_command.copy.source_offset = 0;
+  transfer_command.copy.target_buffer = dst_buffer.get();
+  transfer_command.copy.target_offset = 0;
+  transfer_command.copy.length = dst_size;
+  IREE_RETURN_IF_ERROR(iree_hal_create_transfer_command_buffer(
+      device_.device(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+      IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*transfer_count=*/1, &transfer_command, &transfer_cb));
+  dst_buffer.reset();
+
+  IREE_RETURN_IF_ERROR(iree_hal_device_queue_execute(
+      device_.device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*wait_semaphore_list=*/iree_hal_fence_semaphore_list(ready_fence_.get()),
+      /*signal_semaphore_list=*/iree_hal_semaphore_list_empty(),
+      /*command_buffer_count=*/1, &transfer_cb));
+
+  return iree_ok_status();
+}
+
+iree_status_t BufferInstance::AdvanceReadyFence(iree_hal_semaphore_t* semaphore,
+                                                uint64_t timepoint) {
+  return iree_hal_fence_insert(ready_fence_.get(), semaphore, timepoint);
+}
+
+iree_status_t BufferInstance::AdvanceDoneFence(iree_hal_semaphore_t* semaphore,
+                                               uint64_t timepoint) {
+  return iree_hal_fence_insert(done_fence_.get(), semaphore, timepoint);
+}
+
 //===----------------------------------------------------------------------===//
 // DeviceInstance
 //===----------------------------------------------------------------------===//
@@ -502,12 +585,25 @@ void DeviceInstance::BindApi(PJRT_Api* api) {
   };
 }
 
+iree_status_t DeviceInstance::CreateFence(iree_hal_fence_t** out_fence) {
+  return iree_hal_fence_create(/*capacity=*/2, client_.host_allocator(),
+                               out_fence);
+}
+
 iree_status_t DeviceInstance::OpenDevice() {
   if (device_) return iree_ok_status();
-  return iree_hal_driver_create_device_by_id(
+  IREE_RETURN_IF_ERROR(iree_hal_driver_create_device_by_id(
       driver_, /*device_id=*/info_->device_id,
       /*param_count=*/0, /*params=*/nullptr, client_.host_allocator(),
-      &device_);
+      &device_));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_semaphore_create(device(), 0ull, &main_timeline_));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_semaphore_create(device(), 0ull, &transfer_timeline_));
+  // IREE_RETURN_IF_ERROR(iree_hal_fence_create_at(transfer_timeline_.get(), 0,
+  //                                               client_.host_allocator(),
+  //                                               &transfer_now_fence_));
+  return iree_ok_status();
 }
 
 iree_status_t DeviceInstance::HostBufferToDevice(
@@ -565,26 +661,113 @@ iree_status_t DeviceInstance::HostBufferToDevice(
     byte_length *= dims[i];
   }
 
-  // TODO: Don't do synchronous h2d transfer. Instead issue a command against
-  // the transfer queue like a grown-up. Also pay attention to zero copy flags
-  // and such. Plenty to make efficient here.
+  iree::vm::ref<iree_hal_buffer_t> buffer;
+  // There are multiple ways to implement zero-copy/staged transfers and each
+  // implementation will have different performance cliffs associated with
+  // directly operating on imported host buffers. In many actual
+  // host/device situations, such unified memory is a productivity (not a
+  // performance) feature and best avoided. As such, we always need to be
+  // able to decide to do a staged transfer and implement that here. Using
+  // an imported buffer on the device is left as an optimization for
+  // implementations on which we believe it will be beneficial.
+  bool require_snapshot_now = host_buffer_semantics ==
+                              PJRT_HostBufferSemantics_kImmutableOnlyDuringCall;
+  bool caller_data_done = false;
+  iree::vm::ref<iree_hal_buffer_t> host_staging_buffer;
+  IREE_RETURN_IF_ERROR(AcquireHostStagingBuffer(
+      iree_make_const_byte_span(data, byte_length), require_snapshot_now,
+      &caller_data_done, &host_staging_buffer));
+  if (!caller_data_done) {
+    return iree_make_status(
+        IREE_STATUS_UNIMPLEMENTED,
+        "Deferred snapshot of host data not yet implemented");
+  }
+
+  // Allocate on stream. We serialize across 3 timepoints:
+  //   0. Last transfer complete
+  //   1. Allocation
+  //   2. This transfer complete
+  // There are various ways to be smarter about this but without more
+  // information from the caller, this is ok. If we wanted to favor smaller
+  // allocation scopes, it may be desirable to join with the main execution
+  // timeline, but that would obviously serialize more.
+  uint64_t wait_transfer_start = last_transfer_timepoint_;
+  uint64_t signal_alloca_complete = ++last_transfer_timepoint_;
+  uint64_t signal_copy_complete = ++last_transfer_timepoint_;
   iree_hal_buffer_params_t params;
   memset(&params, 0, sizeof(params));
   params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
-  params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
-
-  iree_hal_buffer_view_t* buffer_view = nullptr;
-  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer(
-      iree_hal_device_allocator(device_.get()), num_dims, &shape[0],
-      element_type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, params,
-      iree_make_const_byte_span(data, byte_length), &buffer_view));
+  params.usage =
+      IREE_HAL_BUFFER_USAGE_DEFAULT | IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET;
+  IREE_RETURN_IF_ERROR(iree_hal_device_queue_alloca(
+      device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*wait_semaphore_list=*/
+      {1, &transfer_timeline_, &wait_transfer_start},
+      /*signal_semaphore_list=*/
+      {1, &transfer_timeline_, &signal_alloca_complete},
+      IREE_HAL_ALLOCATOR_POOL_DEFAULT, params, byte_length, &buffer));
+
+  // Queue up the transfer command.
+  iree::vm::ref<iree_hal_command_buffer_t> transfer_cb;
+  iree_hal_transfer_command_t transfer_command;
+  memset(&transfer_command, 0, sizeof(transfer_command));
+  transfer_command.type = IREE_HAL_TRANSFER_COMMAND_TYPE_COPY;
+  transfer_command.copy.source_buffer = host_staging_buffer.get();
+  transfer_command.copy.source_offset = 0;
+  transfer_command.copy.target_buffer = buffer.get();
+  transfer_command.copy.target_offset = 0;
+  transfer_command.copy.length = byte_length;
+  IREE_RETURN_IF_ERROR(iree_hal_create_transfer_command_buffer(
+      device(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+      IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*transfer_count=*/1, &transfer_command, &transfer_cb));
+  IREE_RETURN_IF_ERROR(iree_hal_device_queue_execute(
+      device(), IREE_HAL_QUEUE_AFFINITY_ANY,
+      /*wait_semaphore_list=*/
+      {1, &transfer_timeline_, &signal_alloca_complete},
+      /*signal_semaphore_list=*/
+      {1, &transfer_timeline_, &signal_copy_complete},
+      /*command_buffer_count=*/1, &transfer_cb));
+
+  // Wrap in a buffer view and return.
+  iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
+      buffer.get(), num_dims, &shape[0], element_type,
+      IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, client_.host_allocator(),
+      &result_buffer_view));
+
+  *out_buffer = new BufferInstance(*this, std::move(result_buffer_view));
+  (*out_buffer)
+      ->AdvanceReadyFence(transfer_timeline_.get(), signal_copy_complete);
+  (*out_buffer)
+      ->AdvanceDoneFence(transfer_timeline_.get(), signal_copy_complete);
+
+  // We snapshotted the caller data when acquiring the host staging buffer,
+  // so we won't be touching it again.
+  *out_done_with_host_buffer_event =
+      new EventInstance(EventInstance::Type::SIGNALLED);
-  // Since we synchronously copied, return an already signalled event.
-  *out_done_with_host_buffer_event = new EventInstance();
-
-  // Construct and return a BufferInstance.
-  *out_buffer = new BufferInstance(*this, buffer_view);
+  return iree_ok_status();
+}
+
+iree_status_t DeviceInstance::AcquireHostStagingBuffer(
+    iree_const_byte_span_t initial_contents, bool snapshot_initial_contents_now,
+    bool* initial_contents_snapshotted, iree_hal_buffer_t** out_buffer) {
+  IREE_TRACE_SCOPE();
+  // There are multiple ways to do this that have different cost/benefits.
+  // Here we do the simplest thing and snapshot into a new host allocation.
+  // This could be replaced with either some form of staging ring buffer
+  // or importing from a raw pointer (on implementations where the cost of
+  // unified addressing is zero).
+  iree_hal_buffer_params_t params;
+  memset(&params, 0, sizeof(params));
+  params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
+  params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER;
+  IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
+      device_allocator(), params, initial_contents.data_length,
+      initial_contents, out_buffer));
+  // We did a synchronous snapshot (memcpy).
+  *initial_contents_snapshotted = true;
   return iree_ok_status();
 }
 
@@ -621,6 +804,7 @@ void ClientInstance::BindApi(PJRT_Api* api) {
   // PJRT_Client_Create is polymorphic
   api->PJRT_Client_Destroy =
       +[](PJRT_Client_Destroy_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Client_Destroy");
     delete ClientInstance::Unwrap(args->client);
     return nullptr;
   };
@@ -674,6 +858,7 @@ void ClientInstance::BindApi(PJRT_Api* api) {
   };
   api->PJRT_Client_Compile =
       +[](PJRT_Client_Compile_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Client_Compile");
     // TODO: It is not great that we only get a client here vs a list of
     // devices to consider (or something).
     // The issue is that systems often
     // have unrelated devices that will not actually be scheduled and those
@@ -698,6 +883,7 @@ void ClientInstance::BindApi(PJRT_Api* api) {
   };
   api->PJRT_Client_BufferFromHostBuffer =
       +[](PJRT_Client_BufferFromHostBuffer_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Client_BufferFromHostBuffer");
     auto status =
         DeviceInstance::Unwrap(args->device)
             ->HostBufferToDevice(
@@ -799,6 +985,9 @@ PJRT_Error* ClientInstance::Compile(PJRT_Program* program,
   if (!job->SetFlag("--iree-input-type=mhlo")) {
     return MakeCompilerError();
   }
+  if (!job->SetFlag("--iree-execution-model=async-external")) {
+    return MakeCompilerError();
+  }
   if (!SetDefaultCompilerFlags(job.get())) {
     return MakeCompilerError();
   }
@@ -854,39 +1043,119 @@ iree_status_t ClientInstance::PopulateVMModules(
   return iree_ok_status();
 }
 
+std::tuple<uint64_t, uint64_t> ClientInstance::AdvanceTimeline() {
+  uint64_t current = execution_timeline_;
+  uint64_t next = current + 1;
+  execution_timeline_ = next;
+  return std::make_tuple(current, next);
+}
+
 //===----------------------------------------------------------------------===//
 // EventInstance
 //===----------------------------------------------------------------------===//
 
+EventInstance::EventInstance(Type type) : type_(type) {
+  switch (type) {
+    case Type::SIGNALLED:
+      is_ready_ = true;
+      break;
+    case Type::EXTERNAL:
+      is_ready_ = false;
+      break;
+  }
+}
+
+EventInstance::~EventInstance() { iree_status_ignore(status_); }
+
 void EventInstance::BindApi(PJRT_Api* api) {
   api->PJRT_Event_Destroy = +[](PJRT_Event_Destroy_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Event_Destroy");
     delete EventInstance::Unwrap(args->event);
     return nullptr;
   };
   api->PJRT_Event_IsReady = +[](PJRT_Event_IsReady_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Event_IsReady");
     args->is_ready = EventInstance::Unwrap(args->event)->is_ready();
     return nullptr;
   };
   api->PJRT_Event_Error = +[](PJRT_Event_Error_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Event_Error");
     return (PJRT_Error*)EventInstance::Unwrap(args->event)->error();
   };
   api->PJRT_Event_Await = +[](PJRT_Event_Await_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Event_Await");
     return MakeError(
         iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Event_Await"));
   };
   api->PJRT_Event_OnReady = +[](PJRT_Event_OnReady_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Event_OnReady");
     return MakeError(EventInstance::Unwrap(args->event)
                          ->OnReady(args->callback, args->user_arg));
   };
 }
 
+ErrorInstance* EventInstance::error() {
+  std::lock_guard<std::mutex> guard(lock_);
+  if (!iree_status_is_ok(status_)) return new ErrorInstance(status_);
+  return nullptr;
+}
+
+bool EventInstance::is_ready() {
+  std::lock_guard<std::mutex> guard(lock_);
+  return is_ready_;
+}
+
 iree_status_t EventInstance::OnReady(PJRT_Event_OnReadyCallback callback,
                                      void* user_arg) {
-  // TODO: Detect if not ready.
-  callback((PJRT_Error*)error_, user_arg);
+  iree_status_t local_status;
+  {
+    std::lock_guard<std::mutex> guard(lock_);
+    if (!is_ready_) {
+      pending_callbacks_.push_back({callback, user_arg});
+      return iree_ok_status();
+    }
+    local_status = status_;
+  }
+
+  // Already signalled. Callback out of lock scope.
+  // Note that the callback may destroy the event - so must only operate on
+  // locals.
+  callback(
+      iree_status_is_ok(local_status)
+          ? nullptr
+          : (PJRT_Error*)new ErrorInstance(iree_status_clone(local_status)),
+      user_arg);
   return iree_ok_status();
 }
 
+void EventInstance::ExternalSignalReady(iree_status_t status) {
+  IREE_TRACE_SCOPE();
+  assert(type_ == Type::EXTERNAL && "expected EXTERNAL Event type");
+  iree_status_t local_status;
+  std::vector<std::pair<PJRT_Event_OnReadyCallback, void*>> local_callbacks;
+  {
+    std::lock_guard<std::mutex> guard(lock_);
+    if (is_ready_) {
+      return;
+    }
+    local_callbacks.swap(pending_callbacks_);
+    is_ready_ = true;
+    status_ = status;
+    local_status = status_;
+  }
+
+  // Trigger callbacks outside of the lock.
+  // Note that the callback may destroy the event - so must only operate on
+  // locals.
+  for (auto& cb : local_callbacks) {
+    IREE_TRACE_SCOPE0("PJRT_User_Callback_Invoke");
+    cb.first(
+        iree_status_is_ok(local_status)
+            ? nullptr
+            : (PJRT_Error*)new ErrorInstance(iree_status_clone(local_status)),
+        cb.second);
  }
+}
+
 //===----------------------------------------------------------------------===//
 // ExecutableInstance
 //===----------------------------------------------------------------------===//
@@ -894,6 +1163,7 @@ iree_status_t EventInstance::OnReady(PJRT_Event_OnReadyCallback callback,
 void ExecutableInstance::BindApi(PJRT_Api* api) {
   api->PJRT_Executable_Destroy =
       +[](PJRT_Executable_Destroy_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Executable_Destroy");
     delete ExecutableInstance::Unwrap(args->executable);
     return nullptr;
   };
@@ -930,11 +1200,13 @@ void ExecutableInstance::BindApi(PJRT_Api* api) {
   };
   api->PJRT_Executable_Execute =
       +[](PJRT_Executable_Execute_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Executable_Execute");
     return MakeError(
         ExecutableInstance::Unwrap(args->executable)->BatchExecute(args));
   };
   api->PJRT_Executable_NumOutputs =
       +[](PJRT_Executable_NumOutputs_Args* args) -> PJRT_Error* {
+    IREE_TRACE_SCOPE0("PJRT_Executable_NumOutputs");
     auto* exec = ExecutableInstance::Unwrap(args->executable);
     iree_host_size_t arg_count;
     iree_host_size_t result_count;
@@ -967,6 +1239,7 @@ void ExecutableInstance::BindApi(PJRT_Api* api) {
 }
 
 iree_status_t ExecutableInstance::LoadAll() {
+  IREE_TRACE_SCOPE();
   if (!loaded_executables_.empty()) return iree_ok_status();
 
   std::vector<LoadedExecutable> new_list;
@@ -1052,6 +1325,11 @@ iree_status_t ExecutableInstance::BatchExecute(
   // Make sure loaded.
   IREE_RETURN_IF_ERROR(LoadAll());
 
+  // Timeline setup. There are two timelines that we synchronize to:
+  // the main execution timeline, which preserves as-called ordering of
+  // execution, and the transfer timeline of each device.
+  auto [wait_timepoint, signal_timepoint] = client_.AdvanceTimeline();
+
   // Initialize invocations.
   auto allocator = client_.host_allocator();
   auto& loaded_execs = loaded_executables_;
@@ -1059,12 +1337,34 @@ iree_status_t ExecutableInstance::BatchExecute(
     LoadedExecutable* dev_exe;
     iree::vm::ref<iree_vm_list_t> inputs;
     iree::vm::ref<iree_vm_list_t> outputs;
+    iree::vm::ref<iree_hal_fence_t> wait_fence;
+    iree::vm::ref<iree_hal_fence_t> signal_fence;
   };
   std::vector<Invocation> invs;
   invs.resize(args->num_devices);
   for (size_t dev_index = 0; dev_index < args->num_devices; ++dev_index) {
     auto& inv = invs[dev_index];
     inv.dev_exe = &loaded_execs[dev_index];
+
+    // Wait fence initial value.
+    // We allocate it to be able to hold two semaphores (main timeline and
+    // transfer timeline) and initialize it with the global invocation order
+    // of the main timeline. As we process inputs, we will also insert their
+    // transfer ready semaphore value so that execution can only begin once
+    // all dependencies are ready. This at most represents two unique
+    // semaphores.
+    IREE_RETURN_IF_ERROR(
+        inv.dev_exe->device_instance->CreateFence(&inv.wait_fence));
+    IREE_RETURN_IF_ERROR(iree_hal_fence_insert(
+        inv.wait_fence.get(), inv.dev_exe->device_instance->main_timeline(),
+        wait_timepoint));
+
+    // Signal fence. This signals the next tick on the main execution
+    // timeline.
+    IREE_RETURN_IF_ERROR(iree_hal_fence_create_at(
+        inv.dev_exe->device_instance->main_timeline(), signal_timepoint,
+        client_.host_allocator(), &inv.signal_fence));
+
     IREE_RETURN_IF_ERROR(iree_vm_list_create(
         /*element_type=*/nullptr, args->num_args, allocator, &inv.inputs));
     IREE_RETURN_IF_ERROR(iree_vm_list_create(
@@ -1078,7 +1378,20 @@ iree_status_t ExecutableInstance::BatchExecute(
           iree_hal_buffer_view_retain_ref(buffer->buffer_view());
       IREE_RETURN_IF_ERROR(
          iree_vm_list_push_ref_move(inv.inputs.get(), &bv_ref));
+
+      // Extend the execute wait to include the input's ready signal.
+      IREE_RETURN_IF_ERROR(
+          iree_hal_fence_extend(inv.wait_fence.get(), buffer->ready_fence()));
+
+      // And extend the buffer's done fence to close over this execution.
+      buffer->AdvanceDoneFence(inv.dev_exe->device_instance->main_timeline(),
+                               signal_timepoint);
     }
+
+    // Add (wait, signal) fences as required by the async-external execution
+    // model.
+    iree_vm_list_push_ref_retain(inv.inputs.get(), inv.wait_fence);
+    iree_vm_list_push_ref_retain(inv.inputs.get(), inv.signal_fence);
   }
 
   // Issue invocations.
@@ -1101,18 +1414,25 @@ iree_status_t ExecutableInstance::BatchExecute(
   for (size_t dev_index = 0; dev_index < args->num_devices; ++dev_index) {
     auto& inv = invs[dev_index];
     for (size_t i = 0; i < inv.dev_exe->result_count; ++i) {
-      iree_hal_buffer_view_t* ret_buffer_view =
-          (iree_hal_buffer_view_t*)iree_vm_list_get_ref_deref(
-              inv.outputs.get(), i, iree_hal_buffer_view_get_descriptor());
+      iree::vm::ref<iree_hal_buffer_view_t> ret_buffer_view =
+          retain_ref((iree_hal_buffer_view_t*)iree_vm_list_get_ref_deref(
+              inv.outputs.get(), i, iree_hal_buffer_view_get_descriptor()));
       // This should not be possible so just hard-assert.
       IREE_ASSERT_ARGUMENT(ret_buffer_view);
-      iree_hal_buffer_view_retain(ret_buffer_view);
-      args->output_lists[dev_index][i] =
-          *(new BufferInstance(*inv.dev_exe->device_instance, ret_buffer_view));
+      auto result_buffer = std::make_unique<BufferInstance>(
+          *inv.dev_exe->device_instance, std::move(ret_buffer_view));
+      IREE_RETURN_IF_ERROR(result_buffer->AdvanceReadyFence(
+          inv.dev_exe->device_instance->main_timeline(), signal_timepoint));
+      IREE_RETURN_IF_ERROR(result_buffer->AdvanceDoneFence(
+          inv.dev_exe->device_instance->main_timeline(), signal_timepoint));
+      args->output_lists[dev_index][i] = *(result_buffer.release());
     }
 
     if (args->device_complete_events) {
-      args->device_complete_events[dev_index] = *(new EventInstance());
+      // TODO: Plumb through signal fence. This doesn't seem to be used in
+      // the simple cases I've seen so far.
+      args->device_complete_events[dev_index] =
+          *(new EventInstance(EventInstance::Type::SIGNALLED));
     }
   }
 
diff --git a/pjrt-plugin/iree/integrations/pjrt/common/api_impl.h b/pjrt-plugin/iree/integrations/pjrt/common/api_impl.h
index 102f1280..b765512f 100644
--- a/pjrt-plugin/iree/integrations/pjrt/common/api_impl.h
+++ b/pjrt-plugin/iree/integrations/pjrt/common/api_impl.h
@@ -7,6 +7,7 @@
 #ifndef IREE_PJRT_PLUGIN_PJRT_COMMON_API_IMPL_H_
 #define IREE_PJRT_PLUGIN_PJRT_COMMON_API_IMPL_H_
 
+#include <mutex>
 #include
 #include
 #include
@@ -68,8 +69,8 @@ inline PJRT_Error* MakeError(iree_status_t status) {
 
 class BufferInstance {
  public:
-  BufferInstance(DeviceInstance& device, iree_hal_buffer_view_t* buffer_view)
-      : device_(device), buffer_view_(buffer_view) {}
+  BufferInstance(DeviceInstance& device,
+                 iree::vm::ref<iree_hal_buffer_view_t> buffer_view);
   ~BufferInstance();
   operator PJRT_Buffer*() { return reinterpret_cast<PJRT_Buffer*>(this); }
   static BufferInstance* Unwrap(PJRT_Buffer* buffer) {
@@ -79,6 +80,7 @@ class BufferInstance {
 
   iree_hal_buffer_view_t* buffer_view() { return buffer_view_.get(); }
   DeviceInstance& device() { return device_; }
+  iree_status_t AsyncDeallocate();
   bool is_deleted() { return false; }
   bool is_on_cpu() {
     // TODO: Plumb through an indication if running on CPU and then implement
@@ -92,12 +94,29 @@ class BufferInstance {
   iree_status_t CopyToHost(void* dst, iree_host_size_t dst_size,
                            EventInstance** done_event);
 
+  // Advance the ready and done fences.
+  iree_status_t AdvanceReadyFence(iree_hal_semaphore_t* semaphore,
+                                  uint64_t timepoint);
+  iree_status_t AdvanceDoneFence(iree_hal_semaphore_t* semaphore,
+                                 uint64_t timepoint);
+
+  iree_hal_fence_t* ready_fence() { return ready_fence_.get(); }
+  iree_hal_fence_t* done_fence() { return done_fence_.get(); }
+
 private:
   DeviceInstance& device_;
-  iree::vm::ref<iree_hal_buffer_view_t> buffer_view_;  // Owned.
+  iree::vm::ref<iree_hal_buffer_view_t> buffer_view_;
   // Various things require XLA's idea of shapes, layouts, etc.
   // We keep one around for such cases.
   std::optional<xla::Shape> cached_shape_;
+
+  // Fences.
+  // ready_fence_: Signalled when the buffer is ready to be consumed.
+  //   Consumers should wait on this fence.
+  // done_fence_: Signalled when all scheduled work on the buffer is done.
+  //   Consumers should advance this fence when using it.
+  iree::vm::ref<iree_hal_fence_t> ready_fence_;
+  iree::vm::ref<iree_hal_fence_t> done_fence_;
 };
 
 //===----------------------------------------------------------------------===//
@@ -137,14 +156,38 @@ class DeviceInstance {
       EventInstance** out_done_with_host_buffer_event,
       BufferInstance** out_buffer);
 
+  // TODO(laurenzo): Eagerly set up device to allow simple access.
   iree_status_t GetHalDevice(iree_hal_device_t** out_device);
 
+  // Only valid once device opened.
+  iree_hal_semaphore_t* main_timeline() { return main_timeline_.get(); }
+
+  iree_hal_device_t* device() { return device_.get(); }
+  iree_hal_allocator_t* device_allocator() {
+    return iree_hal_device_allocator(device_.get());
+  }
+
+  // Creates a fence sized to the maximum number of semaphores in use by the
+  // device.
+  iree_status_t CreateFence(iree_hal_fence_t** out_fence);
+
 private:
   iree_status_t OpenDevice();
+  iree_status_t AcquireHostStagingBuffer(
+      iree_const_byte_span_t initial_contents,
+      bool snapshot_initial_contents_now, bool* initial_contents_snapshotted,
+      iree_hal_buffer_t** out_buffer);
+
   int client_id_;
   ClientInstance& client_;
   iree_hal_driver_t* driver_;  // Owned by client.
   iree::vm::ref<iree_hal_device_t> device_;
+  iree::vm::ref<iree_hal_semaphore_t> main_timeline_;
+  iree::vm::ref<iree_hal_semaphore_t> transfer_timeline_;
+  // A fence that is initialized to the start of the transfer timeline,
+  // effectively being signalled immediately.
+  iree::vm::ref<iree_hal_fence_t> transfer_now_fence_;
+  // The timepoint of the last transfer.
+  uint64_t last_transfer_timepoint_ = 0;
   iree_hal_device_info_t* info_;
 };
 
@@ -154,8 +197,20 @@ class DeviceInstance {
 
 class EventInstance {
  public:
+  enum class Type {
+    // An event that is always signalled.
+    SIGNALLED,
+
+    // An EXTERNAL event will have an outside caller invoke
+    // ExternalSignalReady() when it is ready.
+    // Any further OnReady callback will happen within the context of that
+    // call (which can be on any thread).
+    EXTERNAL,
+  };
+
   // Default construction is always signalled.
-  EventInstance() = default;
+  EventInstance(Type type);
+  ~EventInstance();
   operator PJRT_Event*() { return reinterpret_cast<PJRT_Event*>(this); }
   static void BindApi(PJRT_Api* api);
   static EventInstance* Unwrap(PJRT_Event* exe) {
@@ -163,12 +218,18 @@ class EventInstance {
   }
 
   iree_status_t OnReady(PJRT_Event_OnReadyCallback callback, void* user_arg);
-  ErrorInstance* error() { return error_; }
-  bool is_ready() { return is_ready_; }
+  ErrorInstance* error();
+  bool is_ready();
+
+  // For EXTERNAL events: Signals that the event is ready.
+  void ExternalSignalReady(iree_status_t status);
 
 private:
-  ErrorInstance* error_ = nullptr;
-  bool is_ready_ = true;
+  std::mutex lock_;
+  Type type_;
+  iree_status_t status_ = iree_ok_status();
+  bool is_ready_;
+  std::vector<std::pair<PJRT_Event_OnReadyCallback, void*>> pending_callbacks_;
 };
 
 //===----------------------------------------------------------------------===//
@@ -286,6 +347,9 @@ struct ClientInstance {
   // Returns false on failure (and sets error information on the compiler_job).
   virtual bool SetDefaultCompilerFlags(CompilerJob* compiler_job) = 0;
 
+  // Advances the timeline, returning (current, next) time point values.
+  std::tuple<uint64_t, uint64_t> AdvanceTimeline();
+
 protected:
   iree_allocator_t host_allocator_;
   std::string cached_platform_name_;
@@ -307,6 +371,16 @@ struct ClientInstance {
 
   // VM.
   iree::vm::ref<iree_vm_instance_t> vm_instance_;
+
+  // Synchronization.
+  // We keep one global execution timeline across all devices. The management
+  // of this is currently somewhat primitive: we increment it by one for each
+  // invocation. Batch invocations (i.e. across multiple devices) only
+  // increment by one. In the future, additional parallelism could be plumbed
+  // up to the framework to allow different kinds of timeline management.
+  // Waiting on the current value of |execution_timeline_| will drain all
+  // scheduled work to date.
+  uint64_t execution_timeline_ = 0ull;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/pjrt-plugin/test/test_simple.py b/pjrt-plugin/test/test_simple.py
index dac2aadf..268872aa 100644
--- a/pjrt-plugin/test/test_simple.py
+++ b/pjrt-plugin/test/test_simple.py
@@ -3,8 +3,16 @@
 from jax._src.lib import xla_client
 
+# Do once and print.
 a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9])
 b = a
 for i in range(100):
   b = jax.numpy.asarray([i]) * a + b
 print(b)
+
+# Do a second time with fresh data and print.
+a = jax.numpy.asarray([10, 20, 30, 40, 50, 60, 70, 80, 90])
+b = a
+for i in range(100):
+  b = jax.numpy.asarray([i]) * a + b
+print(b)
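
Appendix (illustrative, not part of the patch): the synchronization pattern this change standardizes on - one monotonically increasing semaphore per timeline, fences as small (semaphore, payload) sets, and queue submissions ordered by waiting on timepoint N and signalling N+1 - can be sketched in isolation. The helper below is a hypothetical example, not plugin code; it assumes an already-opened device and an already-recorded one-shot command buffer, and it uses only HAL calls that appear in the diff plus the standard iree_hal_*_release pairs.

#include "iree/hal/api.h"

// Minimal sketch of the timeline/fence handshake (assumed names throughout).
// NOTE: for brevity, early error returns leak |timeline| and |wait_fence|.
static iree_status_t ScheduleOnTimeline(iree_hal_device_t* device,
                                        iree_hal_command_buffer_t* cb,
                                        iree_allocator_t host_allocator) {
  // One semaphore acts as the timeline; it starts at payload 0.
  iree_hal_semaphore_t* timeline = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_semaphore_create(device, 0ull, &timeline));

  // This submission waits for timepoint 0 (already reached) and signals
  // timepoint 1 when the command buffer completes on the queue.
  uint64_t wait_timepoint = 0;
  uint64_t signal_timepoint = 1;

  // Capacity 2 mirrors DeviceInstance::CreateFence: at most the main and
  // transfer timelines participate in any one fence.
  iree_hal_fence_t* wait_fence = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_fence_create(2, host_allocator, &wait_fence));
  IREE_RETURN_IF_ERROR(
      iree_hal_fence_insert(wait_fence, timeline, wait_timepoint));

  // Consumers that must observe the results wait on (timeline, 1), which is
  // exactly what BufferInstance::AdvanceReadyFence records for outputs.
  iree_hal_semaphore_list_t signal_list = {1, &timeline, &signal_timepoint};
  iree_status_t status = iree_hal_device_queue_execute(
      device, IREE_HAL_QUEUE_AFFINITY_ANY,
      /*wait_semaphore_list=*/iree_hal_fence_semaphore_list(wait_fence),
      /*signal_semaphore_list=*/signal_list,
      /*command_buffer_count=*/1, &cb);

  iree_hal_fence_release(wait_fence);
  iree_hal_semaphore_release(timeline);
  return status;
}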