Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Python plan attributes private #608

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 73 additions & 53 deletions python/cufinufft/cufinufft/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,32 +83,32 @@ def __init__(self, nufft_type, n_modes, n_trans=1, eps=1e-6, isign=None,
self._plan = None

# Setup type bound methods
self.dtype = np.dtype(dtype)
self._dtype = np.dtype(dtype)

if self.dtype == np.float64:
if self._dtype == np.float64:
warnings.warn("Real dtypes are currently deprecated and will be "
"removed in version 2.3. Converting to complex128.",
DeprecationWarning)
self.dtype = np.complex128
self._dtype = np.complex128

if self.dtype == np.float32:
if self._dtype == np.float32:
warnings.warn("Real dtypes are currently deprecated and will be "
"removed in version 2.3. Converting to complex64.",
DeprecationWarning)
self.dtype = np.complex64
self._dtype = np.complex64

if self.dtype == np.complex128:
if self._dtype == np.complex128:
self._make_plan = _make_plan
self._setpts = _set_pts
self._exec_plan = _exec_plan
self._destroy_plan = _destroy_plan
self.real_dtype = np.float64
elif self.dtype == np.complex64:
self._real_dtype = np.float64
elif self._dtype == np.complex64:
self._make_plan = _make_planf
self._setpts = _set_ptsf
self._exec_plan = _exec_planf
self._destroy_plan = _destroy_planf
self.real_dtype = np.float32
self._real_dtype = np.float32
else:
raise TypeError("Expected complex64 or complex128.")

Expand All @@ -130,12 +130,12 @@ def __init__(self, nufft_type, n_modes, n_trans=1, eps=1e-6, isign=None,
if dim not in [1, 2, 3]:
raise ValueError("Only dimensions 1, 2, and 3 supported")

self.dim = dim
self.type = nufft_type
self.isign = isign
self.eps = float(eps)
self.n_modes = n_modes
self.n_trans = n_trans
self._dim = dim
self._type = nufft_type
self._isign = isign
self._eps = float(eps)
self._n_modes = n_modes
self._n_trans = n_trans
self._maxbatch = 1 # TODO: optimize this one day

# Get the default option values.
Expand All @@ -158,6 +158,26 @@ def __init__(self, nufft_type, n_modes, n_trans=1, eps=1e-6, isign=None,
# we want to keep around for life of instance.
self._references = []

@property
def type(self):
return self._type

@property
def dtype(self):
return self._dtype

@property
def dim(self):
return self._dim

@property
def n_modes(self):
return self._n_modes

@property
def n_trans(self):
return self._n_trans

@staticmethod
def _default_opts():
"""
Expand Down Expand Up @@ -186,15 +206,15 @@ def _init_plan(self):
# We extend the mode tuple to 3D as needed,
# and reorder from C/python ndarray.shape style input (nZ, nY, nX)
# to the (F) order expected by the low level library (nX, nY, nZ).
_n_modes = self.n_modes[::-1] + (1,) * (3 - self.dim)
_n_modes = self._n_modes[::-1] + (1,) * (3 - self._dim)
_n_modes = (c_int64 * 3)(*_n_modes)

ier = self._make_plan(self.type,
self.dim,
ier = self._make_plan(self._type,
self._dim,
_n_modes,
self.isign,
self.n_trans,
self.eps,
self._isign,
self._n_trans,
self._eps,
byref(self._plan),
self._opts)

Expand All @@ -221,20 +241,20 @@ def setpts(self, x, y=None, z=None, s=None, t=None, u=None):
points (source for type 1, target for type 2).
"""

_x = _ensure_array_type(x, "x", self.real_dtype)
_y = _ensure_array_type(y, "y", self.real_dtype)
_z = _ensure_array_type(z, "z", self.real_dtype)
_x = _ensure_array_type(x, "x", self._real_dtype)
_y = _ensure_array_type(y, "y", self._real_dtype)
_z = _ensure_array_type(z, "z", self._real_dtype)

_x, _y, _z = _ensure_valid_pts(_x, _y, _z, self.dim)
_x, _y, _z = _ensure_valid_pts(_x, _y, _z, self._dim)

M = _compat.get_array_size(_x)

if self.type == 3:
_s = _ensure_array_type(s, "s", self.real_dtype)
_t = _ensure_array_type(t, "t", self.real_dtype)
_u = _ensure_array_type(u, "u", self.real_dtype)
if self._type == 3:
_s = _ensure_array_type(s, "s", self._real_dtype)
_t = _ensure_array_type(t, "t", self._real_dtype)
_u = _ensure_array_type(u, "u", self._real_dtype)

_s, _t, _u = _ensure_valid_pts(_s, _t, _u, self.dim)
_s, _t, _u = _ensure_valid_pts(_s, _t, _u, self._dim)

N = _compat.get_array_size(_s)
else:
Expand All @@ -254,22 +274,22 @@ def setpts(self, x, y=None, z=None, s=None, t=None, u=None):
# We will also store references to these arrays.
# This keeps python from prematurely cleaning them up.
self._references.append(_x)
if self.dim >= 2:
if self._dim >= 2:
fpts_axes.insert(0, _compat.get_array_ptr(_y))
self._references.append(_y)
if self.dim >= 3:
if self._dim >= 3:
fpts_axes.insert(0, _compat.get_array_ptr(_z))
self._references.append(_z)

# Do the same for type 3
if self.type == 3:
if self._type == 3:
fpts_axes_t3 = [_compat.get_array_ptr(_s), None, None]
self._references.append(_s)
if self.dim >= 2:
if self._dim >= 2:
fpts_axes_t3.insert(0, _compat.get_array_ptr(_t))
self._references.append(_t)

if self.dim >= 3:
if self._dim >= 3:
fpts_axes_t3.insert(0, _compat.get_array_ptr(_u))
self._references.append(_u)
else:
Expand All @@ -280,8 +300,8 @@ def setpts(self, x, y=None, z=None, s=None, t=None, u=None):
M, *fpts_axes[:3],
N, *fpts_axes_t3[:3])

self.nj = M
self.nk = N
self._nj = M
self._nk = N

if ier != 0:
raise RuntimeError('Error setting non-uniform points.')
Expand Down Expand Up @@ -309,37 +329,37 @@ def execute(self, data, out=None):
The output array of the transform(s).
"""

_data = _ensure_array_type(data, "data", self.dtype)
_out = _ensure_array_type(out, "out", self.dtype, output=True)
_data = _ensure_array_type(data, "data", self._dtype)
_out = _ensure_array_type(out, "out", self._dtype, output=True)

if self.type == 1:
req_data_shape = (self.n_trans, self.nj)
req_out_shape = self.n_modes
elif self.type == 2:
req_data_shape = (self.n_trans, *self.n_modes)
req_out_shape = (self.nj,)
elif self.type == 3:
req_data_shape = (self.n_trans, self.nj)
req_out_shape = (self.nk,)
if self._type == 1:
req_data_shape = (self._n_trans, self._nj)
req_out_shape = self._n_modes
elif self._type == 2:
req_data_shape = (self._n_trans, *self._n_modes)
req_out_shape = (self._nj,)
elif self._type == 3:
req_data_shape = (self._n_trans, self._nj)
req_out_shape = (self._nk,)

_data, data_shape = _ensure_array_shape(_data, "data", req_data_shape,
allow_reshape=True)
if self.type == 1:
if self._type == 1:
batch_shape = data_shape[:-1]
else:
batch_shape = data_shape[:-self.dim]
batch_shape = data_shape[:-self._dim]

req_out_shape = batch_shape + req_out_shape

if out is None:
_out = _compat.array_empty_like(_data, req_out_shape, dtype=self.dtype)
_out = _compat.array_empty_like(_data, req_out_shape, dtype=self._dtype)
else:
_out = _ensure_array_shape(_out, "out", req_out_shape)

if self.type in [1, 3]:
if self._type in [1, 3]:
ier = self._exec_plan(self._plan, _compat.get_array_ptr(_data),
_compat.get_array_ptr(_out))
elif self.type == 2:
elif self._type == 2:
ier = self._exec_plan(self._plan, _compat.get_array_ptr(_out),
_compat.get_array_ptr(_data))

Expand Down
30 changes: 30 additions & 0 deletions python/cufinufft/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,33 @@ def test_opts(to_gpu, to_cpu, shape=(8, 8, 8), M=32, tol=1e-3):
fk = to_cpu(fk_gpu)

utils.verify_type1(k, c, fk, tol)


def test_cufinufft_plan_properties():
nufft_type = 2
n_modes = (8, 8)
n_trans = 2
dtype = np.complex64

plan = Plan(nufft_type, n_modes, n_trans, dtype=dtype)

assert plan.type == nufft_type
assert tuple(plan.n_modes) == n_modes
assert plan.dim == len(n_modes)
assert plan.n_trans == n_trans
assert plan.dtype == dtype

with pytest.raises(AttributeError):
plan.type = 1

with pytest.raises(AttributeError):
plan.n_modes = (4, 4)

with pytest.raises(AttributeError):
plan.dim = 1

with pytest.raises(AttributeError):
plan.n_trans = 1

with pytest.raises(AttributeError):
plan.dtype = np.float64
Loading
Loading