Skip to content

Commit

Permalink
Merge pull request #11 from mrc-ide/mrc-5567
Browse files Browse the repository at this point in the history
Generate adjoint model for simple discrete-time systems
  • Loading branch information
richfitz authored Aug 7, 2024
2 parents 44b0db2 + 1751bfa commit 59bda99
Show file tree
Hide file tree
Showing 13 changed files with 493 additions and 19 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Suggests:
knitr,
rmarkdown,
mockery,
numDeriv,
testthat (>= 3.0.0),
withr
Config/testthat/edition: 3
Expand Down
147 changes: 139 additions & 8 deletions R/generate_dust.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ generate_dust_system <- function(dat) {
body$add(sprintf(" %s", generate_dust_system_rhs(dat)))
body$add(sprintf(" %s", generate_dust_system_zero_every(dat)))
body$add(sprintf(" %s", generate_dust_system_compare_data(dat)))
body$add(sprintf(" %s", generate_dust_system_adjoint(dat)))
body$add("};")
body$get()
}
Expand All @@ -35,14 +36,20 @@ generate_prepare <- function(dat) {


generate_dust_system_attributes <- function(dat) {
if (length(dat$phases$compare) > 0) {
if (length(dat$phases$compare) == 0) {
has_compare <- NULL
} else {
has_compare <- "// [[dust2::has_compare()]]"
}
if (is.null(dat$adjoint)) {
has_adjoint <- NULL
} else {
has_compare <- NULL
has_adjoint <- "// [[dust2::has_adjoint()]]"
}
c(sprintf("// [[dust2::class(%s)]]", dat$class),
sprintf("// [[dust2::time_type(%s)]]", dat$time),
has_compare)
has_compare,
has_adjoint)
}


Expand Down Expand Up @@ -191,7 +198,7 @@ generate_dust_system_update <- function(dat) {
i <- variables %in% dat$phases$update$unpack
body$add(sprintf("const auto %s = state[%d];",
variables[i],
match(variables[i], dat$storage$packing$scalar) - 1))
match(variables[i], dat$storage$packing$state$scalar) - 1))
eqs <- dat$phases$update$equations
for (eq in c(dat$equations[eqs], dat$phases$update$variables)) {
lhs <- generate_dust_lhs(eq$lhs$name, dat, "state_next")
Expand All @@ -216,7 +223,7 @@ generate_dust_system_rhs <- function(dat) {
i <- variables %in% dat$phases$deriv$unpack
body$add(sprintf("const auto %s = state[%d];",
variables[i],
match(variables[i], dat$storage$packing$scalar) - 1))
match(variables[i], dat$storage$packing$state$scalar) - 1))
eqs <- dat$phases$deriv$equations
for (eq in c(dat$equations[eqs], dat$phases$deriv$variables)) {
lhs <- generate_dust_lhs(eq$lhs$name, dat, "state_deriv")
Expand Down Expand Up @@ -250,7 +257,7 @@ generate_dust_system_compare_data <- function(dat) {
i <- variables %in% dat$phases$compare$unpack
body$add(sprintf("const auto %s = state[%d];",
variables[i],
match(variables[i], dat$storage$packing$scalar) - 1))
match(variables[i], dat$storage$packing$state$scalar) - 1))
## TODO collision here in names with 'll'; we might need to prefix
## with compare_ perhaps?
body$add("real_type ll = 0;")
Expand All @@ -266,14 +273,134 @@ generate_dust_system_compare_data <- function(dat) {
for (eq in dat$phases$compare$compare) {
eq_args <- vcapply(eq$rhs$args, generate_dust_sexp, dat$sexp_data)
body$add(sprintf("ll += mcstate::density::%s(%s, true);",
eq$rhs$distribution, paste(eq_args, collapse = ", ")))
eq$rhs$density$cpp, paste(eq_args, collapse = ", ")))
}

body$add("return ll;")
cpp_function("real_type", "compare_data", args, body$get(), static = TRUE)
}


generate_dust_system_adjoint <- function(dat) {
if (is.null(dat$adjoint)) {
return(NULL)
}

c(generate_dust_system_adjoint_size(dat),
generate_dust_system_adjoint_update(dat),
generate_dust_system_adjoint_compare_data(dat),
generate_dust_system_adjoint_initial(dat))
}


generate_dust_system_adjoint_size <- function(dat) {
args <- c("const shared_state&" = "shared")
## We might return the _difference_ here, still undecided...
body <- sprintf("return %d;", length(dat$storage$contents$adjoint))
cpp_function("size_t", "adjoint_size", args, body, static = TRUE)
}


generate_dust_system_adjoint_update <- function(dat) {
args <- c("real_type" = "time",
"real_type" = "dt",
"const real_type*" = "state",
"const real_type*" = "adjoint",
"const shared_state&" = "shared",
"internal_state&" = "internal",
"real_type*" = "adjoint_next")
body <- collector()

variables <- dat$storage$contents$variables
i <- variables %in% dat$adjoint$update$unpack
body$add(sprintf("const auto %s = state[%d];",
variables[i],
match(variables[i], dat$storage$packing$state$scalar) - 1))

adjoint <- dat$storage$contents$adjoint
i <- adjoint %in% dat$adjoint$update$unpack_adjoint
body$add(sprintf("const auto %s = adjoint[%d];",
adjoint[i],
match(adjoint[i], dat$storage$packing$adjoint$scalar) - 1))

eqs <- dat$adjoint$update$equations
for (eq in c(dat$equations[eqs], dat$adjoint$update$adjoint)) {
lhs <- generate_dust_lhs(eq$lhs$name, dat, "state_next")
rhs <- generate_dust_sexp(eq$rhs$expr, dat$sexp_data)
body$add(sprintf("%s = %s;", lhs, rhs))
}

cpp_function("void", "adjoint_update", args, body$get(), static = TRUE)
}


