diff --git a/PROTO_tests/tests/operator_test.py b/PROTO_tests/tests/operator_test.py index cfc1bbd5c3..4694ca3dbc 100644 --- a/PROTO_tests/tests/operator_test.py +++ b/PROTO_tests/tests/operator_test.py @@ -233,6 +233,43 @@ def test_invert(self): np_invert = ~np.arange(10, dtype=np.uint) assert ak_invert.to_list() == np_invert.tolist() + def test_bool_bool_addition_binop(self): + np_x = np.array([True, True, False, False]) + np_y = np.array([True, False, True, False]) + ak_x = ak.array(np_x) + ak_y = ak.array(np_y) + # Vector-Vector Case + assert (np_x+np_y).tolist() == (ak_x+ak_y).to_list() + # Scalar-Vector Case + assert (np_x[0]+np_y).tolist() == (ak_x[0]+ak_y).to_list() + assert (np_x[-1]+np_y).tolist() == (ak_x[-1]+ak_y).to_list() + # Vector-Scalar Case + assert (np_x+np_y[0]).tolist() == (ak_x+ak_y[0]).to_list() + assert (np_x+np_y[-1]).tolist() == (ak_x+ak_y[-1]).to_list() + + def test_bool_bool_addition_opeq(self): + np_x = np.array([True, True, False, False]) + np_y = np.array([True, False, True, False]) + ak_x = ak.array(np_x) + ak_y = ak.array(np_y) + np_x += np_y + ak_x += ak_y + # Vector-Vector Case + assert np_x.tolist() == ak_x.to_list() + # Scalar-Vector Case + # True + np_true = np_x[0] + ak_true = ak_x[0] + np_true += np_y + ak_true += ak_y + assert np_x.tolist() == ak_x.to_list() + # False + np_false = np_x[-1] + ak_false = ak_x[-1] + np_false += np_y + ak_false += ak_y + assert np_x.tolist() == ak_x.to_list() + def test_uint_bool_binops(self): # Test fix for issue #1932 # Adding support to binopvv to correctly handle uint and bool types diff --git a/src/BinOp.chpl b/src/BinOp.chpl index b01cf38eb7..da671450f9 100644 --- a/src/BinOp.chpl +++ b/src/BinOp.chpl @@ -109,6 +109,9 @@ module BinOp when ">=" { e.a = l.a:int >= r.a:int; } + when "+" { + e.a = l.a | r.a; + } otherwise { var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype); omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); @@ -520,6 +523,9 @@ module BinOp when ">=" { e.a = l.a:int >= val:int; } + when "+" { + e.a = l.a | val; + } otherwise { var errorMsg = notImplementedError(pn,l.dtype,op,dtype); omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); @@ -911,6 +917,9 @@ module BinOp when ">=" { e.a = val:int >= r.a:int; } + when "+" { + e.a = val | r.a; + } otherwise { var errorMsg = notImplementedError(pn,dtype,op,r.dtype); omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); diff --git a/src/OperatorMsg.chpl b/src/OperatorMsg.chpl index 86e34aeae2..444d2ae563 100644 --- a/src/OperatorMsg.chpl +++ b/src/OperatorMsg.chpl @@ -148,7 +148,7 @@ module OperatorMsg when (DType.Bool, DType.Bool) { var l = toSymEntry(left,bool, nd); var r = toSymEntry(right,bool, nd); - if (op == "<<") || (op == ">>" ) { + if (op == "<<") || (op == ">>" ) { var e = st.addEntry(rname, l.tupShape, int); return doBinOpvv(l, r, e, op, rname, pn, st); } @@ -1346,6 +1346,7 @@ module OperatorMsg when "|=" {l.a |= r.a;} when "&=" {l.a &= r.a;} when "^=" {l.a ^= r.a;} + when "+=" {l.a |= r.a;} otherwise { var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); @@ -1838,6 +1839,18 @@ module OperatorMsg } } } + when (DType.Bool, DType.Bool) { + var l = toSymEntry(left, bool, nd); + var val = value.getBoolValue(); + select op { + when "+=" {l.a |= val;} + otherwise { + var errorMsg = notImplementedError(pn,left.dtype,op,dtype); + omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); + return new MsgTuple(errorMsg, MsgType.ERROR); + } + } + } when (DType.UInt64, DType.BigInt) { var errorMsg = notImplementedError(pn,left.dtype,op,dtype); omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); diff --git a/tests/operator_tests.py b/tests/operator_tests.py index e415f4700c..f0c2c04bcf 100755 --- a/tests/operator_tests.py +++ b/tests/operator_tests.py @@ -305,6 +305,43 @@ def test_invert(self): np_uint_inv = ~np.arange(10, dtype=np.uint) self.assertListEqual(np_uint_inv.tolist(), inverted.to_list()) + def test_bool_bool_addition_binop(self): + np_x = np.array([True, True, False, False]) + np_y = np.array([True, False, True, False]) + ak_x = ak.array(np_x) + ak_y = ak.array(np_y) + # Vector-Vector Case + self.assertListEqual((np_x+np_y).tolist(), (ak_x+ak_y).to_list()) + # Scalar-Vector Case + self.assertListEqual((np_x[0]+np_y).tolist(), (ak_x[0]+ak_y).to_list()) + self.assertListEqual((np_x[-1]+np_y).tolist(), (ak_x[-1]+ak_y).to_list()) + # Vector-Scalar Case + self.assertListEqual((np_x+np_y[0]).tolist(), (ak_x+ak_y[0]).to_list()) + self.assertListEqual((np_x+np_y[-1]).tolist(), (ak_x+ak_y[-1]).to_list()) + + def test_bool_bool_addition_opeq(self): + np_x = np.array([True, True, False, False]) + np_y = np.array([True, False, True, False]) + ak_x = ak.array(np_x) + ak_y = ak.array(np_y) + np_x += np_y + ak_x += ak_y + # Vector-Vector Case + self.assertListEqual(np_x.tolist(), ak_x.to_list()) + # Scalar-Vector Case + # True + np_true = np_x[0] + ak_true = ak_x[0] + np_true += np_y + ak_true += ak_y + self.assertListEqual(np_x.tolist(), ak_x.to_list()) + # False + np_false = np_x[-1] + ak_false = ak_x[-1] + np_false += np_y + ak_false += ak_y + self.assertListEqual(np_x.tolist(), ak_x.to_list()) + def test_uint_bool_binops(self): # Test fix for issue #1932 # Adding support to binopvv to correctly handle uint and bool types