Skip to content

Commit

Permalink
implementing DataFrame.pipe
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-mrojas committed Dec 10, 2024
1 parent 99dde61 commit 2ab515e
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
51 changes: 51 additions & 0 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5689,6 +5689,57 @@ def _format_datatype(name, dtype, nullable=None, depth=0):

return f"root\n{schema_tmp_str}"

def pipe(
self, func: Callable[..., "DataFrame"], *args: str, **kwargs: Dict[str, str]
) -> "DataFrame":
"""
The pipe method facilitates chainable operations by allowing the output of one function to be passed as the
input to another function in a sequence. It is particularly useful for creating a clean, readable workflow
when applying multiple transformations or operations on data. The method supports chaining by returning
the result of each operation, making the sequence easy to understand and maintain.
Args:
func: A function that takes a DataFrame as its first argument and returns a transformed DataFrame.
args: Positional arguments to be passed to the function.
kwargs: Keyword arguments to be passed to the function.
Example::
>>> from snowflake.snowpark import Session
>>> from snowflake.snowpark.functions import lit
>>> data = [
... {"id": 1, "value": 50},
... {"id": 2, "value": 30},
... {"id": 3, "value": 70},
... {"id": 4, "value": 20}
...]
>>> df = session.create_dataframe(data)
>>> def add_column(df, column_name, value):
... return df.withColumn(column_name, F.lit(value))
>>> transformed_df = (
... df
... .pipe(add_column, "new_col", 42)
... .pipe(lambda d: d.filter(d["new_col"]> 40))
... )
>>> transformed_df.show()
------------------------------
|"ID" |"VALUE" |"NEW_COL" |
------------------------------
|1 |50 |42 |
|3 |70 |42 |
------------------------------
<BLANKLINE>
"""
if isinstance(func, tuple):
func, target = func
if target in kwargs:
raise ValueError(
"%s is both the pipe target and a keyword " "argument" % target
)
kwargs[target] = self
return func(*args, **kwargs)
else:
return func(self, *args, **kwargs)

def print_schema(self, level: Optional[int] = None) -> None:
"""
Prints the schema of a dataframe in tree format.
Expand Down
19 changes: 19 additions & 0 deletions tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3389,6 +3389,25 @@ def check_df_with_query_id_result_scan(session, df):
Utils.check_answer(df, df_from_result_scan)


def test_pipe(session):
data = [
{"id": 1, "value": 50},
{"id": 2, "value": 30},
{"id": 3, "value": 70},
{"id": 4, "value": 20},
]
df = session.create_dataframe(data)

def add_column(df, column_name, value):
return df.withColumn(column_name, lit(value))

transformed_df = df.pipe(add_column, "new_col", 42).pipe(
lambda d: d.filter(d["new_col"] > 40)
)
expected_rows = [Row(ID=1, VALUE=50, NEW_COL=42), Row(ID=3, VALUE=70, NEW_COL=42)]
assert transformed_df.collect() == expected_rows


@pytest.mark.xfail(
"config.getoption('local_testing_mode', default=False)",
reason="Result scan is a SQL feature",
Expand Down

0 comments on commit 2ab515e

Please sign in to comment.