Skip to content

Commit

Permalink
[SYCL][Graph] Dynamic command-groups
Browse files Browse the repository at this point in the history
Implement the Dynamic Command-Group feature for intel#14896,
which enables updating the `ur_kernel_handle_t` objects in graph nodes
between executions, as well as the parameters and nd-range of a node.
  • Loading branch information
EwanC committed Oct 11, 2024
1 parent 7f59dea commit c27ec13
Show file tree
Hide file tree
Showing 18 changed files with 1,171 additions and 12 deletions.
50 changes: 50 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class node_impl;
class graph_impl;
class exec_graph_impl;
class dynamic_parameter_impl;
class dynamic_command_group_impl;
} // namespace detail

enum class node_type {
Expand Down Expand Up @@ -213,6 +214,27 @@ class depends_on_all_leaves : public ::sycl::detail::DataLessProperty<
} // namespace node
} // namespace property

/// Holds an ordered list of command-group functions (CGFs) of which exactly
/// one is "active" at a time. The object can be passed to
/// modifiable_command_graph::add() to create a graph node from it.
/// The active index starts at 0 and can be changed with set_active_cgf().
class __SYCL_EXPORT dynamic_command_group {
public:
/// Create a dynamic command-group bound to a context and device.
/// @param SyclContext Context the command-groups are associated with.
/// @param SyclDevice Device the command-groups are associated with.
/// @param CGFList Command-group functions to choose between; copied.
/// @throws sycl::exception with errc::invalid if \p CGFList is empty or if
/// any CGF in the list is not a kernel command-group.
dynamic_command_group(
const context &SyclContext, const device &SyclDevice,
const std::vector<std::function<void(handler &)>> &CGFList);

/// Create a dynamic command-group using the context and device of a queue.
/// @param SyclQueue Queue whose context and device are used.
/// @param CGFList Command-group functions to choose between; copied.
/// @throws sycl::exception with errc::invalid if \p CGFList is empty or if
/// any CGF in the list is not a kernel command-group.
dynamic_command_group(
const queue &SyclQueue,
const std::vector<std::function<void(handler &)>> &CGFList);

/// @return Index into the CGF list of the currently active CGF.
size_t get_active_cgf() const;
/// Select the active CGF.
/// @param Index Position in the CGF list passed at construction.
/// @throws sycl::exception with errc::invalid if \p Index is out of range.
void set_active_cgf(size_t Index);

private:
// Lets the runtime helper getSyclObjImpl() read the private impl member.
template <class Obj>
friend const decltype(Obj::impl) &
sycl::detail::getSyclObjImpl(const Obj &SyclObject);

// Shared implementation object holding the CGF list and active index.
std::shared_ptr<detail::dynamic_command_group_impl> impl;
};

