Skip to content

Commit

Permalink
Merge pull request #295 from clEsperanto/bugfix_range
Browse files Browse the repository at this point in the history
Bugfix range
  • Loading branch information
haesleinhuepf authored Apr 2, 2023
2 parents d07ce98 + 486a000 commit ca95804
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 24 deletions.
2 changes: 1 addition & 1 deletion pyclesperanto_prototype/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@
from ._tier10 import *
from ._tier11 import *

__version__ = "0.23.6"
__version__ = "0.24.0"
__common_alias__ = "cle"
2 changes: 2 additions & 0 deletions pyclesperanto_prototype/_tier0/_array_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ def __setitem__(self, index, value):

def __getitem__(self, index):
result = None
if isinstance(index, slice):
index = (index,)
if isinstance(index, list):
index = tuple(index)
if isinstance(index, (tuple, np.ndarray)) and index[0] is not None and isinstance(index[0], (tuple, list, np.ndarray)):
Expand Down
68 changes: 46 additions & 22 deletions pyclesperanto_prototype/_tier1/_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,37 +35,20 @@ def range(source : Image,
-------
destination
"""

if start_x is None:
start_x = 0
if stop_x is None:
stop_x = source.shape[-1]
if step_x is None:
step_x = 1
if start_y is None:
start_y = 0
if stop_y is None:
stop_y = source.shape[-2]
if step_y is None:
step_y = 1

start_x, stop_x, step_x = correct_range(start_x, stop_x, step_x, source.shape[-1])
start_y, stop_y, step_y = correct_range(start_y, stop_y, step_y, source.shape[-2])
if len(source.shape) > 2:
if start_z is None:
start_z = 0
if stop_z is None:
stop_z = source.shape[0]
if step_z is None:
step_z = 1
start_z, stop_z, step_z = correct_range(start_z, stop_z, step_z, source.shape[-3])
else:
start_z = 0
stop_z = 1
step_z = 1

if destination is None:
if len(source.shape) > 2:
destination = create((stop_z - start_z, stop_y - start_y, stop_x - start_x), source.dtype)
destination = create((abs(stop_z - start_z), abs(stop_y - start_y), abs(stop_x - start_x)), source.dtype)
else:
destination = create((stop_y - start_y, stop_x - start_x), source.dtype)
destination = create((abs(stop_y - start_y), abs(stop_x - start_x)), source.dtype)

parameters = {
"dst":destination,
Expand All @@ -79,4 +62,45 @@ def range(source : Image,
}

execute(__file__, 'range_x.cl', 'range', destination.shape, parameters)

return destination


def correct_range(start, stop, step, size):
# set in case not set (passed None)
if step is None:
step = 1
if start is None:
if step >= 0:
start = 0
else:
start = size - 1

if stop is None:
if step >= 0:
stop = size
else:
stop = -1

# Check if ranges make sense
if start >= size:
if step >= 0:
start = size
else:
start = size - 1
if start < -size + 1:
start = -size + 1
if stop > size:
stop = size
if stop < -size:
if start > 0:
stop = 0 - 1
else:
stop = -size

if start < 0:
start = size - start
if (start > stop and step > 0) or (start < stop and step < 0):
stop = start

return start, stop, step
3 changes: 3 additions & 0 deletions pyclesperanto_prototype/_tier5/_array_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,8 @@ def array_equal(source1: Image, source2: Image) -> bool:
from .._tier4 import mean_squared_error
if not np.array_equal(source1.shape, source2.shape):
return False
if np.prod(source1.shape) == 0 and np.prod(source2.shape) == 0:
# both empty arrays
return True

return mean_squared_error(source1, source2) == 0
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = pyclesperanto_prototype
version = 0.23.6
version = 0.24.0
author = Robert Haase
author_email = [email protected]
url = https://github.com/clEsperanto/pyclesperanto_prototype
Expand Down
91 changes: 91 additions & 0 deletions tests/test_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,3 +400,94 @@ def test_types_2():
print(result)

assert np.allclose(reference, result, 0.0001)

def test_negative_step_2d():
import pyclesperanto_prototype as cle
import numpy as np

numbers = np.reshape(np.asarray([[i] for i in range(0, 20)]), (2, 10))

cle_numbers = cle.asarray(numbers)

assert cle.array_equal( numbers[::-1],
cle_numbers[::-1])
assert cle.array_equal( numbers[1::-1],
cle_numbers[1::-1])
assert cle.array_equal( numbers[:5:-1],
cle_numbers[:5:-1])
assert cle.array_equal( numbers[5:1:-1],
cle_numbers[5:1:-1])
assert cle.array_equal( numbers[100:],
cle_numbers[100:])
assert cle.array_equal( numbers[100::-1],
cle_numbers[100::-1])
assert cle.array_equal( numbers[100:-50:-1],
cle_numbers[100:-50:-1])

assert cle.array_equal( numbers[::,::-1],
cle_numbers[::,::-1])
assert cle.array_equal( numbers[::,1::-1],
cle_numbers[::,1::-1])
assert cle.array_equal( numbers[::,:5:-1],
cle_numbers[::,:5:-1])
assert cle.array_equal( numbers[::,5:1:-1],
cle_numbers[::,5:1:-1])
assert cle.array_equal( numbers[::,100:],
cle_numbers[::,100:])
assert cle.array_equal( numbers[::,100::-1],
cle_numbers[::,100::-1])
assert cle.array_equal( numbers[::,100:-50:-1],
cle_numbers[::,100:-50:-1])

def test_negative_step_3d():
import pyclesperanto_prototype as cle
import numpy as np

numbers = np.reshape(np.asarray([[i] for i in range(0,60)]),(3,4,5))

cle_numbers = cle.asarray(numbers)

assert cle.array_equal( numbers[::-1],
cle_numbers[::-1])
assert cle.array_equal( numbers[1::-1],
cle_numbers[1::-1])
assert cle.array_equal( numbers[:5:-1],
cle_numbers[:5:-1])
assert cle.array_equal( numbers[5:1:-1],
cle_numbers[5:1:-1])
assert cle.array_equal( numbers[100:],
cle_numbers[100:])
assert cle.array_equal( numbers[100::-1],
cle_numbers[100::-1])
assert cle.array_equal( numbers[100:-50:-1],
cle_numbers[100:-50:-1])

assert cle.array_equal( numbers[::, ::-1],
cle_numbers[::, ::-1])
assert cle.array_equal( numbers[::, 1::-1],
cle_numbers[::, 1::-1])
assert cle.array_equal( numbers[::, :5:-1],
cle_numbers[::, :5:-1])
assert cle.array_equal( numbers[::, 5:1:-1],
cle_numbers[::, 5:1:-1])
assert cle.array_equal( numbers[::, 100:],
cle_numbers[::, 100:])
assert cle.array_equal( numbers[::, 100::-1],
cle_numbers[::, 100::-1])
assert cle.array_equal( numbers[::, 100:-50:-1],
cle_numbers[::, 100:-50:-1])

assert cle.array_equal( numbers[::, ::, ::-1],
cle_numbers[::, ::, ::-1])
assert cle.array_equal( numbers[::, ::, 1::-1],
cle_numbers[::, ::, 1::-1])
assert cle.array_equal( numbers[::, ::, :5:-1],
cle_numbers[::, ::, :5:-1])
assert cle.array_equal( numbers[::, ::, 5:1:-1],
cle_numbers[::, ::, 5:1:-1])
assert cle.array_equal( numbers[::, ::, 100:],
cle_numbers[::, ::, 100:])
assert cle.array_equal( numbers[::, ::, 100::-1],
cle_numbers[::, ::, 100::-1])
assert cle.array_equal( numbers[::, ::, 100:-50:-1],
cle_numbers[::, ::, 100:-50:-1])

0 comments on commit ca95804

Please sign in to comment.