Skip to content

Commit

Permalink
Array-API slice Assignment (Bears-R-Us#3166)
Browse files Browse the repository at this point in the history
* implement __setitem__ for multidimensional array assignment

Signed-off-by: Jeremiah Corrado <[email protected]>

* fix flake8 ltl's

Signed-off-by: Jeremiah Corrado <[email protected]>

---------

Signed-off-by: Jeremiah Corrado <[email protected]>
  • Loading branch information
jeremiah-corrado authored May 20, 2024
1 parent b435ed4 commit 828150f
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 19 deletions.
16 changes: 14 additions & 2 deletions arkouda/array_api/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,9 +541,21 @@ def __setitem__(
/,
) -> None:
if isinstance(key, Array):
self._array[key._array] = value
if isinstance(value, Array):
if value.size == 1 or value.shape == ():
self._array[key._array] = value._array[0]
else:
self._array[key._array] = value._array
else:
self._array[key._array] = value
else:
self._array[key] = value
if isinstance(value, Array):
if value.size == 1 or value.shape == ():
self._array[key] = value._array[0]
else:
self._array[key] = value._array
else:
self._array[key] = value

def __sub__(self: Array, other: Union[int, float, Array], /) -> Array:
if isinstance(other, (int, float)):
Expand Down
77 changes: 68 additions & 9 deletions arkouda/pdarrayclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,12 +950,12 @@ def __setitem__(self, key, value):
logger.debug(f"start: {start} stop: {stop} stride: {stride}")
if isinstance(value, pdarray):
generic_msg(
cmd="[slice]=pdarray",
cmd="[slice]=pdarray-1D",
args={
"array": self,
"start": start,
"stop": stop,
"stride": stride,
"starts": start,
"stops": stop,
"strides": stride,
"value": value,
},
)
Expand All @@ -974,14 +974,16 @@ def __setitem__(self, key, value):
else:
raise TypeError(f"Unhandled key type: {key} ({type(key)})")
else:
if isinstance(key, tuple) and not isinstance(value, pdarray):
allScalar = True
if isinstance(key, tuple):
# TODO: add support for an Ellipsis in the key tuple
# (inserts ':' for any unspecified dimensions)
all_scalar_keys = True
starts = []
stops = []
strides = []
for dim, k in enumerate(key):
if isinstance(k, slice):
allScalar = False
all_scalar_keys = False
(start, stop, stride) = k.indices(self.shape[dim])
starts.append(start)
stops.append(stop)
Expand All @@ -998,10 +1000,67 @@ def __setitem__(self, key, value):
else:
# treat this as a single element slice
starts.append(k)
stops.append(k + 1)
stops.append(k+1)
strides.append(1)

if allScalar:
if isinstance(value, pdarray):
if len(starts) == self.ndim:
slice_shape = tuple([stops[i] - starts[i] for i in range(self.ndim)])

# check that the slice is within the bounds of the array
for i in range(self.ndim):
if slice_shape[i] > self.shape[i]:
raise ValueError(
f"slice indices ({key}) out of bounds for array of " +
f"shape {self.shape}"
)

if value.ndim == len(slice_shape):
# check that the slice shape matches the value shape
for i in range(self.ndim):
if slice_shape[i] != value.shape[i]:
raise ValueError(
f"slice shape ({slice_shape}) must match shape of value " +
f"array ({value.shape})"
)
value_ = value
elif value.ndim < len(slice_shape):
# check that the value shape is compatible with the slice shape
iv = 0
for i in range(self.ndim):
if slice_shape[i] == 1:
continue
elif slice_shape[i] == value.shape[iv]:
iv += 1
else:
raise ValueError(
f"slice shape ({slice_shape}) must be compatible with shape " +
f"of value array ({value.shape})"
)

# reshape to add singleton dimensions as needed
value_ = _reshape(value, slice_shape)
else:
raise ValueError(
f"value array must not have more dimensions ({value.ndim}) than the" +
f"slice ({len(slice_shape)})"
)
else:
raise ValueError(
f"slice rank ({len(starts)}) must match array rank ({self.ndim})"
)

generic_msg(
cmd=f"[slice]=pdarray-{self.ndim}D",
args={
"array": self,
"starts": tuple(starts),
"stops": tuple(stops),
"strides": tuple(strides),
"value": value_,
},
)
elif all_scalar_keys:
# use simpler indexing if we got a tuple of only scalars
generic_msg(
cmd=f"[int]=val-{self.ndim}D",
Expand Down
111 changes: 105 additions & 6 deletions src/IndexingMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -1418,14 +1418,114 @@ module IndexingMsg
imLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}


@arkouda.registerND(cmd_prefix="[slice]=pdarray-")
proc setSliceIndexToPdarrayMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
// take simplified path for 1D case
if nd == 1 then return setSliceIndexToPdarrayMsg1D(cmd, msgArgs, st);

param pn = Reflection.getRoutineName();
const starts = msgArgs.get("starts").getTuple(nd),
stops = msgArgs.get("stops").getTuple(nd),
strides = msgArgs.get("strides").getTuple(nd),
name = msgArgs.getValueOf("array"),
yname = msgArgs.getValueOf("value");

var sliceRanges: nd * stridableRange;
for param dim in 0..<nd do
sliceRanges[dim] = convertSlice(starts[dim], stops[dim], strides[dim]);
const sliceDom = {(...sliceRanges)};

imLogger.debug(getModuleName(),pn,getLineNumber(),
"%s into: '%s' over domain: '%?' from: %s''"
.doFormat(cmd, name, sliceDom, yname));

var gX: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st),
gY: borrowed GenSymEntry = getGenericTypedArrayEntry(yname, st);

