From 9313150515f940f7dc961206b057b7bf1f0e5be4 Mon Sep 17 00:00:00 2001 From: Fabian Gonzalez Mendez Date: Wed, 28 Aug 2024 20:36:44 -0600 Subject: [PATCH] SIT-2192 Add a new overload for `com.snowflake.snowpark.functions.round` function (#150) --- .../snowflake/snowpark_java/Functions.java | 70 +++++++++++++++++- .../com/snowflake/snowpark/functions.scala | 73 ++++++++++++++++++- .../snowpark_test/JavaFunctionSuite.java | 16 ++++ .../snowpark_test/FunctionSuite.scala | 14 +++- 4 files changed, 166 insertions(+), 7 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index f84e6082..7d637690 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -1029,7 +1029,25 @@ public static Column pow(Column l, Column r) { } /** - * Returns rounded values for the specified column. + * Rounds the numeric values of the given column {@code e} to the {@code scale} decimal places + * using the half away from zero rounding mode. + * + *

Example: + * + *

{@code
+   * DataFrame df = session.sql("select * from (values (-3.78), (-2.55), (1.23), (2.55), (3.78)) as T(a)");
+   * df.select(round(col("a"), lit(1)).alias("round")).show();
+   *
+   * -----------
+   * |"ROUND"  |
+   * -----------
+   * |-3.8     |
+   * |-2.6     |
+   * |1.2      |
+   * |2.6      |
+   * |3.8      |
+   * -----------
+   * }
* * @since 0.9.0 * @param e The input column @@ -1042,7 +1060,25 @@ public static Column round(Column e, Column scale) { } /** - * Returns rounded values for the specified column. + * Rounds the numeric values of the given column {@code e} to 0 decimal places using the half away + * from zero rounding mode. + * + *

Example: + * + *

{@code
+   * DataFrame df = session.sql("select * from (values (-3.7), (-2.5), (1.2), (2.5), (3.7)) as T(a)");
+   * df.select(round(col("a")).alias("round")).show();
+   *
+   * -----------
+   * |"ROUND"  |
+   * -----------
+   * |-4       |
+   * |-3       |
+   * |1        |
+   * |3        |
+   * |4        |
+   * -----------
+   * }
* * @since 0.9.0 * @param e The input column @@ -1052,6 +1088,36 @@ public static Column round(Column e) { return new Column(com.snowflake.snowpark.functions.round(e.toScalaColumn())); } + /** + * Rounds the numeric values of the given column {@code e} to the {@code scale} decimal places + * using the half away from zero rounding mode. + * + *

Example: + * + *

{@code
+   * DataFrame df = session.sql("select * from (values (-3.78), (-2.55), (1.23), (2.55), (3.78)) as T(a)");
+   * df.select(round(col("a"), 1).alias("round")).show();
+   *
+   * -----------
+   * |"ROUND"  |
+   * -----------
+   * |-3.8     |
+   * |-2.6     |
+   * |1.2      |
+   * |2.6      |
+   * |3.8      |
+   * -----------
+   * }
+ * + * @param e The column of numeric values to round. + * @param scale The number of decimal places to which {@code e} should be rounded. + * @return A new column containing the rounded numeric values. + * @since 1.14.0 + */ + public static Column round(Column e, int scale) { + return new Column(com.snowflake.snowpark.functions.round(e.toScalaColumn(), scale)); + } + /** * Shifts the bits for a numeric expression numBits positions to the left. * diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index fdfc3189..35a3aa43 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -881,21 +881,90 @@ object functions { def pow(l: Column, r: Column): Column = builtin("pow")(l, r) /** - * Returns rounded values for the specified column. + * Rounds the numeric values of the given column `e` to the `scale` decimal places using the + * half away from zero rounding mode. * + * Example: + * {{{ + * val df = session.sql( + * "select * from (values (-3.78), (-2.55), (1.23), (2.55), (3.78)) as T(a)") + * df.select(round(col("a"), lit(1)).alias("round")).show() + * + * ----------- + * |"ROUND" | + * ----------- + * |-3.8 | + * |-2.6 | + * |1.2 | + * |2.6 | + * |3.8 | + * ----------- + * }}} + * + * @param e The column of numeric values to round. + * @param scale A column representing the number of decimal places to which `e` should be rounded. + * @return A new column containing the rounded numeric values. * @group num_func * @since 0.1.0 */ def round(e: Column, scale: Column): Column = builtin("round")(e, scale) /** - * Returns rounded values for the specified column. + * Rounds the numeric values of the given column `e` to 0 decimal places using the + * half away from zero rounding mode. * + * Example: + * {{{ + * val df = session.sql("select * from (values (-3.7), (-2.5), (1.2), (2.5), (3.7)) as T(a)") + * df.select(round(col("a")).alias("round")).show() + * + * ----------- + * |"ROUND" | + * ----------- + * |-4 | + * |-3 | + * |1 | + * |3 | + * |4 | + * ----------- + * }}} + * + * @param e The column of numeric values to round. + * @return A new column containing the rounded numeric values. * @group num_func * @since 0.1.0 */ def round(e: Column): Column = round(e, lit(0)) + /** + * Rounds the numeric values of the given column `e` to the `scale` decimal places using the + * half away from zero rounding mode. + * + * Example: + * {{{ + * val df = session.sql( + * "select * from (values (-3.78), (-2.55), (1.23), (2.55), (3.78)) as T(a)") + * df.select(round(col("a"), 1).alias("round")).show() + * + * ----------- + * |"ROUND" | + * ----------- + * |-3.8 | + * |-2.6 | + * |1.2 | + * |2.6 | + * |3.8 | + * ----------- + * }}} + * + * @param e The column of numeric values to round. + * @param scale The number of decimal places to which `e` should be rounded. + * @return A new column containing the rounded numeric values. + * @group num_func + * @since 1.14.0 + */ + def round(e: Column, scale: Int): Column = round(e, lit(scale)) + /** * Shifts the bits for a numeric expression numBits positions to the left. * diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java index e34cee94..bf91aca1 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java @@ -602,10 +602,26 @@ public void pow() { @Test public void round() { + // Case: Scale greater than or equal to zero. DataFrame df = getSession().sql("select * from values(1.111),(2.222),(3.333) as T(a)"); Row[] expected = {Row.create(1.0), Row.create(2.0), Row.create(3.0)}; checkAnswer(df.select(Functions.round(df.col("a"))), expected, false); checkAnswer(df.select(Functions.round(df.col("a"), Functions.lit(0))), expected, false); + checkAnswer(df.select(Functions.round(df.col("a"), 0)), expected, false); + + // Case: Scale less than zero. + DataFrame df2 = getSession().sql("select * from values(5),(55),(555) as T(a)"); + Row[] expected2 = {Row.create(10, 0), Row.create(60, 100), Row.create(560, 600)}; + checkAnswer( + df2.select( + Functions.round(df2.col("a"), Functions.lit(-1)), + Functions.round(df2.col("a"), Functions.lit(-2))), + expected2, + false); + checkAnswer( + df2.select(Functions.round(df2.col("a"), -1), Functions.round(df2.col("a"), -2)), + expected2, + false); } @Test diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index 3ae6372f..9a32120e 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -274,9 +274,17 @@ trait FunctionSuite extends TestData { } test("round") { - checkAnswer(double1.select(round(col("A"))), Seq(Row(1.0), Row(2.0), Row(3.0))) - checkAnswer(double1.select(round(col("A"), lit(0))), Seq(Row(1.0), Row(2.0), Row(3.0))) - + // Case: Scale greater than or equal to zero. + val expected1 = Seq(Row(1.0), Row(2.0), Row(3.0)) + checkAnswer(double1.select(round(col("A"))), expected1) + checkAnswer(double1.select(round(col("A"), lit(0))), expected1) + checkAnswer(double1.select(round(col("A"), 0)), expected1) + + // Case: Scale less than zero. + val df2 = session.sql("select * from values(5),(55),(555) as T(a)") + val expected2 = Seq(Row(10, 0), Row(60, 100), Row(560, 600)) + checkAnswer(df2.select(round(col("a"), lit(-1)), round(col("a"), lit(-2))), expected2) + checkAnswer(df2.select(round(col("a"), -1), round(col("a"), -2)), expected2) } test("asin acos") {