Skip to content

Commit

Permalink
Issue #386: Add direct model allowing pass through to brms (#393)
Browse files Browse the repository at this point in the history
* Template for functions required by direct model

* Template for direct_model tests

* First try at complete version of direct model (without documentation)

* Document and lint, plus move default formula to delay ~ .

* Update pkgdown

* Document following merge
  • Loading branch information
athowes authored Oct 21, 2024
1 parent 03609f9 commit a9b339a
Show file tree
Hide file tree
Showing 14 changed files with 249 additions and 1 deletion.
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
S3method(add_mean_sd,default)
S3method(add_mean_sd,gamma_samples)
S3method(add_mean_sd,lognormal_samples)
S3method(as_direct_model,data.frame)
S3method(as_latent_individual,data.frame)
S3method(epidist,default)
S3method(epidist_family_model,default)
Expand All @@ -17,9 +18,11 @@ S3method(epidist_model_prior,default)
S3method(epidist_stancode,default)
S3method(epidist_stancode,epidist_latent_individual)
S3method(epidist_validate,default)
S3method(epidist_validate,epidist_direct_model)
S3method(epidist_validate,epidist_latent_individual)
export(add_event_vars)
export(add_mean_sd)
export(as_direct_model)
export(as_latent_individual)
export(epidist)
export(epidist_diagnostics)
Expand All @@ -35,6 +38,7 @@ export(epidist_stancode)
export(epidist_validate)
export(filter_obs_by_obs_time)
export(filter_obs_by_ptime)
export(is_direct_model)
export(is_latent_individual)
export(observe_process)
export(predict_delay_parameters)
Expand Down
68 changes: 68 additions & 0 deletions R/direct_model.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#' Prepare direct model to pass through to `brms`
#'
#' @param data A `data.frame` containing line list data
#' @family direct_model
#' @export
as_direct_model <- function(data) {
UseMethod("as_direct_model")
}

assert_direct_model_input <- function(data) {
assert_data_frame(data)
assert_names(names(data), must.include = c("case", "ptime", "stime"))
assert_integer(data$case, lower = 0)
assert_numeric(data$ptime, lower = 0)
assert_numeric(data$stime, lower = 0)
}

#' Prepare latent individual model
#'
#' This function prepares data for use with the direct model. It does this by
#' adding columns used in the model to the `data` object provided. To do this,
#' the `data` must already have columns for the case number (integer),
#' (positive, numeric) times for the primary and secondary event times. The
#' output of this function is a `epidist_direct_model` class object, which may
#' be passed to [epidist()] to perform inference for the model.
#'
#' @param data A `data.frame` containing line list data
#' @rdname as_direct_model
#' @method as_direct_model data.frame
#' @family direct_model
#' @autoglobal
#' @export
as_direct_model.data.frame <- function(data) {
assert_direct_model_input(data)
class(data) <- c("epidist_direct_model", class(data))
data <- data |>
mutate(delay = .data$stime - .data$ptime)
epidist_validate(data)
return(data)
}

#' Validate direct model data
#'
#' This function checks whether the provided `data` object is suitable for
#' running the direct model. As well as making sure that
#' `is_direct_model()` is true, it also checks that `data` is a `data.frame`
#' with the correct columns.
#'
#' @param data A `data.frame` containing line list data
#' @param ... ...
#' @method epidist_validate epidist_direct_model
#' @family direct_model
#' @export
epidist_validate.epidist_direct_model <- function(data, ...) {
assert_true(is_direct_model(data))
assert_direct_model_input(data)
assert_names(names(data), must.include = c("case", "ptime", "stime", "delay"))
assert_numeric(data$delay, lower = 0)
}

