Skip to content

Commit

Permalink
Issue #388: Refactor preprocessing functionality (#390)
Browse files: browse the repository at this point in the history
* Sketching out refactor

* Some restructuring for clarity

* Refactor along new date approach

* Clear tests

* Refactor lines to save col_names

* Refactor validation functionality

* Redocument

* Remove epidist_validate and move to _model and _data approach plus some linting

* Add documentation of as_epidist_linelist arguments

* Move assert_class into imports and use in place of "check" class

* Documentation for epidist_validate_data.epidist_linelist

* Clear up the direct model file a bit

* Add creating the row_id back in to as_latent_individual

* Passing test-direct_model

* Start working to make data use dates

* Add start of unit tests and bug fix for datetime class check

* Use .row_id rather than row_id

* Use as_epidist_linelist_time function so that tests work with time data

* Fixes to tests

* Group into preprocessing functions

* Update FAQ vignette to run

* Update get started vignette to run

* Update ebola vignette to run

* Update approximate inference vignette to run

* Add documentation

* Methods consistency

* Document ...

* Again on ...

* Remove comment moved to issue

* Include as_epidist_linelist_time ad-hoc

* Add test for datetime column

* Update text in vignettes and add a note that the ad-hoc function will be included in the package soon

* Refactor .rename_columns
  • Loading branch information
athowes authored Nov 13, 2024
1 parent 8a41c4c commit db76be3
Show file tree
Hide file tree
Showing 42 changed files with 446 additions and 398 deletions.
17 changes: 11 additions & 6 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ S3method(add_mean_sd,default)
S3method(add_mean_sd,gamma_samples)
S3method(add_mean_sd,lognormal_samples)
S3method(as_direct_model,data.frame)
S3method(as_latent_individual,data.frame)
S3method(as_latent_individual,epidist_linelist)
S3method(epidist,default)
S3method(epidist_family_model,default)
S3method(epidist_family_model,epidist_latent_individual)
Expand All @@ -17,12 +17,14 @@ S3method(epidist_formula_model,epidist_latent_individual)
S3method(epidist_model_prior,default)
S3method(epidist_stancode,default)
S3method(epidist_stancode,epidist_latent_individual)
S3method(epidist_validate,default)
S3method(epidist_validate,epidist_direct_model)
S3method(epidist_validate,epidist_latent_individual)
export(add_event_vars)
S3method(epidist_validate_data,default)
S3method(epidist_validate_data,epidist_linelist)
S3method(epidist_validate_model,default)
S3method(epidist_validate_model,epidist_direct_model)
S3method(epidist_validate_model,epidist_latent_individual)
export(add_mean_sd)
export(as_direct_model)
export(as_epidist_linelist)
export(as_latent_individual)
export(epidist)
export(epidist_diagnostics)
Expand All @@ -35,10 +37,12 @@ export(epidist_formula_model)
export(epidist_model_prior)
export(epidist_prior)
export(epidist_stancode)
export(epidist_validate)
export(epidist_validate_data)
export(epidist_validate_model)
export(filter_obs_by_obs_time)
export(filter_obs_by_ptime)
export(is_direct_model)
export(is_epidist_linelist)
export(is_latent_individual)
export(observe_process)
export(predict_delay_parameters)
Expand All @@ -50,6 +54,7 @@ export(simulate_uniform_cases)
import(ggplot2)
importFrom(brms,bf)
importFrom(brms,prior)
importFrom(checkmate,assert_class)
importFrom(checkmate,assert_data_frame)
importFrom(checkmate,assert_factor)
importFrom(checkmate,assert_integer)
Expand Down
26 changes: 3 additions & 23 deletions R/direct_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,6 @@ assert_direct_model_input <- function(data) {
assert_numeric(data$stime, lower = 0)
}

#' Prepare latent individual model
#'
#' This function prepares data for use with the direct model. It does this by
#' adding columns used in the model to the `data` object provided. To do this,
#' the `data` must already have columns for the case number (integer),
#' (positive, numeric) times for the primary and secondary event times. The
#' output of this function is a `epidist_direct_model` class object, which may
#' be passed to [epidist()] to perform inference for the model.
#'
#' @param data A `data.frame` containing line list data
#' @rdname as_direct_model
#' @method as_direct_model data.frame
#' @family direct_model
#' @autoglobal
Expand All @@ -35,23 +24,14 @@ as_direct_model.data.frame <- function(data) {
class(data) <- c("epidist_direct_model", class(data))
data <- data |>
mutate(delay = .data$stime - .data$ptime)
epidist_validate(data)
epidist_validate_model(data)
return(data)
}

