diff --git a/NEWS.md b/NEWS.md index bdfc8e30..a0aa6d06 100644 --- a/NEWS.md +++ b/NEWS.md @@ -16,8 +16,6 @@ * Added better printing methods for initial split objects. -* Added a new `balance` option to `group_vfold_cv()` to balance folds either by the number of groups or the number of observations (@mikemahoney218, #300). - # rsample 0.1.1 * Updated documentation on stratified sampling (#245). diff --git a/R/groups.R b/R/groups.R index d8712bca..f43c9732 100644 --- a/R/groups.R +++ b/R/groups.R @@ -22,10 +22,9 @@ #' #' set.seed(123) #' group_vfold_cv(Sacramento, group = city, v = 5) -#' group_vfold_cv(Sacramento, group = city, v = 5, balance = "observations") #' #' @export -group_vfold_cv <- function(data, group = NULL, v = NULL, balance = c("groups", "observations"), ...) { +group_vfold_cv <- function(data, group = NULL, v = NULL, ...) { if (!missing(group)) { group <- tidyselect::vars_select(names(data), !!enquo(group)) if (length(group) == 0) { @@ -42,9 +41,7 @@ group_vfold_cv <- function(data, group = NULL, v = NULL, balance = c("groups", " rlang::abort("`group` should be a column in `data`.") } - balance <- rlang::arg_match(balance) - - split_objs <- group_vfold_splits(data = data, group = group, v = v, balance = balance) + split_objs <- group_vfold_splits(data = data, group = group, v = v) ## We remove the holdout indices since it will save space and we can ## derive them later when they are needed. @@ -68,9 +65,7 @@ group_vfold_cv <- function(data, group = NULL, v = NULL, balance = c("groups", " ) } -group_vfold_splits <- function(data, group, v = NULL, balance = c("groups", "observations")) { - - balance <- rlang::arg_match(balance) +group_vfold_splits <- function(data, group, v = NULL) { group <- getElement(data, group) max_v <- length(unique(group)) @@ -81,7 +76,7 @@ group_vfold_splits <- function(data, group, v = NULL, balance = c("groups", "obs check_v(v = v, max_v = max_v, rows = "rows", call = rlang::caller_env()) } - indices <- make_groups(data, group, v, balance) + indices <- make_groups(data, group, v) indices <- lapply(indices, default_complement, n = nrow(data)) split_objs <- purrr::map(indices, diff --git a/R/make_groups.R b/R/make_groups.R index 6cb07895..73cad123 100644 --- a/R/make_groups.R +++ b/R/make_groups.R @@ -8,22 +8,13 @@ #' @param group A variable in `data` (single character or name) used for #' grouping observations with the same value to either the analysis or #' assessment set within a fold. -#' @param balance If `v` is less than the number of unique groups, how should -#' groups be combined into folds? If `"groups"`, the default, then groups are -#' combined randomly to balance the number of groups in each fold. -#' If `"observations"`, then groups are combined to balance the number of -#' observations in each fold. #' #' @keywords internal -make_groups <- function(data, group, v, balance) { +make_groups <- function(data, group, v) { data_ind <- data.frame(..index = 1:nrow(data), ..group = group) data_ind$..group <- as.character(data_ind$..group) - res <- switch( - balance, - "groups" = balance_groups(data_ind, v), - "observations" = balance_observations(data_ind, v) - ) + res <- balance_groups(data_ind, v) data_ind <- res$data_ind keys <- res$keys @@ -47,27 +38,3 @@ balance_groups <- function(data_ind, v) { keys = keys ) } - -balance_observations <- function(data_ind, v) { - while (vec_unique_count(data_ind$..group) > v) { - freq_table <- vec_count(data_ind$..group) - # Recategorize the largest group to be collapsed - # as the smallest group to be kept - group_to_keep <- vec_slice(freq_table, v) - group_to_collapse <- vec_slice(freq_table, v + 1) - collapse_lgl <- vec_in(data_ind$..group, group_to_collapse$key) - vec_slice(data_ind$..group, collapse_lgl) <- group_to_keep$key - } - unique_groups <- unique(data_ind$..group) - - keys <- data.frame( - ..group = unique_groups, - ..folds = sample(rep(seq_len(v), length.out = length(unique_groups))) - ) - - list( - data_ind = data_ind, - keys = keys - ) - -} diff --git a/man/group_vfold_cv.Rd b/man/group_vfold_cv.Rd index 7c25d590..85ecfb29 100644 --- a/man/group_vfold_cv.Rd +++ b/man/group_vfold_cv.Rd @@ -4,13 +4,7 @@ \alias{group_vfold_cv} \title{Group V-Fold Cross-Validation} \usage{ -group_vfold_cv( - data, - group = NULL, - v = NULL, - balance = c("groups", "observations"), - ... -) +group_vfold_cv(data, group = NULL, v = NULL, ...) } \arguments{ \item{data}{A data frame.} @@ -22,12 +16,6 @@ assessment set within a fold.} \item{v}{The number of partitions of the data set. If left as \code{NULL}, \code{v} will be set to the number of unique values in the group.} -\item{balance}{If \code{v} is less than the number of unique groups, how should -groups be combined into folds? If \code{"groups"}, the default, then groups are -combined randomly to balance the number of groups in each fold. -If \code{"observations"}, then groups are combined to balance the number of -observations in each fold.} - \item{...}{Not currently used.} } \value{ @@ -51,6 +39,5 @@ data(Sacramento, package = "modeldata") set.seed(123) group_vfold_cv(Sacramento, group = city, v = 5) -group_vfold_cv(Sacramento, group = city, v = 5, balance = "observations") \dontshow{\}) # examplesIf} } diff --git a/man/make_groups.Rd b/man/make_groups.Rd index ac969161..b2d0cc49 100644 --- a/man/make_groups.Rd +++ b/man/make_groups.Rd @@ -4,7 +4,7 @@ \alias{make_groups} \title{Make groupings for grouped rsplits} \usage{ -make_groups(data, group, v, balance) +make_groups(data, group, v) } \arguments{ \item{data}{A data frame.} @@ -14,12 +14,6 @@ grouping observations with the same value to either the analysis or assessment set within a fold.} \item{v}{The number of partitions of the data set.} - -\item{balance}{If \code{v} is less than the number of unique groups, how should -groups be combined into folds? If \code{"groups"}, the default, then groups are -combined randomly to balance the number of groups in each fold. -If \code{"observations"}, then groups are combined to balance the number of -observations in each fold.} } \description{ This function powers \link{group_vfold_cv} by splitting the data based upon diff --git a/tests/testthat/test-groups.R b/tests/testthat/test-groups.R index f9cade07..907621fe 100644 --- a/tests/testthat/test-groups.R +++ b/tests/testthat/test-groups.R @@ -84,37 +84,6 @@ test_that("tibble input", { expect_true(all(table(sp_out) == 1)) }) -test_that("other balance methods", { - data(ames, package = "modeldata") - set.seed(11) - rs1 <- group_vfold_cv(ames, "Neighborhood", balance = "observations", v = 2) - sizes1 <- dim_rset(rs1) - - expect_true(all(sizes1$analysis == 1465)) - expect_true(all(sizes1$assessment == 1465)) - same_data <- - purrr::map_lgl(rs1$splits, function(x) { - all.equal(x$data, ames) - }) - expect_true(all(same_data)) - - good_holdout <- purrr::map_lgl( - rs1$splits, - function(x) { - length(intersect(x$in_ind, x$out_id)) == 0 - } - ) - expect_true(all(good_holdout)) - - expect_true( - !any( - unique(as.character(assessment(rs1$splits[[1]])$Neighborhood)) %in% - unique(as.character(analysis(rs1$splits[[1]])$Neighborhood)) - ) - ) - -}) - test_that("printing", { expect_snapshot(group_vfold_cv(warpbreaks, "tension")) })