Skip to content

Commit

Permalink
Respect na.rm = TRUE in pmin() and pmax() for Snowflake (#1330)
Browse files Browse the repository at this point in the history
* respect na.rm=TRUE for Snowflake

* enable dots

* add news

* tidy tests

* grammar

* Update backend-snowflake.R

INT to FLOAT

* Update test-backend-snowflake.R

INT to FLOAT

* style

* account for pairs

* support >2 comparisons with na.rm = TRUE"

* add snapshot

* code review comments
  • Loading branch information
fh-mthomson authored Jul 14, 2023
1 parent 49512ac commit 5fa4410
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 24 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@
* Teradata:
* `as.Date(x)` is now translate to `CAST(x AS DATE)` again unless `x` is a
string (@mgirlich, #1285).

* Snowflake:
* `na.rm = TRUE` is now respected in `pmin()` and `pmax()` instead of being silently ignored (@fh-mthomson, #1329)

* `remote_name()` now returns a string with the name of the table. To get the
qualified identifier use the newly added `remote_table()` (@mgirlich, #1280).
Expand Down
43 changes: 39 additions & 4 deletions R/backend-snowflake.R
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,33 @@ sql_translation.Snowflake <- function(con) {
},
# https://docs.snowflake.com/en/sql-reference/functions/date_trunc.html
floor_date = function(x, unit = "seconds") {
unit <- arg_match(unit,
c("second", "minute", "hour", "day", "week", "month", "quarter", "year",
"seconds", "minutes", "hours", "days", "weeks", "months", "quarters", "years")
unit <- arg_match(
unit,
c(
"second", "minute", "hour", "day", "week", "month", "quarter", "year",
"seconds", "minutes", "hours", "days", "weeks", "months", "quarters", "years"
)
)
sql_expr(DATE_TRUNC(!!unit, !!x))
},
# LEAST / GREATEST on Snowflake will not respect na.rm = TRUE by default (similar to Oracle/Access)
# https://docs.snowflake.com/en/sql-reference/functions/least
# https://docs.snowflake.com/en/sql-reference/functions/greatest
pmin = function(..., na.rm = FALSE) {
dots <- list(...)
if (identical(na.rm, TRUE)) {
snowflake_pmin_pmax_sql_expression(dots = dots, comparison = "<=")
} else {
glue_sql2(sql_current_con(), "LEAST({.val dots*})")
}
},
pmax = function(..., na.rm = FALSE) {
dots <- list(...)
if (identical(na.rm, TRUE)) {
snowflake_pmin_pmax_sql_expression(dots = dots, comparison = ">=")
} else {
glue_sql2(sql_current_con(), "GREATEST({.val dots*})")
}
}
),
sql_translator(
Expand Down Expand Up @@ -246,8 +268,9 @@ snowflake_grepl <- function(pattern,
# REGEXP on Snowflaake "implicitly anchors a pattern at both ends", which
# grepl does not. Left- and right-pad `pattern` with .* to get grepl-like
# behavior
sql_expr(((!!x)) %REGEXP% (".*" || !!paste0('(', pattern, ')') || ".*"))
sql_expr(((!!x)) %REGEXP% (".*" || !!paste0("(", pattern, ")") || ".*"))
}

snowflake_round <- function(x, digits = 0L) {
digits <- as.integer(digits)
sql_expr(round(((!!x)) %::% FLOAT, !!digits))
Expand All @@ -265,4 +288,16 @@ snowflake_paste <- function(default_sep) {
}
}

snowflake_pmin_pmax_sql_expression <- function(dots, comparison){
dot_combined <- dots[[1]]
for (i in 2:length(dots)){
dot_combined <- snowflake_pmin_pmax_builder(dots[i], dot_combined, comparison)
}
dot_combined
}

snowflake_pmin_pmax_builder <- function(dot_1, dot_2, comparison){
glue_sql2(sql_current_con(), glue("COALESCE(IFF({dot_2} {comparison} {dot_1}, {dot_2}, {dot_1}), {dot_2}, {dot_1})"))
}

utils::globalVariables(c("%REGEXP%", "DAYNAME", "DECODE", "FLOAT", "MONTHNAME", "POSITION", "trim"))
14 changes: 14 additions & 0 deletions tests/testthat/_snaps/backend-snowflake.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,17 @@
! `ignore.case = TRUE` isn't supported in Snowflake translation.
i It must be FALSE instead.

# pmin() and pmax() respect na.rm

Code
test_translate_sql(pmin(x, y, z, na.rm = TRUE))
Output
<SQL> COALESCE(IFF(COALESCE(IFF(`x` <= `y`, `x`, `y`), `x`, `y`) <= `z`, COALESCE(IFF(`x` <= `y`, `x`, `y`), `x`, `y`), `z`), COALESCE(IFF(`x` <= `y`, `x`, `y`), `x`, `y`), `z`)

---

Code
test_translate_sql(pmax(x, y, z, na.rm = TRUE))
Output
<SQL> COALESCE(IFF(COALESCE(IFF(`x` >= `y`, `x`, `y`), `x`, `y`) >= `z`, COALESCE(IFF(`x` >= `y`, `x`, `y`), `x`, `y`), `z`), COALESCE(IFF(`x` >= `y`, `x`, `y`), `x`, `y`), `z`)

80 changes: 60 additions & 20 deletions tests/testthat/test-backend-snowflake.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@ test_that("custom scalar translated correctly", {
local_con(simulate_snowflake())
expect_equal(test_translate_sql(log10(x)), sql("LOG(10.0, `x`)"))
expect_equal(test_translate_sql(round(x, digits = 1.1)), sql("ROUND((`x`) :: FLOAT, 1)"))
expect_equal(test_translate_sql(grepl("exp", x)), sql("(`x`) REGEXP ('.*' || '(exp)' || '.*')"))
expect_equal(test_translate_sql(grepl("exp", x)), sql("(`x`) REGEXP ('.*' || '(exp)' || '.*')"))
expect_snapshot((expect_error(test_translate_sql(grepl("exp", x, ignore.case = TRUE)))))
})

test_that("pasting translated correctly", {
local_con(simulate_snowflake())

expect_equal(test_translate_sql(paste(x, y)), sql("ARRAY_TO_STRING(ARRAY_CONSTRUCT_COMPACT(`x`, `y`), ' ')"))
expect_equal(test_translate_sql(paste(x, y)), sql("ARRAY_TO_STRING(ARRAY_CONSTRUCT_COMPACT(`x`, `y`), ' ')"))
expect_equal(test_translate_sql(paste0(x, y)), sql("ARRAY_TO_STRING(ARRAY_CONSTRUCT_COMPACT(`x`, `y`), '')"))
expect_equal(test_translate_sql(str_c(x, y)), sql("CONCAT_WS('', `x`, `y`)"))
expect_equal(test_translate_sql(str_c(x, y, sep = '|')), sql("CONCAT_WS('|', `x`, `y`)"))
expect_equal(test_translate_sql(str_c(x, y, sep = "|")), sql("CONCAT_WS('|', `x`, `y`)"))

expect_error(test_translate_sql(paste0(x, collapse = "")), "`collapse` not supported")

Expand Down Expand Up @@ -40,33 +40,33 @@ test_that("aggregates are translated correctly", {
local_con(simulate_snowflake())

expect_equal(test_translate_sql(cor(x, y), window = FALSE), sql("CORR(`x`, `y`)"))
expect_equal(test_translate_sql(cor(x, y), window = TRUE), sql("CORR(`x`, `y`) OVER ()"))
expect_equal(test_translate_sql(cor(x, y), window = TRUE), sql("CORR(`x`, `y`) OVER ()"))

expect_equal(test_translate_sql(cov(x, y), window = FALSE), sql("COVAR_SAMP(`x`, `y`)"))
expect_equal(test_translate_sql(cov(x, y), window = TRUE), sql("COVAR_SAMP(`x`, `y`) OVER ()"))
expect_equal(test_translate_sql(cov(x, y), window = TRUE), sql("COVAR_SAMP(`x`, `y`) OVER ()"))

expect_equal(test_translate_sql(all(x, na.rm = TRUE), window = FALSE), sql("BOOLAND_AGG(`x`)"))
expect_equal(test_translate_sql(all(x, na.rm = TRUE), window = TRUE), sql("BOOLAND_AGG(`x`) OVER ()"))
expect_equal(test_translate_sql(all(x, na.rm = TRUE), window = TRUE), sql("BOOLAND_AGG(`x`) OVER ()"))

expect_equal(test_translate_sql(any(x, na.rm = TRUE), window = FALSE), sql("BOOLOR_AGG(`x`)"))
expect_equal(test_translate_sql(any(x, na.rm = TRUE), window = TRUE), sql("BOOLOR_AGG(`x`) OVER ()"))
expect_equal(test_translate_sql(any(x, na.rm = TRUE), window = TRUE), sql("BOOLOR_AGG(`x`) OVER ()"))

expect_equal(test_translate_sql(sd(x, na.rm = TRUE), window = FALSE), sql("STDDEV(`x`)"))
expect_equal(test_translate_sql(sd(x, na.rm = TRUE), window = TRUE), sql("STDDEV(`x`) OVER ()"))
expect_equal(test_translate_sql(sd(x, na.rm = TRUE), window = TRUE), sql("STDDEV(`x`) OVER ()"))
})

test_that("snowflake mimics two argument log", {
local_con(simulate_snowflake())

expect_equal(test_translate_sql(log(x)), sql('LN(`x`)'))
expect_equal(test_translate_sql(log(x, 10)), sql('LOG(10.0, `x`)'))
expect_equal(test_translate_sql(log(x, 10L)), sql('LOG(10, `x`)'))
expect_equal(test_translate_sql(log(x)), sql("LN(`x`)"))
expect_equal(test_translate_sql(log(x, 10)), sql("LOG(10.0, `x`)"))
expect_equal(test_translate_sql(log(x, 10L)), sql("LOG(10, `x`)"))
})

test_that("custom lubridate functions translated correctly", {
local_con(simulate_snowflake())

expect_equal(test_translate_sql(day(x)), sql("EXTRACT(DAY FROM `x`)"))
expect_equal(test_translate_sql(day(x)), sql("EXTRACT(DAY FROM `x`)"))
expect_equal(test_translate_sql(mday(x)), sql("EXTRACT(DAY FROM `x`)"))
expect_equal(test_translate_sql(yday(x)), sql("EXTRACT('dayofyear', `x`)"))
expect_equal(test_translate_sql(wday(x)), sql("EXTRACT('dayofweek', DATE(`x`) + 0) + 1.0"))
Expand All @@ -88,12 +88,52 @@ test_that("custom lubridate functions translated correctly", {

expect_equal(test_translate_sql(seconds(x)), sql("INTERVAL '`x` second'"))
expect_equal(test_translate_sql(minutes(x)), sql("INTERVAL '`x` minute'"))
expect_equal(test_translate_sql(hours(x)), sql("INTERVAL '`x` hour'"))
expect_equal(test_translate_sql(days(x)), sql("INTERVAL '`x` day'"))
expect_equal(test_translate_sql(weeks(x)), sql("INTERVAL '`x` week'"))
expect_equal(test_translate_sql(months(x)), sql("INTERVAL '`x` month'"))
expect_equal(test_translate_sql(years(x)), sql("INTERVAL '`x` year'"))

expect_equal(test_translate_sql(floor_date(x, 'month')), sql("DATE_TRUNC('month', `x`)"))
expect_equal(test_translate_sql(floor_date(x, 'week')), sql("DATE_TRUNC('week', `x`)"))
expect_equal(test_translate_sql(hours(x)), sql("INTERVAL '`x` hour'"))
expect_equal(test_translate_sql(days(x)), sql("INTERVAL '`x` day'"))
expect_equal(test_translate_sql(weeks(x)), sql("INTERVAL '`x` week'"))
expect_equal(test_translate_sql(months(x)), sql("INTERVAL '`x` month'"))
expect_equal(test_translate_sql(years(x)), sql("INTERVAL '`x` year'"))

expect_equal(test_translate_sql(floor_date(x, "month")), sql("DATE_TRUNC('month', `x`)"))
expect_equal(test_translate_sql(floor_date(x, "week")), sql("DATE_TRUNC('week', `x`)"))
})

test_that("min() and max()", {
local_con(simulate_snowflake())

expect_equal(test_translate_sql(min(x, na.rm = TRUE)), sql("MIN(`x`) OVER ()"))
expect_equal(test_translate_sql(max(x, na.rm = TRUE)), sql("MAX(`x`) OVER ()"))

# na.rm = FALSE is ignored
# https://docs.snowflake.com/en/sql-reference/functions/min
# https://docs.snowflake.com/en/sql-reference/functions/max
# NULL values are ignored unless all the records are NULL, in which case a NULL value is returned.
expect_equal(
test_translate_sql(min(x, na.rm = TRUE)),
test_translate_sql(min(x, na.rm = FALSE))
)

expect_equal(
test_translate_sql(max(x, na.rm = TRUE)),
test_translate_sql(max(x, na.rm = FALSE))
)
})

test_that("pmin() and pmax() respect na.rm", {
local_con(simulate_snowflake())

# Snowflake default for LEAST/GREATEST: If any of the argument values is NULL, the result is NULL.
# https://docs.snowflake.com/en/sql-reference/functions/least
# https://docs.snowflake.com/en/sql-reference/functions/greatest

# na.rm = TRUE: override default behavior for Snowflake (only supports pairs)
expect_equal(test_translate_sql(pmin(x, y, na.rm = TRUE)), sql("COALESCE(IFF(`x` <= `y`, `x`, `y`), `x`, `y`)"))
expect_equal(test_translate_sql(pmax(x, y, na.rm = TRUE)), sql("COALESCE(IFF(`x` >= `y`, `x`, `y`), `x`, `y`)"))

expect_snapshot(test_translate_sql(pmin(x, y, z, na.rm = TRUE)))
expect_snapshot(test_translate_sql(pmax(x, y, z, na.rm = TRUE)))

# na.rm = FALSE: leverage default behavior for Snowflake
expect_equal(test_translate_sql(pmin(x, y, z, na.rm = FALSE)), sql("LEAST(`x`, `y`, `z`)"))
expect_equal(test_translate_sql(pmax(x, y, z, na.rm = FALSE)), sql("GREATEST(`x`, `y`, `z`)"))
})

0 comments on commit 5fa4410

Please sign in to comment.