Skip to content

Commit

Permalink
remove checks that were added to mlr3 (#271)
Browse files Browse the repository at this point in the history
This was added in mlr3: mlr-org/mlr3@738b7bd
  • Loading branch information
sebffischer authored Aug 21, 2024
1 parent 41a5233 commit f9c0dcd
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 38 deletions.
9 changes: 0 additions & 9 deletions R/LearnerTorch.R
Original file line number Diff line number Diff line change
Expand Up @@ -452,15 +452,6 @@ LearnerTorch = R6Class("LearnerTorch",
# Ideally we could rely on state$train_task, but there is this complication
# https://github.com/mlr-org/mlr3/issues/947
param_vals$device = auto_device(param_vals$device)
if (!test_equal_col_info(ci_train, ci_predict)) { # nolint
stopf(paste0(
"Predict task's column info does not match the train task's column info.\n",
"This migth be handled more gracefully in the future.\n",
"Training column info:\n'%s'\n",
"Prediction column info:\n'%s'"),
paste0(capture.output(ci_train), collapse = "\n"),
paste0(capture.output(ci_predict), collapse = "\n"))
}
private$.verify_predict_task(task, param_vals)

with_torch_settings(seed = self$model$seed, num_threads = param_vals$num_threads, {
Expand Down
18 changes: 0 additions & 18 deletions tests/testthat/test_LearnerTorch.R
Original file line number Diff line number Diff line change
Expand Up @@ -457,24 +457,6 @@ test_that("Input verification works during `$train()` (train-predict shapes work
expect_true(nrow(rr2$errors) == 0L)
})

test_that("Input verification works during `$predict()` (same column info, problematic fct -> int conversion)", {
task1 = as_task_classif(data.table(
y = factor(c("A", "B"))
), target = "y", id = "test1")

task2 = as_task_classif(data.table(
y = factor(c("A", "B"), labels = c("B", "A"), levels = c("B", "A"))
), target = "y", id = "test2")

learner = lrn("classif.torch_featureless", batch_size = 1L, epochs = 0L)

learner$train(task1)
expect_error(
learner$predict(task2),
"does not match"
)
})

test_that("col_info is propertly subset when comparing task validity during predict", {
task = tsk("iris")$select("Sepal.Length")
learner = classif_mlp2()
Expand Down
11 changes: 0 additions & 11 deletions tests/testthat/test_learner_torch_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,3 @@ test_that("learner_torch_dataloader_predict works", {
expect_false(dl$drop_last)
expect_class(dl$batch_sampler$sampler, "utils_sampler_sequential")
})

test_that("wrong column info stops learner from prediction.", {
d1 = data.table(x = 1:10, y = 1:10)
d2 = data.table(x = runif(10), y = 1:10)
t1 = as_task_regr(d1, target = "y")
t2 = as_task_regr(d2, target = "y")

learner = lrn("regr.torch_featureless", epochs = 1, batch_size = 50)
learner$train(t1)
expect_error(learner$predict(t2), "more gracefully")
})

0 comments on commit f9c0dcd

Please sign in to comment.