#' Validate direct model data
#'
#' This function checks whether the provided `data` object is suitable for
#' running the direct model. As well as making sure that
#' `is_direct_model()` is true, it also checks that `data` is a `data.frame`
#' with the correct columns.
#'
#' @param data A `data.frame` containing line list data
#' @param ... ...
#' @method epidist_validate epidist_direct_model
#' @method epidist_validate_model epidist_direct_model
#' @family direct_model
#' @export
epidist_validate.epidist_direct_model <- function(data, ...) {
epidist_validate_model.epidist_direct_model <- function(data, ...) {
assert_true(is_direct_model(data))
assert_direct_model_input(data)
assert_names(names(data), must.include = c("case", "ptime", "stime", "delay"))
Expand Down
2 changes: 1 addition & 1 deletion R/epidist-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#' @importFrom dplyr filter select
#' @importFrom brms bf prior
#' @importFrom checkmate assert_data_frame assert_names assert_integer
#' assert_true assert_factor assert_numeric
#' assert_true assert_factor assert_numeric assert_class
#' @importFrom cli cli_abort cli_inform cli_warn
#' @importFrom stats as.formula
## usethis namespace: end
Expand Down
2 changes: 1 addition & 1 deletion R/family.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#' @family family
#' @export
epidist_family <- function(data, family = "lognormal", ...) {
epidist_validate(data)
epidist_validate_model(data)
family <- brms:::validate_family(family)
class(family) <- c(family$family, class(family))
family <- .add_dpar_info(family)
Expand Down
2 changes: 1 addition & 1 deletion R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ epidist <- function(data, formula, family, prior, backend, fn, ...) {
epidist.default <- function(data, formula = mu ~ 1,
family = "lognormal", prior = NULL,
backend = "cmdstanr", fn = brms::brm, ...) {
epidist_validate(data)
epidist_validate_model(data)
epidist_family <- epidist_family(data, family)
epidist_formula <- epidist_formula(
data = data, family = epidist_family, formula = formula
Expand Down
2 changes: 1 addition & 1 deletion R/formula.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#' @family formula
#' @export
epidist_formula <- function(data, family, formula, ...) {
epidist_validate(data)
epidist_validate_model(data)
formula <- brms:::validate_formula(formula, data = data, family = family)
formula <- .make_intercepts_explicit(formula)
formula <- epidist_formula_model(data, formula)
Expand Down
3 changes: 0 additions & 3 deletions R/globals.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@

utils::globalVariables(c(
"samples", # <epidist_diagnostics>
"stime_lwr", # <as_latent_individual.data.frame>
"stime_upr", # <as_latent_individual.data.frame>
"ptime_upr", # <as_latent_individual.data.frame>
"woverlap", # <epidist_stancode.epidist_latent_individual>
":=", # <filter_obs_by_ptime>
"rlnorm", # <simulate_secondary>
Expand Down
83 changes: 18 additions & 65 deletions R/latent_individual.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,89 +7,42 @@ as_latent_individual <- function(data) {
UseMethod("as_latent_individual")
}

assert_latent_individual_input <- function(data) {
assert_data_frame(data)
assert_names(
names(data),
must.include = c("case", "ptime_lwr", "ptime_upr",
"stime_lwr", "stime_upr", "obs_time")
)
assert_integer(data$case, lower = 0)
assert_numeric(data$ptime_lwr, lower = 0)
assert_numeric(data$ptime_upr, lower = 0)
assert_true(all(data$ptime_upr - data$ptime_lwr > 0))
assert_numeric(data$stime_lwr, lower = 0)
assert_numeric(data$stime_upr, lower = 0)
assert_true(all(data$stime_upr - data$stime_lwr > 0))
assert_numeric(data$obs_time, lower = 0)
}

#' Prepare latent individual model
#'
#' This function prepares data for use with the latent individual model. It does
#' this by adding columns used in the model to the `data` object provided. To do
#' this, the `data` must already have columns for the case number (integer),
#' (positive, numeric) upper and lower bounds for the primary and secondary
#' event times, as well as a (positive, numeric) time that observation takes
#' place. The output of this function is a `epidist_latent_individual` class
#' object, which may be passed to [epidist()] to perform inference for the
#' model.
#'
#' @param data A `data.frame` containing line list data
#' @rdname as_latent_individual
#' @method as_latent_individual data.frame
#' @method as_latent_individual epidist_linelist
#' @family latent_individual
#' @autoglobal
#' @export
as_latent_individual.data.frame <- function(data) {
assert_latent_individual_input(data)
as_latent_individual.epidist_linelist <- function(data) {
epidist_validate_data(data)
class(data) <- c("epidist_latent_individual", class(data))
data <- data |>
mutate(
relative_obs_time = .data$obs_time - .data$ptime_lwr,
pwindow = ifelse(
stime_lwr < .data$ptime_upr,
stime_upr - .data$ptime_lwr,
ptime_upr - .data$ptime_lwr
.data$stime_lwr < .data$ptime_upr,
.data$stime_upr - .data$ptime_lwr,
.data$ptime_upr - .data$ptime_lwr
),
woverlap = as.numeric(.data$stime_lwr < .data$ptime_upr),
swindow = .data$stime_upr - .data$stime_lwr,
delay = .data$stime_lwr - .data$ptime_lwr,
row_id = dplyr::row_number()
.row_id = dplyr::row_number()
)
if (nrow(data) > 1) {
data <- mutate(data, row_id = factor(.data$row_id))
}
epidist_validate(data)
epidist_validate_model(data)
return(data)
}

#' Validate latent individual model data
#'
#' This function checks whether the provided `data` object is suitable for
#' running the latent individual model. As well as making sure that
#' `is_latent_individual()` is true, it also checks that `data` is a
#' `data.frame` with the correct columns.
#'
#' @param data A `data.frame` containing line list data
#' @param ... ...
#' @method epidist_validate epidist_latent_individual
#' @method epidist_validate_model epidist_latent_individual
#' @family latent_individual
#' @export
epidist_validate.epidist_latent_individual <- function(data, ...) {
epidist_validate_model.epidist_latent_individual <- function(data, ...) {
assert_true(is_latent_individual(data))
assert_latent_individual_input(data)
assert_names(
names(data),
must.include = c("case", "ptime_lwr", "ptime_upr",
"stime_lwr", "stime_upr", "obs_time",
"relative_obs_time", "pwindow", "woverlap",
"swindow", "delay", "row_id")
col_names <- c(
"ptime_lwr", "ptime_upr", "stime_lwr", "stime_upr", "obs_time",
"relative_obs_time", "pwindow", "woverlap", "swindow", "delay", ".row_id"
)
if (nrow(data) > 1) {
assert_factor(data$row_id)
}
assert_names(names(data), must.include = col_names)
assert_numeric(data$relative_obs_time, lower = 0)
# TODO: should we also check here that pwindow is consistent with the primary
# event bounds and swindow with the secondary event bounds?
assert_numeric(data$pwindow, lower = 0)
assert_numeric(data$woverlap, lower = 0)
assert_numeric(data$swindow, lower = 0)
Expand Down Expand Up @@ -159,7 +112,7 @@ epidist_stancode.epidist_latent_individual <- function(data,
epidist_formula(data),
...) {

epidist_validate(data)
epidist_validate_model(data)

stanvars_version <- .version_stanvar()

Expand Down Expand Up @@ -202,13 +155,13 @@ epidist_stancode.epidist_latent_individual <- function(data,
brms::stanvar(
block = "data",
scode = "array[N - wN] int noverlap;",
x = filter(data, woverlap == 0)$row_id,
x = filter(data, woverlap == 0)$.row_id,
name = "noverlap"
) +
brms::stanvar(
block = "data",
scode = "array[wN] int woverlap;",
x = filter(data, woverlap > 0)$row_id,
x = filter(data, woverlap > 0)$.row_id,
name = "woverlap"
)

Expand Down
Loading

0 comments on commit db76be3

Please sign in to comment.