Skip to content

Commit

Permalink
Merge pull request #110 from mrc-ide/mrc-5723-3
Browse files Browse the repository at this point in the history
Allow dim(a, b, c) <- ...
  • Loading branch information
richfitz authored Nov 14, 2024
2 parents ae22c51 + abbc878 commit aa49363
Show file tree
Hide file tree
Showing 9 changed files with 362 additions and 51 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: odin2
Title: Next generation odin
Version: 0.2.9
Version: 0.2.10
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
21 changes: 20 additions & 1 deletion R/parse_expr.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ parse_expr_assignment_lhs <- function(lhs, src, call) {
update = function(name) NULL,
deriv = function(name) NULL,
output = function(name) NULL,
dim = function(name) NULL,
dim = function(...) NULL,
config = function(name) NULL)

args <- NULL
Expand All @@ -153,6 +153,25 @@ parse_expr_assignment_lhs <- function(lhs, src, call) {
"E1003", src, call)
}

if (special == "dim") {
if (length(lhs) < 2) {
odin_parse_error(c("Invalid call to dim function; no variables given"),
"E1003", src, call)
}
lhs <- vcapply(lhs[-1], function(x) {
if (!is.symbol(x)) {
odin_parse_error("Invalid target '{x}' in dim declaration",
"E1005", src, call)
}
deparse(x)
})

return(list(
name = lhs[1],
names = lhs,
special = special))
}