generate_dust_system_adjoint_compare_data <- function(dat) {
args <- c("real_type" = "time",
"real_type" = "dt",
"const real_type*" = "state",
"const real_type*" = "adjoint",
"const data_type&" = "data",
"const shared_state&" = "shared",
"internal_state&" = "internal",
"real_type*" = "adjoint_next")
body <- collector()

variables <- dat$storage$contents$variables
i <- variables %in% dat$adjoint$compare$unpack
body$add(sprintf("const auto %s = state[%d];",
variables[i],
match(variables[i], dat$storage$packing$state$scalar) - 1))

adjoint <- dat$storage$contents$adjoint
i <- adjoint %in% dat$adjoint$compare$unpack_adjoint
body$add(sprintf("const auto %s = adjoint[%d];",
adjoint[i],
match(adjoint[i], dat$storage$packing$adjoint$scalar) - 1))

eqs <- dat$adjoint$compare$equations
for (eq in c(dat$equations[eqs], dat$adjoint$compare$adjoint)) {
lhs <- generate_dust_lhs(eq$lhs$name, dat, "state_next")
rhs <- generate_dust_sexp(eq$rhs$expr, dat$sexp_data)
body$add(sprintf("%s = %s;", lhs, rhs))
}

cpp_function("void", "adjoint_compare_data", args, body$get(), static = TRUE)
}


generate_dust_system_adjoint_initial <- function(dat) {
args <- c("real_type" = "time",
"real_type" = "dt",
"const real_type*" = "state",
"const real_type*" = "adjoint",
"const shared_state&" = "shared",
"internal_state&" = "internal",
"real_type*" = "adjoint_next")
body <- collector()

variables <- dat$storage$contents$variables
i <- variables %in% dat$adjoint$initial$unpack
body$add(sprintf("const auto %s = state[%d];",
variables[i],
match(variables[i], dat$storage$packing$state$scalar) - 1))

adjoint <- dat$storage$contents$adjoint
i <- adjoint %in% dat$adjoint$initial$unpack_adjoint
body$add(sprintf("const auto %s = adjoint[%d];",
adjoint[i],
match(adjoint[i], dat$storage$packing$adjoint$scalar) - 1))

eqs <- dat$adjoint$initial$equations
for (eq in c(dat$equations[eqs], dat$adjoint$initial$adjoint)) {
lhs <- generate_dust_lhs(eq$lhs$name, dat, "state_next")
rhs <- generate_dust_sexp(eq$rhs$expr, dat$sexp_data)
body$add(sprintf("%s = %s;", lhs, rhs))
}

cpp_function("void", "adjoint_initial", args, body$get(), static = TRUE)
}


generate_dust_lhs <- function(name, dat, name_state) {
location <- dat$storage$location[[name]]
if (location == "stack") {
Expand All @@ -282,7 +409,11 @@ generate_dust_lhs <- function(name, dat, name_state) {
## TODO: this wil need to change, once we support arrays: at that
## point we'll make this more efficient too by computing
## expressions for access.
sprintf("%s[%s]", name_state, match(name, dat$storage$packing$scalar) - 1)
sprintf("%s[%s]",
name_state, match(name, dat$storage$packing$state$scalar) - 1)
} else if (location == "adjoint") {
sprintf("%s[%s]",
"adjoint_next", match(name, dat$storage$packing$adjoint$scalar) - 1)
} else {
stop("Unsupported location")
}
Expand Down
16 changes: 14 additions & 2 deletions R/generate_dust_sexp.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,23 @@ generate_dust_sexp <- function(expr, dat, options = list()) {
ret <- sprintf("-%s", args[[1]])
} else if (n == 1 && fn == "+") {
ret <- args[[1]]
} else if (n == 2 && fn %in% c("+", "-", "*", "/")) {
} else if (n == 2 && fn %in% c("+", "-", "*", "/", "==")) {
## Some care might be needed for division in some cases.
ret <- sprintf("%s %s %s", args[[1]], fn, args[[2]])
} else if (fn %in% "exp") {
ret <- sprintf("mcstate::math::%s(%s)",
fn, paste(args, collapse = ", "))
} else if (fn == "%%") {
## TODO: we'll use our usual fmodr thing here once we get that
## into mcstate's math library, but for now this is ok.
ret <- sprintf("std::fmod(%s, %s)", args[[1]], args[[2]])
} else if (fn == "if") {
## NOTE: The ternary operator has very low precendence, so we
## will agressively parenthesise it. This is strictly not
## needed when this expression is the only element of `expr` but
## that's hard to detect so we'll tolerate a few additional
## parens for now.
ret <- sprintf("(%s ? %s : %s)", args[[1L]], args[[2L]], args[[3L]])
} else {
## TODO: we should catch this during parse; erroring here is a
## bug as we don't offer context.
Expand All @@ -29,7 +40,8 @@ generate_dust_sexp <- function(expr, dat, options = list()) {
} else {
location <- dat$location[[name]]
shared_exists <- !isFALSE(options$shared_exists)
if (location %in% c("state", "stack", if (!shared_exists) "shared")) {
if (location %in% c("state", "stack", "adjoint",
if (!shared_exists) "shared")) {
ret <- name
} else { # shared, internal, data
ret <- sprintf("%s.%s", location, name)
Expand Down
1 change: 1 addition & 0 deletions R/parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ odin_parse_quo <- function(quo, input_type, compatibility, call) {
data = system$data)

parse_check_usage(exprs, ret, call)
ret <- parse_adjoint(ret)

ret
}
Expand Down
Loading

0 comments on commit 59bda99

Please sign in to comment.