Skip to content

Commit

Permalink
first stab at adding new partial predictive plot functions to this
Browse files Browse the repository at this point in the history
  • Loading branch information
njtierney committed Jan 28, 2025
1 parent 291d624 commit c838b2f
Show file tree
Hide file tree
Showing 9 changed files with 267 additions and 61 deletions.
9 changes: 9 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,15 @@ S3method(as_setting_prediction_matrix,list)
S3method(autoplot,conmat_age_matrix)
S3method(autoplot,conmat_setting_prediction_matrix)
S3method(autoplot,ngm_setting_matrix)
S3method(autoplot,partial_predictions)
S3method(autoplot,partial_predictions_sum)
S3method(autoplot,setting_partial_predictions)
S3method(autoplot,setting_vaccination_matrix)
S3method(autoplot,transmission_probability_matrix)
S3method(conmat_partial_effects,contact_model)
S3method(conmat_partial_effects,setting_contact_model)
S3method(conmat_partial_effects_sum,contact_model)
S3method(conmat_partial_effects_sum,setting_contact_model)
S3method(generate_ngm,conmat_population)
S3method(generate_ngm,conmat_setting_prediction_matrix)
S3method(get_age_population_function,conmat_population)
Expand Down Expand Up @@ -68,6 +75,7 @@ export(apply_vaccination)
export(as_conmat_population)
export(as_setting_prediction_matrix)
export(autoplot)
export(conmat_partial_effects_sum)
export(conmat_population)
export(estimate_setting_contacts)
export(extrapolate_polymod)
Expand Down Expand Up @@ -111,6 +119,7 @@ importFrom(ggplot2,facet_grid)
importFrom(ggplot2,facet_wrap)
importFrom(ggplot2,geom_tile)
importFrom(ggplot2,ggplot)
importFrom(ggplot2,labs)
importFrom(ggplot2,scale_fill_viridis_c)
importFrom(ggplot2,theme_minimal)
importFrom(magrittr,"%>%")
Expand Down
7 changes: 7 additions & 0 deletions R/constructors.R
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,13 @@ new_setting_contact_model <- function(list_model,
)
}

new_contact_model <- function(model){
structure(
model,
class = c("contact_model", class(model))
)
}

