Skip to content

Commit

Permalink
Merge pull request #310 from tidymodels/a_little_unbalanced
Browse files Browse the repository at this point in the history
Remove balance
  • Loading branch information
juliasilge authored Jun 24, 2022
2 parents 7752ef1 + 18d0113 commit 4b853fe
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 98 deletions.
2 changes: 0 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

* Added better printing methods for initial split objects.

* Added a new `balance` option to `group_vfold_cv()` to balance folds either by the number of groups or the number of observations (@mikemahoney218, #300).

# rsample 0.1.1

* Updated documentation on stratified sampling (#245).
Expand Down
13 changes: 4 additions & 9 deletions R/groups.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@
#'
#' set.seed(123)
#' group_vfold_cv(Sacramento, group = city, v = 5)
#' group_vfold_cv(Sacramento, group = city, v = 5, balance = "observations")
#'
#' @export
group_vfold_cv <- function(data, group = NULL, v = NULL, balance = c("groups", "observations"), ...) {
group_vfold_cv <- function(data, group = NULL, v = NULL, ...) {
if (!missing(group)) {
group <- tidyselect::vars_select(names(data), !!enquo(group))
if (length(group) == 0) {
Expand All @@ -42,9 +41,7 @@ group_vfold_cv <- function(data, group = NULL, v = NULL, balance = c("groups", "
rlang::abort("`group` should be a column in `data`.")
}

balance <- rlang::arg_match(balance)

split_objs <- group_vfold_splits(data = data, group = group, v = v, balance = balance)
split_objs <- group_vfold_splits(data = data, group = group, v = v)

## We remove the holdout indices since it will save space and we can
## derive them later when they are needed.
Expand All @@ -68,9 +65,7 @@ group_vfold_cv <- function(data, group = NULL, v = NULL, balance = c("groups", "
)
}

group_vfold_splits <- function(data, group, v = NULL, balance = c("groups", "observations")) {

balance <- rlang::arg_match(balance)
group_vfold_splits <- function(data, group, v = NULL) {

group <- getElement(data, group)
max_v <- length(unique(group))
Expand All @@ -81,7 +76,7 @@ group_vfold_splits <- function(data, group, v = NULL, balance = c("groups", "obs
check_v(v = v, max_v = max_v, rows = "rows", call = rlang::caller_env())
}

indices <- make_groups(data, group, v, balance)
indices <- make_groups(data, group, v)
indices <- lapply(indices, default_complement, n = nrow(data))
split_objs <-
purrr::map(indices,
Expand Down
37 changes: 2 additions & 35 deletions R/make_groups.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,13 @@
#' @param group A variable in `data` (single character or name) used for
#' grouping observations with the same value to either the analysis or
#' assessment set within a fold.
#' @param balance If `v` is less than the number of unique groups, how should
#' groups be combined into folds? If `"groups"`, the default, then groups are
#' combined randomly to balance the number of groups in each fold.
#' If `"observations"`, then groups are combined to balance the number of
#' observations in each fold.
#'
#' @keywords internal
make_groups <- function(data, group, v, balance) {
make_groups <- function(data, group, v) {
data_ind <- data.frame(..index = 1:nrow(data), ..group = group)
data_ind$..group <- as.character(data_ind$..group)

res <- switch(
balance,
"groups" = balance_groups(data_ind, v),
"observations" = balance_observations(data_ind, v)
)
res <- balance_groups(data_ind, v)
data_ind <- res$data_ind
keys <- res$keys

Expand All @@ -47,27 +38,3 @@ balance_groups <- function(data_ind, v) {
keys = keys
)
}

balance_observations <- function(data_ind, v) {
while (vec_unique_count(data_ind$..group) > v) {
freq_table <- vec_count(data_ind$..group)
# Recategorize the largest group to be collapsed
# as the smallest group to be kept
group_to_keep <- vec_slice(freq_table, v)
group_to_collapse <- vec_slice(freq_table, v + 1)
collapse_lgl <- vec_in(data_ind$..group, group_to_collapse$key)
vec_slice(data_ind$..group, collapse_lgl) <- group_to_keep$key
}
unique_groups <- unique(data_ind$..group)

keys <- data.frame(
..group = unique_groups,
..folds = sample(rep(seq_len(v), length.out = length(unique_groups)))
)

list(
data_ind = data_ind,
keys = keys
)

}
15 changes: 1 addition & 14 deletions man/group_vfold_cv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 1 addition & 7 deletions man/make_groups.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 0 additions & 31 deletions tests/testthat/test-groups.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,37 +84,6 @@ test_that("tibble input", {
expect_true(all(table(sp_out) == 1))
})

test_that("other balance methods", {
data(ames, package = "modeldata")
set.seed(11)
rs1 <- group_vfold_cv(ames, "Neighborhood", balance = "observations", v = 2)
sizes1 <- dim_rset(rs1)

expect_true(all(sizes1$analysis == 1465))
expect_true(all(sizes1$assessment == 1465))
same_data <-
purrr::map_lgl(rs1$splits, function(x) {
all.equal(x$data, ames)
})
expect_true(all(same_data))

good_holdout <- purrr::map_lgl(
rs1$splits,
function(x) {
length(intersect(x$in_ind, x$out_id)) == 0
}
)
expect_true(all(good_holdout))

expect_true(
!any(
unique(as.character(assessment(rs1$splits[[1]])$Neighborhood)) %in%
unique(as.character(analysis(rs1$splits[[1]])$Neighborhood))
)
)

})

test_that("printing", {
expect_snapshot(group_vfold_cv(warpbreaks, "tension"))
})
Expand Down

0 comments on commit 4b853fe

Please sign in to comment.