-
Notifications
You must be signed in to change notification settings - Fork 90
/
Chicago_grid.R
68 lines (51 loc) · 1.43 KB
/
Chicago_grid.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
library(tidymodels)
library(tune)
library(doMC)
# Register the doMC backend for parallel processing. At most 8 workers are
# requested, capped at the number of cores the machine reports so that smaller
# machines are not oversubscribed. `parallel::detectCores()` may return NA;
# `na.rm = TRUE` then keeps the original value of 8.
registerDoMC(cores = min(8L, parallel::detectCores(), na.rm = TRUE))
# ------------------------------------------------------------------------------
# Time-ordered resamples built with rolling_origin(): `initial = 364 * 15`
# rows go to each analysis set and `assess = 7 * 4` rows to each assessment
# set. `cumulative = FALSE` keeps the analysis set from growing as the origin
# moves forward, and `skip = 7 * 4` thins the resamples that would otherwise
# be generated.
# NOTE(review): the window sizes read like 15 years / 4 weeks, which assumes
# one row per day in `Chicago` -- confirm against the data.
set.seed(7898)
data_folds <-
  rolling_origin(
    Chicago,
    initial = 364 * 15,
    assess = 7 * 4,
    skip = 7 * 4,
    cumulative = FALSE
  )
# ------------------------------------------------------------------------------
library(stringr)
# Holiday names returned by timeDate::listHolidays() that either start with
# "US" or contain "Easter"; passed to step_holiday() as `holidays`.
us_hol <- str_subset(timeDate::listHolidays(), "(^US)|(Easter)")
# Base recipe: `ridership` is the outcome and every other column of `Chicago`
# is a predictor.
chi_rec <- recipe(ridership ~ ., data = Chicago) %>%
  # Holiday features computed from `date` for the holidays listed in `us_hol`.
  step_holiday(date, holidays = us_hol) %>%
  # Date-derived features (step_date() defaults), after which the raw `date`
  # column is dropped.
  step_date(date) %>%
  step_rm(date) %>%
  # Convert nominal columns to dummy variables, then drop any zero-variance
  # predictors.
  step_dummy(all_nominal()) %>%
  step_zv(all_predictors())
# Recipe used for the MARS model: extends `chi_rec` by normalizing the columns
# named in `stations` and then applying PCA to those same columns. The number
# of components is left as a tuning parameter with id "pca comps".
# NOTE(review): `stations` is not defined in this script -- presumably it is
# provided alongside the `Chicago` data; confirm it is in scope before running.
mars_rec <- chi_rec %>%
  step_normalize(one_of(!!stations)) %>%
  step_pca(
    one_of(!!stations),
    num_comp = tune("pca comps")
  )
# MARS regression model specification fit with the "earth" engine.
# `num_terms` (id "mars terms") and `prod_degree` are marked for tuning;
# `prune_method` is fixed at "none".
mars_mod <-
  mars(
    num_terms = tune("mars terms"),
    prod_degree = tune(),
    prune_method = "none"
  ) %>%
  set_engine("earth") %>%
  set_mode("regression")
# Workflow that pairs the MARS recipe (`mars_rec`) with the MARS model
# specification (`mars_mod`) so both can be tuned together.
chi_wflow <- workflow() %>%
  add_recipe(mars_rec) %>%
  add_model(mars_mod)
# Full factorial grid of candidate values. The column names are the `tune()`
# ids used in the recipe and model spec: "pca comps" (0 through 20),
# `prod_degree` (1 or 2) and "mars terms" (2 through 100). The first column
# varies fastest, as with any expand.grid() result.
chi_grid <-
  expand.grid(
    `pca comps` = seq(0L, 20L),
    prod_degree = seq_len(2L),
    `mars terms` = seq(2L, 100L)
  )
# Evaluate every row of `chi_grid` on every resample in `data_folds`, with
# verbose progress logging turned on.
all_res <- tune_grid(
  chi_wflow,
  resamples = data_folds,
  grid = chi_grid,
  control = control_grid(verbose = TRUE)
)

# Print the first element of the `.notes` column of the tuning results.
print(all_res[[".notes"]][[1]])
# Keep only the RMSE rows of the collected metrics.
complete_mars_grid <-
  dplyr::filter(collect_metrics(all_res), .metric == "rmse")

# Persist the RMSE table to the working directory.
save(complete_mars_grid, file = "complete_mars_grid.RData")

# End the R session without saving the workspace.
q(save = "no")