Skip to content

Commit

Permalink
SIT-2192 Add a new overload for `com.snowflake.snowpark.functions.rou…
Browse files Browse the repository at this point in the history
…nd` function (#150)
  • Loading branch information
sfc-gh-fgonzalezmendez authored Aug 29, 2024
1 parent 2258be0 commit 9313150
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 7 deletions.
70 changes: 68 additions & 2 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,25 @@ public static Column pow(Column l, Column r) {
}

/**
* Returns rounded values for the specified column.
* Rounds the numeric values of the given column {@code e} to the {@code scale} decimal places
* using the half away from zero rounding mode.
*
* <p>Example:
*
* <pre>{@code
* DataFrame df = session.sql("select * from (values (-3.78), (-2.55), (1.23), (2.55), (3.78)) as T(a)");
* df.select(round(col("a"), lit(1)).alias("round")).show();
*
* -----------
* |"ROUND" |
* -----------
* |-3.8 |
* |-2.6 |
* |1.2 |
* |2.6 |
* |3.8 |
* -----------
* }</pre>
*
* @since 0.9.0
* @param e The input column
Expand All @@ -1042,7 +1060,25 @@ public static Column round(Column e, Column scale) {
}

/**
* Returns rounded values for the specified column.
* Rounds the numeric values of the given column {@code e} to 0 decimal places using the half away
* from zero rounding mode.
*
* <p>Example:
*
* <pre>{@code
* DataFrame df = session.sql("select * from (values (-3.7), (-2.5), (1.2), (2.5), (3.7)) as T(a)");
* df.select(round(col("a")).alias("round")).show();
*
* -----------
* |"ROUND" |
* -----------
* |-4 |
* |-3 |
* |1 |
* |3 |
* |4 |
* -----------
* }</pre>
*
* @since 0.9.0
* @param e The input column
Expand All @@ -1052,6 +1088,36 @@ public static Column round(Column e) {
return new Column(com.snowflake.snowpark.functions.round(e.toScalaColumn()));
}

/**
* Rounds the numeric values of the given column {@code e} to the {@code scale} decimal places
* using the half away from zero rounding mode.
*
* <p>Example:
*
* <pre>{@code
* DataFrame df = session.sql("select * from (values (-3.78), (-2.55), (1.23), (2.55), (3.78)) as T(a)");
* df.select(round(col("a"), 1).alias("round")).show();
*
* -----------
* |"ROUND" |
* -----------
* |-3.8 |
* |-2.6 |
* |1.2 |
* |2.6 |
* |3.8 |
* -----------
* }</pre>
*
* @param e The column of numeric values to round.
* @param scale The number of decimal places to which {@code e} should be rounded.
* @return A new column containing the rounded numeric values.
* @since 1.14.0
*/
public static Column round(Column e, int scale) {
return new Column(com.snowflake.snowpark.functions.round(e.toScalaColumn(), scale));
}

/**
* Shifts the bits for a numeric expression numBits positions to the left.
*
Expand Down
73 changes: 71 additions & 2 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -881,21 +881,90 @@ object functions {
def pow(l: Column, r: Column): Column = builtin("pow")(l, r)

/**
* Returns rounded values for the specified column.
* Rounds the numeric values of the given column `e` to the `scale` decimal places using the
* half away from zero rounding mode.
*
* Example:
* {{{
* val df = session.sql(
* "select * from (values (-3.78), (-2.55), (1.23), (2.55), (3.78)) as T(a)")
* df.select(round(col("a"), lit(1)).alias("round")).show()
*
* -----------
* |"ROUND" |
* -----------
* |-3.8 |
* |-2.6 |
* |1.2 |
* |2.6 |
* |3.8 |
* -----------
* }}}
*
* @param e The column of numeric values to round.
* @param scale A column representing the number of decimal places to which `e` should be rounded.
* @return A new column containing the rounded numeric values.
* @group num_func
* @since 0.1.0
*/
def round(e: Column, scale: Column): Column = builtin("round")(e, scale)

/**
* Returns rounded values for the specified column.
* Rounds the numeric values of the given column `e` to 0 decimal places using the
* half away from zero rounding mode.
*
* Example:
* {{{
* val df = session.sql("select * from (values (-3.7), (-2.5), (1.2), (2.5), (3.7)) as T(a)")
* df.select(round(col("a")).alias("round")).show()
*
* -----------
* |"ROUND" |
* -----------
* |-4 |
* |-3 |
* |1 |
* |3 |
* |4 |
* -----------
* }}}
*
* @param e The column of numeric values to round.
* @return A new column containing the rounded numeric values.
* @group num_func
* @since 0.1.0
*/
def round(e: Column): Column = round(e, lit(0))

/**
* Rounds the numeric values of the given column `e` to the `scale` decimal places using the
* half away from zero rounding mode.
*
* Example:
* {{{
* val df = session.sql(
* "select * from (values (-3.78), (-2.55), (1.23), (2.55), (3.78)) as T(a)")
* df.select(round(col("a"), 1).alias("round")).show()
*
* -----------
* |"ROUND" |
* -----------
* |-3.8 |
* |-2.6 |
* |1.2 |
* |2.6 |
* |3.8 |
* -----------
* }}}
*
* @param e The column of numeric values to round.
* @param scale The number of decimal places to which `e` should be rounded.
* @return A new column containing the rounded numeric values.
* @group num_func
* @since 1.14.0
*/
def round(e: Column, scale: Int): Column = round(e, lit(scale))

/**
* Shifts the bits for a numeric expression numBits positions to the left.
*
Expand Down
16 changes: 16 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -602,10 +602,26 @@ public void pow() {

@Test
public void round() {
// Case: Scale greater than or equal to zero.
DataFrame df = getSession().sql("select * from values(1.111),(2.222),(3.333) as T(a)");
Row[] expected = {Row.create(1.0), Row.create(2.0), Row.create(3.0)};
checkAnswer(df.select(Functions.round(df.col("a"))), expected, false);
checkAnswer(df.select(Functions.round(df.col("a"), Functions.lit(0))), expected, false);
checkAnswer(df.select(Functions.round(df.col("a"), 0)), expected, false);

// Case: Scale less than zero.
DataFrame df2 = getSession().sql("select * from values(5),(55),(555) as T(a)");
Row[] expected2 = {Row.create(10, 0), Row.create(60, 100), Row.create(560, 600)};
checkAnswer(
df2.select(
Functions.round(df2.col("a"), Functions.lit(-1)),
Functions.round(df2.col("a"), Functions.lit(-2))),
expected2,
false);
checkAnswer(
df2.select(Functions.round(df2.col("a"), -1), Functions.round(df2.col("a"), -2)),
expected2,
false);
}

@Test
Expand Down
14 changes: 11 additions & 3 deletions src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,17 @@ trait FunctionSuite extends TestData {
}

test("round") {
checkAnswer(double1.select(round(col("A"))), Seq(Row(1.0), Row(2.0), Row(3.0)))
checkAnswer(double1.select(round(col("A"), lit(0))), Seq(Row(1.0), Row(2.0), Row(3.0)))

// Case: Scale greater than or equal to zero.
val expected1 = Seq(Row(1.0), Row(2.0), Row(3.0))
checkAnswer(double1.select(round(col("A"))), expected1)
checkAnswer(double1.select(round(col("A"), lit(0))), expected1)
checkAnswer(double1.select(round(col("A"), 0)), expected1)

// Case: Scale less than zero.
val df2 = session.sql("select * from values(5),(55),(555) as T(a)")
val expected2 = Seq(Row(10, 0), Row(60, 100), Row(560, 600))
checkAnswer(df2.select(round(col("a"), lit(-1)), round(col("a"), lit(-2))), expected2)
checkAnswer(df2.select(round(col("a"), -1), round(col("a"), -2)), expected2)
}

test("asin acos") {
Expand Down

0 comments on commit 9313150

Please sign in to comment.