diff --git a/R/paramset_torchlearner.R b/R/paramset_torchlearner.R index 36121820..24b303fe 100644 --- a/R/paramset_torchlearner.R +++ b/R/paramset_torchlearner.R @@ -70,7 +70,7 @@ paramset_torchlearner = function(task_type) { collate_fn = p_uty(tags = c("train", "predict"), default = NULL), pin_memory = p_lgl(default = FALSE, tags = c("train", "predict")), drop_last = p_lgl(tags = "train", default = FALSE), - timeout = p_int(default = -1L, tags = c("train_predict")), + timeout = p_int(default = -1L, tags = c("train", "predict")), worker_init_fn = p_uty(tags = c("train", "predict")), worker_globals = p_uty(tags = c("train", "predict")), worker_packages = p_uty(tags = c("train", "predict"), custom_check = check_character, special_vals = list(NULL)) diff --git a/tests/testthat/helper_learner.R b/tests/testthat/helper_learner.R index 2a043b4b..8000a730 100644 --- a/tests/testthat/helper_learner.R +++ b/tests/testthat/helper_learner.R @@ -29,20 +29,29 @@ LearnerTorchTest1 = R6Class("LearnerTorchTest1", .dataloader = function(task, param_vals) { ingress_token = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, length(task$feature_names))) dataset = task_dataset( - task, + task, feature_ingress_tokens = list(num = ingress_token), target_batchgetter = crate(function(data, device) { torch_tensor(data = as.integer(data[[1]]), dtype = torch_long(), device = device) }), device = param_vals$device ) - dl = dataloader( - dataset = dataset, - batch_size = param_vals$batch_size, - shuffle = param_vals$shuffle + dl_args = c( + "batch_size", + "shuffle", + "sampler", + "batch_sampler", + "num_workers", + "collate_fn", + "pin_memory", + "drop_last", + "timeout", + "worker_init_fn", + "worker_globals", + "worker_packages" ) - return(dl) - + args = param_vals[names(param_vals) %in% dl_args] + invoke(dataloader, dataset = dataset) } ) ) diff --git a/tests/testthat/test_paramset_torchlearner.R b/tests/testthat/test_paramset_torchlearner.R index 0ed4272a..21eba467 100644 --- a/tests/testthat/test_paramset_torchlearner.R +++ b/tests/testthat/test_paramset_torchlearner.R @@ -2,10 +2,6 @@ test_that("paramset works", { test_ps = function(param_set) { expect_r6(param_set, "ParamSet") expect_true(all(map_lgl(param_set$tags, function(tags) "train" %in% tags || "predict" %in% tags))) - # all parameters are required and have initial values set - expect_true(all(map_lgl(param_set$tags, function(tags) "required" %in% tags))) - # only parameters batch_size and epochs don't have initial values - expect_true(length(param_set$values) == param_set$length - 2) } param_set_regr = paramset_torchlearner("regr")