Skip to content

Commit

Permalink
Closes Bears-R-Us#2695: Add uint to extrema methods (Bears-R-Us#2719)
Browse files Browse the repository at this point in the history
This PR (closes Bears-R-Us#2695) adds uint functionality to the extremea methods

Co-authored-by: Pierce Hayes <[email protected]>
  • Loading branch information
stress-tess and Pierce Hayes authored Aug 28, 2023
1 parent 9ec222c commit 04386f0
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 34 deletions.
12 changes: 6 additions & 6 deletions PROTO_tests/tests/extrema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

import arkouda as ak

NO_UINT = ["int64", "float64", "bool"]
NUMERIC_TYPES = ["int64", "uint64", "float64", "bool"]


def make_np_arrays(size, dtype):
if dtype == "int64":
return np.random.randint(-(2**32), 2**32, size=size, dtype=dtype)
elif dtype == "uint64":
return ak.cast(ak.randint(-(2**32), 2**32, size=size), dtype)
elif dtype == "float64":
return np.random.uniform(-(2**32), 2**32, size=size)
elif dtype == "bool":
Expand All @@ -18,10 +20,9 @@ def make_np_arrays(size, dtype):

class TestExtrema:
@pytest.mark.parametrize("prob_size", pytest.prob_size)
@pytest.mark.parametrize("dtype", ["int64", "float64"])
@pytest.mark.parametrize("dtype", ["int64", "uint64", "float64"])
def test_extrema(self, prob_size, dtype):
# TODO add testing for uint once #2695 is completed
pda = ak.randint(-(2**32), 2**32, size=prob_size, dtype=dtype)
pda = ak.array(make_np_arrays(prob_size, dtype))
ak_sorted = ak.sort(pda)
K = prob_size // 2

Expand All @@ -33,9 +34,8 @@ def test_extrema(self, prob_size, dtype):
assert (ak.maxk(pda, K) == ak_sorted[-K:]).all()
assert (pda[ak.argmaxk(pda, K)] == ak_sorted[-K:]).all()

@pytest.mark.parametrize("dtype", NO_UINT)
@pytest.mark.parametrize("dtype", NUMERIC_TYPES)
def test_argmin_and_argmax(self, dtype):
# TODO add testing for uint once #2695 is completed
np_arr = make_np_arrays(1000, dtype)
ak_arr = ak.array(np_arr)

Expand Down
14 changes: 7 additions & 7 deletions arkouda/pdarrayclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def unescape(s):
mydtype = dtype(dtname)
if mydtype == bigint:
# we have to strip off quotes prior to 1.32
if value[0] == "\"":
if value[0] == '"':
return int(value[1:-1])
else:
return int(value)
Expand Down Expand Up @@ -832,13 +832,13 @@ def max(self) -> numpy_scalars:
"""
return max(self)

def argmin(self) -> np.int64:
def argmin(self) -> Union[np.int64, np.uint64]:
"""
Return the index of the first occurrence of the array min value
"""
return argmin(self)

def argmax(self) -> np.int64:
def argmax(self) -> Union[np.int64, np.uint64]:
"""
Return the index of the first occurrence of the array max value.
"""
Expand Down Expand Up @@ -2179,7 +2179,7 @@ def max(pda: pdarray) -> numpy_scalars:


@typechecked
def argmin(pda: pdarray) -> np.int64:
def argmin(pda: pdarray) -> Union[np.int64, np.uint64]:
"""
Return the index of the first occurrence of the array min value.
Expand All @@ -2190,7 +2190,7 @@ def argmin(pda: pdarray) -> np.int64:
Returns
-------
np.int64
Union[np.int64, np.uint64]
The index of the argmin calculated from the pda
Raises
Expand All @@ -2205,7 +2205,7 @@ def argmin(pda: pdarray) -> np.int64:


@typechecked
def argmax(pda: pdarray) -> np.int64:
def argmax(pda: pdarray) -> Union[np.int64, np.uint64]:
"""
Return the index of the first occurrence of the array max value.
Expand All @@ -2216,7 +2216,7 @@ def argmax(pda: pdarray) -> np.int64:
Returns
-------
np.int64
Union[np.int64, np.uint64]
The index of the argmax calculated from the pda
Raises
Expand Down
61 changes: 40 additions & 21 deletions src/KExtremeMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -48,31 +48,41 @@ module KExtremeMsg
select(gEnt.dtype) {
when (DType.Int64) {
var e = toSymEntry(gEnt,int);
var aV;

if !returnIndices {
aV = computeExtremaValues(e.a, k);
} else {
aV = computeExtremaIndices(e.a, k);
}

var aV = if !returnIndices then computeExtremaValues(e.a, k) else computeExtremaIndices(e.a, k);
st.addEntry(vname, new shared SymEntry(aV));

repMsg = "created " + st.attrib(vname);
keLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}
when (DType.UInt64) {
var e = toSymEntry(gEnt,uint);
if !returnIndices {
var aV = computeExtremaValues(e.a, k);
st.addEntry(vname, new shared SymEntry(aV));

repMsg = "created " + st.attrib(vname);
keLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
} else {
var aV = computeExtremaIndices(e.a, k);
st.addEntry(vname, new shared SymEntry(aV));

repMsg = "created " + st.attrib(vname);
keLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}
}
when (DType.Float64) {
var e = toSymEntry(gEnt,real);
if !returnIndices {
var e = toSymEntry(gEnt,real);
var aV = computeExtremaValues(e.a, k);
st.addEntry(vname, new shared SymEntry(aV));

repMsg = "created " + st.attrib(vname);
keLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
} else {
var e = toSymEntry(gEnt,real);
var aV = computeExtremaIndices(e.a, k);
st.addEntry(vname, new shared SymEntry(aV));

Expand Down Expand Up @@ -109,33 +119,42 @@ module KExtremeMsg
select(gEnt.dtype) {
when (DType.Int64) {
var e = toSymEntry(gEnt,int);
var aV;
if !returnIndices {
aV = computeExtremaValues(e.a, k, false);
} else {
aV = computeExtremaIndices(e.a, k, false);
}

var aV = if !returnIndices then computeExtremaValues(e.a, k, false) else computeExtremaIndices(e.a, k, false);
st.addEntry(vname, new shared SymEntry(aV));

repMsg = "created " + st.attrib(vname);
keLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}
when (DType.Float64) {
when (DType.UInt64) {
var e = toSymEntry(gEnt,uint);
if !returnIndices {
var e = toSymEntry(gEnt,real);
var aV = computeExtremaValues(e.a, k, false);

st.addEntry(vname, new shared SymEntry(aV));

repMsg = "created " + st.attrib(vname);
keLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
} else {
var e = toSymEntry(gEnt,real);
var aV = computeExtremaIndices(e.a, k, false);
st.addEntry(vname, new shared SymEntry(aV));

repMsg = "created " + st.attrib(vname);
keLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}
}
when (DType.Float64) {
var e = toSymEntry(gEnt,real);
if !returnIndices {
var aV = computeExtremaValues(e.a, k, false);
st.addEntry(vname, new shared SymEntry(aV));

repMsg = "created " + st.attrib(vname);
keLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
} else {
var aV = computeExtremaIndices(e.a, k, false);
st.addEntry(vname, new shared SymEntry(aV));

repMsg = "created " + st.attrib(vname);
Expand Down

0 comments on commit 04386f0

Please sign in to comment.