Skip to content

Commit

Permalink
Compile a trivial model
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed May 31, 2024
1 parent 132b781 commit 5f0b00b
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 5 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ Imports:
cli,
rlang
Suggests:
dust2@mrc-5356,
mrc-ide/dust2@mrc-5356,
testthat (>= 3.0.0)
Config/testthat/edition: 3
136 changes: 136 additions & 0 deletions R/generate_dust.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
generate_dust_model <- function(dat) {
core <- generate_dust_model_core(dat)
body <- collector()
body$add("#include <dust2/common.hpp>")
body$add(sprintf("// [[dust2::class(%s)]]", core$class))
body$add(sprintf("class %s {", core$class))
body$add("public:")
body$add(sprintf(" %s() = delete;", core$class))
body$add(" using real_type = double;")
body$add(" using data_type = dust2::no_data;")
body$add(" using rng_state_type = mcstate::random::generator<real_type>;")
body$add(paste0(" ", core$shared_state))
body$add(paste0(" ", core$internal_state))
body$add(paste0(" ", core$size))
body$add(paste0(" ", core$build_shared))
body$add(paste0(" ", core$build_internal))
body$add(paste0(" ", core$update_shared))
body$add(paste0(" ", core$update_internal))
body$add(paste0(" ", core$initial))
body$add(paste0(" ", core$update))
body$add("};")
body$get()
}


generate_dust_model_core <- function(dat) {
list(class = dat$class,
shared_state = generate_dust_model_core_shared_state(dat),
internal_state = generate_dust_model_core_internal_state(dat),
size = generate_dust_model_core_size(dat),
build_shared = generate_dust_model_core_build_shared(dat),
build_internal = generate_dust_model_core_build_internal(dat),
update_shared = generate_dust_model_core_update_shared(dat),
update_internal = generate_dust_model_core_update_internal(dat),
initial = generate_dust_model_core_initial(dat),
update = generate_dust_model_core_update(dat))
}


generate_dust_model_core_shared_state <- function(dat) {
nms <- dat$location$contents$shared
type <- dat$location$type[nms]
c("struct shared_state {",
sprintf(" %s %s;", type, nms),
"};")
}


generate_dust_model_core_internal_state <- function(dat) {
"struct internal_state {};"
}


generate_dust_model_core_size <- function(dat) {
args <- c("const shared_state&" = "shared")
body <- sprintf("return %d;", length(dat$location$contents$variables))
cpp_function("size_t", "size", args, body, static = TRUE)
}


generate_dust_model_core_build_shared <- function(dat) {
eqs <- dat$phases$create_shared$equations
body <- collector()
for (eq in dat$equations[eqs]) {
lhs <- eq$lhs$name
rhs <- generate_dust_sexp(eq$rhs$expr, dat)
body$add(sprintf("real_type %s = %s;", lhs, rhs))
}
body$add(sprintf("return shared_state{%s};", paste(eqs, collapse = ", ")))
args <- c("cpp11::list" = "parameters")
cpp_function("shared_state", "build_shared", args, body$get(), static = TRUE)
}


generate_dust_model_core_build_internal <- function(dat) {
args <- c("const shared_state&" = "shared")
body <- "return internal_state{};"
cpp_function("internal_state", "build_internal", args, body, static = TRUE)
}


generate_dust_model_core_update_shared <- function(dat) {
args <- c("cpp11::list" = "pars", "shared_state&" = "shared")
body <- character()
cpp_function("void", "update_shared", args, body, static = TRUE)
}


generate_dust_model_core_update_internal <- function(dat) {
args <- c("const shared_state&" = "shared", "internal_state&" = "internal")
body <- character()
cpp_function("void", "update_internal", args, body, static = TRUE)
}


generate_dust_model_core_initial <- function(dat) {
args <- c("real_type" = "time",
"real_type" = "dt",
"const shared_state&" = "shared",
"internal_state&" = "internal",
"rng_state_type&" = "rng_state",
"real_type*" = "state")
body <- collector()
eqs <- dat$phases$initial$equations
for (eq in c(dat$equations[eqs], dat$phases$initial$variables)) {
lhs <- generate_dust_lhs(eq$lhs$name, dat, "state")
rhs <- generate_dust_sexp(eq$rhs$expr, dat)
body$add(sprintf("%s = %s;", lhs, rhs))
}
cpp_function("void", "initial", args, body$get(), static = TRUE)
}


