Skip to content

Commit

Permalink
Implements Index.putmask
Browse files Browse the repository at this point in the history
  • Loading branch information
beobest2 committed Jun 2, 2020
1 parent ae57c2a commit 5be5ba2
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 1 deletion.
49 changes: 49 additions & 0 deletions databricks/koalas/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from pyspark import sql as spark
from pyspark.sql import functions as F, Window
from pyspark.sql.types import BooleanType, NumericType, StringType, TimestampType
from pyspark.sql.functions import udf

from databricks import koalas as ks # For running doctests and reference resolution in PyCharm.
from databricks.koalas.config import get_option, option_context
Expand Down Expand Up @@ -1539,6 +1540,54 @@ def argmin(self):

return sdf.orderBy(self.spark.column.asc(), F.col(sequence_col).asc()).first()[0]

def putmask(self, mask, value):
"""
Return a new Index of the values set with the mask.
.. note:: this API can be pretty expensive since it is based on
a global sequence internally.
Returns
-------
Index
Example
-------
>>> kidx = ks.Index(['a', 'b', 'c', 'd', 'e'])
>>> kidx
Index(['a', 'b', 'c', 'd', 'e'], dtype='object')
>>> kidx.putmask([True if x < 2 else False for x in range(5)], "Koalas").sort_values()
Index(['Koalas', 'Koalas', 'c', 'd', 'e'], dtype='object')
"""
origin_col = self._internal.index_spark_column_names[0]
sdf = self._internal.spark_frame.select(self.spark.column)

sequence_col = verify_temp_column_name(sdf, "__distributed_sequence_column__")
sdf = InternalFrame.attach_distributed_sequence_column(sdf, column_name=sequence_col)

masking_col = verify_temp_column_name(sdf, "__masking_column__")
masking_udf = udf(lambda x: mask[x], BooleanType())

sdf = sdf.withColumn(masking_col, masking_udf(sequence_col))
# spark_frame here looks like below
# +-------------------------------+-----------------+------------------+
# |__distributed_sequence_column__|__index_level_0__|__masking_column__|
# +-------------------------------+-----------------+------------------+
# | 0| a| true|
# | 3| d| false|
# | 1| b| true|
# | 2| c| false|
# | 4| e| false|
# +-------------------------------+-----------------+------------------+

cond = F.when(sdf[masking_col], value).otherwise(sdf[origin_col])
sdf = sdf.select(cond.alias(origin_col))

internal = InternalFrame(spark_frame=sdf, index_map=self._internal.index_map)

return ks.DataFrame(internal).index

def set_names(self, names, level=None, inplace=False):
"""
Set Index or MultiIndex name.
Expand Down
1 change: 0 additions & 1 deletion databricks/koalas/missing/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ class MissingPandasLikeIndex(object):
is_type_compatible = _unsupported_function("is_type_compatible")
join = _unsupported_function("join")
map = _unsupported_function("map")
putmask = _unsupported_function("putmask")
ravel = _unsupported_function("ravel")
reindex = _unsupported_function("reindex")
searchsorted = _unsupported_function("searchsorted")
Expand Down
11 changes: 11 additions & 0 deletions databricks/koalas/tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,17 @@ def test_dropna(self):
self.assert_eq(kidx.dropna(), pidx.dropna())
self.assert_eq((kidx + 1).dropna(), (pidx + 1).dropna())

def test_putmask(self):
pidx = pd.Index(["a", "b", "c", "d", "e"])
kidx = ks.from_pandas(pidx)

mask = [True if x < 2 else False for x in range(5)]
value = "Koalas"

self.assert_eq(
kidx.putmask(mask, value).sort_values(), pidx.putmask(mask, value).sort_values()
)

def test_index_symmetric_difference(self):
pidx1 = pd.Index([1, 2, 3, 4])
pidx2 = pd.Index([2, 3, 4, 5])
Expand Down

0 comments on commit 5be5ba2

Please sign in to comment.