diff --git a/CHANGELOG.md b/CHANGELOG.md index 975e7d9280b..78e5588c6e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ #### New Features - Added support for `Series.str.ljust` and `Series.str.rjust`. +- Added support for `Series.str.center`. ## 1.26.0 (2024-12-05) diff --git a/docs/source/modin/supported/series_str_supported.rst b/docs/source/modin/supported/series_str_supported.rst index 4af8fa93142..a2f41cdc79f 100644 --- a/docs/source/modin/supported/series_str_supported.rst +++ b/docs/source/modin/supported/series_str_supported.rst @@ -23,7 +23,7 @@ the method in the left column. +-----------------------------+---------------------------------+----------------------------------------------------+ | ``cat`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``center`` | N | | +| ``center`` | Y | | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``contains`` | P | ``N`` if the `na` parameter is set to a non-bool | | | | value. | diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 7d8096bd3da..16e885c12f3 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -16253,8 +16253,39 @@ def str___getitem__(self, key: Union[Scalar, slice]) -> "SnowflakeQueryCompiler" raise ValueError("slice step cannot be zero") return self.str_slice(key.start, key.stop, key.step) - def str_center(self, width: int, fillchar: str = " ") -> None: - ErrorMessage.method_not_implemented_error("center", "Series.str") + def str_center(self, width: int, fillchar: str = " ") -> "SnowflakeQueryCompiler": + if not isinstance(width, int): + raise TypeError( + f"width must be of integer type, not {type(width).__name__}" + ) + if not isinstance(fillchar, str): + raise TypeError( + f"fillchar must be of integer type, not {type(fillchar).__name__}" + ) + if len(fillchar) != 1: + raise TypeError("fillchar must be a character, not str") + + def output_col(column: SnowparkColumn) -> SnowparkColumn: + new_col = rpad( + lpad( + column, + greatest( + length(column), + length(column) + + (pandas_lit(width) - length(column) - pandas_lit(1)) + / pandas_lit(2), + ), + pandas_lit(fillchar), + ), + greatest(length(column), pandas_lit(width)), + pandas_lit(fillchar), + ) + return self._replace_non_str(column, new_col) + + new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns( + output_col + ) + return SnowflakeQueryCompiler(new_internal_frame) def str_contains( self, @@ -16473,6 +16504,12 @@ def str_ljust(self, width: int, fillchar: str = " ") -> "SnowflakeQueryCompiler" raise TypeError( f"width must be of integer type, not {type(width).__name__}" ) + if not isinstance(fillchar, str): + raise TypeError( + f"fillchar must be of integer type, not {type(fillchar).__name__}" + ) + if len(fillchar) != 1: + raise TypeError("fillchar must be a character, not str") def output_col(column: SnowparkColumn) -> SnowparkColumn: new_col = rpad( @@ -16507,6 +16544,12 @@ def str_rjust(self, width: int, fillchar: str = " ") -> "SnowflakeQueryCompiler" raise TypeError( f"width must be of integer type, not {type(width).__name__}" ) + if not isinstance(fillchar, str): + raise TypeError( + f"fillchar must be of integer type, not {type(fillchar).__name__}" + ) + if len(fillchar) != 1: + raise TypeError("fillchar must be a character, not str") def output_col(column: SnowparkColumn) -> SnowparkColumn: new_col = lpad( diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py b/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py index b23852490d5..0fae8ce1024 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py @@ -337,7 +337,51 @@ def pad(): pass def center(): - pass + """ + Pad left and right side of strings in the Series/Index. + + Equivalent to str.center(). + + Parameters + ---------- + width : int + Minimum width of resulting string; additional characters will be filled with fillchar. + fillchar : str + Additional character for filling, default is whitespace. + + Returns + ------- + Series/Index of objects. + + Examples + -------- + For Series.str.center: + + >>> ser = pd.Series(['dog', 'bird', 'mouse']) + >>> ser.str.center(8, fillchar='.') + 0 ..dog... + 1 ..bird.. + 2 .mouse.. + dtype: object + + For Series.str.ljust: + + >>> ser = pd.Series(['dog', 'bird', 'mouse']) + >>> ser.str.ljust(8, fillchar='.') + 0 dog..... + 1 bird.... + 2 mouse... + dtype: object + + For Series.str.rjust: + + >>> ser = pd.Series(['dog', 'bird', 'mouse']) + >>> ser.str.rjust(8, fillchar='.') + 0 .....dog + 1 ....bird + 2 ...mouse + dtype: object + """ def ljust(): """ diff --git a/tests/integ/modin/series/test_str_accessor.py b/tests/integ/modin/series/test_str_accessor.py index f0b9c330969..6cfc25a4ecb 100644 --- a/tests/integ/modin/series/test_str_accessor.py +++ b/tests/integ/modin/series/test_str_accessor.py @@ -422,11 +422,11 @@ def test_str_no_params(func): ) -@pytest.mark.parametrize("func", ["ljust", "rjust"]) +@pytest.mark.parametrize("func", ["center", "ljust", "rjust"]) @pytest.mark.parametrize("width", [-1, 0, 1, 10, 100]) @pytest.mark.parametrize("fillchar", [" ", "#"]) @sql_count_checker(query_count=1) -def test_str_ljust_rjust(func, width, fillchar): +def test_str_center_ljust_rjust(func, width, fillchar): native_ser = native_pd.Series(TEST_DATA) snow_ser = pd.Series(native_ser) eval_snowpark_pandas_result( @@ -436,18 +436,20 @@ def test_str_ljust_rjust(func, width, fillchar): ) -@pytest.mark.parametrize("func", ["ljust", "rjust"]) +@pytest.mark.parametrize("func", ["center", "ljust", "rjust"]) @pytest.mark.parametrize( "width, fillchar", [ (None, " "), + ("ten", " "), (10, ""), (10, "ab"), (10, None), + (10, 10), ], ) @sql_count_checker(query_count=0) -def test_str_ljust_rjust_neg(func, width, fillchar): +def test_str_center_ljust_rjust_neg(func, width, fillchar): native_ser = native_pd.Series(TEST_DATA) snow_ser = pd.Series(native_ser) with pytest.raises(TypeError): diff --git a/tests/unit/modin/test_series_strings.py b/tests/unit/modin/test_series_strings.py index ef303f69d58..2668835a9b8 100644 --- a/tests/unit/modin/test_series_strings.py +++ b/tests/unit/modin/test_series_strings.py @@ -37,7 +37,6 @@ def test_str_cat_no_others(mock_str_register, mock_series): (lambda s: s.str.rsplit("_", n=1), "rsplit"), (lambda s: s.str.join("_"), "join"), (lambda s: s.str.pad(10), "pad"), - (lambda s: s.str.center(10), "center"), (lambda s: s.str.zfill(8), "zfill"), (lambda s: s.str.wrap(3), "wrap"), (lambda s: s.str.slice_replace(start=3, stop=5, repl="abc"), "slice_replace"), @@ -105,11 +104,6 @@ def test_str_methods_with_dataframe_return(func, func_name, mock_series): TypeError, "fillchar must be a character, not str", ), - ( - lambda s: s.str.center(8, fillchar="abc"), - TypeError, - "fillchar must be a character, not str", - ), (lambda s: s.str.wrap(-1), ValueError, r"invalid width -1 \(must be > 0\)"), ( lambda s: s.str.count(12),