Skip to content

Commit

Permalink
Merge pull request #118 from mrc-ide/mrc-6012
Browse files Browse the repository at this point in the history
Fix two array bugs
  • Loading branch information
weshinsley authored Nov 14, 2024
2 parents 51641bd + c4897d8 commit bbae425
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 17 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: odin2
Title: Next generation odin
Version: 0.2.12
Version: 0.2.13
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
4 changes: 3 additions & 1 deletion R/generate_dust.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,12 @@ generate_dust_system_packing_gradient <- function(dat) {

generate_dust_system_packing <- function(name, dat) {
packing <- dat$storage$packing[[name]]
arrays <- dat$storage$arrays
args <- c("const shared_state&" = "shared")
fmt <- "std::vector<size_t>(shared.dim.%s.dim.begin(), shared.dim.%s.dim.end())"
dim_name <- arrays$alias[match(packing$name, arrays$name)]
dims <- ifelse(packing$rank == 0, "{}",
sprintf(fmt, packing$name, packing$name))
sprintf(fmt, dim_name, dim_name))
els <- sprintf('{"%s", %s}', packing$name, dims)
## trailing comma if needed
els[-length(els)] <- sprintf("%s,", els[-length(els)])
Expand Down
1 change: 1 addition & 0 deletions R/parse_system.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ parse_system_overall <- function(exprs, call) {
check_duplicate_dims(arrays, exprs, call)
arrays <- resolve_array_references(arrays)
arrays <- resolve_split_dependencies(arrays, call)
exprs <- add_alias_dependency(exprs, arrays)

parameters <- parse_system_overall_parameters(exprs, arrays)
data <- data_frame(
Expand Down
24 changes: 24 additions & 0 deletions R/parse_system_arrays.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,27 @@ resolve_split_dependencies <- function(arrays, call) {

arrays
}


add_alias_dependency <- function(exprs, arrays) {
is_alias <- arrays$alias != arrays$name
if (!any(is_alias)) {
return(exprs)
}

remap <- set_names(
odin_dim_name(arrays$alias[is_alias]),
odin_dim_name(arrays$name[is_alias]))

update_alias_dependency <- function(eq) {
i <- eq$rhs$depends$variables %in% names(remap)
if (any(i)) {
eq$rhs$depends$variables <- union(
eq$rhs$depends$variables,
unname(remap[eq$rhs$depends$variables[i]]))
}
eq
}

lapply(exprs, update_alias_dependency)
}
5 changes: 2 additions & 3 deletions R/util.R
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ counter <- function() {
}


odin_dim_name <- function(n) {
assert_scalar_character(n)
sprintf("dim_%s", n)
odin_dim_name <- function(name) {
sprintf("dim_%s", name)
}
47 changes: 35 additions & 12 deletions tests/testthat/test-generate.R
Original file line number Diff line number Diff line change
Expand Up @@ -1022,26 +1022,26 @@ test_that("can generate system with aliased arrays dimmed together", {
" size_t x;",
" } state;",
" } offset;",
" std::vector<real_type> b;",
" std::vector<real_type> a;",
" std::vector<real_type> b;",
"};"))

expect_equal(
generate_dust_system_build_shared(dat),
c(method_args$build_shared,
" shared_state::dim_type dim;",
" dim.a.set({static_cast<size_t>(1)});",
" std::vector<real_type> b(dim.a.size);",
" for (size_t i = 1; i <= dim.a.size; ++i) {",
" b[i - 1] = 2;",
" }",
" std::vector<real_type> a(dim.a.size);",
" for (size_t i = 1; i <= dim.a.size; ++i) {",
" a[i - 1] = 1;",
" }",
" std::vector<real_type> b(dim.a.size);",
" for (size_t i = 1; i <= dim.a.size; ++i) {",
" b[i - 1] = 2;",
" }",
" shared_state::offset_type offset;",
" offset.state.x = 0;",
" return shared_state{dim, offset, b, a};",
" return shared_state{dim, offset, a, b};",
"}"))
})

Expand Down Expand Up @@ -1120,8 +1120,8 @@ test_that("can generate system with dependent-arrays dimmed together", {
" size_t x;",
" } state;",
" } offset;",
" std::vector<real_type> c;",
" std::vector<real_type> a;",
" std::vector<real_type> c;",
" std::vector<real_type> b;",
"};"))

Expand All @@ -1130,21 +1130,21 @@ test_that("can generate system with dependent-arrays dimmed together", {
c(method_args$build_shared,
" shared_state::dim_type dim;",
" dim.a.set({static_cast<size_t>(1)});",
" std::vector<real_type> c(dim.a.size);",
" for (size_t i = 1; i <= dim.a.size; ++i) {",
" c[i - 1] = 3;",
" }",
" std::vector<real_type> a(dim.a.size);",
" for (size_t i = 1; i <= dim.a.size; ++i) {",
" a[i - 1] = 1;",
" }",
" std::vector<real_type> c(dim.a.size);",
" for (size_t i = 1; i <= dim.a.size; ++i) {",
" c[i - 1] = 3;",
" }",
" std::vector<real_type> b(dim.a.size);",
" for (size_t i = 1; i <= dim.a.size; ++i) {",
" b[i - 1] = 2;",
" }",
" shared_state::offset_type offset;",
" offset.state.x = 0;",
" return shared_state{dim, offset, c, a, b};",
" return shared_state{dim, offset, a, c, b};",
"}"))
})

Expand Down Expand Up @@ -2568,3 +2568,26 @@ test_that("Can read parameters into var with aliased dimension", {
" }",
"}"))
})