proc sliceAssignHelper(type xt, type yt, param adjustMaxSize=false): MsgTuple throws {
// note 'value'/'y' needs to be expanded to match 'array'/'x's rank before
// calling this command
var ex = toSymEntry(gX,xt,nd);
const ey = toSymEntry(gY,yt,nd);

// ensure the slice assignment is valid
for dim in 0..<nd {
if ey.tupShape[dim] != sliceDom.dim[dim].size {
const errMsg = "shape of slice does not match array in dimension %i".doFormat(dim) +
" (%i != %i)".doFormat(ey.tupShape[dim], sliceDom.dim[dim].size);
imLogger.error(getModuleName(),pn,getLineNumber(),errMsg);
return new MsgTuple(errMsg, MsgType.ERROR);
}
if sliceDom.dim[dim].low < 0 || sliceDom.dim[dim].high > ex.tupShape[dim] {
const errMsg = "slice indices out of bounds in dimension %i".doFormat(dim) +
" (%i..%i not in 0..<%i)".doFormat(sliceDom.dim[dim].low,
sliceDom.dim[dim].high, ex.tupShape[dim]);
imLogger.error(getModuleName(),pn,getLineNumber(),errMsg);
return new MsgTuple(errMsg, MsgType.ERROR);
}
}

// adjust y's max size for bigint arrays
if adjustMaxSize {
var ya = ey.a:bigint;
if ex.max_bits != -1 {
var max_size = 1:bigint;
max_size <<= ex.max_bits;
max_size -= 1;
forall y in ya with (const local_max_size = max_size) {
y &= local_max_size;
}
}

ex.a[sliceDom] = ya;
} else {
// otherwise, just assign the values
ex.a[sliceDom] = ey.a:xt;
}

const repMsg = "%s success".doFormat(pn);
imLogger.debug(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

select (gX.dtype, gY.dtype) {
when (DType.Int64, DType.Int64) do return sliceAssignHelper(int, int);
when (DType.Int64, DType.UInt64) do return sliceAssignHelper(int, uint);
when (DType.Int64, DType.Float64) do return sliceAssignHelper(int, real);
when (DType.Int64, DType.Bool) do return sliceAssignHelper(int, bool);
when (DType.UInt64, DType.Int64) do return sliceAssignHelper(uint, int);
when (DType.UInt64, DType.UInt64) do return sliceAssignHelper(uint, uint);
when (DType.UInt64, DType.Float64) do return sliceAssignHelper(uint, real);
when (DType.UInt64, DType.Bool) do return sliceAssignHelper(uint, bool);
when (DType.Float64, DType.Int64) do return sliceAssignHelper(real, int);
when (DType.Float64, DType.UInt64) do return sliceAssignHelper(real, uint);
when (DType.Float64, DType.Float64) do return sliceAssignHelper(real, real);
when (DType.Float64, DType.Bool) do return sliceAssignHelper(real, bool);
when (DType.Bool, DType.Int64) do return sliceAssignHelper(bool, int);
when (DType.Bool, DType.UInt64) do return sliceAssignHelper(bool, uint);
when (DType.Bool, DType.Float64) do return sliceAssignHelper(bool, real);
when (DType.Bool, DType.Bool) do return sliceAssignHelper(bool, bool);
when (DType.BigInt, DType.BigInt) do return sliceAssignHelper(bigint, bigint, true);
when (DType.BigInt, DType.Int64) do return sliceAssignHelper(bigint, int, true);
when (DType.BigInt, DType.UInt64) do return sliceAssignHelper(bigint, uint, true);
when (DType.BigInt, DType.Bool) do return sliceAssignHelper(bigint, bool, true);
otherwise {
const errorMsg = notImplementedError(pn,
"("+dtype2str(gX.dtype)+","+dtype2str(gY.dtype)+")");
imLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
}

/* setSliceIndexToPdarray "a[slice] = pdarray" response to __setitem__(slice, pdarray) */
proc setSliceIndexToPdarrayMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
proc setSliceIndexToPdarrayMsg1D(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
param pn = Reflection.getRoutineName();
var repMsg: string; // response message
const start = msgArgs.get("start").getIntValue();
const stop = msgArgs.get("stop").getIntValue();
const stride = msgArgs.get("stride").getIntValue();
const start = msgArgs.get("starts").getIntValue();
const stop = msgArgs.get("stops").getIntValue();
const stride = msgArgs.get("strides").getIntValue();
var slice: stridableRange;

const name = msgArgs.getValueOf("array");
Expand Down Expand Up @@ -1691,5 +1791,4 @@ module IndexingMsg
registerFunction("[pdarray]", pdarrayIndexMsg, getModuleName());
registerFunction("[pdarray]=val", setPdarrayIndexToValueMsg, getModuleName());
registerFunction("[pdarray]=pdarray", setPdarrayIndexToPdarrayMsg, getModuleName());
registerFunction("[slice]=pdarray", setSliceIndexToPdarrayMsg, getModuleName());
}
21 changes: 19 additions & 2 deletions tests/array_api/indexing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import unittest

from base_test import ArkoudaTest
from context import arkouda as ak
import arkouda.array_api as Array
Expand All @@ -16,6 +14,25 @@ def randArr(shape):


class IndexingTests(ArkoudaTest):
def test_rank_changing_assignment(self):
a = randArr((5, 6, 7))
b = randArr((5, 6))
c = randArr((6, 7))
d = randArr((6,))
e = randArr((5, 6, 7))

a[:, :, 0] = b
self.assertEqual((a[:, :, 0]).tolist(), b.tolist())

a[1, :, :] = c
self.assertEqual((a[1, :, :]).tolist(), c.tolist())

a[2, :, 3] = d
self.assertEqual((a[2, :, 3]).tolist(), d.tolist())

a[:, :, :] = e
self.assertEqual(a.tolist(), e.tolist())

def test_pdarray_index(self):
a = randArr((5, 6, 7))
anp = np.asarray(a.tolist())
Expand Down

0 comments on commit 828150f

Please sign in to comment.