namespace detail {
// Templateless modifiable command-graph base class.
class __SYCL_EXPORT modifiable_command_graph {
Expand Down Expand Up @@ -269,6 +291,28 @@ class __SYCL_EXPORT modifiable_command_graph {
return Node;
}

/// Create a node from a dynamic command-group and insert it into the graph.
/// @param DynamicCG Dynamic command-group object the node is created from.
/// @param PropList Property list; may carry property::node::depends_on to
/// name [0..n] predecessor nodes and/or
/// property::node::depends_on_all_leaves.
/// @return Constructed node which has been added to the graph.
node add(dynamic_command_group &DynamicCG,
         const property_list &PropList = {}) {
  // Explicit predecessors are forwarded when present, otherwise the node is
  // created with an empty dependency list.
  const bool HasExplicitDeps =
      PropList.has_property<property::node::depends_on>();
  node NewNode =
      HasExplicitDeps
          ? addImpl(DynamicCG,
                    PropList.get_property<property::node::depends_on>()
                        .get_dependencies())
          : addImpl(DynamicCG, {});

  // Leaf dependencies are applied after the node exists, in either case.
  if (PropList.has_property<property::node::depends_on_all_leaves>()) {
    addGraphLeafDependencies(NewNode);
  }
  return NewNode;
}

/// Add a dependency between two nodes.
/// @param Src Node which will be a dependency of \p Dest.
/// @param Dest Node which will be dependent on \p Src.
Expand Down Expand Up @@ -328,6 +372,12 @@ class __SYCL_EXPORT modifiable_command_graph {
modifiable_command_graph(const std::shared_ptr<detail::graph_impl> &Impl)
: impl(Impl) {}

/// Template-less implementation of add() for dynamic command-group nodes.
/// The node is built from the CGF that is active in \p DynCGF when this is
/// called, and the dynamic command-group is recorded on the created node.
/// @param DynCGF Dynamic Command-group function object to add.
/// @param Dep List of predecessor nodes.
/// @return Node added to the graph.
/// @throws sycl::exception with errc::invalid if the context or device of
/// \p DynCGF does not match the graph.
node addImpl(dynamic_command_group &DynCGF, const std::vector<node> &Dep);

/// Template-less implementation of add() for CGF nodes.
/// @param CGF Command-group function to add.
/// @param Dep List of predecessor nodes.
Expand Down
1 change: 1 addition & 0 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3376,6 +3376,7 @@ class __SYCL_EXPORT handler {
size_t Size, bool Block = false);
friend class ext::oneapi::experimental::detail::graph_impl;
friend class ext::oneapi::experimental::detail::dynamic_parameter_impl;
friend class ext::oneapi::experimental::detail::dynamic_command_group_impl;

bool DisableRangeRounding();

Expand Down
106 changes: 103 additions & 3 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,10 +700,18 @@ exec_graph_impl::enqueueNodeDirect(sycl::context Ctx,
StreamID, InstanceID, CmdTraceEvent, xpti::trace_task_begin, nullptr);
#endif

std::vector<sycl::detail::CGExecKernel *> KernelAlternatives{};
if (Node->MDynCG) {
for (auto &CG : Node->MDynCG->MKernels) {
KernelAlternatives.push_back(
static_cast<sycl::detail::CGExecKernel *>(CG.get()));
}
}

ur_result_t Res = sycl::detail::enqueueImpCommandBufferKernel(
Ctx, DeviceImpl, CommandBuffer,
*static_cast<sycl::detail::CGExecKernel *>((Node->MCommandGroup.get())),
Deps, &NewSyncPoint, &NewCommand, nullptr);
KernelAlternatives, Deps, &NewSyncPoint, &NewCommand, nullptr);

MCommandMap[Node] = NewCommand;

Expand Down Expand Up @@ -1376,8 +1384,11 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
auto DeviceImpl = sycl::detail::getSyclObjImpl(MGraphImpl->getDevice());

// Gather arg information from Node
auto &ExecCG =
*(static_cast<sycl::detail::CGExecKernel *>(Node->MCommandGroup.get()));
sycl::detail::CG *NodeCG = (Node->MDynCG)
? Node->MDynCG->getActiveKernel().get()
: Node->MCommandGroup.get();
auto ExecCG = *(static_cast<sycl::detail::CGExecKernel *>(NodeCG));

// Copy args because we may modify them
std::vector<sycl::detail::ArgDesc> NodeArgs = ExecCG.getArguments();
// Copy NDR desc since we need to modify it
Expand Down Expand Up @@ -1560,6 +1571,39 @@ modifiable_command_graph::modifiable_command_graph(
: impl(std::make_shared<detail::graph_impl>(
SyclQueue.get_context(), SyclQueue.get_device(), PropList)) {}

/// Add a node created from a dynamic command-group to the graph.
/// @param DynCGF Dynamic command-group whose active CGF defines the node.
/// @param Deps Nodes the new node depends on.
/// @return Node added to the graph.
/// @throws sycl::exception with errc::invalid when the dynamic command-group
/// was created for a different context or device than the graph.
node modifiable_command_graph::addImpl(dynamic_command_group &DynCGF,
                                       const std::vector<node> &Deps) {
  impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function");

  auto DynCGImpl = sycl::detail::getSyclObjImpl(DynCGF);

  // The dynamic command-group must target the same context and device as
  // the graph it is being added to.
  if (DynCGImpl->MContext != impl->getContext()) {
    throw sycl::exception(
        make_error_code(sycl::errc::invalid),
        "Context of dynamic command-group does not match graph.");
  }
  if (DynCGImpl->MDevice != impl->getDevice()) {
    throw sycl::exception(
        make_error_code(sycl::errc::invalid),
        "Device of dynamic command-group does not match graph.");
  }

  // Translate the user-facing predecessor nodes into their impl objects.
  std::vector<std::shared_ptr<detail::node_impl>> PredecessorImpls;
  for (const node &Predecessor : Deps) {
    PredecessorImpls.push_back(sycl::detail::getSyclObjImpl(Predecessor));
  }

  // The node is created from whichever CGF is active right now.
  const std::function<void(handler &)> &ActiveCGF = DynCGImpl->getActiveCGF();

  graph_impl::WriteLock Lock(impl->MMutex);
  std::shared_ptr<detail::node_impl> NewNodeImpl =
      impl->add(ActiveCGF, {}, PredecessorImpls);

  // Remember the dynamic command-group on the node.
  NewNodeImpl->MDynCG = DynCGImpl;

  return sycl::detail::createSyclObjFromImpl<node>(NewNodeImpl);
}

node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function");
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
Expand Down Expand Up @@ -1760,6 +1804,31 @@ void dynamic_parameter_base::updateAccessor(
impl->updateAccessor(Acc);
}

/// Build the implementation object: copy the CGF list, make index 0 active,
/// and evaluate every CGF once to obtain its kernel command-group.
/// @param Context Context the command-groups are associated with.
/// @param Device Device the command-groups are associated with.
/// @param CGFList Command-group functions to store and evaluate.
/// @throws sycl::exception with errc::invalid if a CGF is not a kernel
/// command-group.
dynamic_command_group_impl::dynamic_command_group_impl(
    const context &Context, const device &Device,
    const std::vector<std::function<void(handler &)>> &CGFList)
    : MContext(Context), MDevice(Device), MActiveCGF(0), MCGFList(CGFList) {

  // A throwaway graph is needed only so that handler objects can be
  // constructed to run the CGFs against.
  auto PlaceholderGraph = std::make_shared<graph_impl>(MContext, MDevice);

  for (const auto &CommandGroupFn : MCGFList) {
    // A fresh handler per CGF, so the runtime does not see one command-group
    // containing several commands.
    sycl::handler CGFHandler{PlaceholderGraph};
    CommandGroupFn(CGFHandler);

    const bool IsKernel =
        CGFHandler.getType() == sycl::detail::CGType::Kernel;
    if (!IsKernel) {
      throw sycl::exception(
          make_error_code(errc::invalid),
          "The only type of command-groups that can be used in "
          "dynamic command-groups is kernels.");
    }

    CGFHandler.finalize();
    // Take ownership of the command-group the handler produced.
    MKernels.emplace_back(std::move(CGFHandler.impl->MGraphNodeCG));
  }
}
} // namespace detail

/// @return The node type stored on this node's implementation object.
node_type node::get_type() const { return impl->MNodeType; }
Expand Down Expand Up @@ -1798,6 +1867,37 @@ template <> __SYCL_EXPORT void node::update_range<2>(range<2> Range) {
/// Explicit specialization of node::update_range for 3 dimensions; forwards
/// \p Range to the implementation object's updateRange().
template <> __SYCL_EXPORT void node::update_range<3>(range<3> Range) {
impl->updateRange(Range);
}

/// Construct a dynamic command-group for an explicit context and device.
/// @param SyclContext Context the command-groups are associated with.
/// @param SyclDevice Device the command-groups are associated with.
/// @param CGFList Command-group functions to choose between.
/// @throws sycl::exception with errc::invalid if \p CGFList is empty.
dynamic_command_group::dynamic_command_group(
    const context &SyclContext, const device &SyclDevice,
    const std::vector<std::function<void(handler &)>> &CGFList)
    : impl(std::make_shared<detail::dynamic_command_group_impl>(
          SyclContext, SyclDevice, CGFList)) {
  // Nothing more to do when at least one CGF was supplied.
  if (!CGFList.empty()) {
    return;
  }

  throw sycl::exception(sycl::make_error_code(errc::invalid),
                        "Dynamic command-group cannot be created with an "
                        "empty CGF list.");
}

/// Construct a dynamic command-group using the context and device of
/// \p SyclQueue. Construction and the empty-list check are performed by the
/// context/device constructor this one delegates to.
/// @param SyclQueue Queue whose context and device are used.
/// @param CGFList Command-group functions to choose between.
/// @throws sycl::exception with errc::invalid if \p CGFList is empty.
dynamic_command_group::dynamic_command_group(
    const queue &SyclQueue,
    const std::vector<std::function<void(handler &)>> &CGFList)
    : dynamic_command_group(SyclQueue.get_context(), SyclQueue.get_device(),
                            CGFList) {}

/// @return Index of the currently active command-group function.
size_t dynamic_command_group::get_active_cgf() const {
  const size_t ActiveIndex = impl->getActiveIndex();
  return ActiveIndex;
}
/// Select which command-group function is active.
/// @param Index Position of the CGF in the list given at construction.
void dynamic_command_group::set_active_cgf(size_t Index) {
  impl->setActiveIndex(Index);
}
} // namespace experimental
} // namespace oneapi
} // namespace ext
Expand Down
40 changes: 39 additions & 1 deletion sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
/// Stores the executable graph impl associated with this node if it is a
/// subgraph node.
std::shared_ptr<exec_graph_impl> MSubGraphImpl;
/// For Dynamic command-group nodes, stores the dynamic command-group object.
std::shared_ptr<dynamic_command_group_impl> MDynCG;