#' Check if data has the `epidist_direct_model` class
#'
#' @param data A `data.frame` containing line list data
#' @family latent_individual
#' @export
is_direct_model <- function(data) {
inherits(data, "epidist_direct_model")
}
3 changes: 3 additions & 0 deletions R/formula.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,8 @@ epidist_formula_model <- function(data, formula, ...) {
#' @family formula
#' @export
epidist_formula_model.default <- function(data, formula, ...) {
formula <- stats::update(
formula, delay ~ .
)
return(formula)
}
4 changes: 4 additions & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ reference:
desc: Specific methods for the latent individual model
contents:
- has_concept("latent_individual")
- title: Direct model
desc: Specific methods for the direct model
contents:
- has_concept("direct_model")
- title: Postprocess
desc: Functions for postprocessing model output
contents:
Expand Down
30 changes: 30 additions & 0 deletions man/as_direct_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/as_latent_individual.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/epidist_family_model.epidist_latent_individual.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/epidist_formula_model.epidist_latent_individual.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions man/epidist_validate.epidist_direct_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/epidist_validate.epidist_latent_individual.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions man/is_direct_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/is_latent_individual.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 50 additions & 0 deletions tests/testthat/test-direct_model.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
test_that("as_direct_model.data.frame with default settings an object with the correct classes", { # nolint: line_length_linter.
prep_obs <- as_direct_model(sim_obs)
expect_s3_class(prep_obs, "data.frame")
expect_s3_class(prep_obs, "epidist_direct_model")
})

test_that("as_direct_model.data.frame errors when passed incorrect inputs", { # nolint: line_length_linter.
expect_error(as_direct_model(list()))
expect_error(as_direct_model(sim_obs[, 1]))
expect_error({
sim_obs$case <- paste("case_", seq_len(nrow(sim_obs)))
as_direct_model(sim_obs)
})
})

# Make this data available for other tests
prep_obs <- as_direct_model(sim_obs)
family_lognormal <- epidist_family(prep_obs, family = brms::lognormal())

test_that("is_direct_model returns TRUE for correct input", { # nolint: line_length_linter.
expect_true(is_direct_model(prep_obs))
expect_true({
x <- list()
class(x) <- "epidist_direct_model"
is_direct_model(x)
})
})

test_that("is_direct_model returns FALSE for incorrect input", { # nolint: line_length_linter.
expect_false(is_direct_model(list()))
expect_false({
x <- list()
class(x) <- "epidist_direct_model_extension"
is_direct_model(x)
})
})

test_that("epidist_validate.epidist_direct_model doesn't produce an error for correct input", { # nolint: line_length_linter.
expect_no_error(epidist_validate(prep_obs))
})

test_that("epidist_validate.epidist_direct_model returns FALSE for incorrect input", { # nolint: line_length_linter.
expect_error(epidist_validate(list()))
expect_error(epidist_validate(prep_obs[, 1]))
expect_error({
x <- list()
class(x) <- "epidist_direct_model"
epidist_validate(x)
})
})
36 changes: 36 additions & 0 deletions tests/testthat/test-int-direct_model.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Note: some tests in this script are stochastic. As such, test failure may be
# bad luck rather than indicate an issue with the code. However, as these tests
# are reproducible, the distribution of test failures may be investigated by
# varying the input seed. Test failure at an unusually high rate does suggest
# a potential code issue.

prep_obs <- as_direct_model(sim_obs)

test_that("epidist.epidist_direct_model Stan code has no syntax errors and compiles in the default case", { # nolint: line_length_linter.
skip_on_cran()
stancode <- epidist(
data = prep_obs,
fn = brms::make_stancode,
output_dir = fs::dir_create(tempfile())
)
mod <- cmdstanr::cmdstan_model(
stan_file = cmdstanr::write_stan_file(stancode), compile = FALSE
)
expect_true(mod$check_syntax())
expect_no_error(mod$compile())
})

test_that("epidist.epidist_direct_model fits and the MCMC converges in the default case", { # nolint: line_length_linter.
# Note: this test is stochastic. See note at the top of this script
skip_on_cran()
set.seed(1)
fit <- epidist(
data = prep_obs,
seed = 1,
silent = 2,
output_dir = fs::dir_create(tempfile())
)
expect_s3_class(fit, "brmsfit")
expect_s3_class(fit, "epidist_fit")
expect_convergence(fit)
})

0 comments on commit a9b339a

Please sign in to comment.