Skip to content

Commit

Permalink
[SYCL][Graph] Dynamic command-groups
Browse files Browse the repository at this point in the history
Implement Dynamic Command-Group feature for intel#14896
to enable updating `ur_kernel_handle_t` objects in graph nodes
between executions.
  • Loading branch information
EwanC committed Oct 10, 2024
1 parent 7f59dea commit 3fc5f63
Show file tree
Hide file tree
Showing 16 changed files with 928 additions and 7 deletions.
50 changes: 50 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class node_impl;
class graph_impl;
class exec_graph_impl;
class dynamic_parameter_impl;
class dynamic_command_group_impl;
} // namespace detail

enum class node_type {
Expand Down Expand Up @@ -213,6 +214,27 @@ class depends_on_all_leaves : public ::sycl::detail::DataLessProperty<
} // namespace node
} // namespace property

class __SYCL_EXPORT dynamic_command_group {
public:
dynamic_command_group(
const context &SyclContext, const device &SyclDevice,
const std::vector<std::function<void(handler &)>> &CGFList);

dynamic_command_group(
const queue &SyclQueue,
const std::vector<std::function<void(handler &)>> &CGFList);

size_t get_active_cgf() const;
void set_active_cgf(size_t Index);

private:
template <class Obj>
friend const decltype(Obj::impl) &
sycl::detail::getSyclObjImpl(const Obj &SyclObject);

std::shared_ptr<detail::dynamic_command_group_impl> impl;
};