/// Used for tracking visited status during cycle checks.
bool MVisited = false;
Expand Down Expand Up @@ -160,7 +162,7 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
: enable_shared_from_this(Other), MSuccessors(Other.MSuccessors),
MPredecessors(Other.MPredecessors), MCGType(Other.MCGType),
MNodeType(Other.MNodeType), MCommandGroup(Other.getCGCopy()),
MSubGraphImpl(Other.MSubGraphImpl) {}
MSubGraphImpl(Other.MSubGraphImpl), MDynCG(Other.MDynCG) {}

/// Copy-assignment operator. This will perform a deep-copy of the
/// command group object associated with this node.
Expand All @@ -172,6 +174,7 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
MNodeType = Other.MNodeType;
MCommandGroup = Other.getCGCopy();
MSubGraphImpl = Other.MSubGraphImpl;
MDynCG = Other.MDynCG;
}
return *this;
}
Expand Down Expand Up @@ -1579,6 +1582,41 @@ class dynamic_parameter_impl {
std::vector<std::byte> MValueStorage;
};

/// Implementation object behind dynamic_command_group. Stores the list of
/// command-group functions, the kernel command-groups produced from them,
/// and the index of the one that is currently active.
class dynamic_command_group_impl {
public:
  /// @param Context Context the command-groups are associated with.
  /// @param Device Device the command-groups are associated with.
  /// @param CGFList Command-group functions; copied into this object.
  dynamic_command_group_impl(
      const sycl::context &Context, const sycl::device &Device,
      const std::vector<std::function<void(handler &)>> &CGFList);

