Skip to content

Commit

Permalink
docs: update py getting started to use cupy
Browse files Browse the repository at this point in the history
Simpler.
  • Loading branch information
janden committed Dec 26, 2023
1 parent 1a615a3 commit 580e160
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions docs/python_gpu.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,12 @@ Quick-start examples
As mentioned in the :ref:`Python GPU installation instructions <install-python-gpu>`, the easiest way to install the Python interface for cuFINUFFT is to run ``pip install cufinufft``.

Assuming cuFINUFFT has been installed, we will now consider how to calculate a 1D type 1 transform.
To manage the GPU and transfer to and from host and device, we will use the ``pycuda`` library.
To manage the GPU and transfer to and from host and device, we will use the ``cupy`` library.
Consequently, we start with a few import statements.

.. code-block:: python
import numpy as np
import pycuda.autoinit
from pycuda.gpuarray import to_gpu
import cupy as cp
import cufinufft
Expand All @@ -30,26 +27,26 @@ We then proceed to setting up a few parameters.
N = 200000
# generate positions for the nonuniform points and the coefficients
x = 2 * np.pi * np.random.uniform(size=M)
c = (np.random.standard_normal(size=M)
+ 1J * np.random.standard_normal(size=M))
x_gpu = 2 * cp.pi * cp.random.uniform(size=M)
c_gpu = (cp.random.standard_normal(size=M)
+ 1J * cp.random.standard_normal(size=M))
Now that the data is prepared, we need to set up a cuFINUFFT plan that can be executed on that data.

.. code-block:: python
# create plan
plan = cufinufft.Plan(1, (N,), dtype=np.float64)
plan = cufinufft.Plan(1, (N,), dtype="complex128")
# set the nonuniform points
plan.setpts(to_gpu(x))
plan.setpts(x_gpu)
With everything set up, we are now ready to execute the plan.

.. code-block:: python
# execute the plan
f_gpu = plan.execute(to_gpu(c))
f_gpu = plan.execute(c_gpu)
# move results off the GPU
f = f_gpu.get()
Expand Down

0 comments on commit 580e160

Please sign in to comment.