Skip to content

Commit

Permalink
Merge pull request #14 from mrc-ide/mrc-5588
Browse files Browse the repository at this point in the history
Translate old user calls to new parameter calls
  • Loading branch information
richfitz authored Aug 7, 2024
2 parents ecbe675 + f659d0c commit 44b0db2
Show file tree
Hide file tree
Showing 10 changed files with 223 additions and 12 deletions.
13 changes: 11 additions & 2 deletions R/odin.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,24 @@
##' an error if the wrong type of input is given. Using this may be
##' beneficial in programmatic environments.
##'
##' @param compatibility Compatibility mode to use. Valid options are
##' "warning", which updates code that can be fixed, with warnings,
##' and "error", which will error. The option "silent" will
##' silently rewrite code, but this is not recommended for general
##' use as eventually the compatibility mode will be removed (this
##' option is primarily intended for comparing output of odin1 and
##' odin2 models against old code).
##'
##' @inheritParams dust2::dust_compile
##'
##' @return A `dust_system_generator` object, suitable for using with
##' dust functions (starting from [dust2::dust_system_create])
##'
##' @export
odin <- function(expr, input_type = NULL, quiet = FALSE, workdir = NULL,
debug = FALSE, skip_cache = FALSE) {
dat <- odin_parse_quo(rlang::enquo(expr), input_type, call)
debug = FALSE, skip_cache = FALSE, compatibility = "warning") {
call <- environment()
dat <- odin_parse_quo(rlang::enquo(expr), input_type, compatibility, call)
code <- generate_dust_system(dat)
tmp <- tempfile(fileext = ".cpp")
on.exit(unlink(tmp))
Expand Down
9 changes: 5 additions & 4 deletions R/parse.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
odin_parse <- function(expr, input_type = NULL) {
odin_parse <- function(expr, input_type = NULL, compatibility = "warning") {
call <- environment()
odin_parse_quo(rlang::enquo(expr), input_type, call)
odin_parse_quo(rlang::enquo(expr), input_type, compatibility, call)
}


odin_parse_quo <- function(quo, input_type, call) {
odin_parse_quo <- function(quo, input_type, compatibility, call) {
match_value(compatibility, c("silent", "warning", "error"), call = call)
dat <- parse_prepare(quo, input_type, call)
dat$exprs <- parse_compat(dat$exprs, compatibility, call)
exprs <- lapply(dat$exprs, function(x) parse_expr(x$value, x, call = call))

system <- parse_system_overall(exprs, call)
equations <- parse_system_depends(
system$exprs$equations, system$variables, call)
Expand Down
77 changes: 77 additions & 0 deletions R/parse_compat.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
parse_compat <- function(exprs, action, call) {
## Things we can translate:
##
## user() -> parameter()
## rbinom() -> Binomial() [etc]
## dt -> drop
## step -> time
exprs <- lapply(exprs, parse_compat_fix_user, call)
parse_compat_report(exprs, action, call)
exprs
}


parse_compat_fix_user <- function(expr, call) {
is_user_assignment <-
rlang::is_call(expr$value, c("<-", "=")) &&
rlang::is_call(expr$value[[3]], "user")
if (is_user_assignment) {
res <- match_call(
expr$value[[3]],
function(default, integer, min, max, ...) NULL)
if (!res$success) {
odin_parse_error(
"Failed to translate your 'user()' expression to use 'parameter()'",
"E1016", expr, call = call)
}
for (arg in c("integer", "min", "max")) {
if (!rlang::is_missing(res$value[[arg]])) {
odin_parse_error(
c("Can't yet translate 'user()' calls that use the '{arg}' argument",
i = paste("We don't support this argument in parameter() yet,",
"but once we do we will support translation")),
"E0001", expr, call = call)
}
}

original <- expr$value
args <- list(as.name("parameter"))
if (!rlang::is_missing(res$value$default)) {
args <- c(args, list(res$value$default))
}
expr$value[[3]] <- as.call(args)
expr$compat <- list(type = "user", original = original)
}
expr
}


parse_compat_report <- function(exprs, action, call) {
i <- !vlapply(exprs, function(x) is.null(x$compat))
if (action != "silent" && any(i)) {
description <- c(
user = "Replace calls to 'user()' with 'parameter()'")
type <- vcapply(exprs[i], function(x) x$compat$type)

detail <- NULL
for (t in intersect(names(description), type)) {
j <- i[type == t]
## Getting line numbers here is really hard, so let's just not
## try for now and do this on deparsed expressions.
updated <- vcapply(exprs[j], function(x) deparse1(x$value))
original <- vcapply(exprs[j], function(x) deparse1(x$compat$original))
context_t <- set_names(
c(rbind(updated, original, deparse.level = 0)),
rep(c("x", "v"), length(updated)))
detail <- c(detail, description[[t]], cli_nbsp(context_t))
}

header <- "Found {sum(i)} compatibility issue{?s}"

if (action == "error") {
odin_parse_error(c(header, detail), "E1017", exprs[i], call)
} else {
cli::cli_warn(c(header, detail), call = call)
}
}
}
22 changes: 19 additions & 3 deletions R/parse_error.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ cnd_footer.odin_parse_error <- function(cnd, ...) {
## and say "here, this is where you are wrong" but that's not done
## yet...
src <- cnd$src

if (is.null(src[[1]]$str)) {
context <- unlist(lapply(cnd$src, function(x) deparse1(x$value)))
} else {
line <- unlist(lapply(src, function(x) seq(x$start, x$end)))
src <- unlist(lapply(src, "[[", "str"))
context <- sprintf("%s| %s", gsub(" ", "\u00a0", format(line)), src)
src_str <- unlist(lapply(src, "[[", "str"))
context <- sprintf("%s| %s", cli_nbsp(format(line)), src_str)
}

code <- cnd$code
Expand All @@ -35,8 +36,23 @@ cnd_footer.odin_parse_error <- function(cnd, ...) {
explain <- cli::format_inline(
"For more information, run {.run odin2::odin_error_explain(\"{code}\")}")

## It's quite annoying to try and show the original and updated code
## here at the same time so instead let's just let the user know
## that things might not be totally accurate.
uses_compat <- !vlapply(src, function(x) is.null(x$compat))
if (any(uses_compat)) {
compat_warning <- c(
"!" = cli::format_inline(
paste("{cli::qty(length(src))}{?The expression/Expressions} above",
"{?has/have} been translated while updating for use with",
"odin2, the context may not reflect your original code.")))
} else {
compat_warning <- NULL
}

c(">" = "Context:", context,
"i" = explain)
"i" = explain,
compat_warning)
}


Expand Down
Binary file modified R/sysdata.rda
Binary file not shown.
5 changes: 5 additions & 0 deletions R/util.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,8 @@ set_names <- function(x, nms) {
names(x) <- nms
x
}


cli_nbsp <- function(x) {
gsub(" ", "\u00a0", x)
}
11 changes: 10 additions & 1 deletion man/odin.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

82 changes: 82 additions & 0 deletions tests/testthat/test-parse-compat.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
test_that("can translate user()", {
d <- odin_parse({
a <- user(1)
initial(x) <- 0
update(x) <- x + a
}, compatibility = "silent")
expect_equal(d$equations$a$src$value,
quote(a <- parameter(1)))
expect_equal(d$equations$a$src$compat,
list(type = "user", original = quote(a <- user(1))))
})


test_that("can control severity of reporting", {
code <- "a <- user(1)\ninitial(x) <- 0\nupdate(x) <- x + a"
expect_silent(odin_parse(code, compatibility = "silent"))
w <- expect_warning(
odin_parse(code, compatibility = "warning"),
"Found 1 compatibility issue")

e <- expect_error(
odin_parse(code, compatibility = "error"),
"Found 1 compatibility issue")

expect_true(startsWith(conditionMessage(e), conditionMessage(w)))
})


test_that("can translate simple user calls", {
expect_equal(
parse_compat_fix_user(list(value = quote(a <- user()))),
list(value = quote(a <- parameter()),
compat = list(type = "user", original = quote(a <- user()))))
expect_equal(
parse_compat_fix_user(list(value = quote(a <- user(1)))),
list(value = quote(a <- parameter(1)),
compat = list(type = "user", original = quote(a <- user(1)))))
expect_error(
parse_compat_fix_user(list(value = quote(a <- user(integer = TRUE)))),
"Can't yet translate 'user()' calls that use the 'integer' argument",
fixed = TRUE,
class = "odin_parse_error")
expect_error(
parse_compat_fix_user(list(value = quote(a <- user(min = 0)))),
"Can't yet translate 'user()' calls that use the 'min' argument",
fixed = TRUE,
class = "odin_parse_error")
expect_error(
parse_compat_fix_user(list(value = quote(a <- user(max = 1)))),
"Can't yet translate 'user()' calls that use the 'max' argument",
fixed = TRUE,
class = "odin_parse_error")
})


test_that("handle errors that occur in translated code", {
err <- expect_error(
odin_parse({
a <- user(sqrt(b))
b <- 1
initial(x) <- 0
update(x) <- x + a
}, compatibility = "silent",),
"Invalid default argument to 'parameter()': sqrt(b)",
fixed = TRUE)

expect_match(
conditionMessage(err),
"The expression above has been translated")
})


test_that("handle failure to pass a user call", {
expect_error(
odin_parse({
a <- user(min = 1, min = 1)
initial(x) <- 0
update(x) <- x + a
}),
"Failed to translate your 'user()' expression to use 'parameter()'",
fixed = TRUE)
})
6 changes: 4 additions & 2 deletions tests/testthat/test-zzz-integration.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ test_that("can compile a simple ode model", {
expect_s3_class(res(), "dust_system_generator")
expect_mapequal(res()$properties,
list(time_type = "continuous",
has_compare = FALSE))
has_compare = FALSE,
has_adjoint = FALSE))

pars <- list(N = 100, beta = 0.2, gamma = 0.1, I0 = 1)
sys <- dust2::dust_system_create(res(), pars, 1)
Expand Down Expand Up @@ -84,7 +85,8 @@ test_that("can compile a discrete-time model that compares to data", {
expect_s3_class(res(), "dust_system_generator")
expect_mapequal(res()$properties,
list(time_type = "discrete",
has_compare = TRUE))
has_compare = TRUE,
has_adjoint = FALSE))

pars <- list(N = 1000, beta = 0.2, gamma = 0.1, I0 = 10,
exp_noise = 1e6)
Expand Down
10 changes: 10 additions & 0 deletions vignettes/errors.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,16 @@ which is impossible. Parameters that are differentiable need to be able to be s

Probably you want to set at least one of these to `FALSE`, or omit the argument and accept the default.

# `E1016`

Failed to translate a `user()` expression (valid in odin before version 2) into a call to `parameter()`. This was likely code that would not work in old odin either.

# `E1017`

Compatibility issues were present in the system (e.g., using `user()` instead of `parameter()` and the compatibility action was `"error"`. You can, in the short term, disable failure here by using `compatibility = "warning"` or `compatibility = "silent"`, but eventually this will become an error that is always thrown when running with old odin code.

The error message will explain how to update your code to use new odin2 syntax.

# `E2001`

Your system of equations does not include any expressions with `initial()` on the lhs. This is what we derive the set of variables from, so at least one must be present.
Expand Down

0 comments on commit 44b0db2

Please sign in to comment.