namespace detail {
// Templateless modifiable command-graph base class.
class __SYCL_EXPORT modifiable_command_graph {
Expand Down Expand Up @@ -269,6 +291,28 @@ class __SYCL_EXPORT modifiable_command_graph {
return Node;
}

/// Add a Dynamic command-group node to the graph.
/// @param DynamicCG Dynamic command-group function to create node with.
/// @param PropList Property list used to pass [0..n] predecessor nodes.
/// @return Constructed node which has been added to the graph.
node add(dynamic_command_group &DynamicCG,
const property_list &PropList = {}) {
if (PropList.has_property<property::node::depends_on>()) {
auto Deps = PropList.get_property<property::node::depends_on>();
node Node = addImpl(DynamicCG, Deps.get_dependencies());
if (PropList.has_property<property::node::depends_on_all_leaves>()) {
addGraphLeafDependencies(Node);
}
return Node;
}

node Node = addImpl(DynamicCG, {});
if (PropList.has_property<property::node::depends_on_all_leaves>()) {
addGraphLeafDependencies(Node);
}
return Node;
}

/// Add a dependency between two nodes.
/// @param Src Node which will be a dependency of \p Dest.
/// @param Dest Node which will be dependent on \p Src.
Expand Down Expand Up @@ -328,6 +372,12 @@ class __SYCL_EXPORT modifiable_command_graph {
modifiable_command_graph(const std::shared_ptr<detail::graph_impl> &Impl)
: impl(Impl) {}

/// Template-less implementation of add() for dynamic command-group nodes.
/// @param DynCGF Dynamic Command-group function object to add.
/// @param Dep List of predecessor nodes.
/// @return Node added to the graph.
node addImpl(dynamic_command_group &DynCGF, const std::vector<node> &Dep);

/// Template-less implementation of add() for CGF nodes.
/// @param CGF Command-group function to add.
/// @param Dep List of predecessor nodes.
Expand Down
1 change: 1 addition & 0 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3376,6 +3376,7 @@ class __SYCL_EXPORT handler {
size_t Size, bool Block = false);
friend class ext::oneapi::experimental::detail::graph_impl;
friend class ext::oneapi::experimental::detail::dynamic_parameter_impl;
friend class ext::oneapi::experimental::detail::dynamic_command_group_impl;

bool DisableRangeRounding();

Expand Down
82 changes: 79 additions & 3 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,10 +700,18 @@ exec_graph_impl::enqueueNodeDirect(sycl::context Ctx,
StreamID, InstanceID, CmdTraceEvent, xpti::trace_task_begin, nullptr);
#endif

std::vector<sycl::detail::CGExecKernel *> KernelAlternatives{};
if (Node->MDynCG) {
for (auto &CG : Node->MDynCG->MKernels) {
KernelAlternatives.push_back(
static_cast<sycl::detail::CGExecKernel *>(CG.get()));
}
}

ur_result_t Res = sycl::detail::enqueueImpCommandBufferKernel(
Ctx, DeviceImpl, CommandBuffer,
*static_cast<sycl::detail::CGExecKernel *>((Node->MCommandGroup.get())),
Deps, &NewSyncPoint, &NewCommand, nullptr);
KernelAlternatives, Deps, &NewSyncPoint, &NewCommand, nullptr);

MCommandMap[Node] = NewCommand;

Expand Down Expand Up @@ -1376,8 +1384,11 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
auto DeviceImpl = sycl::detail::getSyclObjImpl(MGraphImpl->getDevice());

// Gather arg information from Node
auto &ExecCG =
*(static_cast<sycl::detail::CGExecKernel *>(Node->MCommandGroup.get()));
sycl::detail::CG *NodeCG = (Node->MDynCG)
? Node->MDynCG->getActiveKernel().get()
: Node->MCommandGroup.get();
auto ExecCG = *(static_cast<sycl::detail::CGExecKernel *>(NodeCG));

// Copy args because we may modify them
std::vector<sycl::detail::ArgDesc> NodeArgs = ExecCG.getArguments();
// Copy NDR desc since we need to modify it
Expand Down Expand Up @@ -1560,6 +1571,27 @@ modifiable_command_graph::modifiable_command_graph(
: impl(std::make_shared<detail::graph_impl>(
SyclQueue.get_context(), SyclQueue.get_device(), PropList)) {}

node modifiable_command_graph::addImpl(dynamic_command_group &DynCGF,
const std::vector<node> &Deps) {
impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function");

std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
for (auto &D : Deps) {
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
}

auto DynCGFImpl = sycl::detail::getSyclObjImpl(DynCGF);
const std::function<void(handler &)> &CGF = DynCGFImpl->getActiveCGF();

graph_impl::WriteLock Lock(impl->MMutex);
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(CGF, {}, DepImpls);

// Track the dynamic command-group used inside the node object
NodeImpl->MDynCG = DynCGFImpl;

return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
}

node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function");
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
Expand Down Expand Up @@ -1760,6 +1792,31 @@ void dynamic_parameter_base::updateAccessor(
impl->updateAccessor(Acc);
}

dynamic_command_group_impl::dynamic_command_group_impl(
const context &Context, const device &Device,
const std::vector<std::function<void(handler &)>> &CGFList)
: MContext(Context), MDevice(Device), MActiveCGF(0), MCGFList(CGFList) {

// Create a placeholder graph object so we can use it to construct a handler
// object to process the CGFs.
auto TmpGraph = std::make_shared<graph_impl>(MContext, MDevice);
for (const auto &CGF : MCGFList) {
// Handler defined inside the loop so it doesn't appear to the runtime
// as a single command-group with multiple commands inside.
sycl::handler Handler{TmpGraph};
CGF(Handler);

if (Handler.getType() != sycl::detail::CGType::Kernel) {
throw sycl::exception(
make_error_code(errc::invalid),
"The only type of command-groups that can be used in "
"dynamic command-groups is kernels.");
}

Handler.finalize();
MKernels.push_back(std::move(Handler.impl->MGraphNodeCG));
}
}
} // namespace detail

node_type node::get_type() const { return impl->MNodeType; }
Expand Down Expand Up @@ -1798,6 +1855,25 @@ template <> __SYCL_EXPORT void node::update_range<2>(range<2> Range) {
template <> __SYCL_EXPORT void node::update_range<3>(range<3> Range) {
impl->updateRange(Range);
}

dynamic_command_group::dynamic_command_group(
const context &SyclContext, const device &SyclDevice,
const std::vector<std::function<void(handler &)>> &CGFList)
: impl(std::make_shared<detail::dynamic_command_group_impl>(
SyclContext, SyclDevice, CGFList)) {}

dynamic_command_group::dynamic_command_group(
const queue &SyclQueue,
const std::vector<std::function<void(handler &)>> &CGFList)
: impl(std::make_shared<detail::dynamic_command_group_impl>(
SyclQueue.get_context(), SyclQueue.get_device(), CGFList)) {}

size_t dynamic_command_group::get_active_cgf() const {
return impl->getActiveIndex();
}
void dynamic_command_group::set_active_cgf(size_t Index) {
return impl->setActiveIndex(Index);
}
} // namespace experimental
} // namespace oneapi
} // namespace ext
Expand Down
40 changes: 39 additions & 1 deletion sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
/// Stores the executable graph impl associated with this node if it is a
/// subgraph node.
std::shared_ptr<exec_graph_impl> MSubGraphImpl;
/// For Dynamic command-group nodes, stores the dynamic command-group object.
std::shared_ptr<dynamic_command_group_impl> MDynCG;

