Skip to content

Commit

Permalink
Closes Bears-R-Us#3079 and Bears-R-Us#3080: Sum and Plus Equal of Boo…
Browse files Browse the repository at this point in the history
…lean Arrays (Bears-R-Us#3154)

* adding + Bool-Bool VV, SV, VS and += Bool-Bool VV and SV

* remove casting

* fix vs and sv cases

* sv opeq

---------

Co-authored-by: jaketrookman <[email protected]>
  • Loading branch information
jaketrookman and jaketrookman authored May 7, 2024
1 parent e5d4813 commit bc64163
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 1 deletion.
37 changes: 37 additions & 0 deletions PROTO_tests/tests/operator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,43 @@ def test_invert(self):
np_invert = ~np.arange(10, dtype=np.uint)
assert ak_invert.to_list() == np_invert.tolist()

def test_bool_bool_addition_binop(self):
np_x = np.array([True, True, False, False])
np_y = np.array([True, False, True, False])
ak_x = ak.array(np_x)
ak_y = ak.array(np_y)
# Vector-Vector Case
assert (np_x+np_y).tolist() == (ak_x+ak_y).to_list()
# Scalar-Vector Case
assert (np_x[0]+np_y).tolist() == (ak_x[0]+ak_y).to_list()
assert (np_x[-1]+np_y).tolist() == (ak_x[-1]+ak_y).to_list()
# Vector-Scalar Case
assert (np_x+np_y[0]).tolist() == (ak_x+ak_y[0]).to_list()
assert (np_x+np_y[-1]).tolist() == (ak_x+ak_y[-1]).to_list()

def test_bool_bool_addition_opeq(self):
np_x = np.array([True, True, False, False])
np_y = np.array([True, False, True, False])
ak_x = ak.array(np_x)
ak_y = ak.array(np_y)
np_x += np_y
ak_x += ak_y
# Vector-Vector Case
assert np_x.tolist() == ak_x.to_list()
# Scalar-Vector Case
# True
np_true = np_x[0]
ak_true = ak_x[0]
np_true += np_y
ak_true += ak_y
assert np_x.tolist() == ak_x.to_list()
# False
np_false = np_x[-1]
ak_false = ak_x[-1]
np_false += np_y
ak_false += ak_y
assert np_x.tolist() == ak_x.to_list()

def test_uint_bool_binops(self):
# Test fix for issue #1932
# Adding support to binopvv to correctly handle uint and bool types
Expand Down
9 changes: 9 additions & 0 deletions src/BinOp.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ module BinOp
when ">=" {
e.a = l.a:int >= r.a:int;
}
when "+" {
e.a = l.a | r.a;
}
otherwise {
var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype);
omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
Expand Down Expand Up @@ -520,6 +523,9 @@ module BinOp
when ">=" {
e.a = l.a:int >= val:int;
}
when "+" {
e.a = l.a | val;
}
otherwise {
var errorMsg = notImplementedError(pn,l.dtype,op,dtype);
omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
Expand Down Expand Up @@ -911,6 +917,9 @@ module BinOp
when ">=" {
e.a = val:int >= r.a:int;
}
when "+" {
e.a = val | r.a;
}
otherwise {
var errorMsg = notImplementedError(pn,dtype,op,r.dtype);
omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
Expand Down
15 changes: 14 additions & 1 deletion src/OperatorMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ module OperatorMsg
when (DType.Bool, DType.Bool) {
var l = toSymEntry(left,bool, nd);
var r = toSymEntry(right,bool, nd);
if (op == "<<") || (op == ">>" ) {
if (op == "<<") || (op == ">>" ) {
var e = st.addEntry(rname, l.tupShape, int);
return doBinOpvv(l, r, e, op, rname, pn, st);
}
Expand Down Expand Up @@ -1346,6 +1346,7 @@ module OperatorMsg
when "|=" {l.a |= r.a;}
when "&=" {l.a &= r.a;}
when "^=" {l.a ^= r.a;}
when "+=" {l.a |= r.a;}
otherwise {
var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype);
omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
Expand Down Expand Up @@ -1838,6 +1839,18 @@ module OperatorMsg
}
}
}
when (DType.Bool, DType.Bool) {
var l = toSymEntry(left, bool, nd);
var val = value.getBoolValue();
select op {
when "+=" {l.a |= val;}
otherwise {
var errorMsg = notImplementedError(pn,left.dtype,op,dtype);
omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
}
when (DType.UInt64, DType.BigInt) {
var errorMsg = notImplementedError(pn,left.dtype,op,dtype);
omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
Expand Down
37 changes: 37 additions & 0 deletions tests/operator_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,43 @@ def test_invert(self):
np_uint_inv = ~np.arange(10, dtype=np.uint)
self.assertListEqual(np_uint_inv.tolist(), inverted.to_list())

def test_bool_bool_addition_binop(self):
np_x = np.array([True, True, False, False])
np_y = np.array([True, False, True, False])
ak_x = ak.array(np_x)
ak_y = ak.array(np_y)
# Vector-Vector Case
self.assertListEqual((np_x+np_y).tolist(), (ak_x+ak_y).to_list())
# Scalar-Vector Case
self.assertListEqual((np_x[0]+np_y).tolist(), (ak_x[0]+ak_y).to_list())
self.assertListEqual((np_x[-1]+np_y).tolist(), (ak_x[-1]+ak_y).to_list())
# Vector-Scalar Case
self.assertListEqual((np_x+np_y[0]).tolist(), (ak_x+ak_y[0]).to_list())
self.assertListEqual((np_x+np_y[-1]).tolist(), (ak_x+ak_y[-1]).to_list())

def test_bool_bool_addition_opeq(self):
np_x = np.array([True, True, False, False])
np_y = np.array([True, False, True, False])
ak_x = ak.array(np_x)
ak_y = ak.array(np_y)
np_x += np_y
ak_x += ak_y
# Vector-Vector Case
self.assertListEqual(np_x.tolist(), ak_x.to_list())
# Scalar-Vector Case
# True
np_true = np_x[0]
ak_true = ak_x[0]
np_true += np_y
ak_true += ak_y
self.assertListEqual(np_x.tolist(), ak_x.to_list())
# False
np_false = np_x[-1]
ak_false = ak_x[-1]
np_false += np_y
ak_false += ak_y
self.assertListEqual(np_x.tolist(), ak_x.to_list())

def test_uint_bool_binops(self):
# Test fix for issue #1932
# Adding support to binopvv to correctly handle uint and bool types
Expand Down

0 comments on commit bc64163

Please sign in to comment.