Skip to content

Commit

Permalink
Although semantically there is no dtoh copy in the compiler backend, …
Browse files Browse the repository at this point in the history
…a copy of some kind is still needed. Plus a test.
  • Loading branch information
isazi committed Sep 19, 2024
1 parent 4587539 commit 4c77414
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
29 changes: 19 additions & 10 deletions kernel_tuner/backends/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,24 +356,33 @@ def memset(self, allocation, value, size):
C.memset(allocation.ctypes, value, size)

def memcpy_dtoh(self, dest, src):
"""There is no memcpy_dtoh for the compiler backend."""
pass
"""This method implements the semantic of a device to host copy for the Compiler backend.
There is no actual copy from device to host happening, but host to host.
:param dest: A numpy or cupy array to store the data
:type dest: np.ndarray or cupy.ndarray
:param src: An Argument for some memory allocation
:type src: Argument
"""
# there is no real copy from device to host, but host to host
if isinstance(dest, np.ndarray) and is_cupy_array(src.numpy):
# Implicit conversion to a NumPy array is not allowed.
value = src.numpy.get()
else:
value = src.numpy
xp = get_array_module(dest)
dest[:] = xp.asarray(value)

def memcpy_htod(self, dest, src):
"""There is no memcpy_htod for the compiler backend."""
"""There is no memcpy_htod implemented for the compiler backend."""
pass

def refresh_memory(self, arguments, should_sync):
"""Copy the preserved content of the output memory to used arrays."""
for i, arg in enumerate(arguments):
if should_sync[i]:
if isinstance(arg, np.ndarray) and is_cupy_array(self.allocations[i].numpy):
# Implicit conversion to a NumPy array is not allowed.
value = self.allocations[i].numpy.get()
else:
value = self.allocations[i].numpy
xp = get_array_module(arg)
arg[:] = xp.asarray(value)
self.memcpy_dtoh(arg, self.allocations[i])

def cleanup_lib(self):
"""Unload the previously loaded shared library"""
Expand Down
11 changes: 11 additions & 0 deletions test/test_compiler_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,14 @@ def test_refresh_memory():
assert np.all(arguments[0] == [0, 0, 0])
cfunc.refresh_memory(arguments, [True])
assert np.all(arguments[0] == [1, 2, 3])


def test_memcpy_dtoh():
arg1 = np.array([0, 5, 0, 7]).astype(np.int32)
arguments = [arg1]
cfunc = CompilerFunctions()
ready_arguments = cfunc.ready_argument_list(arguments)
expected = np.array([0, 0, 0, 0]).astype(np.float32)
assert np.all(ready_arguments.numpy != expected)
cfunc.memcpy_dtoh(expected, ready_arguments)
assert np.all(ready_arguments.numpy == expected)

0 comments on commit 4c77414

Please sign in to comment.