From 068a88f7b63a8992254a4a20332d43ef2aace45f Mon Sep 17 00:00:00 2001 From: Adam Black Date: Tue, 9 Jan 2024 18:13:29 +0100 Subject: [PATCH] Add more date/time translations (#1357) For SQL server, redshift, postgres, and snowflake. --- NEWS.md | 4 +++ R/backend-mssql.R | 37 ++++++++++++++++++++++++- R/backend-oracle.R | 26 ++++++++++++++++- R/backend-postgres.R | 35 +++++++++++++++++++++++ R/backend-redshift.R | 35 +++++++++++++++++++++++ R/backend-snowflake.R | 35 +++++++++++++++++++++++ R/backend-spark-sql.R | 37 ++++++++++++++++++++++++- R/tidyeval.R | 2 +- tests/testthat/test-backend-mssql.R | 21 ++++++++++++++ tests/testthat/test-backend-oracle.R | 16 +++++++++++ tests/testthat/test-backend-postgres.R | 21 ++++++++++++++ tests/testthat/test-backend-redshift.R | 21 ++++++++++++++ tests/testthat/test-backend-snowflake.R | 21 ++++++++++++++ tests/testthat/test-backend-spark-sql.R | 20 +++++++++++++ tests/testthat/test-tidyeval.R | 1 + 15 files changed, 328 insertions(+), 4 deletions(-) create mode 100644 tests/testthat/test-backend-spark-sql.R diff --git a/NEWS.md b/NEWS.md index 9dfbc547f..c1fb3f5e2 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,9 @@ # dbplyr (development version) +* Add translations for clock functions `add_years()`, `add_days()`, + `date_build()`, `get_year()`, `get_month()`, `get_day()`, + and `base::difftime()` on SQL server, Redshift, Snowflake, and Postgres. + * SQL server: `filter()` does a better job of converting logical vectors from bit to boolean (@ejneer, #1288). diff --git a/R/backend-mssql.R b/R/backend-mssql.R index ec259797a..71880030d 100644 --- a/R/backend-mssql.R +++ b/R/backend-mssql.R @@ -350,6 +350,41 @@ simulate_mssql <- function(version = "15.0") { sql_expr(DATEPART(QUARTER, !!x)) } }, + + # clock --------------------------------------------------------------- + add_days = function(x, n, ...) { + check_dots_empty() + sql_expr(DATEADD(DAY, !!n, !!x)) + }, + add_years = function(x, n, ...) { + check_dots_empty() + sql_expr(DATEADD(YEAR, !!n, !!x)) + }, + date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) { + sql_expr(DATEFROMPARTS(!!year, !!month, !!day)) + }, + get_year = function(x) { + sql_expr(DATEPART('year', !!x)) + }, + get_month = function(x) { + sql_expr(DATEPART('month', !!x)) + }, + get_day = function(x) { + sql_expr(DATEPART('day', !!x)) + }, + + difftime = function(time1, time2, tz, units = "days") { + + if (!missing(tz)) { + cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.") + } + + if (units[1] != "days") { + cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"') + } + + sql_expr(DATEDIFF(day, !!time1, !!time2)) + } ) if (mssql_version(con) >= "11.0") { # MSSQL 2012 @@ -607,7 +642,7 @@ mssql_update_where_clause <- function(qry) { } qry$where <- lapply( - qry$where, + qry$where, function(x) set_expr(x, bit_to_boolean(get_expr(x))) ) qry diff --git a/R/backend-oracle.R b/R/backend-oracle.R index 9e2bfaa6e..fca852660 100644 --- a/R/backend-oracle.R +++ b/R/backend-oracle.R @@ -145,7 +145,31 @@ sql_translation.Oracle <- function(con) { # lubridate -------------------------------------------------------------- today = function() sql_expr(TRUNC(CURRENT_TIMESTAMP)), - now = function() sql_expr(CURRENT_TIMESTAMP) + now = function() sql_expr(CURRENT_TIMESTAMP), + + # clock ------------------------------------------------------------------ + add_days = function(x, n, ...) { + check_dots_empty() + sql_expr((!!x + NUMTODSINTERVAL(!!n, 'day'))) + }, + add_years = function(x, n, ...) { + check_dots_empty() + sql_expr((!!x + NUMTODSINTERVAL(!!n * 365.25, 'day'))) + }, + + difftime = function(time1, time2, tz, units = "days") { + + if (!missing(tz)) { + cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.") + } + + if (units[1] != "days") { + cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"') + } + + sql_expr(CEIL(CAST(!!time2 %AS% DATE) - CAST(!!time1 %AS% DATE))) + } + ), base_odbc_agg, base_odbc_win diff --git a/R/backend-postgres.R b/R/backend-postgres.R index 812c033bb..3a3a6a9b3 100644 --- a/R/backend-postgres.R +++ b/R/backend-postgres.R @@ -235,6 +235,41 @@ sql_translation.PqConnection <- function(con) { ) sql_expr(DATE_TRUNC(!!unit, !!x)) }, + + # clock --------------------------------------------------------------- + add_days = function(x, n, ...) { + check_dots_empty() + glue_sql2(sql_current_con(), "({.col x} + {.val n}*INTERVAL'1 day')") + }, + add_years = function(x, n, ...) { + check_dots_empty() + glue_sql2(sql_current_con(), "({.col x} + {.val n}*INTERVAL'1 year')") + }, + date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) { + sql_expr(make_date(!!year, !!month, !!day)) + }, + get_year = function(x) { + sql_expr(date_part('year', !!x)) + }, + get_month = function(x) { + sql_expr(date_part('month', !!x)) + }, + get_day = function(x) { + sql_expr(date_part('day', !!x)) + }, + + difftime = function(time1, time2, tz, units = "days") { + + if (!missing(tz)) { + cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.") + } + + if (units[1] != "days") { + cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"') + } + + sql_expr((CAST(!!time2 %AS% DATE) - CAST(!!time1 %AS% DATE))) + }, ), sql_translator(.parent = base_agg, cor = sql_aggregate_2("CORR"), diff --git a/R/backend-redshift.R b/R/backend-redshift.R index 735085ebb..f40186f3e 100644 --- a/R/backend-redshift.R +++ b/R/backend-redshift.R @@ -60,6 +60,41 @@ sql_translation.RedshiftConnection <- function(con) { str_replace = sql_not_supported("str_replace"), str_replace_all = function(string, pattern, replacement) { sql_expr(REGEXP_REPLACE(!!string, !!pattern, !!replacement)) + }, + + # clock --------------------------------------------------------------- + add_days = function(x, n, ...) { + check_dots_empty() + sql_expr(DATEADD(DAY, !!n, !!x)) + }, + add_years = function(x, n, ...) { + check_dots_empty() + sql_expr(DATEADD(YEAR, !!n, !!x)) + }, + date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) { + glue_sql2(sql_current_con(), "TO_DATE(CAST({.val year} AS TEXT) || '-' CAST({.val month} AS TEXT) || '-' || CAST({.val day} AS TEXT)), 'YYYY-MM-DD')") + }, + get_year = function(x) { + sql_expr(DATE_PART('year', !!x)) + }, + get_month = function(x) { + sql_expr(DATE_PART('month', !!x)) + }, + get_day = function(x) { + sql_expr(DATE_PART('day', !!x)) + }, + + difftime = function(time1, time2, tz, units = "days") { + + if (!missing(tz)) { + cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.") + } + + if (units[1] != "days") { + cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"') + } + + sql_expr(DATEDIFF(day, !!time1, !!time2)) } ), sql_translator(.parent = postgres$aggregate, diff --git a/R/backend-snowflake.R b/R/backend-snowflake.R index 11254577d..a72561524 100644 --- a/R/backend-snowflake.R +++ b/R/backend-snowflake.R @@ -210,6 +210,41 @@ sql_translation.Snowflake <- function(con) { ) sql_expr(DATE_TRUNC(!!unit, !!x)) }, + # clock --------------------------------------------------------------- + add_days = function(x, n, ...) { + check_dots_empty() + sql_expr(DATEADD(DAY, !!n, !!x)) + }, + add_years = function(x, n, ...) { + check_dots_empty() + sql_expr(DATEADD(YEAR, !!n, !!x)) + }, + date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) { + # https://docs.snowflake.com/en/sql-reference/functions/date_from_parts + sql_expr(DATE_FROM_PARTS(!!year, !!month, !!day)) + }, + get_year = function(x) { + sql_expr(DATE_PART('year', !!x)) + }, + get_month = function(x) { + sql_expr(DATE_PART('month', !!x)) + }, + get_day = function(x) { + sql_expr(DATE_PART('day', !!x)) + }, + + difftime = function(time1, time2, tz, units = "days") { + + if (!missing(tz)) { + cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.") + } + + if (units[1] != "days") { + cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"') + } + + sql_expr(DATEDIFF(day, !!time1, !!time2)) + }, # LEAST / GREATEST on Snowflake will not respect na.rm = TRUE by default (similar to Oracle/Access) # https://docs.snowflake.com/en/sql-reference/functions/least # https://docs.snowflake.com/en/sql-reference/functions/greatest diff --git a/R/backend-spark-sql.R b/R/backend-spark-sql.R index 79a556107..5b43651b5 100644 --- a/R/backend-spark-sql.R +++ b/R/backend-spark-sql.R @@ -36,7 +36,42 @@ simulate_spark_sql <- function() simulate_dbi("Spark SQL") #' @export `sql_translation.Spark SQL` <- function(con) { sql_variant( - base_odbc_scalar, + sql_translator(.parent = base_odbc_scalar, + # clock --------------------------------------------------------------- + add_days = function(x, n, ...) { + check_dots_empty() + sql_expr(date_add(!!x, !!n)) + }, + add_years = function(x, n, ...) { + check_dots_empty() + sql_expr(add_months(!!!x, !!n*12)) + }, + date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) { + sql_expr(make_date(!!year, !!month, !!day)) + }, + get_year = function(x) { + sql_expr(date_part('YEAR', !!x)) + }, + get_month = function(x) { + sql_expr(date_part('MONTH', !!x)) + }, + get_day = function(x) { + sql_expr(date_part('DAY', !!x)) + }, + + difftime = function(time1, time2, tz, units = "days") { + + if (!missing(tz)) { + cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.") + } + + if (units[1] != "days") { + cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"') + } + + sql_expr(datediff(!!time2, !!time1)) + } + ), sql_translator(.parent = base_odbc_agg, var = sql_aggregate("VARIANCE", "var"), quantile = sql_quantile("PERCENTILE"), diff --git a/R/tidyeval.R b/R/tidyeval.R index 293fffb0e..f6d9fcc9c 100644 --- a/R/tidyeval.R +++ b/R/tidyeval.R @@ -163,7 +163,7 @@ partial_eval_sym <- function(sym, data, env) { } is_namespaced_dplyr_call <- function(call) { - packages <- c("base", "dplyr", "stringr", "lubridate") + packages <- c("base", "dplyr", "stringr", "lubridate", "clock") is_symbol(call[[1]], "::") && is_symbol(call[[2]], packages) } diff --git a/tests/testthat/test-backend-mssql.R b/tests/testthat/test-backend-mssql.R index bebaf7364..880aeb1d0 100644 --- a/tests/testthat/test-backend-mssql.R +++ b/tests/testthat/test-backend-mssql.R @@ -124,6 +124,27 @@ test_that("custom lubridate functions translated correctly", { expect_error(test_translate_sql(quarter(x, fiscal_start = 5))) }) +test_that("custom clock functions translated correctly", { + local_con(simulate_mssql()) + expect_equal(test_translate_sql(add_years(x, 1)), sql("DATEADD(YEAR, 1.0, `x`)")) + expect_equal(test_translate_sql(add_days(x, 1)), sql("DATEADD(DAY, 1.0, `x`)")) + expect_error(test_translate_sql(add_days(x, 1, "dots", "must", "be empty"))) + expect_equal(test_translate_sql(date_build(2020, 1, 1)), sql("DATEFROMPARTS(2020.0, 1.0, 1.0)")) + expect_equal(test_translate_sql(date_build(year_column, 1L, 1L)), sql("DATEFROMPARTS(`year_column`, 1, 1)")) + expect_equal(test_translate_sql(get_year(date_column)), sql("DATEPART('year', `date_column`)")) + expect_equal(test_translate_sql(get_month(date_column)), sql("DATEPART('month', `date_column`)")) + expect_equal(test_translate_sql(get_day(date_column)), sql("DATEPART('day', `date_column`)")) +}) + +test_that("difftime is translated correctly", { + local_con(simulate_mssql()) + expect_equal(test_translate_sql(difftime(start_date, end_date, units = "days")), sql("DATEDIFF(day, `start_date`, `end_date`)")) + expect_equal(test_translate_sql(difftime(start_date, end_date)), sql("DATEDIFF(day, `start_date`, `end_date`)")) + + expect_error(test_translate_sql(difftime(start_date, end_date, units = "auto"))) + expect_error(test_translate_sql(difftime(start_date, end_date, tz = "UTC", units = "days"))) +}) + test_that("last_value_sql() translated correctly", { con <- simulate_mssql() expect_equal( diff --git a/tests/testthat/test-backend-oracle.R b/tests/testthat/test-backend-oracle.R index 9824c8245..1ab5462c6 100644 --- a/tests/testthat/test-backend-oracle.R +++ b/tests/testthat/test-backend-oracle.R @@ -82,3 +82,19 @@ test_that("copy_inline uses UNION ALL", { copy_inline(con, y, types = types) %>% remote_query() }) }) + +test_that("custom clock functions translated correctly", { + local_con(simulate_oracle()) + expect_equal(test_translate_sql(add_years(x, 1)), sql("(`x` + NUMTODSINTERVAL(1.0 * 365.25, 'day'))")) + expect_equal(test_translate_sql(add_days(x, 1)), sql("(`x` + NUMTODSINTERVAL(1.0, 'day'))")) + expect_error(test_translate_sql(add_days(x, 1, "dots", "must", "be empty"))) +}) + +test_that("difftime is translated correctly", { + local_con(simulate_oracle()) + expect_equal(test_translate_sql(difftime(start_date, end_date, units = "days")), sql("CEIL(CAST(`end_date` AS DATE) - CAST(`start_date` AS DATE))")) + expect_equal(test_translate_sql(difftime(start_date, end_date)), sql("CEIL(CAST(`end_date` AS DATE) - CAST(`start_date` AS DATE))")) + + expect_error(test_translate_sql(difftime(start_date, end_date, units = "auto"))) + expect_error(test_translate_sql(difftime(start_date, end_date, tz = "UTC", units = "days"))) +}) diff --git a/tests/testthat/test-backend-postgres.R b/tests/testthat/test-backend-postgres.R index d62d737b4..0517f195e 100644 --- a/tests/testthat/test-backend-postgres.R +++ b/tests/testthat/test-backend-postgres.R @@ -88,6 +88,27 @@ test_that("custom lubridate functions translated correctly", { expect_equal(test_translate_sql(floor_date(x, 'week')), sql("DATE_TRUNC('week', `x`)")) }) +test_that("custom clock functions translated correctly", { + local_con(simulate_postgres()) + expect_equal(test_translate_sql(add_years(x, 1)), sql("(`x` + 1.0*INTERVAL'1 year')")) + expect_equal(test_translate_sql(add_days(x, 1)), sql("(`x` + 1.0*INTERVAL'1 day')")) + expect_error(test_translate_sql(add_days(x, 1, "dots", "must", "be empty"))) + expect_equal(test_translate_sql(date_build(2020, 1, 1)), sql("MAKE_DATE(2020.0, 1.0, 1.0)")) + expect_equal(test_translate_sql(date_build(year_column, 1L, 1L)), sql("MAKE_DATE(`year_column`, 1, 1)")) + expect_equal(test_translate_sql(get_year(date_column)), sql("DATE_PART('year', `date_column`)")) + expect_equal(test_translate_sql(get_month(date_column)), sql("DATE_PART('month', `date_column`)")) + expect_equal(test_translate_sql(get_day(date_column)), sql("DATE_PART('day', `date_column`)")) +}) + +test_that("difftime is translated correctly", { + local_con(simulate_postgres()) + expect_equal(test_translate_sql(difftime(start_date, end_date, units = "days")), sql("(CAST(`end_date` AS DATE) - CAST(`start_date` AS DATE))")) + expect_equal(test_translate_sql(difftime(start_date, end_date)), sql("(CAST(`end_date` AS DATE) - CAST(`start_date` AS DATE))")) + + expect_error(test_translate_sql(difftime(start_date, end_date, units = "auto"))) + expect_error(test_translate_sql(difftime(start_date, end_date, tz = "UTC", units = "days"))) +}) + test_that("custom window functions translated correctly", { local_con(simulate_postgres()) diff --git a/tests/testthat/test-backend-redshift.R b/tests/testthat/test-backend-redshift.R index 9e90aa192..55e66b20f 100644 --- a/tests/testthat/test-backend-redshift.R +++ b/tests/testthat/test-backend-redshift.R @@ -57,3 +57,24 @@ test_that("copy_inline uses UNION ALL", { copy_inline(con, y, types = types) %>% remote_query() }) }) + +test_that("custom clock functions translated correctly", { + local_con(simulate_redshift()) + expect_equal(test_translate_sql(add_years(x, 1)), sql("DATEADD(YEAR, 1.0, `x`)")) + expect_equal(test_translate_sql(add_days(x, 1)), sql("DATEADD(DAY, 1.0, `x`)")) + expect_error(test_translate_sql(add_days(x, 1, "dots", "must", "be empty"))) + expect_equal(test_translate_sql(date_build(2020, 1, 1)), sql("TO_DATE(CAST(2020.0 AS TEXT) || '-' CAST(1.0 AS TEXT) || '-' || CAST(1.0 AS TEXT)), 'YYYY-MM-DD')")) + expect_equal(test_translate_sql(date_build(year_column, 1L, 1L)), sql("TO_DATE(CAST(`year_column` AS TEXT) || '-' CAST(1 AS TEXT) || '-' || CAST(1 AS TEXT)), 'YYYY-MM-DD')")) + expect_equal(test_translate_sql(get_year(date_column)), sql("DATE_PART('year', `date_column`)")) + expect_equal(test_translate_sql(get_month(date_column)), sql("DATE_PART('month', `date_column`)")) + expect_equal(test_translate_sql(get_day(date_column)), sql("DATE_PART('day', `date_column`)")) +}) + +test_that("difftime is translated correctly", { + local_con(simulate_redshift()) + expect_equal(test_translate_sql(difftime(start_date, end_date, units = "days")), sql("DATEDIFF(day, `start_date`, `end_date`)")) + expect_equal(test_translate_sql(difftime(start_date, end_date)), sql("DATEDIFF(day, `start_date`, `end_date`)")) + + expect_error(test_translate_sql(difftime(start_date, end_date, units = "auto"))) + expect_error(test_translate_sql(difftime(start_date, end_date, tz = "UTC", units = "days"))) +}) diff --git a/tests/testthat/test-backend-snowflake.R b/tests/testthat/test-backend-snowflake.R index 0a6e58157..c628308d5 100644 --- a/tests/testthat/test-backend-snowflake.R +++ b/tests/testthat/test-backend-snowflake.R @@ -102,6 +102,27 @@ test_that("custom lubridate functions translated correctly", { expect_equal(test_translate_sql(floor_date(x, "week")), sql("DATE_TRUNC('week', `x`)")) }) +test_that("custom clock functions translated correctly", { + local_con(simulate_snowflake()) + expect_equal(test_translate_sql(add_years(x, 1)), sql("DATEADD(YEAR, 1.0, `x`)")) + expect_equal(test_translate_sql(add_days(x, 1)), sql("DATEADD(DAY, 1.0, `x`)")) + expect_error(test_translate_sql(add_days(x, 1, "dots", "must", "be empty"))) + expect_equal(test_translate_sql(date_build(2020, 1, 1)), sql("DATE_FROM_PARTS(2020.0, 1.0, 1.0)")) + expect_equal(test_translate_sql(date_build(year_column, 1L, 1L)), sql("DATE_FROM_PARTS(`year_column`, 1, 1)")) + expect_equal(test_translate_sql(get_year(date_column)), sql("DATE_PART('year', `date_column`)")) + expect_equal(test_translate_sql(get_month(date_column)), sql("DATE_PART('month', `date_column`)")) + expect_equal(test_translate_sql(get_day(date_column)), sql("DATE_PART('day', `date_column`)")) +}) + +test_that("difftime is translated correctly", { + local_con(simulate_snowflake()) + expect_equal(test_translate_sql(difftime(start_date, end_date, units = "days")), sql("DATEDIFF(day, `start_date`, `end_date`)")) + expect_equal(test_translate_sql(difftime(start_date, end_date)), sql("DATEDIFF(day, `start_date`, `end_date`)")) + + expect_error(test_translate_sql(difftime(start_date, end_date, units = "auto"))) + expect_error(test_translate_sql(difftime(start_date, end_date, tz = "UTC", units = "days"))) +}) + test_that("min() and max()", { local_con(simulate_snowflake()) diff --git a/tests/testthat/test-backend-spark-sql.R b/tests/testthat/test-backend-spark-sql.R new file mode 100644 index 000000000..e1276c7a0 --- /dev/null +++ b/tests/testthat/test-backend-spark-sql.R @@ -0,0 +1,20 @@ +test_that("custom clock functions translated correctly", { + local_con(simulate_spark_sql()) + expect_equal(test_translate_sql(add_years(x, 1)), sql("ADD_MONTHS('`x`', 1.0 * 12.0)")) + expect_equal(test_translate_sql(add_days(x, 1)), sql("DATE_ADD(`x`, 1.0)")) + expect_error(test_translate_sql(add_days(x, 1, "dots", "must", "be empty"))) + expect_equal(test_translate_sql(date_build(2020, 1, 1)), sql("MAKE_DATE(2020.0, 1.0, 1.0)")) + expect_equal(test_translate_sql(date_build(year_column, 1L, 1L)), sql("MAKE_DATE(`year_column`, 1, 1)")) + expect_equal(test_translate_sql(get_year(date_column)), sql("DATE_PART('YEAR', `date_column`)")) + expect_equal(test_translate_sql(get_month(date_column)), sql("DATE_PART('MONTH', `date_column`)")) + expect_equal(test_translate_sql(get_day(date_column)), sql("DATE_PART('DAY', `date_column`)")) +}) + +test_that("difftime is translated correctly", { + local_con(simulate_spark_sql()) + expect_equal(test_translate_sql(difftime(start_date, end_date, units = "days")), sql("DATEDIFF(`end_date`, `start_date`)")) + expect_equal(test_translate_sql(difftime(start_date, end_date)), sql("DATEDIFF(`end_date`, `start_date`)")) + + expect_error(test_translate_sql(difftime(start_date, end_date, units = "auto"))) + expect_error(test_translate_sql(difftime(start_date, end_date, tz = "UTC", units = "days"))) +}) diff --git a/tests/testthat/test-tidyeval.R b/tests/testthat/test-tidyeval.R index 652e90207..ef105ce8c 100644 --- a/tests/testthat/test-tidyeval.R +++ b/tests/testthat/test-tidyeval.R @@ -38,6 +38,7 @@ test_that("namespaced calls to dplyr functions are stripped", { # hack to avoid check complaining about not declared imports expect_equal(partial_eval(rlang::parse_expr("stringr::str_to_lower(x)"), lf), expr(str_to_lower(x))) expect_equal(partial_eval(rlang::parse_expr("lubridate::today()"), lf), expr(today())) + expect_equal(partial_eval(rlang::parse_expr("clock::add_years(x, 1)"), lf), expr(add_years(x, 1))) }) test_that("use quosure environment for unevaluted formulas", {