Skip to content

Commit

Permalink
Merge pull request #465 from tidymodels/expand-group-intervals
Browse files Browse the repository at this point in the history
Expand grouping variables for bootstrap intervals
  • Loading branch information
hfrick authored Sep 12, 2024
2 parents 8ccee92 + d839b1f commit 617b619
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 38 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

* `vfold_cv()` and `clustering_cv()` now error on implicit leave-one-out cross-validation (@seb09, #527).

* Bootstrap intervals via `int_pctl()`, `int_t()`, and `int_bca()` now allow for more flexible grouping (#465).

## Bug fixes

* `vfold_cv()` now utilizes the `breaks` argument correctly for repeated cross-validation (@ZWael, #471).
Expand Down
115 changes: 81 additions & 34 deletions R/bootci.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,41 +65,46 @@ check_tidy <- function(x, std_col = FALSE) {
if (std_col) {
std_candidates <- colnames(x) %in% std_exp
std_candidates <- colnames(x)[std_candidates]
re_name <- list(std_err = std_candidates)
if (has_id) {
x <-
dplyr::select(x, term, estimate, id, tidyselect::one_of(std_candidates)) %>%
mutate(id = (id == "Apparent")) %>%
setNames(c("term", "estimate", "orig", "std_err"))
dplyr::select(x, term, estimate, id, tidyselect::one_of(std_candidates),
dplyr::starts_with(".")) %>%
mutate(orig = (id == "Apparent")) %>%
dplyr::rename(!!!re_name)
} else {
x <-
dplyr::select(x, term, estimate, tidyselect::one_of(std_candidates)) %>%
setNames(c("term", "estimate", "std_err"))
dplyr::select(x, term, estimate, tidyselect::one_of(std_candidates),
dplyr::starts_with(".")) %>%
dplyr::rename(!!!re_name)
}
} else {
if (has_id) {
x <-
dplyr::select(x, term, estimate, id) %>%
dplyr::select(x, term, estimate, id, dplyr::starts_with(".")) %>%
mutate(orig = (id == "Apparent")) %>%
dplyr::select(-id)
} else {
x <- dplyr::select(x, term, estimate)
x <- dplyr::select(x, term, estimate, dplyr::starts_with("."))
}
}

x
}


get_p0 <- function(x, alpha = 0.05) {
get_p0 <- function(x, alpha = 0.05, groups) {
group_sym <- rlang::syms(groups)

orig <- x %>%
group_by(term) %>%
group_by(!!!group_sym) %>%
dplyr::filter(orig) %>%
dplyr::select(term, theta_0 = estimate) %>%
dplyr::select(!!!group_sym, theta_0 = estimate) %>%
ungroup()
x %>%
dplyr::filter(!orig) %>%
inner_join(orig, by = "term") %>%
group_by(term) %>%
inner_join(orig, by = groups) %>%
group_by(!!!group_sym) %>%
summarize(p0 = mean(estimate <= theta_0, na.rm = TRUE)) %>%
mutate(
Z0 = stats::qnorm(p0),
Expand Down Expand Up @@ -172,9 +177,10 @@ pctl_single <- function(stats, alpha = 0.05) {
#' @param statistics An unquoted column name or `dplyr` selector that identifies
#' a single column in the data set containing the individual bootstrap
#' estimates. This must be a list column of tidy tibbles (with columns
#' `term` and `estimate`). For t-intervals, a
#' standard tidy column (usually called `std.err`) is required.
#' See the examples below.
#' `term` and `estimate`). Optionally, users can include columns whose names
#' begin with a period and the intervals will be created for each combination
#' of these variables and the `term` column. For t-intervals, a standard tidy
#' column (usually called `std.err`) is required. See the examples below.
#' @param alpha Level of significance.
#' @param .fn A function to calculate statistic of interest. The
#' function should take an `rsplit` as the first argument and the `...` are
Expand All @@ -200,12 +206,15 @@ pctl_single <- function(stats, alpha = 0.05) {
#' Application_. Cambridge: Cambridge University Press.
#' doi:10.1017/CBO9780511802843
#'
#' @examplesIf rlang::is_installed("broom")
#' @examplesIf rlang::is_installed("broom") & rlang::is_installed("modeldata")
#' \donttest{
#' library(broom)
#' library(dplyr)
#' library(purrr)
#' library(tibble)
#' library(tidyr)
#'
#' # ------------------------------------------------------------------------------
#'
#' lm_est <- function(split, ...) {
#' lm(mpg ~ disp + hp, data = analysis(split)) %>%
Expand All @@ -221,6 +230,8 @@ pctl_single <- function(stats, alpha = 0.05) {
#' int_t(car_rs, results)
#' int_bca(car_rs, results, .fn = lm_est)
#'
#' # ------------------------------------------------------------------------------
#'
#' # putting results into a tidy format
#' rank_corr <- function(split) {
#' dat <- analysis(split)
Expand All @@ -237,6 +248,31 @@ pctl_single <- function(stats, alpha = 0.05) {
#' bootstraps(Sacramento, 1000, apparent = TRUE) %>%
#' mutate(correlations = map(splits, rank_corr)) %>%
#' int_pctl(correlations)
#'
#' # ------------------------------------------------------------------------------
#' # An example of computing the interval for each value of a custom grouping
#' # factor (type of house in this example)
#'
#' # Get regression estimates for each house type
#' lm_est <- function(split, ...) {
#' analysis(split) %>%
#' tidyr::nest(.by = c(type)) %>%
#' # Compute regression estimates for each house type
#' mutate(
#' betas = purrr::map(data, ~ lm(log10(price) ~ sqft, data = .x) %>% tidy())
#' ) %>%
#' # Convert the column name to begin with a period
#' rename(.type = type) %>%
#' select(.type, betas) %>%
#' unnest(cols = betas)
#' }
#'
#' set.seed(52156)
#' house_rs <-
#' bootstraps(Sacramento, 1000, apparent = TRUE) %>%
#' mutate(results = map(splits, lm_est))
#'
#' int_pctl(house_rs, results)
#' }
#' @export
int_pctl <- function(.data, ...) {
Expand All @@ -263,8 +299,11 @@ int_pctl.bootstraps <- function(.data, statistics, alpha = 0.05, ...) {

check_num_resamples(stats, B = 1000)

stat_groups <- c("term", grep("^\\.", names(stats), value = TRUE))
stat_groups <- rlang::syms(stat_groups)

vals <- stats %>%
dplyr::group_by(term) %>%
dplyr::group_by(!!!stat_groups) %>%
dplyr::do(pctl_single(.$estimate, alpha = alpha)) %>%
dplyr::ungroup()
vals
Expand Down Expand Up @@ -343,9 +382,10 @@ int_t.bootstraps <- function(.data, statistics, alpha = 0.05, ...) {

check_num_resamples(stats, B = 500)

vals <-
stats %>%
dplyr::group_by(term) %>%
stat_groups <- c("term", grep("^\\.", names(stats), value = TRUE))
stat_groups <- rlang::syms(stat_groups)
vals <- stats %>%
dplyr::group_by(!!!stat_groups) %>%
dplyr::do(t_single(.$estimate, .$std_err, .$orig, alpha = alpha)) %>%
dplyr::ungroup()
vals
Expand All @@ -361,8 +401,11 @@ bca_calc <- function(stats, orig_data, alpha = 0.05, .fn, ...) {
cli_abort("All statistics have missing values.")
}

stat_groups_chr <- c("term", grep("^\\.", names(stats), value = TRUE))
stat_groups_sym <- rlang::syms(stat_groups_chr)

### Estimating Z0 bias-correction
bias_corr_stats <- get_p0(stats, alpha = alpha)
bias_corr_stats <- get_p0(stats, alpha = alpha, groups = stat_groups_chr)

# need the original data frame here
loo_rs <- loo_cv(orig_data)
Expand All @@ -380,16 +423,16 @@ bca_calc <- function(stats, orig_data, alpha = 0.05, .fn, ...) {

loo_estimate <-
loo_res %>%
dplyr::group_by(term) %>%
dplyr::group_by(!!!stat_groups_sym) %>%
dplyr::summarize(loo = mean(estimate, na.rm = TRUE)) %>%
dplyr::inner_join(loo_res, by = "term", multiple = "all") %>%
dplyr::group_by(term) %>%
dplyr::inner_join(loo_res, by = stat_groups_chr, multiple = "all") %>%
dplyr::group_by(!!!stat_groups_sym) %>%
dplyr::summarize(
cubed = sum((loo - estimate)^3),
squared = sum((loo - estimate)^2)
) %>%
dplyr::ungroup() %>%
dplyr::inner_join(bias_corr_stats, by = "term") %>%
dplyr::inner_join(bias_corr_stats, by = stat_groups_chr) %>%
dplyr::mutate(
a = cubed / (6 * (squared^(3 / 2))),
Zu = (Z0 + Za) / (1 - a * (Z0 + Za)) + Z0,
Expand All @@ -400,21 +443,25 @@ bca_calc <- function(stats, orig_data, alpha = 0.05, .fn, ...) {

terms <- loo_estimate$term
stats <- stats %>% dplyr::filter(!orig)
for (i in seq_along(terms)) {
tmp <- new_stats(stats$estimate[stats$term == terms[i]],
lo = loo_estimate$lo[i],
hi = loo_estimate$hi[i]
)
tmp$term <- terms[i]

keys <- stats %>% dplyr::distinct(!!!stat_groups_sym)
for (i in seq_len(nrow(keys))) {
tmp_stats <- dplyr::inner_join(stats, keys[i,], by = stat_groups_chr)
tmp_loo <- dplyr::inner_join(loo_estimate, keys[i,], by = stat_groups_chr)

tmp <- new_stats(tmp_stats$estimate,
lo = tmp_loo$lo,
hi = tmp_loo$hi)
tmp <- dplyr::bind_cols(tmp, keys[i,])
if (i == 1) {
ci_bca <- tmp
} else {
ci_bca <- bind_rows(ci_bca, tmp)
ci_bca <- dplyr::bind_rows(ci_bca, tmp)
}
}
ci_bca <-
ci_bca %>%
dplyr::select(term, .lower, .estimate, .upper) %>%
dplyr::select(!!!stat_groups_sym, .lower, .estimate, .upper) %>%
dplyr::mutate(
.alpha = alpha,
.method = "BCa"
Expand All @@ -441,7 +488,7 @@ int_bca.bootstraps <- function(.data, statistics, alpha = 0.05, .fn, ...) {
if (length(column_name) != 1) {
cli_abort(stat_fmt_err)
}
stats <- .data %>% dplyr::select(!!column_name, id)
stats <- .data %>% dplyr::select(!!column_name, id, dplyr::starts_with("."))
stats <- check_tidy(stats)

check_num_resamples(stats, B = 1000)
Expand Down
39 changes: 35 additions & 4 deletions man/int_pctl.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

62 changes: 62 additions & 0 deletions tests/testthat/test-bootci.R
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,65 @@ test_that("regression intervals", {
"must be a single numeric value"
)
})

test_that("compute intervals with additional grouping terms", {
skip_if_not_installed("broom")

lm_coefs <- function(dat) {
mod <- lm(mpg ~ I(1/disp), data = dat)
tidy(mod)
}

coef_by_engine_shape <- function(split, ...) {
split %>%
analysis() %>%
dplyr::rename(.vs = vs) %>%
tidyr::nest(.by = .vs) %>%
dplyr::mutate(coefs = map(data, lm_coefs)) %>%
dplyr::select(-data) %>%
tidyr::unnest(coefs)
}

set.seed(270)
boot_rs <-
bootstraps(mtcars, 1000, apparent = TRUE) %>%
dplyr::mutate(results = purrr::map(splits, coef_by_engine_shape))

pctl_res <- int_pctl(boot_rs, results)
t_res <- int_t(boot_rs, results)
bca_res <- int_bca(boot_rs, results, .fn = coef_by_engine_shape)

exp_ptype <-
tibble::tibble(
term = character(0),
.vs = numeric(0),
.lower = numeric(0),
.estimate = numeric(0),
.upper = numeric(0),
.alpha = numeric(0),
.method = character(0)
)

expect_equal(pctl_res[0, ], exp_ptype)
expect_equal(t_res[0, ], exp_ptype)
expect_equal(bca_res[0, ], exp_ptype)

exp_combos <-
tibble::tribble(
~term, ~.vs,
"(Intercept)", 0,
"(Intercept)", 1,
"I(1/disp)", 0,
"I(1/disp)", 1
)

group_patterns <- function(x) {
dplyr::distinct(x, term, .vs) %>%
dplyr::arrange(term, .vs)
}

expect_equal(group_patterns(pctl_res), exp_combos)
expect_equal(group_patterns(t_res), exp_combos)
expect_equal(group_patterns(bca_res), exp_combos)
})

0 comments on commit 617b619

Please sign in to comment.