/// Used for tracking visited status during cycle checks.
bool MVisited = false;
Expand Down Expand Up @@ -160,7 +162,7 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
: enable_shared_from_this(Other), MSuccessors(Other.MSuccessors),
MPredecessors(Other.MPredecessors), MCGType(Other.MCGType),
MNodeType(Other.MNodeType), MCommandGroup(Other.getCGCopy()),
MSubGraphImpl(Other.MSubGraphImpl) {}
MSubGraphImpl(Other.MSubGraphImpl), MDynCG(Other.MDynCG) {}

/// Copy-assignment operator. This will perform a deep-copy of the
/// command group object associated with this node.
Expand All @@ -172,6 +174,7 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
MNodeType = Other.MNodeType;
MCommandGroup = Other.getCGCopy();
MSubGraphImpl = Other.MSubGraphImpl;
MDynCG = Other.MDynCG;
}
return *this;
}
Expand Down Expand Up @@ -1579,6 +1582,41 @@ class dynamic_parameter_impl {
std::vector<std::byte> MValueStorage;
};

class dynamic_command_group_impl {
public:
dynamic_command_group_impl(
const sycl::context &Context, const sycl::device &Device,
const std::vector<std::function<void(handler &)>> &CGFList);

size_t getActiveIndex() const { return MActiveCGF; }
void setActiveIndex(size_t Index) {
if (Index >= MCGFList.size()) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"Index is out of range.");
}

MActiveCGF = Index;
}

const std::function<void(handler &)> &getActiveCGF() const {
return MCGFList[MActiveCGF];
}

const std::unique_ptr<sycl::detail::CG> &getActiveKernel() const {
return MKernels[MActiveCGF];
}

sycl::context MContext; // TODO - verify
sycl::device MDevice; // TODO - verify
size_t MActiveCGF; // TODO Thread safe?

// List of CGFs. Initialized on creation of dynamic command-group object by
// copying by value the list of std::functions passed by the user.
const std::vector<std::function<void(handler &)>> MCGFList;

/// List of kernel command-groups for dynamic command-group nodes
std::vector<std::unique_ptr<sycl::detail::CG>> MKernels;
};
} // namespace detail
} // namespace experimental
} // namespace oneapi
Expand Down
22 changes: 19 additions & 3 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2479,6 +2479,7 @@ ur_result_t enqueueImpCommandBufferKernel(
context Ctx, DeviceImplPtr DeviceImpl,
ur_exp_command_buffer_handle_t CommandBuffer,
const CGExecKernel &CommandGroup,
const std::vector<sycl::detail::CGExecKernel *> &AlternativeKernels,
std::vector<ur_exp_command_buffer_sync_point_t> &SyncPoints,
ur_exp_command_buffer_sync_point_t *OutSyncPoint,
ur_exp_command_buffer_command_handle_t *OutCommand,
Expand Down Expand Up @@ -2520,6 +2521,18 @@ ur_result_t enqueueImpCommandBufferKernel(
ContextImpl, DeviceImpl, CommandGroup.MKernelName);
}

// TODO - refactor this naive impl
std::vector<ur_kernel_handle_t> AltK;
for (const auto &CGKernel : AlternativeKernels) {
ur_kernel_handle_t URK;
std::tie(URK, std::ignore, std::ignore, std::ignore) =
sycl::detail::ProgramManager::getInstance().getOrCreateKernel(
ContextImpl, DeviceImpl, CGKernel->MKernelName);
if (URK != UrKernel) {
AltK.push_back(URK);
}
}