generate_dust_model_core_update <- function(dat) {
args <- c("real_type" = "time",
"real_type" = "dt",
"const real_type*" = "state",
"const shared_state&" = "shared",
"internal_state&" = "internal",
"rng_state_type&" = "rng_state",
"real_type*" = "state_next")
body <- collector()
variables <- dat$location$contents$variables
packing <- dat$location$packing$state
i <- variables %in% dat$phases$update$unpack
## TODO: this will get changed, and also reused.
body$add(sprintf("const auto %s = state[%d];",
variables[i], unlist(packing[i])))
eqs <- dat$phases$update$equations
for (eq in c(dat$equations[eqs], dat$phases$update$variables)) {
lhs <- generate_dust_lhs(eq$lhs$name, dat, "state_next")
rhs <- generate_dust_sexp(eq$rhs$expr, dat)
body$add(sprintf("%s = %s;", lhs, rhs))
}
cpp_function("void", "update", args, body$get(), static = TRUE)
}
55 changes: 55 additions & 0 deletions R/generate_dust_sexp.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
generate_dust_sexp <- function(expr, dat) {
if (is.recursive(expr)) {
fn <- as.character(expr[[1]])
args <- expr[-1]
n <- length(args)
values <- vcapply(args, generate_dust_sexp, dat)

if (fn == "(") {
ret <- sprintf("(%s)", values[[1]])
} else if (n == 1 && fn == "-") {
ret <- sprintf("-%s", values[[1]])
} else if (n == 1 && fn == "+") {
ret <- values[[1]]
} else if (n == 2 && fn %in% c("+", "-", "*", "/")) {
ret <- sprintf("%s %s %s", values[[1]], fn, values[[2]])
} else {
## TODO: we should catch this elsewhere.
stop("Unhandled function")
}
} else if (is.symbol(expr)) {
name <- as.character(expr)
location <- dat$location$location[[name]]
if (location %in% c("state", "stack")) {
ret <- name
} else if (location == "shared") {
ret <- sprintf("shared.%s", name)
} else {
stop("Unhandled location")
}
} else if (is.numeric(expr)) {
if (expr %% 1 == 0) {
ret <- format(expr)
} else {
ret <- sprintf("static_cast<real_type>(%s)",
deparse(expr, control = "digits17"))
}
} else if (is.logical(expr)) {
ret <- tolower(expr)
} else {
stop("Unhandled data type")
}
ret
}


generate_dust_lhs <- function(name, dat, name_state = "state") {
location <- dat$location$location[[name]]
if (location == "stack") {
sprintf("const %s %s", dat$location$type[[name]], name)
} else if (location == "state") {
sprintf("%s[%s]", name_state, dat$location$packing$state[[name]])
} else {
stop("Unsupported location")
}
}
3 changes: 3 additions & 0 deletions R/parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ odin_parse <- function(expr, input = NULL) {
exprs <- lapply(dat$exprs, function(x) parse_expr(x$value, x, call = call))
system <- parse_system(exprs, call)
ret <- parse_depends(system, call)
## This changes immensely once we have arrays as we need to work
## with a more flexible packing structure. For now we just cheat
## and assume variables are packed in order as they are all scalars.
ret
}

Expand Down
28 changes: 24 additions & 4 deletions R/parse_depends.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ parse_depends <- function(system, call) {
location <- parse_depend_location(equations, system$variables)

list(time = system$time,
class = "odin",
location = location,
phases = phases,
equations = equations)
Expand Down Expand Up @@ -87,7 +88,7 @@ parse_depend_equations <- function(equations, implicit) {

## This is going to be the grossest bit I think.
parse_depend_phases <- function(exprs, equations, variables) {
phases <- list(create_shared = list(equations = character()))
phases <- list(build_shared = list(equations = character()))
used <- character()

stage <- vcapply(equations, function(x) x$stage)
Expand All @@ -109,8 +110,8 @@ parse_depend_phases <- function(exprs, equations, variables) {

## These bits will change with user variables, and where we
## support parameters and updates. This is fine for now though.
phases$create_shared$equations <-
union(phases$create_shared$equations,
phases$build_shared$equations <-
union(phases$build_shared$equations,
eqs[stage[eqs] == "constant"])

if (phase == "update") { # also deriv
Expand All @@ -129,17 +130,36 @@ parse_depend_phases <- function(exprs, equations, variables) {
}
}

## Reorder following the dependency graph:
phases$build_shared$equations <- intersect(
names(equations),
phases$build_shared$equations)

phases
}


parse_depend_location <- function(equations, variables) {
stage <- vcapply(equations, "[[", "stage")
list(

contents <- list(
variables = variables,
shared = names(stage)[stage == "constant"],
internal = character(),
data = character(),
output = character(),
stack = names(stage)[stage == "time"])
location <- set_names(rep(names(contents), lengths(contents)),
unlist(contents, FALSE, TRUE))
location[location == "variables"] <- "state"
type <- set_names(rep("real_type", length(location)), names(location))
## This will change soon, as we'll need more flexibility with
## arrays, and output, and adjoints.
packing <- list(
state = set_names(as.list(seq_along(variables) - 1), variables))

list(contents = contents,
location = location,
type = type,
packing = packing)
}
21 changes: 21 additions & 0 deletions R/util_cpp.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
cpp_function <- function (return_type, name, args, body, static = FALSE) {
c(cpp_args(return_type, name, args, static = static),
paste0(" ", body),
"}")
}


cpp_args <- function(return_type, name, args, static = FALSE) {
static_str <- if (static) "static " else ""
args_str <- paste(sprintf("%s %s", names(args), unname(args)),
collapse = ", ")
sprintf("%s%s %s(%s) {",
static_str, return_type, name, args_str)
}


cpp_block <- function(body) {
c("{",
paste0(" ", body),
"}")
}

0 comments on commit 5f0b00b

Please sign in to comment.