Skip to content

Commit

Permalink
Fix bug on pow and rpow (#2047)
Browse files Browse the repository at this point in the history
Fixed behavior of `pow` and `rpow` that was incompatible with pandas: when the base is 1, the result must be 1.0 even if the exponent is NaN.

In pandas:

```python
>>> pd.Series([1, 2, 3]) ** np.nan
0    1.0
1    NaN
2    NaN
dtype: float64

>>> 1 ** pd.Series([np.nan, 2, 3])
0    1.0
1    1.0
2    1.0
dtype: float64
```

In Koalas (before this fix):

```python
>>> ks.Series([1, 2, 3]) ** np.nan
0    NaN  # doesn't match pandas, which returns 1.0
1    NaN
2    NaN
dtype: float64

>>> 1 ** ks.Series([np.nan, 2, 3])
0    NaN  # doesn't match pandas, which returns 1.0
1    1.0
2    1.0
dtype: float64
```
  • Loading branch information
itholic authored Feb 12, 2021
1 parent e6a9628 commit 87f5b18
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 2 deletions.
14 changes: 12 additions & 2 deletions databricks/koalas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,8 +579,18 @@ def rmod(left, right):

return column_op(rmod)(self, other)

# NOTE(review): these two assignments look like the 2 lines this diff deletes
# (the +/- markers were lost in extraction); the explicit __pow__/__rpow__
# methods that follow take their place -- confirm against the commit.
__pow__ = column_op(Column.__pow__)
__rpow__ = column_op(Column.__rpow__)
def __pow__(self, other) -> Union["Series", "Index"]:
    """Return ``self ** other`` element-wise, keeping pandas' special cases.

    pandas (through NumPy) evaluates ``1 ** x`` and ``x ** 0`` to 1 for every
    ``x``, including NaN. A plain ``Column.__pow__`` does not guarantee this
    when an operand is NaN or null (for instance, a null base stays null),
    so both identities are short-circuited before delegating to it.

    Parameters
    ----------
    other
        The exponent; whatever operand types ``column_op`` accepts
        (presumably a scalar or another Series/Index -- confirm with callers).

    Returns
    -------
    Series or Index
        The element-wise power, built through ``column_op``.
    """

    def pow_func(left, right):
        return (
            # A base of 1 stays 1 whatever the exponent is, NaN included.
            F.when(left == 1, left)
            # A zero exponent yields 1 whatever the base is, missing included.
            # F.lit() lets `right` be either a Python scalar or a Column.
            .when(F.lit(right) == 0, 1)
            .otherwise(Column.__pow__(left, right))
        )

    return column_op(pow_func)(self, other)

def __rpow__(self, other) -> Union["Series", "Index"]:
    """Return ``other ** self`` element-wise (the reflected power).

    When the base ``other`` equals 1 the result is that base itself, even if
    the exponent is NaN, which is what pandas returns for ``1 ** series``.
    Every other case is delegated to ``Column.__rpow__``.

    Parameters
    ----------
    other
        The base; whatever operand types ``column_op`` accepts.

    Returns
    -------
    Series or Index
        The element-wise reflected power, built through ``column_op``.
    """

    def rpow_func(left, right):
        # In the reflected form `right` is the base and `left` the exponent.
        # `right == 1` is a plain bool for scalars and a Column otherwise;
        # the F.lit() call is kept so both forms go through the same path.
        base_is_one = F.lit(right == 1)
        one_branch = F.when(base_is_one, right)
        return one_branch.otherwise(Column.__rpow__(left, right))

    return column_op(rpow_func)(self, other)

# abs(obj) is built by passing F.abs through column_op, the same wrapper the
# neighbouring operators use.
__abs__ = column_op(F.abs)

# comparison operators
Expand Down
23 changes: 23 additions & 0 deletions databricks/koalas/tests/test_ops_on_diff_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -1518,6 +1518,16 @@ def test_align(self):
self.assertRaises(ValueError, lambda: kdf1.align(kdf3, axis=None))
self.assertRaises(ValueError, lambda: kdf1.align(kdf3, axis=1))

def test_pow_and_rpow(self):
    """pow, ``**`` and rpow between two separately built series match pandas.

    The operands place 1 and NaN so that ``1 ** NaN`` is exercised, along
    with NaN as a base.
    """
    pandas_left = pd.Series([1, 2, np.nan])
    koalas_left = ks.from_pandas(pandas_left)
    pandas_right = pd.Series([np.nan, 2, 3])
    koalas_right = ks.from_pandas(pandas_right)

    # Koalas results are sorted by index before being compared with pandas.
    expected_pow = pandas_left.pow(pandas_right)
    self.assert_eq(expected_pow, koalas_left.pow(koalas_right).sort_index())

    expected_operator = pandas_left ** pandas_right
    self.assert_eq(expected_operator, (koalas_left ** koalas_right).sort_index())

    expected_rpow = pandas_left.rpow(pandas_right)
    self.assert_eq(expected_rpow, koalas_left.rpow(koalas_right).sort_index())


class OpsOnDiffFramesDisabledTest(ReusedSQLTestCase, SQLTestUtils):
@classmethod
Expand Down Expand Up @@ -1671,3 +1681,16 @@ def test_align(self):

with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
kdf1.align(kdf2, axis=0)

def test_pow_and_rpow(self):
    """pow, ``**`` and rpow across different frames must raise ValueError."""
    pandas_left = pd.Series([1, 2, np.nan])
    koalas_left = ks.from_pandas(pandas_left)
    pandas_right = pd.Series([np.nan, 2, 3])
    koalas_right = ks.from_pandas(pandas_right)

    expected_message = "Cannot combine the series or dataframe"
    # Same three operations, in the same order: method, operator, reflected.
    operations = (
        lambda: koalas_left.pow(koalas_right),
        lambda: koalas_left ** koalas_right,
        lambda: koalas_left.rpow(koalas_right),
    )
    for operation in operations:
        with self.assertRaisesRegex(ValueError, expected_message):
            operation()
9 changes: 9 additions & 0 deletions databricks/koalas/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2673,3 +2673,12 @@ def test_align(self):
self.assert_eq(kdf_r, pdf_r)

self.assertRaises(ValueError, lambda: kdf.a.align(kdf.b, axis=1))

def test_pow_and_rpow(self):
    """pow/rpow with a NaN scalar, and ``1 ** series``, must match pandas."""
    pandas_series = pd.Series([1, 2, np.nan])
    koalas_series = ks.from_pandas(pandas_series)
    nan = np.nan

    # NaN exponent: method form, then operator form.
    self.assert_eq(pandas_series.pow(nan), koalas_series.pow(nan))
    self.assert_eq(pandas_series ** nan, koalas_series ** nan)

    # NaN base through the reflected method.
    self.assert_eq(pandas_series.rpow(nan), koalas_series.rpow(nan))

    # Base of 1 with a NaN among the exponents.
    self.assert_eq(1 ** pandas_series, 1 ** koalas_series)

0 comments on commit 87f5b18

Please sign in to comment.