Skip to content

Commit

Permalink
Resolving binding references when applying deferred command buffers. (i…
Browse files Browse the repository at this point in the history
…ree-org#17840)

This uses the binding table passed in during application to translate
any indirect buffer references into direct ones before passing them on
to the replay target. This allows the replay target to assume all
incoming buffer refs are direct.
  • Loading branch information
benvanik authored Jul 9, 2024
1 parent b5a12ee commit 8513e5f
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 12 deletions.
38 changes: 38 additions & 0 deletions runtime/src/iree/hal/command_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,44 @@ iree_hal_buffer_binding_table_empty(void) {
return table;
}

// Returns an unretained buffer specified in |buffer_ref| or from
// |binding_table| with the slot specified if indirect. If the caller needs to
// preserve the buffer for longer than the (known) lifetime of the binding table
// then it must be retained or added to a resource set.
static inline iree_status_t iree_hal_buffer_binding_table_resolve_ref(
iree_hal_buffer_binding_table_t binding_table,
iree_hal_buffer_ref_t buffer_ref, iree_hal_buffer_ref_t* out_resolved_ref) {
if (buffer_ref.buffer) {
// Direct buffer reference.
*out_resolved_ref = buffer_ref;
return iree_ok_status();
} else if (binding_table.count == 0) {
// NULL buffer reference.
memset(out_resolved_ref, 0, sizeof(*out_resolved_ref));
return iree_ok_status();
} else if (IREE_UNLIKELY(buffer_ref.buffer_slot >= binding_table.count)) {
// Out of bounds slot (validation should have caught). May be worth removing
// this case as this is a hot path.
// NOTE: this asserts that all incoming buffers must not be NULL. That may
// not be true.
return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
"buffer binding %u out of range of binding table "
"with capacity %" PRIhsz,
buffer_ref.buffer_slot, binding_table.count);
} else {
// Indirect buffer reference - need to combine the final range based on
// the binding table range and the range of the reference.
const iree_hal_buffer_binding_t* binding =
&binding_table.bindings[buffer_ref.buffer_slot];
out_resolved_ref->ordinal = buffer_ref.ordinal;
out_resolved_ref->buffer_slot = 0;
out_resolved_ref->buffer = binding->buffer;
return iree_hal_buffer_calculate_range(
binding->offset, binding->length, buffer_ref.offset, buffer_ref.length,
&out_resolved_ref->offset, &out_resolved_ref->length);
}
}

//===----------------------------------------------------------------------===//
// iree_hal_command_buffer_t
//===----------------------------------------------------------------------===//
Expand Down
53 changes: 41 additions & 12 deletions runtime/src/iree/hal/utils/deferred_command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,11 @@ static iree_status_t iree_hal_deferred_command_buffer_apply_discard_buffer(
iree_hal_command_buffer_t* target_command_buffer,
iree_hal_buffer_binding_table_t binding_table,
const iree_hal_cmd_discard_buffer_t* cmd) {
iree_hal_buffer_ref_t buffer_ref;
IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
binding_table, cmd->buffer_ref, &buffer_ref));
return iree_hal_command_buffer_discard_buffer(target_command_buffer,
cmd->buffer_ref);
buffer_ref);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -507,9 +510,12 @@ static iree_status_t iree_hal_deferred_command_buffer_apply_fill_buffer(
iree_hal_command_buffer_t* target_command_buffer,
iree_hal_buffer_binding_table_t binding_table,
const iree_hal_cmd_fill_buffer_t* cmd) {
return iree_hal_command_buffer_fill_buffer(
target_command_buffer, cmd->target_ref, (void**)&cmd->pattern,
cmd->pattern_length);
iree_hal_buffer_ref_t target_ref;
IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
binding_table, cmd->target_ref, &target_ref));
return iree_hal_command_buffer_fill_buffer(target_command_buffer, target_ref,
(void**)&cmd->pattern,
cmd->pattern_length);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -547,8 +553,11 @@ static iree_status_t iree_hal_deferred_command_buffer_apply_update_buffer(
iree_hal_command_buffer_t* target_command_buffer,
iree_hal_buffer_binding_table_t binding_table,
const iree_hal_cmd_update_buffer_t* cmd) {
iree_hal_buffer_ref_t target_ref;
IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
binding_table, cmd->target_ref, &target_ref));
return iree_hal_command_buffer_update_buffer(
target_command_buffer, cmd->source_buffer, 0, cmd->target_ref);
target_command_buffer, cmd->source_buffer, 0, target_ref);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -591,8 +600,14 @@ static iree_status_t iree_hal_deferred_command_buffer_apply_copy_buffer(
iree_hal_command_buffer_t* target_command_buffer,
iree_hal_buffer_binding_table_t binding_table,
const iree_hal_cmd_copy_buffer_t* cmd) {
return iree_hal_command_buffer_copy_buffer(target_command_buffer,
cmd->source_ref, cmd->target_ref);
iree_hal_buffer_ref_t source_ref;
IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
binding_table, cmd->source_ref, &source_ref));
iree_hal_buffer_ref_t target_ref;
IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
binding_table, cmd->target_ref, &target_ref));
return iree_hal_command_buffer_copy_buffer(target_command_buffer, source_ref,
target_ref);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -639,9 +654,15 @@ static iree_status_t iree_hal_deferred_command_buffer_apply_collective(
iree_hal_command_buffer_t* target_command_buffer,
iree_hal_buffer_binding_table_t binding_table,
const iree_hal_cmd_collective_t* cmd) {
iree_hal_buffer_ref_t send_ref;
IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
binding_table, cmd->send_ref, &send_ref));
iree_hal_buffer_ref_t recv_ref;
IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
binding_table, cmd->recv_ref, &recv_ref));
return iree_hal_command_buffer_collective(target_command_buffer, cmd->channel,
cmd->op, cmd->param, cmd->send_ref,
cmd->recv_ref, cmd->element_count);
cmd->op, cmd->param, send_ref,
recv_ref, cmd->element_count);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -728,9 +749,15 @@ static iree_status_t iree_hal_deferred_command_buffer_apply_push_descriptor_set(
iree_hal_command_buffer_t* target_command_buffer,
iree_hal_buffer_binding_table_t binding_table,
const iree_hal_cmd_push_descriptor_set_t* cmd) {
iree_hal_buffer_ref_t* binding_refs = (iree_hal_buffer_ref_t*)iree_alloca(
cmd->binding_count * sizeof(iree_hal_buffer_ref_t));
for (iree_host_size_t i = 0; i < cmd->binding_count; ++i) {
IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
binding_table, cmd->bindings[i], &binding_refs[i]));
}
return iree_hal_command_buffer_push_descriptor_set(
target_command_buffer, cmd->pipeline_layout, cmd->set, cmd->binding_count,
cmd->bindings);
binding_refs);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -814,9 +841,11 @@ static iree_status_t iree_hal_deferred_command_buffer_apply_dispatch_indirect(
iree_hal_command_buffer_t* target_command_buffer,
iree_hal_buffer_binding_table_t binding_table,
const iree_hal_cmd_dispatch_indirect_t* cmd) {
iree_hal_buffer_ref_t workgroups_ref;
IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
binding_table, cmd->workgroups_ref, &workgroups_ref));
return iree_hal_command_buffer_dispatch_indirect(
target_command_buffer, cmd->executable, cmd->entry_point,
cmd->workgroups_ref);
target_command_buffer, cmd->executable, cmd->entry_point, workgroups_ref);
}

//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 8513e5f

Please sign in to comment.