lhs <- lhs[[2]]
if (length(m$value) > 2) {
args <- as.list(m$value[-(1:2)])
Expand Down
54 changes: 6 additions & 48 deletions R/parse_system.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,20 +132,10 @@ parse_system_overall <- function(exprs, call) {
output <- NULL
}

dims <- lapply(exprs[is_dim], function(x) x$rhs$value)
names <- vcapply(exprs[is_dim], function(x) x$lhs$name_data)
sizes <- lapply(dims, function(x) {
if (rlang::is_call(x, "dim")) {
1
} else {
expr_prod(x)
}
})
arrays <- resolve_array_references(data_frame(
name = names,
rank = lengths(dims),
dims = I(dims),
size = I(sizes)))
arrays <- build_array_table(exprs[is_dim], call)
check_duplicate_dims(arrays, exprs, call)
arrays <- resolve_array_references(arrays)
arrays <- resolve_split_dependencies(arrays, call)

parameters <- parse_system_overall_parameters(exprs, arrays)
data <- data_frame(
Expand Down Expand Up @@ -233,38 +223,6 @@ parse_system_overall <- function(exprs, call) {
exprs = exprs)
}


resolve_array_references <- function(arrays) {
lookup_array <- function(name, copy_from, d) {
i <- which(d$name == copy_from)
if (length(i) == 0) {
return(NULL)
}
dim_i <- d$dims[i]
if (rlang::is_call(dim_i[[1]], "dim")) {
rhs_dim_var <- deparse(dim_i[[1]][[2]])
return(lookup_array(name, rhs_dim_var, d[-i, ]))
}
return(list(rank = d$rank[i], alias = d$name[i]))
}

arrays$alias <- arrays$name
is_ref <- vlapply(arrays$dims, rlang::is_call, "dim")

for (i in which(is_ref)) {
lhs_dim_var <- arrays$name[i]
rhs_dim_var <- deparse(arrays$dims[i][[1]][[2]])
res <- lookup_array(lhs_dim_var, rhs_dim_var, arrays[-i, ])
if (!is.null(res)) {
arrays$dims[i] <- list(NULL)
arrays$size[i] <- NA_integer_
arrays$rank[i] <- res$rank
arrays$alias[i] <- res$alias
}
}
arrays
}

parse_system_depends <- function(equations, variables, call) {
automatic <- c("time", "dt")
implicit <- c(variables, automatic)
Expand Down Expand Up @@ -561,7 +519,7 @@ parse_zero_every <- function(time, phases, equations, variables, call) {
parse_system_arrays <- function(exprs, call) {
is_dim <- vlapply(exprs, function(x) identical(x$special, "dim"))

dim_nms <- vcapply(exprs[is_dim], function(x) x$lhs$name_data)
dim_nms <- unlist(lapply(exprs[is_dim], function(x) x$lhs$names))

## First, look for any array calls that do not have a corresponding
## dim()
Expand Down Expand Up @@ -598,7 +556,7 @@ parse_system_arrays <- function(exprs, call) {


name_dim_equation <- set_names(
vcapply(exprs[is_dim], function(eq) eq$lhs$name),
unlist(lapply(dim_nms, odin_dim_name)),
dim_nms)

is_array_assignment <- is_array | (nms %in% dim_nms)
Expand Down
112 changes: 112 additions & 0 deletions R/parse_system_arrays.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
build_array_table <- function(exprs, call) {
dims <- list()
names <- list()
sizes <- list()
n <- 1
for (expr in exprs) {
names_i <- expr$lhs$names
dims_i <- expr$rhs$value
if (rlang::is_call(dims_i, "dim")) {
size_i <- 1
} else {
size_i <- expr_prod(dims_i)
if (is.null(size_i)) {
size_i <- list(NULL)
}
}

first_dim <- call("dim", as.symbol(expr$lhs$names[1]))

for (j in seq_along(names_i)) {
dims[[n]] <- if (j == 1) dims_i else first_dim
names[[n]] <- names_i[j]
sizes[[n]] <- size_i
n <- n + 1
}
}

data_frame(
name = unlist(names),
rank = lengths(dims),
dims = I(dims),
size = I(sizes))
}

check_duplicate_dims <- function(arrays, exprs, call) {
throw_duplicate_dim <- function(name, src) {
odin_parse_error(
paste("The variable {name} was given dimensions multiple times."),
"E2021", src, call)
}

names <- unlist(arrays$name)
if (any(duplicated(names))) {
dup_dim <- unique(names[duplicated(names)])[1]
lines <- vlapply(exprs, function(x) {
isTRUE(x$special == "dim" &
dup_dim %in% c(x$lhs$names))
})
srcs <- lapply(exprs[lines], "[[", "src")
throw_duplicate_dim(dup_dim, srcs)
}
}


resolve_array_references <- function(arrays) {
lookup_array <- function(name, copy_from, d) {
i <- which(d$name == copy_from)
if (length(i) == 0) {
return(NULL)
}
dim_i <- d$dims[i]
if (rlang::is_call(dim_i[[1]], "dim")) {
rhs_dim_var <- deparse(dim_i[[1]][[2]])
return(lookup_array(name, rhs_dim_var, d[-i, ]))
}
return(list(rank = d$rank[i], alias = d$name[i]))
}

arrays$alias <- arrays$name
is_ref <- vlapply(arrays$dims, rlang::is_call, "dim")

for (i in which(is_ref)) {
lhs_dim_var <- arrays$name[i]
rhs_dim_var <- deparse(arrays$dims[i][[1]][[2]])
res <- lookup_array(lhs_dim_var, rhs_dim_var, arrays[-i, ])
if (!is.null(res)) {
arrays$dims[i] <- list(NULL)
arrays$size[i] <- NA_integer_
arrays$rank[i] <- res$rank
arrays$alias[i] <- res$alias
}
}

arrays
}

resolve_split_dependencies <- function(arrays, call) {
# Resolve case where
# dim(a) <- 1
# dim(b, c) <- dim(a)
# At this point, dim(c) will be aliased to dim(b), not dim(a),
# so find aliases that actually point to other aliases, and
# resolve them to something that is not an alias.

find_non_alias <- function(current, original = current, visited = NULL) {
stopifnot(!any(duplicated(visited)))

array <- arrays[arrays$name == current, ]
if (array$alias == array$name) {
return(array$alias)
}
find_non_alias(array$alias, c(visited, original, array$name))
}

not_aliased <- arrays$name[arrays$name == arrays$alias]
wrong <- arrays$name != arrays$alias & !(arrays$alias %in% not_aliased)
for (i in which(wrong)) {
arrays$alias[i] <- find_non_alias(arrays$alias[i], arrays$name[i])
}

arrays
}
Binary file modified R/sysdata.rda
Binary file not shown.
148 changes: 148 additions & 0 deletions tests/testthat/test-generate.R
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,154 @@ test_that("can generate system with aliased array", {
"}"))
})

test_that("can generate system with aliased arrays dimmed together", {
dat <- odin_parse({
initial(x) <- 0
update(x) <- x + a[1] + b[1]
dim(a, b) <- 1
a[] <- 1
b[] <- 2
})
dat <- generate_prepare(dat)

expect_equal(
generate_dust_system_shared_state(dat),
c("struct shared_state {",
" struct dim_type {",
" dust2::array::dimensions<1> a;",
" } dim;",
" struct offset_type {",
" struct {",
" size_t x;",
" } state;",
" } offset;",
" std::vector<real_type> b;",
" std::vector<real_type> a;",
"};"))

expect_equal(
generate_dust_system_build_shared(dat),
c(method_args$build_shared,
" shared_state::dim_type dim;",
" dim.a.set({static_cast<size_t>(1)});",
" std::vector<real_type> b(dim.a.size);",
" for (size_t i = 1; i <= dim.a.size; ++i) {",
" b[i - 1] = 2;",
" }",
" std::vector<real_type> a(dim.a.size);",
" for (size_t i = 1; i <= dim.a.size; ++i) {",
" a[i - 1] = 1;",
" }",
" shared_state::offset_type offset;",
" offset.state.x = 0;",
" return shared_state{dim, offset, b, a};",
"}"))
})


test_that("can generate system with 2 aliased arrays dimmed together", {
dat <- odin_parse({
initial(x) <- 0
update(x) <- x + a[1] + b[1] + c[1]
dim(c, a, b) <- 1
a[] <- 1
b[] <- 2
c[] <- 3
})
dat <- generate_prepare(dat)

expect_equal(
generate_dust_system_shared_state(dat),
c("struct shared_state {",
" struct dim_type {",
" dust2::array::dimensions<1> c;",
" } dim;",
" struct offset_type {",
" struct {",
" size_t x;",
" } state;",
" } offset;",
" std::vector<real_type> a;",
" std::vector<real_type> b;",
" std::vector<real_type> c;",
"};"))

expect_equal(
generate_dust_system_build_shared(dat),
c(method_args$build_shared,
" shared_state::dim_type dim;",
" dim.c.set({static_cast<size_t>(1)});",
" std::vector<real_type> a(dim.c.size);",
" for (size_t i = 1; i <= dim.c.size; ++i) {",
" a[i - 1] = 1;",
" }",
" std::vector<real_type> b(dim.c.size);",
" for (size_t i = 1; i <= dim.c.size; ++i) {",
" b[i - 1] = 2;",
" }",
" std::vector<real_type> c(dim.c.size);",
" for (size_t i = 1; i <= dim.c.size; ++i) {",
" c[i - 1] = 3;",
" }",
" shared_state::offset_type offset;",
" offset.state.x = 0;",
" return shared_state{dim, offset, a, b, c};",
"}"))
})


test_that("can generate system with dependent-arrays dimmed together", {
dat <- odin_parse({
update(x) <- sum(a) + sum(b) + sum(c)
initial(x) <- 0
dim(a) <- 1
dim(b, c) <- dim(a)
a[] <- 1
b[] <- 2
c[] <- 3
})
dat <- generate_prepare(dat)

expect_equal(
generate_dust_system_shared_state(dat),
c("struct shared_state {",
" struct dim_type {",
" dust2::array::dimensions<1> a;",
" } dim;",
" struct offset_type {",
" struct {",
" size_t x;",
" } state;",
" } offset;",
" std::vector<real_type> c;",
" std::vector<real_type> a;",
" std::vector<real_type> b;",
"};"))

expect_equal(
generate_dust_system_build_shared(dat),
c(method_args$build_shared,
" shared_state::dim_type dim;",
" dim.a.set({static_cast<size_t>(1)});",
" std::vector<real_type> c(dim.a.size);",
" for (size_t i = 1; i <= dim.a.size; ++i) {",
" c[i - 1] = 3;",
" }",
" std::vector<real_type> a(dim.a.size);",
" for (size_t i = 1; i <= dim.a.size; ++i) {",
" a[i - 1] = 1;",
" }",
" std::vector<real_type> b(dim.a.size);",
" for (size_t i = 1; i <= dim.a.size; ++i) {",
" b[i - 1] = 2;",
" }",
" shared_state::offset_type offset;",
" offset.state.x = 0;",
" return shared_state{dim, offset, c, a, b};",
"}"))
})



test_that("can generate system with length and sum of aliased array", {
dat <- odin_parse({
Expand Down
Loading

0 comments on commit aa49363

Please sign in to comment.