new_setting_vaccination_matrix <- function(list_matrix,
age_breaks) {
structure(
Expand Down
7 changes: 6 additions & 1 deletion R/fit_single_contact_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ fit_single_contact_model <- function(contact_data,
formula <- update(formula_no_offset, formula_offset)

# contact model for all locations together
contact_data %>%
model <- contact_data %>%
# NOTE
# Do we need to have this data cleaning step in here?
# I think we should instead have this as a separate preparation step for
Expand All @@ -228,4 +228,9 @@ fit_single_contact_model <- function(contact_data,
offset = log(participants),
data = .
)

new_contact_model(
model = model
)

}
188 changes: 129 additions & 59 deletions R/partial-prediction-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,70 +42,140 @@
#' similar ages and with their parents. Visualising the partial predictive
#' plots for other settings (school, work and other) show patterns that
#' correspond with real-life situations.
#'
#'
#' @param model either a model fit with
#' @param ages vector of integer ages
#' @return data frame with 20 columns plus n rows based on expand.grid
#' combination of ages. Contains transformed coefficients from ages.
#' @name partial-prediction
#' @noRd
#' @examples
#' fit_home <- polymod_setting_models$home
#' age_grid <- create_age_grid(ages = 1:99)
#' term_names <- extract_term_names(fit_home)
#' term_var_names <- clean_term_names(term_names)
#' age_predictions <- predict_individual_terms(
#' age_grid = age_grid,
#' fit = fit_home,
#' term_names = term_names,
#' term_var_names = term_var_names
#' )
#'
#' age_predictions_all_settings <- map_dfr(
#' .x = polymod_setting_models,
#' .f = function(x) {
#' predict_individual_terms(
#' age_grid = age_grid,
#' fit = x,
#' term_names = term_names,
#' term_var_names = term_var_names
#' )
#' },
#' .id = "setting"
#' )
#'
#' plot_age_term_settings <- gg_age_terms_settings(age_predictions_all_settings)
#' age_predictions_long <- pivot_longer_age_preds(age_predictions)
#'
#' library(ggplot2)
#' plot_age_predictions_long <- gg_age_partial_pred_long(age_predictions_long) +
#' coord_equal() +
#' labs(
#' x = "Age from",
#' y = "Age to"
#' ) +
#' theme(
#' legend.position = "bottom",
#' axis.text = element_text(size = 6),
#' panel.spacing = unit(x = 1, units = "lines")
#' ) +
#' scale_x_continuous(expand = c(0,0)) +
#' scale_y_continuous(expand = c(0,0)) +
#' expand_limits(x = c(0, 100), y = c(0, 100))
#'
#' age_predictions_long_sum <- add_age_partial_sum(age_predictions_long)
#' plot_age_predictions_sum <- gg_age_partial_sum(age_predictions_long_sum) + coord_equal() +
#' labs(x = "Age from",
#' y = "Age to") +
#' theme(
#' legend.position = "bottom"
#' ) +
#' scale_x_continuous(expand = c(0,0)) +
#' scale_y_continuous(expand = c(0,0)) +
#' expand_limits(x = c(0, 100), y = c(0, 100))
#'
#' plot_age_term_settings
#' plot_age_predictions_long
#' plot_age_predictions_sum
#' fit_home_partials <- conmat_partial_effects(
#' polymod_setting_models$home,
#' ages = 1:99
#' )
#' fit_setting_partials <- conmat_partial_effects(
#' polymod_setting_models,
#' ages = 1:99
#' )
#' autoplot(fit_home_partials)
#' autoplot(fit_setting_partials)
conmat_partial_effects <- function(model, ages, ...){
UseMethod("conmat_partial_effects")
}

#' @rdname partial-prediction
#' @export
conmat_partial_effects.contact_model <- function(model, ages, ...){

age_grid <- create_age_grid(ages = ages)
term_names <- extract_term_names(model)
term_var_names <- clean_term_names(term_names)

predict_individual_terms(
age_grid = age_grid,
fit = model,
term_names = term_names,
term_var_names = term_var_names
)

age_predictions_long <- pivot_longer_age_preds(age_predictions)

structure(
age_predictions_long,
class = c("partial_predictions", class(age_predictions_long))
)
}

#' @rdname partial-prediction
#' @export
conmat_partial_effects.setting_contact_model <- function(model, ages, ...){

age_grid <- create_age_grid(ages = ages)
term_names <- extract_term_names(model)
term_var_names <- clean_term_names(term_names)

age_predictions_all_settings <- purrr::map_dfr(
.x = model,
.f = function(x) {
predict_individual_terms(
age_grid = age_grid,
fit = x,
term_names = term_names,
term_var_names = term_var_names
)
},
.id = "setting"
)

structure(
age_predictions_all_settings,
class = c("setting_partial_predictions",
class(age_predictions_all_settings))
)
}

#' @rdname partial-prediction
#' @export
conmat_partial_effects_sum <- function(model, ages, ...){
UseMethod("conmat_partial_effects_sum")
}

#' @rdname partial-prediction
#' @export
conmat_partial_effects_sum.contact_model <- function(model, ages, ...){
age_predictions_long <- conmat_partial_effects(model, ages)
partial_sums <- add_age_partial_sum(age_predictions_long)
structure(
partial_sums,
class = c("partial_predictions_sum",
class(partial_sums))
)
}

#' @rdname autoplot-conmat-partial
#' @export
autoplot.partial_predictions_sum <- function(object, ...){
gg_age_partial_sum(object)
}

#' @rdname partial-prediction
#' @export
conmat_partial_effects_sum.setting_contact_model <- function(model, ages, ...){
setting_age_predictions_long <- conmat_partial_effects(model, ages)

setting_partial_sums <- purrr::map_dfr(
.x = setting_age_predictions_long,
.f = add_age_partial_sum,
.id = "setting"
)
structure(
setting_partial_sums,
class = c("setting_partial_predictions_sum",
class(setting_partial_sums))
)
}

# TODO
# add autoplot method for summed partials settings?


#' Plot partial predictive plots using ggplot2
#'
#' @param object An object with partial predictions from
#' @param ... Other arguments passed on
#' @return a ggplot visualisation of partial effects
#' @name autoplot-conmat-partial
#' @export
autoplot.partial_predictions <- function(object, ...){
gg_age_partial_pred_long(object)
}

#' @rdname autoplot-conmat-partial
#' @export
autoplot.setting_partial_predictions <- function(object, ...){
gg_age_terms_settings(object)
}


create_age_grid <- function(ages) {
age_grid <- expand.grid(
Expand Down
1 change: 1 addition & 0 deletions conmat.Rproj
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
Version: 1.0
ProjectId: 6b48146e-4332-448c-9082-a9cd977e528e

RestoreWorkspace: Default
SaveWorkspace: Default
Expand Down
2 changes: 1 addition & 1 deletion data-raw/create-polymod-model.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
library(conmat)
set.seed(2022 - 12 - 19)
set.seed(2025 - 01 - 28 - 1802)
polymod_contact_data <- get_polymod_setting_data()
polymod_survey_data <- get_polymod_population()
polymod_setting_models <- fit_setting_contacts(
Expand Down
Binary file modified data/polymod_setting_models.rda
Binary file not shown.
26 changes: 26 additions & 0 deletions man/autoplot-conmat-partial.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

88 changes: 88 additions & 0 deletions man/partial-prediction.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit c838b2f

Please sign in to comment.