  /// @return Index of the currently active CGF.
  size_t getActiveIndex() const { return MActiveCGF; }

  /// Make the CGF at \p Index the active one.
  /// @throws sycl::exception with errc::invalid if \p Index is not a valid
  /// position in the CGF list.
  void setActiveIndex(size_t Index) {
    if (Index < MCGFList.size()) {
      MActiveCGF = Index;
      return;
    }

    throw sycl::exception(sycl::make_error_code(errc::invalid),
                          "Index is out of range.");
  }

  /// @return The user-supplied CGF at the active index.
  const std::function<void(handler &)> &getActiveCGF() const {
    return MCGFList[MActiveCGF];
  }

  /// @return The kernel command-group at the active index.
  const std::unique_ptr<sycl::detail::CG> &getActiveKernel() const {
    return MKernels[MActiveCGF];
  }

  /// Context this dynamic command-group was created for.
  sycl::context MContext;
  /// Device this dynamic command-group was created for.
  sycl::device MDevice;
  /// Index of the active entry in MCGFList / MKernels.
  size_t MActiveCGF; // TODO Thread safe?

  /// List of CGFs. Filled when the dynamic command-group object is created,
  /// by copying the list of std::functions passed by the user.
  const std::vector<std::function<void(handler &)>> MCGFList;

  /// List of kernel command-groups for dynamic command-group nodes.
  std::vector<std::unique_ptr<sycl::detail::CG>> MKernels;
};
} // namespace detail
} // namespace experimental
} // namespace oneapi
Expand Down
Loading

0 comments on commit c27ec13

Please sign in to comment.