Skip to content

Commit

Permalink
rework check_prop() and use it
Browse files Browse the repository at this point in the history
  • Loading branch information
hfrick committed Sep 25, 2024
1 parent d9fcbfe commit 92e653a
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 37 deletions.
7 changes: 4 additions & 3 deletions R/initial_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
initial_split <- function(data, prop = 3 / 4,
strata = NULL, breaks = 4, pool = 0.1, ...) {
check_dots_empty()
check_prop(prop)

res <-
mc_cv(
data = data,
Expand Down Expand Up @@ -74,9 +76,7 @@ initial_split <- function(data, prop = 3 / 4,
#' @export
initial_time_split <- function(data, prop = 3 / 4, lag = 0, ...) {
check_dots_empty()
if (!is.numeric(prop) | prop >= 1 | prop <= 0) {
cli_abort("{.arg prop} must be a number on (0, 1).")
}
check_prop(prop)

if (!is.numeric(lag) | !(lag %% 1 == 0)) {
cli_abort("{.arg lag} must be a whole number.")
Expand Down Expand Up @@ -156,6 +156,7 @@ testing.rsplit <- function(x, ...) {
#' @export
group_initial_split <- function(data, group, prop = 3 / 4, ..., strata = NULL, pool = 0.1) {
check_dots_empty()
check_prop(prop)

if (missing(strata)) {
res <- group_mc_cv(
Expand Down
15 changes: 0 additions & 15 deletions R/make_groups.R
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ balance_observations_helper <- function(data_split, v, target_per_fold) {

balance_prop <- function(prop, data_ind, v, replace = FALSE, strata = NULL, ...) {
rlang::check_dots_empty()
check_prop(prop, replace)

# This is the core difference between stratification and not:
#
Expand Down Expand Up @@ -290,20 +289,6 @@ balance_prop_helper <- function(prop, data_ind, v, replace) {
list_rbind()
}

check_prop <- function(prop, replace) {
acceptable_prop <- is.numeric(prop)
acceptable_prop <- acceptable_prop &&
((prop <= 1 && replace) || (prop < 1 && !replace))
acceptable_prop <- acceptable_prop && prop > 0
if (!acceptable_prop) {
cli_abort(
"{.arg prop} must be a number between 0 and 1.",
call = rlang::caller_env()
)
}
}


collapse_groups <- function(freq_table, data_ind, v) {
data_ind <- dplyr::left_join(
data_ind,
Expand Down
6 changes: 2 additions & 4 deletions R/mc.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
mc_cv <- function(data, prop = 3 / 4, times = 25,
strata = NULL, breaks = 4, pool = 0.1, ...) {
check_dots_empty()
check_prop(prop)

if (!missing(strata)) {
strata <- tidyselect::vars_select(names(data), !!enquo(strata))
Expand Down Expand Up @@ -103,10 +104,6 @@ mc_complement <- function(ind, n) {

mc_splits <- function(data, prop = 3 / 4, times = 25,
strata = NULL, breaks = 4, pool = 0.1) {
if (!is.numeric(prop) | prop >= 1 | prop <= 0) {
cli_abort("{.arg prop} must be a number on (0, 1).")
}

n <- nrow(data)
if (is.null(strata)) {
indices <- purrr::map(rep(n, times), sample, size = floor(n * prop))
Expand Down Expand Up @@ -170,6 +167,7 @@ group_mc_cv <- function(data, group, prop = 3 / 4, times = 25, ...,
strata = NULL, pool = 0.1) {

check_dots_empty()
check_prop(prop)

group <- validate_group({{ group }}, data)

Expand Down
11 changes: 11 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,17 @@ add_class <- function(x, cls) {
x
}

check_prop <- function(prop, call = caller_env()) {
check_number_decimal(prop, call = call)
if (!(prop > 0)) {
cli_abort("{.arg prop} must be greater than 0.", call = call)
}
if (!(prop < 1)) {
cli_abort("{.arg prop} must be less than 1.", call = call)
}
invisible(NULL)
}

check_strata <- function(strata, data, call = caller_env()) {
check_string(strata, allow_null = TRUE, call = call)

Expand Down
3 changes: 3 additions & 0 deletions R/validation_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ validation_split <- function(data, prop = 3 / 4,
)

check_dots_empty()
check_prop(prop)

if (!missing(strata)) {
strata <- tidyselect::vars_select(names(data), !!enquo(strata))
Expand Down Expand Up @@ -114,6 +115,7 @@ validation_time_split <- function(data, prop = 3 / 4, lag = 0, ...) {

check_dots_empty()

check_prop(prop)
if (!is.numeric(prop) | prop >= 1 | prop <= 0) {
rlang::abort("`prop` must be a number on (0, 1).")
}
Expand Down Expand Up @@ -155,6 +157,7 @@ group_validation_split <- function(data, group, prop = 3 / 4, ..., strata = NULL
check_dots_empty()

group <- validate_group({{ group }}, data)
check_prop(prop)

if (!missing(strata)) {
strata <- check_grouped_strata({{ group }}, {{ strata }}, pool, data)
Expand Down
12 changes: 6 additions & 6 deletions tests/testthat/_snaps/initial_split.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
initial_time_split(drinks, prop = 2)
Condition
Error in `initial_time_split()`:
! `prop` must be a number on (0, 1).
! `prop` must be less than 1.

---

Expand Down Expand Up @@ -43,22 +43,22 @@
Code
initial_split(mtcars, prop = 1)
Condition
Error in `mc_splits()`:
! `prop` must be a number on (0, 1).
Error in `initial_split()`:
! `prop` must be less than 1.

---

Code
initial_time_split(mtcars, prop = 1)
Condition
Error in `initial_time_split()`:
! `prop` must be a number on (0, 1).
! `prop` must be less than 1.

---

Code
group_initial_split(mtcars, group = "cyl", prop = 1)
Condition
Error in `balance_prop()`:
! `prop` must be a number between 0 and 1.
Error in `group_initial_split()`:
! `prop` must be less than 1.

8 changes: 4 additions & 4 deletions tests/testthat/_snaps/mc.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
Code
mc_cv(mtcars, prop = 1)
Condition
Error in `mc_splits()`:
! `prop` must be a number on (0, 1).
Error in `mc_cv()`:
! `prop` must be less than 1.

---

Expand Down Expand Up @@ -83,8 +83,8 @@
Code
group_mc_cv(mtcars, group = "cyl", prop = 1)
Condition
Error in `balance_prop()`:
! `prop` must be a number between 0 and 1.
Error in `group_mc_cv()`:
! `prop` must be less than 1.

---

Expand Down
10 changes: 5 additions & 5 deletions tests/testthat/_snaps/validation_split.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,24 +82,24 @@
Code
validation_split(mtcars, prop = 1)
Condition
Error in `mc_splits()`:
! `prop` must be a number on (0, 1).
Error in `validation_split()`:
! `prop` must be less than 1.

---

Code
validation_time_split(mtcars, prop = 1)
Condition
Error in `validation_time_split()`:
! `prop` must be a number on (0, 1).
! `prop` must be less than 1.

---

Code
group_validation_split(mtcars, group = "cyl", prop = 1)
Condition
Error in `balance_prop()`:
! `prop` must be a number between 0 and 1.
Error in `group_validation_split()`:
! `prop` must be less than 1.

---

Expand Down

0 comments on commit 92e653a

Please sign in to comment.