From 828150f817cdcc094b26f7245ebfcc27f02050f8 Mon Sep 17 00:00:00 2001 From: jeremiah-corrado <62707311+jeremiah-corrado@users.noreply.github.com> Date: Mon, 20 May 2024 11:07:19 -0600 Subject: [PATCH] Array-API slice Assignment (#3166) * implement __setitem__ for multidimensional array assignment Signed-off-by: Jeremiah Corrado * fix flake8 ltl's Signed-off-by: Jeremiah Corrado --------- Signed-off-by: Jeremiah Corrado --- arkouda/array_api/_array_object.py | 16 ++++- arkouda/pdarrayclass.py | 77 +++++++++++++++++--- src/IndexingMsg.chpl | 111 +++++++++++++++++++++++++++-- tests/array_api/indexing.py | 21 +++++- 4 files changed, 206 insertions(+), 19 deletions(-) diff --git a/arkouda/array_api/_array_object.py b/arkouda/array_api/_array_object.py index 93d368bbb2..55f0d832d4 100644 --- a/arkouda/array_api/_array_object.py +++ b/arkouda/array_api/_array_object.py @@ -541,9 +541,21 @@ def __setitem__( /, ) -> None: if isinstance(key, Array): - self._array[key._array] = value + if isinstance(value, Array): + if value.size == 1 or value.shape == (): + self._array[key._array] = value._array[0] + else: + self._array[key._array] = value._array + else: + self._array[key._array] = value else: - self._array[key] = value + if isinstance(value, Array): + if value.size == 1 or value.shape == (): + self._array[key] = value._array[0] + else: + self._array[key] = value._array + else: + self._array[key] = value def __sub__(self: Array, other: Union[int, float, Array], /) -> Array: if isinstance(other, (int, float)): diff --git a/arkouda/pdarrayclass.py b/arkouda/pdarrayclass.py index 8afa556e0f..f8b0a6d047 100755 --- a/arkouda/pdarrayclass.py +++ b/arkouda/pdarrayclass.py @@ -950,12 +950,12 @@ def __setitem__(self, key, value): logger.debug(f"start: {start} stop: {stop} stride: {stride}") if isinstance(value, pdarray): generic_msg( - cmd="[slice]=pdarray", + cmd="[slice]=pdarray-1D", args={ "array": self, - "start": start, - "stop": stop, - "stride": stride, + "starts": start, + "stops": stop, + "strides": stride, "value": value, }, ) @@ -974,14 +974,16 @@ def __setitem__(self, key, value): else: raise TypeError(f"Unhandled key type: {key} ({type(key)})") else: - if isinstance(key, tuple) and not isinstance(value, pdarray): - allScalar = True + if isinstance(key, tuple): + # TODO: add support for an Ellipsis in the key tuple + # (inserts ':' for any unspecified dimensions) + all_scalar_keys = True starts = [] stops = [] strides = [] for dim, k in enumerate(key): if isinstance(k, slice): - allScalar = False + all_scalar_keys = False (start, stop, stride) = k.indices(self.shape[dim]) starts.append(start) stops.append(stop) @@ -998,10 +1000,67 @@ def __setitem__(self, key, value): else: # treat this as a single element slice starts.append(k) - stops.append(k + 1) + stops.append(k+1) strides.append(1) - if allScalar: + if isinstance(value, pdarray): + if len(starts) == self.ndim: + slice_shape = tuple([stops[i] - starts[i] for i in range(self.ndim)]) + + # check that the slice is within the bounds of the array + for i in range(self.ndim): + if slice_shape[i] > self.shape[i]: + raise ValueError( + f"slice indices ({key}) out of bounds for array of " + + f"shape {self.shape}" + ) + + if value.ndim == len(slice_shape): + # check that the slice shape matches the value shape + for i in range(self.ndim): + if slice_shape[i] != value.shape[i]: + raise ValueError( + f"slice shape ({slice_shape}) must match shape of value " + + f"array ({value.shape})" + ) + value_ = value + elif value.ndim < len(slice_shape): + # check that the value shape is compatible with the slice shape + iv = 0 + for i in range(self.ndim): + if slice_shape[i] == 1: + continue + elif slice_shape[i] == value.shape[iv]: + iv += 1 + else: + raise ValueError( + f"slice shape ({slice_shape}) must be compatible with shape " + + f"of value array ({value.shape})" + ) + + # reshape to add singleton dimensions as needed + value_ = _reshape(value, slice_shape) + else: + raise ValueError( + f"value array must not have more dimensions ({value.ndim}) than the" + + f"slice ({len(slice_shape)})" + ) + else: + raise ValueError( + f"slice rank ({len(starts)}) must match array rank ({self.ndim})" + ) + + generic_msg( + cmd=f"[slice]=pdarray-{self.ndim}D", + args={ + "array": self, + "starts": tuple(starts), + "stops": tuple(stops), + "strides": tuple(strides), + "value": value_, + }, + ) + elif all_scalar_keys: # use simpler indexing if we got a tuple of only scalars generic_msg( cmd=f"[int]=val-{self.ndim}D", diff --git a/src/IndexingMsg.chpl b/src/IndexingMsg.chpl index 6f80ec42de..b0c15eb62b 100644 --- a/src/IndexingMsg.chpl +++ b/src/IndexingMsg.chpl @@ -1418,14 +1418,114 @@ module IndexingMsg imLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg); return new MsgTuple(repMsg, MsgType.NORMAL); } - + + @arkouda.registerND(cmd_prefix="[slice]=pdarray-") + proc setSliceIndexToPdarrayMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws { + // take simplified path for 1D case + if nd == 1 then return setSliceIndexToPdarrayMsg1D(cmd, msgArgs, st); + + param pn = Reflection.getRoutineName(); + const starts = msgArgs.get("starts").getTuple(nd), + stops = msgArgs.get("stops").getTuple(nd), + strides = msgArgs.get("strides").getTuple(nd), + name = msgArgs.getValueOf("array"), + yname = msgArgs.getValueOf("value"); + + var sliceRanges: nd * stridableRange; + for param dim in 0.. ex.tupShape[dim] { + const errMsg = "slice indices out of bounds in dimension %i".doFormat(dim) + + " (%i..%i not in 0..<%i)".doFormat(sliceDom.dim[dim].low, + sliceDom.dim[dim].high, ex.tupShape[dim]); + imLogger.error(getModuleName(),pn,getLineNumber(),errMsg); + return new MsgTuple(errMsg, MsgType.ERROR); + } + } + + // adjust y's max size for bigint arrays + if adjustMaxSize { + var ya = ey.a:bigint; + if ex.max_bits != -1 { + var max_size = 1:bigint; + max_size <<= ex.max_bits; + max_size -= 1; + forall y in ya with (const local_max_size = max_size) { + y &= local_max_size; + } + } + + ex.a[sliceDom] = ya; + } else { + // otherwise, just assign the values + ex.a[sliceDom] = ey.a:xt; + } + + const repMsg = "%s success".doFormat(pn); + imLogger.debug(getModuleName(),pn,getLineNumber(),repMsg); + return new MsgTuple(repMsg, MsgType.NORMAL); + } + + select (gX.dtype, gY.dtype) { + when (DType.Int64, DType.Int64) do return sliceAssignHelper(int, int); + when (DType.Int64, DType.UInt64) do return sliceAssignHelper(int, uint); + when (DType.Int64, DType.Float64) do return sliceAssignHelper(int, real); + when (DType.Int64, DType.Bool) do return sliceAssignHelper(int, bool); + when (DType.UInt64, DType.Int64) do return sliceAssignHelper(uint, int); + when (DType.UInt64, DType.UInt64) do return sliceAssignHelper(uint, uint); + when (DType.UInt64, DType.Float64) do return sliceAssignHelper(uint, real); + when (DType.UInt64, DType.Bool) do return sliceAssignHelper(uint, bool); + when (DType.Float64, DType.Int64) do return sliceAssignHelper(real, int); + when (DType.Float64, DType.UInt64) do return sliceAssignHelper(real, uint); + when (DType.Float64, DType.Float64) do return sliceAssignHelper(real, real); + when (DType.Float64, DType.Bool) do return sliceAssignHelper(real, bool); + when (DType.Bool, DType.Int64) do return sliceAssignHelper(bool, int); + when (DType.Bool, DType.UInt64) do return sliceAssignHelper(bool, uint); + when (DType.Bool, DType.Float64) do return sliceAssignHelper(bool, real); + when (DType.Bool, DType.Bool) do return sliceAssignHelper(bool, bool); + when (DType.BigInt, DType.BigInt) do return sliceAssignHelper(bigint, bigint, true); + when (DType.BigInt, DType.Int64) do return sliceAssignHelper(bigint, int, true); + when (DType.BigInt, DType.UInt64) do return sliceAssignHelper(bigint, uint, true); + when (DType.BigInt, DType.Bool) do return sliceAssignHelper(bigint, bool, true); + otherwise { + const errorMsg = notImplementedError(pn, + "("+dtype2str(gX.dtype)+","+dtype2str(gY.dtype)+")"); + imLogger.error(getModuleName(),pn,getLineNumber(),errorMsg); + return new MsgTuple(errorMsg, MsgType.ERROR); + } + } + } + /* setSliceIndexToPdarray "a[slice] = pdarray" response to __setitem__(slice, pdarray) */ - proc setSliceIndexToPdarrayMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws { + proc setSliceIndexToPdarrayMsg1D(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws { param pn = Reflection.getRoutineName(); var repMsg: string; // response message - const start = msgArgs.get("start").getIntValue(); - const stop = msgArgs.get("stop").getIntValue(); - const stride = msgArgs.get("stride").getIntValue(); + const start = msgArgs.get("starts").getIntValue(); + const stop = msgArgs.get("stops").getIntValue(); + const stride = msgArgs.get("strides").getIntValue(); var slice: stridableRange; const name = msgArgs.getValueOf("array"); @@ -1691,5 +1791,4 @@ module IndexingMsg registerFunction("[pdarray]", pdarrayIndexMsg, getModuleName()); registerFunction("[pdarray]=val", setPdarrayIndexToValueMsg, getModuleName()); registerFunction("[pdarray]=pdarray", setPdarrayIndexToPdarrayMsg, getModuleName()); - registerFunction("[slice]=pdarray", setSliceIndexToPdarrayMsg, getModuleName()); } diff --git a/tests/array_api/indexing.py b/tests/array_api/indexing.py index c67d5d0ee2..3e6c48bf4b 100644 --- a/tests/array_api/indexing.py +++ b/tests/array_api/indexing.py @@ -1,5 +1,3 @@ -import unittest - from base_test import ArkoudaTest from context import arkouda as ak import arkouda.array_api as Array @@ -16,6 +14,25 @@ def randArr(shape): class IndexingTests(ArkoudaTest): + def test_rank_changing_assignment(self): + a = randArr((5, 6, 7)) + b = randArr((5, 6)) + c = randArr((6, 7)) + d = randArr((6,)) + e = randArr((5, 6, 7)) + + a[:, :, 0] = b + self.assertEqual((a[:, :, 0]).tolist(), b.tolist()) + + a[1, :, :] = c + self.assertEqual((a[1, :, :]).tolist(), c.tolist()) + + a[2, :, 3] = d + self.assertEqual((a[2, :, 3]).tolist(), d.tolist()) + + a[:, :, :] = e + self.assertEqual(a.tolist(), e.tolist()) + def test_pdarray_index(self): a = randArr((5, 6, 7)) anp = np.asarray(a.tolist())