Skip to content

Commit

Permalink
Merge pull request #550 from tidymodels/bootci-errors
Browse files Browse the repository at this point in the history
Errors and checking for bootstrap intervals
  • Loading branch information
hfrick authored Sep 25, 2024
2 parents 7b92366 + 2ca6b5e commit fe24aaa
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 89 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ export(vfold_cv)
import(rlang)
import(vctrs)
importFrom(cli,cli_abort)
importFrom(cli,cli_text)
importFrom(cli,cli_warn)
importFrom(dplyr,"%>%")
importFrom(dplyr,arrange)
Expand Down
104 changes: 53 additions & 51 deletions R/bootci.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,47 +4,53 @@
# helpers


# NOTE(review): this is a diff hunk rendered without +/- markers, so two
# versions of the helper are interleaved and the text is not valid R as shown.
# Presumably the `check_rset()` lines are the removed side and the
# `check_includes_apparent()` lines are the added side -- confirm against the
# commit before reusing either version.
#
# check_rset(x, app = TRUE):
#   Aborts unless `x` inherits from "bootstraps". When `app` is TRUE it also
#   aborts unless exactly one row of `x` has `id == "Apparent"`.
check_rset <- function(x, app = TRUE) {
if (!inherits(x, "bootstraps")) {
cli_abort("{.arg .data} should be an {.cls rset} object generated from {.fn bootstraps}.")
}

if (app) {
if (x %>% dplyr::filter(id == "Apparent") %>% nrow() != 1) {
cli_abort("Please set {.code apparent = TRUE} in {.fn bootstraps} function.")
}
# check_includes_apparent(x, call = caller_env()):
#   Aborts unless exactly one row of `x` has `id == "Apparent"`. The message
#   is a two-element vector: the problem statement plus an `i` bullet telling
#   the user to set `apparent = TRUE`. `call` is forwarded to `cli_abort()` as
#   its `call` argument. Returns `invisible(NULL)` when the check passes.
check_includes_apparent <- function(x, call = caller_env()) {
if (x %>% dplyr::filter(id == "Apparent") %>% nrow() != 1) {
cli_abort(c(
"The bootstrap resamples must include an apparent sample.",
i = "Please set {.code apparent = TRUE} in the {.fn bootstraps} function."
),
call = call
)
}
invisible(NULL)
}


# NOTE(review): diff hunk without +/- markers; old and new constants are
# interleaved (the `paste(` call below has no visible closing parenthesis of
# its own). Presumably `stat_fmt_err` / `stat_nm_err` are the removed side and
# `statistics_format_error` is the added side -- confirm against the commit.
#
# Shared error message used when `statistics` does not resolve to a single
# list column of tidy results.
stat_fmt_err <- "{.arg statistics} should select a list column of tidy results."
# Shared error message for a missing 'estimate' or 'term' column.
stat_nm_err <- paste(
"The tibble in {.arg statistics} should have columns for",
"'estimate' and 'term'."
# Same wording as `stat_fmt_err`, but passed through `cli::format_inline()`
# when the constant is defined rather than when the error is raised.
statistics_format_error <- cli::format_inline(
"{.arg statistics} should select a list column of tidy results."
)
# Column names accepted as the standard-error column; the checks below match
# `colnames(x)` against this vector.
std_exp <- c("std.error", "robust.se")

# NOTE(review): diff hunk without +/- markers; two versions are interleaved
# and the text is not valid R as shown. Presumably `check_tidy_names()` and
# the `cli_abort()` calls without `call = call` are the removed side, and
# `check_statistics_names()` with the multi-line `cli_abort()` calls is the
# added side -- confirm against the commit.
#
# Validates the column names of the statistics data frame `x`:
#   * exactly one column named "estimate",
#   * exactly one column named "term",
#   * when `std_col` is TRUE, exactly one column whose name is in `std_exp`.
# Aborts on the first failed check; otherwise returns `invisible(TRUE)`.
# In the `check_statistics_names()` version, `call` is forwarded to
# `cli_abort()` as its `call` argument.
check_tidy_names <- function(x, std_col) {
check_statistics_names <- function(x, std_col, call = caller_env()) {
# check for proper columns
# `sum(... == name) != 1` rejects both a missing and a duplicated column.
if (sum(colnames(x) == "estimate") != 1) {
cli_abort(stat_nm_err)
cli_abort(
"The tibble in {.arg statistics} must have a column for 'estimate'.",
call = call
)
}
if (sum(colnames(x) == "term") != 1) {
cli_abort(stat_nm_err)
cli_abort(
"The tibble in {.arg statistics} must have a column for 'term'.",
call = call
)
}
# Only checked when the caller asks for a standard-error column.
if (std_col) {
std_candidates <- colnames(x) %in% std_exp
if (sum(std_candidates) != 1) {
cli_abort("{.arg statistics} should select a single column for the standard error.")
cli_abort(
"{.arg statistics} should select a single column for the standard error.",
call = call
)
}
}
invisible(TRUE)
}

check_tidy <- function(x, std_col = FALSE) {
check_statistics <- function(x, std_col = FALSE, call = caller_env()) {
if (!is.list(x)) {
rlang::abort(stat_fmt_err)
cli_abort(statistics_format_error, call = call)
}

# convert to data frame from list
Expand All @@ -57,10 +63,10 @@ check_tidy <- function(x, std_col = FALSE) {
}

if (inherits(x, "try-error")) {
cli_abort(stat_fmt_err)
cli_abort(statistics_format_error, call = call)
}

check_tidy_names(x, std_col)
check_statistics_names(x, std_col, call = call)

if (std_col) {
std_candidates <- colnames(x) %in% std_exp
Expand Down Expand Up @@ -117,15 +123,15 @@ new_stats <- function(x, lo, hi) {
tibble(.lower = min(res), .estimate = mean(x, na.rm = TRUE), .upper = max(res))
}

# NOTE(review): diff hunk without +/- markers; two versions are interleaved
# and the text is not valid R as shown. Presumably `has_dots()` and the
# `cli_abort()` call without `call = call` are the removed side, and
# `check_has_dots()` is the added side -- confirm against the commit.
#
# Aborts unless the function `x` has `...` among its formal arguments.
# Returns `invisible(NULL)` when the check passes. In the `check_has_dots()`
# version, `call` is forwarded to `cli_abort()` as its `call` argument.
has_dots <- function(x) {
check_has_dots <- function(x, call = caller_env()) {
# Names of the formal arguments of `x`.
nms <- names(formals(x))
if (!any(nms == "...")) {
cli_abort("{.arg .fn} must have an argument {.arg ...}.")
cli_abort("{.arg .fn} must have an argument {.arg ...}.", call = call)
}
invisible(NULL)
}

check_num_resamples <- function(x, B = 1000) {
check_num_resamples <- function(x, B = 1000, call = caller_env()) {
x <-
x %>%
dplyr::group_by(term) %>%
Expand All @@ -134,7 +140,10 @@ check_num_resamples <- function(x, B = 1000) {

if (nrow(x) > 0) {
terms <- x$term
cli_warn("Recommend at least {B} non-missing bootstrap resamples for {cli::qty(terms)} term{?s} {.code {terms}}.")
cli_warn(
"Recommend at least {B} non-missing bootstrap resamples for {cli::qty(terms)} term{?s} {.code {terms}}.",
call = call
)
}
invisible(NULL)
}
Expand All @@ -145,11 +154,11 @@ check_num_resamples <- function(x, B = 1000) {

pctl_single <- function(stats, alpha = 0.05) {
if (all(is.na(stats))) {
cli_abort("All statistics have missing values.")
cli_abort("All statistics have missing values.", call = call2("int_pctl"))
}

if (!is.numeric(stats)) {
cli_abort("{.arg stats} must be a numeric vector.")
cli_abort("All statistics must be numeric.", call = call2("int_pctl"))
}

# stats is a numeric vector of values
Expand Down Expand Up @@ -283,19 +292,16 @@ int_pctl <- function(.data, ...) {
#' @rdname int_pctl
int_pctl.bootstraps <- function(.data, statistics, alpha = 0.05, ...) {
check_dots_empty()
check_rset(.data, app = FALSE)
if (length(alpha) != 1 || !is.numeric(alpha)) {
cli_abort("{.arg alpha} must be a single numeric value.")
}
check_number_decimal(alpha, min = 0, max = 1)

.data <- .data %>% dplyr::filter(id != "Apparent")

column_name <- tidyselect::vars_select(names(.data), !!rlang::enquo(statistics))
if (length(column_name) != 1) {
rlang::abort(stat_fmt_err)
cli_abort(statistics_format_error)
}
stats <- .data[[column_name]]
stats <- check_tidy(stats, std_col = FALSE)
stats <- check_statistics(stats, std_col = FALSE)

check_num_resamples(stats, B = 1000)

Expand All @@ -319,7 +325,7 @@ t_single <- function(stats, std_err, is_orig, alpha = 0.05) {
# which_orig is the index of stats and std_err that has the original result

if (all(is.na(stats))) {
cli_abort("All statistics have missing values.")
cli_abort("All statistics have missing values.", call = call2("int_t"))
}

if (!is.logical(is_orig) || any(is.na(is_orig))) {
Expand Down Expand Up @@ -368,17 +374,15 @@ int_t <- function(.data, ...) {
#' @export
int_t.bootstraps <- function(.data, statistics, alpha = 0.05, ...) {
check_dots_empty()
check_rset(.data)
if (length(alpha) != 1 || !is.numeric(alpha)) {
cli_abort("{.arg alpha} must be a single numeric value.")
}
check_includes_apparent(.data)
check_number_decimal(alpha, min = 0, max = 1)

column_name <- tidyselect::vars_select(names(.data), !!enquo(statistics))
if (length(column_name) != 1) {
cli_abort(stat_fmt_err)
cli_abort(statistics_format_error)
}
stats <- .data %>% dplyr::select(!!column_name, id)
stats <- check_tidy(stats, std_col = TRUE)
stats <- check_statistics(stats, std_col = TRUE)

check_num_resamples(stats, B = 500)

Expand All @@ -394,11 +398,11 @@ int_t.bootstraps <- function(.data, statistics, alpha = 0.05, ...) {

# ----------------------------------------------------------------

bca_calc <- function(stats, orig_data, alpha = 0.05, .fn, ...) {
bca_calc <- function(stats, orig_data, alpha = 0.05, .fn, ..., call = caller_env()) {

# TODO check per term
if (all(is.na(stats$estimate))) {
cli_abort("All statistics have missing values.")
cli_abort("All statistics have missing values.", call = call)
}

stat_groups_chr <- c("term", grep("^\\.", names(stats), value = TRUE))
Expand All @@ -414,9 +418,9 @@ bca_calc <- function(stats, orig_data, alpha = 0.05, .fn, ...) {
# To test, we run on the first LOO data set and see if it is a vector or df
loo_test <- try(rlang::exec(.fn, loo_rs$splits[[1]], ...), silent = TRUE)
if (inherits(loo_test, "try-error")) {
cat("Running `.fn` on the LOO resamples produced an error:\n")
cli_text("Running {.fn .fn} on the LOO resamples produced an error:")
print(loo_test)
cli_abort("{.arg .fn} failed.")
cli_abort("{.arg .fn} failed.", call = call)
}

loo_res <- furrr::future_map(loo_rs$splits, .fn, ...) %>% list_rbind()
Expand Down Expand Up @@ -477,19 +481,17 @@ int_bca <- function(.data, ...) {
#' @rdname int_pctl
#' @export
int_bca.bootstraps <- function(.data, statistics, alpha = 0.05, .fn, ...) {
check_rset(.data)
if (length(alpha) != 1 || !is.numeric(alpha)) {
cli_abort("{.arg alpha} must be a single numeric value.")
}
check_includes_apparent(.data)
check_number_decimal(alpha, min = 0, max = 1)

has_dots(.fn)
check_has_dots(.fn)

column_name <- tidyselect::vars_select(names(.data), !!enquo(statistics))
if (length(column_name) != 1) {
cli_abort(stat_fmt_err)
cli_abort(statistics_format_error)
}
stats <- .data %>% dplyr::select(!!column_name, id, dplyr::starts_with("."))
stats <- check_tidy(stats)
stats <- check_statistics(stats)

check_num_resamples(stats, B = 1000)

Expand Down
2 changes: 1 addition & 1 deletion R/rsample-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
## usethis namespace: start
#' @import rlang
#' @importFrom lifecycle deprecated
#' @importFrom cli cli_abort cli_warn
#' @importFrom cli cli_abort cli_warn cli_text
#' @importFrom utils globalVariables
#' @importFrom purrr map map2 map_dbl pluck map_lgl list_rbind
#' @importFrom tibble tibble is_tibble as_tibble obj_sum
Expand Down
Loading

0 comments on commit fe24aaa

Please sign in to comment.