Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Jun 17, 2024
1 parent d24ae7d commit 94e18f2
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 12 deletions.
2 changes: 1 addition & 1 deletion R/paramset_torchlearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ paramset_torchlearner = function(task_type) {
collate_fn = p_uty(tags = c("train", "predict"), default = NULL),
pin_memory = p_lgl(default = FALSE, tags = c("train", "predict")),
drop_last = p_lgl(tags = "train", default = FALSE),
timeout = p_int(default = -1L, tags = c("train_predict")),
timeout = p_int(default = -1L, tags = c("train", "predict")),
worker_init_fn = p_uty(tags = c("train", "predict")),
worker_globals = p_uty(tags = c("train", "predict")),
worker_packages = p_uty(tags = c("train", "predict"), custom_check = check_character, special_vals = list(NULL))
Expand Down
23 changes: 16 additions & 7 deletions tests/testthat/helper_learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,29 @@ LearnerTorchTest1 = R6Class("LearnerTorchTest1",
.dataloader = function(task, param_vals) {
ingress_token = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, length(task$feature_names)))
dataset = task_dataset(
task,
task,
feature_ingress_tokens = list(num = ingress_token),
target_batchgetter = crate(function(data, device) {
torch_tensor(data = as.integer(data[[1]]), dtype = torch_long(), device = device)
}),
device = param_vals$device
)
dl = dataloader(
dataset = dataset,
batch_size = param_vals$batch_size,
shuffle = param_vals$shuffle
dl_args = c(
"batch_size",
"shuffle",
"sampler",
"batch_sampler",
"num_workers",
"collate_fn",
"pin_memory",
"drop_last",
"timeout",
"worker_init_fn",
"worker_globals",
"worker_packages"
)
return(dl)

args = param_vals[names(param_vals) %in% dl_args]
invoke(dataloader, dataset = dataset)
}
)
)
Expand Down
4 changes: 0 additions & 4 deletions tests/testthat/test_paramset_torchlearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@ test_that("paramset works", {
test_ps = function(param_set) {
expect_r6(param_set, "ParamSet")
expect_true(all(map_lgl(param_set$tags, function(tags) "train" %in% tags || "predict" %in% tags)))
# all parameters are required and have initial values set
expect_true(all(map_lgl(param_set$tags, function(tags) "required" %in% tags)))
# only parameters batch_size and epochs don't have initial values
expect_true(length(param_set$values) == param_set$length - 2)
}
param_set_regr = paramset_torchlearner("regr")

Expand Down

0 comments on commit 94e18f2

Please sign in to comment.