test_that("correct packing with aliased arrays", {
dat <- odin_parse({
dim(a, b) <- c(x, y)
x <- 2
y <- 3
initial(a[, ]) <- 0
initial(b[, ]) <- 0
update(a[, ]) <- 1
update(b[, ]) <- 1
})
dat <- generate_prepare(dat)

expect_equal(
generate_dust_system_packing_state(dat),
c(method_args$packing_state,
" return dust2::packing{",
' {"a", std::vector<size_t>(shared.dim.a.dim.begin(), shared.dim.a.dim.end())},',
' {"b", std::vector<size_t>(shared.dim.a.dim.begin(), shared.dim.a.dim.end())}',
" };",
"}"))
})
31 changes: 31 additions & 0 deletions tests/testthat/test-parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -1143,3 +1143,34 @@ test_that("can't use browser twice in the same phase", {
"Multiple calls to 'browser()' in phase 'update'",
fixed = TRUE)
})


test_that("correctly resolve dependency order with aliased parameter dims", {
dat <- odin_parse({
deriv(S) <- -beta * I * S / N + gamma * I
deriv(I) <- beta * I * S / N - gamma * I
initial(S) <- N - I0
initial(I) <- I0
I0 <- parameter(10)
N <- parameter(1000)
beta0 <- parameter(0.2)
schools <- interpolate(schools_time, schools_open, "constant")
schools_time <- parameter(constant = TRUE)
schools_open <- parameter(constant = TRUE)
dim(schools_time, schools_open) <- parameter(rank = 1)
schools_modifier <- parameter(0.6)
beta <- ((1 - schools) * (1 - schools_modifier) + schools) * beta0
gamma <- 0.1
})
## schools_time is the real one:
expect_equal(
dat$storage$arrays[c("name", "alias")],
data_frame(name = c("schools_time", "schools_open"),
alias = "schools_time"))
## dim_schools_time resoplved before schools_time, schools_time
## before schools_open:
expect_equal(
dat$phases$build_shared$equations,
c("I0", "N", "beta0", "dim_schools_time", "schools_modifier",
"gamma", "schools_time", "schools_open", "interpolate_schools"))
})

0 comments on commit bbae425

Please sign in to comment.