Skip to content

Commit

Permalink
Merge pull request #16 from mrc-ide/mrc-5578
Browse files Browse the repository at this point in the history
Parse systems that periodically zero
  • Loading branch information
weshinsley authored Aug 9, 2024
2 parents f6ec9fe + eac4fed commit 3b91997
Show file tree
Hide file tree
Showing 15 changed files with 210 additions and 14 deletions.
3 changes: 0 additions & 3 deletions R/constants.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
SPECIAL_LHS <- c(
"initial", "deriv", "update", "output", "dim", "config", "compare")

FUNCTIONS <- list(
exp = function(x) {})
10 changes: 9 additions & 1 deletion R/generate_dust.R
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,15 @@ generate_dust_system_rhs <- function(dat) {

generate_dust_system_zero_every <- function(dat) {
args <- c("const shared_state&" = "shared")
body <- "return dust2::zero_every_type<real_type>();"
if (is.null(dat$zero_every)) {
body <- "return dust2::zero_every_type<real_type>();"
} else {
index <- match(names(dat$zero_every), dat$variables) - 1
every <- vcapply(dat$zero_every, generate_dust_sexp, dat$sexp_data,
USE.NAMES = FALSE)
str <- paste(sprintf("{%s, {%s}}", every, index), collapse = ", ")
body <- sprintf("return dust2::zero_every_type<real_type>{%s};", str)
}
cpp_function("auto", "zero_every", args, body, static = TRUE)
}

Expand Down
3 changes: 3 additions & 0 deletions R/parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ odin_parse_quo <- function(quo, input_type, compatibility, call) {
system$exprs, equations, system$variables, system$data$name, call)
storage <- parse_storage(
equations, phases, system$variables, system$data, call)
zero_every <- parse_zero_every(system$time, phases, equations,
system$variables, call)

ret <- list(time = system$time,
class = "odin",
Expand All @@ -24,6 +26,7 @@ odin_parse_quo <- function(quo, input_type, compatibility, call) {
equations = equations,
phases = phases,
storage = storage,
zero_every = zero_every,
data = system$data)

parse_check_usage(exprs, ret, call)
Expand Down
4 changes: 2 additions & 2 deletions R/parse_error.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ cnd_footer.odin_parse_error <- function(cnd, ...) {
## TODO: later, we might want to point at specific bits of the error
## and say "here, this is where you are wrong" but that's not done
## yet...
src <- cnd$src
src <- cnd$src[order(viapply(cnd$src, function(x) x$index %||% 1L))]

if (is.null(src[[1]]$str)) {
context <- unlist(lapply(cnd$src, function(x) deparse1(x$value)))
context <- unlist(lapply(src, function(x) deparse1(x$value)))
} else {
line <- unlist(lapply(src, function(x) seq(x$start, x$end)))
src_str <- unlist(lapply(src, "[[", "str"))
Expand Down
54 changes: 51 additions & 3 deletions R/parse_expr.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@ parse_expr_assignment <- function(expr, src, call) {
special <- "parameter"
}

if (identical(special, "initial")) {
zero_every <- lhs$args$zero_every
if (!rlang::is_missing(zero_every) && !is.null(zero_every)) {
if (!rlang::is_integerish(zero_every)) {
odin_parse_error(
"Argument to 'zero_every' must be an integer",
"E1019", src, call)
}
if (!(identical(rhs$expr, 0) || identical(rhs$expr, 0L))) {
odin_parse_error(
"Initial condition of periodically zeroed variable must be 0",
"E1020", src, call)
}
}
}

list(special = special,
lhs = lhs,
rhs = rhs,
Expand All @@ -48,14 +64,32 @@ parse_expr_assignment_lhs <- function(lhs, src, call) {
special <- NULL
name <- NULL

if (rlang::is_call(lhs, SPECIAL_LHS)) {
special_def <- list(
initial = function(name, zero_every) NULL,
update = function(name) NULL,
deriv = function(name) NULL,
output = function(name) NULL,
dim = function(name) NULL,
config = function(name) NULL,
compare = function(name) NULL)

args <- NULL
if (rlang::is_call(lhs, names(special_def))) {
special <- deparse1(lhs[[1]])
if (length(lhs) != 2 || !is.null(names(lhs))) {
m <- match_call(lhs, special_def[[special]])
if (!m$success) {
odin_parse_error(c("Invalid special function call",
x = conditionMessage(m$error)),
"E1003", src, call)
}
if (rlang::is_missing(m$value$name)) {
odin_parse_error(
c("Invalid special function call",
i = "Expected a single unnamed argument to '{special}()'"),
i = paste("Missing target for '{special}()', typically the first",
"(unnamed) argument")),
"E1003", src, call)
}

if (special == "compare") {
## TODO: a good candidate for pointing at the source location of
## the error.
Expand All @@ -68,6 +102,14 @@ parse_expr_assignment_lhs <- function(lhs, src, call) {
"E1004", src, call)
}
lhs <- lhs[[2]]
if (length(m$value) > 2) {
args <- as.list(m$value[-(1:2)])
i <- !vlapply(args, rlang::is_missing)
args <- args[i]
if (length(args) == 0) {
args <- NULL
}
}
}

is_array <- rlang::is_call(lhs, "[")
Expand All @@ -81,6 +123,12 @@ parse_expr_assignment_lhs <- function(lhs, src, call) {
lhs <- list(
name = name,
special = special)

if (!is.null(args)) {
lhs$args <- args
}

lhs
}


Expand Down
6 changes: 5 additions & 1 deletion R/parse_prepare.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ parse_prepare <- function(quo, input_type, call) {
"Expected 'expr' to be a multiline expression within curly braces",
arg = "expr", call = call)
}
exprs <- lapply(as.list(info$value[-1]), function(x) list(value = x))
exprs <- as.list(info$value[-1])
exprs <- Map(list,
value = exprs,
index = seq_along(exprs))
} else {
if (info$type == "file") {
filename <- info$value
Expand All @@ -31,6 +34,7 @@ parse_prepare <- function(quo, input_type, call) {
})
exprs <- Map(list,
value = as.list(exprs),
index = seq_along(exprs),
start = start,
end = end,
str = src_str)
Expand Down
24 changes: 24 additions & 0 deletions R/parse_system.R
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,27 @@ parse_storage <- function(equations, phases, variables, data, call) {
type = type,
packing = packing)
}


parse_zero_every <- function(time, phases, equations, variables, call) {
zero_every <- lapply(phases$initial$variables, function(eq) {
eq$lhs$args$zero_every
})
i <- !vlapply(zero_every, is.null)
if (!any(i)) {
return(NULL)
}

names(zero_every) <- variables
zero_every <- zero_every[i]

is_zero <- function(expr) {
identical(expr, 0) || identical(expr, 0L)
}

## If time is continuous, we should also check that the reset
## variables don't reference any other variables, even indirectly;
## do this as mrc-5615.

zero_every
}
Binary file modified R/sysdata.rda
Binary file not shown.
6 changes: 6 additions & 0 deletions R/util.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ vlapply <- function(...) {
}


viapply <- function(...) {
vapply(..., FUN.VALUE = 1L)
}


vcapply <- function(...) {
vapply(..., FUN.VALUE = "")
}
Expand All @@ -31,6 +36,7 @@ match_value <- function(x, choices, name = deparse(substitute(x)), arg = name,
}


## See mrc-5614 for ideas about improving this, for later.
match_call <- function(call, fn) {
## We'll probably expand on the error case here to return something
## much nicer?
Expand Down
1 change: 1 addition & 0 deletions tests/testthat/helper-odin2.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ method_args <- list(
update = "static void update(real_type time, real_type dt, const real_type* state, const shared_state& shared, internal_state& internal, rng_state_type& rng_state, real_type* state_next) {",
rhs = "static void rhs(real_type time, const real_type* state, const shared_state& shared, internal_state& internal, real_type* state_deriv) {",
compare_data = "static real_type compare_data(real_type time, real_type dt, const real_type* state, const data_type& data, const shared_state& shared, internal_state& internal, rng_state_type& rng_state) {",
zero_every = "static auto zero_every(const shared_state& shared) {",
adjoint_size = "static size_t adjoint_size(const shared_state& shared) {",
adjoint_update = "static void adjoint_update(real_type time, real_type dt, const real_type* state, const real_type* adjoint, const shared_state& shared, internal_state& internal, real_type* adjoint_next) {",
adjoint_compare_data = "static void adjoint_compare_data(real_type time, real_type dt, const real_type* state, const real_type* adjoint, const data_type& data, const shared_state& shared, internal_state& internal, real_type* adjoint_next) {",
Expand Down
30 changes: 30 additions & 0 deletions tests/testthat/test-generate.R
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,33 @@ test_that("can generate simple stochastic system", {
" state_next[0] = mcstate::random::normal(rng_state, x, 1);",
"}"))
})


test_that("can generate empty zero_every method", {
dat <- odin_parse({
update(x) <- 0
initial(x) <- 0
})
dat <- generate_prepare(dat)
expect_equal(
generate_dust_system_zero_every(dat),
c(method_args$zero_every,
" return dust2::zero_every_type<real_type>();",
"}"))
})


test_that("can generate nontrivial zero_every method", {
dat <- odin_parse({
update(x) <- 0
initial(x) <- 0
update(y) <- 1
initial(y, zero_every = 4) <- 0
})
dat <- generate_prepare(dat)
expect_equal(
generate_dust_system_zero_every(dat),
c(method_args$zero_every,
" return dust2::zero_every_type<real_type>{{4, {1}}};",
"}"))
})
4 changes: 2 additions & 2 deletions tests/testthat/test-parse-expr.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ test_that("allow calls on lhs", {

test_that("require that special calls are (currently) simple", {
expect_error(
parse_expr(quote(initial(x, TRUE) <- 1), NULL, NULL),
parse_expr(quote(update(x, TRUE) <- 1), NULL, NULL),
"Invalid special function call")
expect_error(
parse_expr(quote(initial() <- 1), NULL, NULL),
"Invalid special function call")
expect_error(
err <- expect_error(
parse_expr(quote(initial(x = 1) <- 1), NULL, NULL),
"Invalid special function call")
})
Expand Down
8 changes: 6 additions & 2 deletions tests/testthat/test-parse-prepare.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ test_that("can load code from an expression", {
parse_prepare(quo, NULL, NULL),
list(type = "expression",
filename = NULL,
exprs = list(list(value = quote(a <- 1)),
list(value = quote(b <- 2)))))
exprs = list(list(value = quote(a <- 1), index = 1),
list(value = quote(b <- 2), index = 2))))
})


Expand Down Expand Up @@ -155,10 +155,12 @@ test_that("can read expressions from file", {
list(type = "file",
filename = path,
exprs = list(list(value = quote(a <- 1),
index = 1,
start = 1,
end = 1,
str = "a <- 1"),
list(value = quote(b <- 2),
index = 2,
start = 2,
end = 2,
str = "b <- 2"))))
Expand All @@ -172,10 +174,12 @@ test_that("can read expressions from file", {
list(type = "text",
filename = NULL,
exprs = list(list(value = quote(a <- 1),
index = 1,
start = 1,
end = 1,
str = "a <- 1"),
list(value = quote(b <- 2),
index = 2,
start = 2,
end = 2,
str = "b <- 2"))))
Expand Down
49 changes: 49 additions & 0 deletions tests/testthat/test-parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ test_that("throw error with context", {
expect_equal(
err$src,
list(list(value = quote(b <- parameter(invalid = TRUE)),
index = 3,
start = 3,
end = 3,
str = c("b<-parameter(invalid=TRUE)"))))
Expand Down Expand Up @@ -93,3 +94,51 @@ test_that("throw stochastic parse error sensibly", {
deparse(quote(update(x) <- Normal(mu = 0, sd = 1))),
fixed = TRUE)
})


test_that("can parse system that resets", {
d <- odin_parse({
update(x) <- x + 1
initial(x, zero_every = 4) <- 0
update(y) <- y + 1
initial(y) <- 0
})
expect_equal(d$zero_every, list(x = 4))
})


test_that("zero_reset requires that initial conditions are zero", {
expect_error(
odin_parse({
update(x) <- x + 1
initial(x, zero_every = 1) <- 1
}),
"Initial condition of periodically zeroed variable must be 0")
})


test_that("zero_reset requires an integer argument", {
expect_error(
odin_parse({
update(x) <- x + 1
initial(x, zero_every = 1.4) <- 1
}),
"Argument to 'zero_every' must be an integer")
expect_error(
odin_parse({
update(x) <- x + 1
initial(x, zero_every = a) <- 1
}),
"Argument to 'zero_every' must be an integer")
})


test_that("can parse ode system that resets", {
d <- odin_parse({
deriv(x) <- 1
initial(x, zero_every = 4) <- 0
deriv(y) <- 1
initial(y) <- 0
})
expect_equal(d$zero_every, list(x = 4))
})
22 changes: 22 additions & 0 deletions vignettes/errors.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,28 @@ Example calls that will fail:

The details for the failure will be included in the body of the error message.

# `E1019`

Invalid value for the `zero_every` argument to `initial()`. At present, this must be a literal value, and that value must be an integer-like number (e.g., 2 or 2L). We may relax this in future to allow more flexibility (e.g., a variable which contains an integer-like number).

Examples that would error

```
initial(x, zero_every = a) <- 0
initial(y, zero_every = 2.5) <- 0
```

# `E1020`

The right hand side of a call to `initial()` that uses the `zero_every` argument was not 0, but it must be. Because we periodically reset values to zero, any initial condition other than zero makes no sense. See [the `dust2` docs on periodic variables](https://mrc-ide.github.io/dust2/articles/periodic.html) for details.

Examples that would error

```
initial(x, zero_every = 1) <- 10
initial(x, zero_every = 1) <- a
```

# `E2001`

Your system of equations does not include any expressions with `initial()` on the lhs. This is what we derive the set of variables from, so at least one must be present.
Expand Down

0 comments on commit 3b91997

Please sign in to comment.