diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index c4dd09095fd..4e5e1867cb4 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -5689,6 +5689,57 @@ def _format_datatype(name, dtype, nullable=None, depth=0): return f"root\n{schema_tmp_str}" + def pipe( + self, func: Callable[..., "DataFrame"], *args: str, **kwargs: Dict[str, str] + ) -> "DataFrame": + """ + The pipe method facilitates chainable operations by allowing the output of one function to be passed as the + input to another function in a sequence. It is particularly useful for creating a clean, readable workflow + when applying multiple transformations or operations on data. The method supports chaining by returning + the result of each operation, making the sequence easy to understand and maintain. + + Args: + func: A function that takes a DataFrame as its first argument and returns a transformed DataFrame. + args: Positional arguments to be passed to the function. + kwargs: Keyword arguments to be passed to the function. + + Example:: + >>> from snowflake.snowpark import Session + >>> from snowflake.snowpark.functions import lit + >>> data = [ + ... {"id": 1, "value": 50}, + ... {"id": 2, "value": 30}, + ... {"id": 3, "value": 70}, + ... {"id": 4, "value": 20} + ...] + >>> df = session.create_dataframe(data) + >>> def add_column(df, column_name, value): + ... return df.withColumn(column_name, F.lit(value)) + >>> transformed_df = ( + ... df + ... .pipe(add_column, "new_col", 42) + ... .pipe(lambda d: d.filter(d["new_col"]> 40)) + ... ) + >>> transformed_df.show() + ------------------------------ + |"ID" |"VALUE" |"NEW_COL" | + ------------------------------ + |1 |50 |42 | + |3 |70 |42 | + ------------------------------ + + """ + if isinstance(func, tuple): + func, target = func + if target in kwargs: + raise ValueError( + "%s is both the pipe target and a keyword " "argument" % target + ) + kwargs[target] = self + return func(*args, **kwargs) + else: + return func(self, *args, **kwargs) + def print_schema(self, level: Optional[int] = None) -> None: """ Prints the schema of a dataframe in tree format. diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py index 75d1754a762..532ae117621 100644 --- a/tests/integ/test_dataframe.py +++ b/tests/integ/test_dataframe.py @@ -3389,6 +3389,25 @@ def check_df_with_query_id_result_scan(session, df): Utils.check_answer(df, df_from_result_scan) +def test_pipe(session): + data = [ + {"id": 1, "value": 50}, + {"id": 2, "value": 30}, + {"id": 3, "value": 70}, + {"id": 4, "value": 20}, + ] + df = session.create_dataframe(data) + + def add_column(df, column_name, value): + return df.withColumn(column_name, lit(value)) + + transformed_df = df.pipe(add_column, "new_col", 42).pipe( + lambda d: d.filter(d["new_col"] > 40) + ) + expected_rows = [Row(ID=1, VALUE=50, NEW_COL=42), Row(ID=3, VALUE=70, NEW_COL=42)] + assert transformed_df.collect() == expected_rows + + @pytest.mark.xfail( "config.getoption('local_testing_mode', default=False)", reason="Result scan is a SQL feature",