Skip to content

Commit

Permalink
Refactor group logic (#308)
Browse files Browse the repository at this point in the history
* Refactor groups

* Move group functions

* Move group documentation

* Don't add make_groups to website
  • Loading branch information
mikemahoney218 authored Jun 22, 2022
1 parent 0d67974 commit 7752ef1
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 46 deletions.
50 changes: 7 additions & 43 deletions R/groups.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,10 @@
#' out at a time. A common use of this kind of resampling is when you have
#' repeated measures of the same subject.
#'
#' @param data A data frame.
#' @param group A variable in `data` (single character or name) used for
#' grouping observations with the same value to either the analysis or
#' assessment set within a fold.
#' @param v The number of partitions of the data set. If let
#' `NULL`, `v` will be set to the number of unique values
#' in the group.
#' @param balance If `v` is less than the number of unique groups, how should
#' groups be combined into folds? If `"groups"`, the default, then groups are
#' combined randomly to balance the number of groups in each fold.
#' If `"observations"`, then groups are combined to balance the number of
#' observations in each fold.
#' @param ... Not currently used.
#' @inheritParams vfold_cv
#' @param v The number of partitions of the data set. If left as `NULL`, `v`
#' will be set to the number of unique values in the group.
#' @inheritParams make_groups
#' @export
#' @return A tibble with classes `group_vfold_cv`,
#' `rset`, `tbl_df`, `tbl`, and `data.frame`.
Expand Down Expand Up @@ -81,43 +72,16 @@ group_vfold_splits <- function(data, group, v = NULL, balance = c("groups", "obs

balance <- rlang::arg_match(balance)


uni_groups <- unique(getElement(data, group))
max_v <- length(uni_groups)
group <- getElement(data, group)
max_v <- length(unique(group))

if (is.null(v)) {
v <- max_v
} else {
check_v(v = v, max_v = max_v, rows = "rows", call = rlang::caller_env())
}
data_ind <- data.frame(..index = 1:nrow(data), ..group = getElement(data, group))
data_ind$..group <- as.character(data_ind$..group)

if (balance == "groups") {
keys <- data.frame(..group = uni_groups)
keys$..folds <- sample(rep(1:v, length.out = max_v))
} else if (balance == "observations") {
while (vec_unique_count(data_ind$..group) > v) {
freq_table <- vec_count(data_ind$..group)
# Recategorize the largest group to be collapsed
# as the smallest group to be kept
group_to_keep <- vec_slice(freq_table, v)
group_to_collapse <- vec_slice(freq_table, v + 1)
collapse_lgl <- vec_in(data_ind$..group, group_to_collapse$key)
vec_slice(data_ind$..group, collapse_lgl) <- group_to_keep$key
}

keys <- data.frame(..group = unique(data_ind$..group))
n <- nrow(keys)
keys$..folds <- sample(rep(1:v, length.out = n))
}

keys$..group <- as.character(keys$..group)

data_ind <- data_ind %>%
full_join(keys, by = "..group") %>%
arrange(..index)
indices <- split_unnamed(data_ind$..index, data_ind$..folds)
indices <- make_groups(data, group, v, balance)
indices <- lapply(indices, default_complement, n = nrow(data))
split_objs <-
purrr::map(indices,
Expand Down
73 changes: 73 additions & 0 deletions R/make_groups.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#' Make groupings for grouped rsplits
#'
#' This function powers [group_vfold_cv] by splitting the data based upon
#' a grouping variable and returning the assessment set indices for each
#' split.
#'
#' @inheritParams vfold_cv
#' @param group A variable in `data` (single character or name) used for
#' grouping observations with the same value to either the analysis or
#' assessment set within a fold.
#' @param balance If `v` is less than the number of unique groups, how should
#' groups be combined into folds? If `"groups"`, the default, then groups are
#' combined randomly to balance the number of groups in each fold.
#' If `"observations"`, then groups are combined to balance the number of
#' observations in each fold.
#'
#' @keywords internal
make_groups <- function(data, group, v, balance) {
  # `group` is placed directly into a column next to one index per row of
  # `data`, so it is expected to hold one grouping value per row.
  # seq_len() rather than 1:nrow(data): for zero-row input `1:0` is c(1, 0),
  # which cannot be paired with a zero-length `group`.
  data_ind <- data.frame(..index = seq_len(nrow(data)), ..group = group)
  # Groups are compared as character so the join with `keys` matches by label
  data_ind$..group <- as.character(data_ind$..group)

  # Each balancing helper returns the (possibly re-labelled) row table and a
  # `keys` table mapping each group label to a fold id.
  res <- switch(
    balance,
    "groups" = balance_groups(data_ind, v),
    "observations" = balance_observations(data_ind, v),
    # Without a default, switch() would silently return NULL and the failure
    # would only surface later as an obscure error in the join.
    stop(
      "`balance` must be one of \"groups\" or \"observations\".",
      call. = FALSE
    )
  )
  data_ind <- res$data_ind
  keys <- res$keys

  keys$..group <- as.character(keys$..group)

  # Attach each row's fold id, then restore the original row order
  data_ind <- data_ind %>%
    full_join(keys, by = "..group") %>%
    arrange(..index)

  # Row indices split by fold id; presumably an unnamed list with one element
  # per fold (see split_unnamed(), defined elsewhere in the package).
  split_unnamed(data_ind$..index, data_ind$..folds)
}

balance_groups <- function(data_ind, v) {
  # Assign whole groups to folds so that each fold receives (as close as
  # possible to) the same number of groups.
  #
  # data_ind: data frame with a `..group` column of group labels.
  # v:        number of folds.
  #
  # Returns a list with `data_ind` (unchanged) and `keys`, a data frame
  # mapping every distinct group label (`..group`) to a fold id (`..folds`).
  group_ids <- unique(data_ind$..group)
  n_groups <- length(group_ids)

  # Recycle the fold ids 1..v across the groups, then shuffle the assignment
  fold_ids <- sample(rep(seq_len(v), length.out = n_groups))

  fold_key <- data.frame(..group = group_ids, ..folds = fold_ids)

  list(data_ind = data_ind, keys = fold_key)
}

balance_observations <- function(data_ind, v) {
  # Merge groups until at most `v` distinct labels remain, then assign the
  # remaining labels to folds.
  #
  # data_ind: data frame with a `..group` column of group labels.
  # v:        number of folds.
  #
  # Returns a list with `data_ind` (its `..group` column re-labelled by the
  # merging) and `keys`, a data frame mapping each remaining label
  # (`..group`) to a fold id (`..folds`).
  labels <- data_ind$..group

  while (vec_unique_count(labels) > v) {
    counts <- vec_count(labels)
    # Rows `v` and `v + 1` of the frequency table: the smallest group that is
    # kept and the largest group that is collapsed into it.
    kept_key <- vec_slice(counts, v)$key
    merged_key <- vec_slice(counts, v + 1)$key
    is_merged <- vec_in(labels, merged_key)
    vec_slice(labels, is_merged) <- kept_key
  }
  data_ind$..group <- labels

  remaining <- unique(labels)
  # Recycle the fold ids 1..v across the remaining labels, then shuffle
  fold_ids <- sample(rep(seq_len(v), length.out = length(remaining)))
  fold_key <- data.frame(..group = remaining, ..folds = fold_ids)

  list(data_ind = data_ind, keys = fold_key)
}
5 changes: 2 additions & 3 deletions man/group_vfold_cv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions man/make_groups.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 7752ef1

Please sign in to comment.