auto SetFunc = [&Adapter, &UrKernel, &DeviceImageImpl, &Ctx,
&getMemAllocationFunc](sycl::detail::ArgDesc &Arg,
size_t NextTrueIndex) {
Expand Down Expand Up @@ -2561,7 +2574,8 @@ ur_result_t enqueueImpCommandBufferKernel(
ur_result_t Res =
Adapter->call_nocheck<UrApiKind::urCommandBufferAppendKernelLaunchExp>(
CommandBuffer, UrKernel, NDRDesc.Dims, &NDRDesc.GlobalOffset[0],
&NDRDesc.GlobalSize[0], LocalSize, 0, nullptr, SyncPoints.size(),
&NDRDesc.GlobalSize[0], LocalSize, AltK.size(),
AltK.size() ? AltK.data() : nullptr, SyncPoints.size(),
SyncPoints.size() ? SyncPoints.data() : nullptr, OutSyncPoint,
OutCommand);

Expand Down Expand Up @@ -2777,10 +2791,12 @@ ur_result_t ExecCGCommand::enqueueImpCommandBuffer() {
return AllocaCmd->getMemAllocation();
};

// TODO
std::vector<sycl::detail::CGExecKernel *> AlternativeKernels{};
auto result = enqueueImpCommandBufferKernel(
MQueue->get_context(), MQueue->getDeviceImplPtr(), MCommandBuffer,
*ExecKernel, MSyncPointDeps, &OutSyncPoint, &OutCommand,
getMemAllocationFunc);
*ExecKernel, AlternativeKernels, MSyncPointDeps, &OutSyncPoint,
&OutCommand, getMemAllocationFunc);
MEvent->setSyncPoint(OutSyncPoint);
MEvent->setCommandBufferCommand(OutCommand);
return result;
Expand Down
1 change: 1 addition & 0 deletions sycl/source/detail/scheduler/commands.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,7 @@ ur_result_t enqueueImpCommandBufferKernel(
context Ctx, DeviceImplPtr DeviceImpl,
ur_exp_command_buffer_handle_t CommandBuffer,
const CGExecKernel &CommandGroup,
const std::vector<sycl::detail::CGExecKernel *> &AlternativeKernels,
std::vector<ur_exp_command_buffer_sync_point_t> &SyncPoints,
ur_exp_command_buffer_sync_point_t *OutSyncPoint,
ur_exp_command_buffer_command_handle_t *OutCommand,
Expand Down
58 changes: 58 additions & 0 deletions sycl/test-e2e/Graph/Update/dyn_cgf_accessor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out
// Extra run to check for leaks in Level Zero using UR_L0_LEAKS_DEBUG
// RUN: %if level_zero %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=0 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %}
// Extra run to check for immediate-command-list in Level Zero
// RUN: %if level_zero %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %}
//

// XFAIL: level_zero

// Tests using dynamic command-group objects with buffer accessors

#include "../graph_common.hpp"

int main() {
queue Queue{};
const size_t N = 1024;
std::vector<int> HostData(N, 0);
buffer Buf{HostData};
Buf.set_write_back(false);
auto Acc = Buf.get_access();

exp_ext::command_graph Graph{
Queue.get_context(),
Queue.get_device(),
{exp_ext::property::graph::assume_buffer_outlives_graph{}}};

int PatternA = 42;
auto CGFA = [&](handler &CGH) {
CGH.parallel_for(N, [=](item<1> Item) { Acc[Item.get_id()] = PatternA; });
};

int PatternB = 0xA;
auto CGFB = [&](handler &CGH) {
CGH.parallel_for(N, [=](item<1> Item) { Acc[Item.get_id()] = PatternB; });
};

auto DynamicCG = exp_ext::dynamic_command_group(Queue, {CGFA, CGFB});
auto DynamicCGNode = Graph.add(DynamicCG);
auto ExecGraph = Graph.finalize(exp_ext::property::graph::updatable{});

Queue.ext_oneapi_graph(ExecGraph).wait();
Queue.copy(Acc, HostData.data()).wait();
for (size_t i = 0; i < N; i++) {
assert(HostData[i] == PatternA);
}

DynamicCG.set_active_cgf(1);
ExecGraph.update(DynamicCGNode);
Queue.ext_oneapi_graph(ExecGraph).wait();

Queue.copy(Acc, HostData.data()).wait();
for (size_t i = 0; i < N; i++) {
assert(HostData[i] == PatternB);
}

return 0;
}
Loading

0 comments on commit 3fc5f63

Please sign in to comment.