Skip to content

Commit

Permalink
Regards #690: We've extracted the free() calls out of the async sub…
Browse files Browse the repository at this point in the history
…-namespace + some comment tweaks and redundancy removals
  • Loading branch information
eyalroz committed Oct 28, 2024
1 parent 58cc11d commit b8dcd9b
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 28 deletions.
59 changes: 38 additions & 21 deletions src/cuda/api/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,41 +229,58 @@ inline region_t allocate(
return allocate_in_current_context(size_in_bytes, stream_handle);
}

} // namespace detail_

/// Free a region of device-side memory (regardless of how it was allocated)
inline void free(void* ptr)
#if CUDA_VERSION >= 11020
inline void free(
context::handle_t context_handle,
void* allocated_region_start,
optional<stream::handle_t> stream_handle = {})
#else
inline void free(
context::handle_t context_handle,
void* allocated_region_start)
#endif
{
auto result = cuMemFree(address(ptr));
#if CUDA_VERSION >= 11020
if (stream_handle) {
auto status = cuMemFreeAsync(device::address(allocated_region_start), *stream_handle);
throw_if_error_lazy(status,
"Failed scheduling an asynchronous freeing of the global memory region starting at "
+ cuda::detail_::ptr_as_hex(allocated_region_start) + " on "
+ stream::detail_::identify(*stream_handle, context_handle));
return;
}
#endif
auto result = cuMemFree(address(allocated_region_start));
#ifdef CAW_THROW_ON_FREE_IN_DESTROYED_CONTEXT
if (result == status::success) { return; }
#else
if (result == status::success or result == status::context_is_destroyed) { return; }
#endif
throw runtime_error(result, "Freeing device memory at " + cuda::detail_::ptr_as_hex(ptr));
throw runtime_error(result, "Freeing device memory at " + cuda::detail_::ptr_as_hex(allocated_region_start));
}

/// @copydoc free(void*)
inline void free(region_t region) { free(region.start()); }
} // namespace detail_

/**
 * Free a region of device-side memory (regardless of how it was allocated)
 *
 * @param region_start (named `ptr` in the pre-11.2 declaration) beginning of
 *     the region to be freed
 * @param stream (CUDA >= 11.2 only) an optional stream; when one is provided,
 *     its handle is passed on to the lower-level freeing function, otherwise
 *     the freeing is performed without involving any stream
 */
#if CUDA_VERSION >= 11020
inline void free(void* region_start, optional_ref<const stream_t> stream = {});
#else
inline void free(void* ptr);
#endif

inline void free(
context::handle_t context_handle,
stream::handle_t stream_handle,
void* allocated_region_start)
/// @copydoc free(void*, optional_ref<const stream_t>)
#if CUDA_VERSION >= 11020
inline void free(region_t region, optional_ref<const stream_t> stream = {})
#else
inline void free(region_t region)
#endif
{
auto status = cuMemFreeAsync(device::address(allocated_region_start), stream_handle);
throw_if_error_lazy(status,
"Failed scheduling an asynchronous freeing of the global memory region starting at "
+ cuda::detail_::ptr_as_hex(allocated_region_start) + " on "
+ stream::detail_::identify(stream_handle, context_handle) );
free(region.start(), stream);
}

} // namespace detail_
#if CUDA_VERSION >= 11020

namespace async {

/**
* Schedule a de-allocation of device-side memory on a CUDA stream.
Expand Down
19 changes: 15 additions & 4 deletions src/cuda/api/multi_wrapper_impls/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,24 @@ inline region_t allocate(size_t size_in_bytes, optional_ref<const stream_t> stre
detail_::allocate_in_current_context(size_in_bytes);
}

namespace async {
#endif // CUDA_VERSION >= 11020

inline void free(const stream_t& stream, void* region_start)
#if CUDA_VERSION >= 11020
inline void free(void* region_start, optional_ref<const stream_t> stream)
#else
inline void free(void* ptr)
#endif // CUDA_VERSION >= 11020
{
return detail_::free(stream.context().handle(), stream.handle(), region_start);
auto cch = context::current::detail_::get_handle();
#if CUDA_VERSION >= 11020
if (stream) {
detail_::free(cch, region_start, stream->handle());
}
#endif
detail_::free(cch,region_start);
}
#endif // CUDA_VERSION >= 11020

namespace async {

template <typename T>
inline void typed_set(T* start, const T& value, size_t num_elements, const stream_t& stream)
Expand Down
6 changes: 3 additions & 3 deletions src/cuda/api/stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -601,14 +601,14 @@ class stream_t {
///@{
/// Free the memory region beginning at @p region_start, by passing it - together
/// with the associated stream - to memory::device::free()
void free(void* region_start) const
{
    memory::device::free(region_start, associated_stream);
}

/// Free the memory region @p region, by passing it - together with the
/// associated stream - to memory::device::free()
void free(memory::region_t region) const
{
    memory::device::free(region, associated_stream);
}
#endif
#endif // CUDA_VERSION >= 11020

/**
* Sets the attachment of a region of managed memory (i.e. in the address space visible
Expand Down

0 comments on commit b8dcd9b

Please sign in to comment.