diff --git a/DESCRIPTION b/DESCRIPTION index 51a5ee1f..debd1135 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -33,9 +33,9 @@ Description: Deep Learning library that extends the mlr3 framework by building defined in 'mlr3pipelines'. License: LGPL (>= 3) Depends: - mlr3 (>= 0.16.1), - mlr3pipelines (>= 0.5.0), - torch (>= 0.11.0), + mlr3 (>= 0.19.0), + mlr3pipelines (>= 0.5.2), + torch (>= 0.12.0), R (>= 3.5.0) Imports: backports, @@ -58,12 +58,13 @@ Suggests: magick, progress, rmarkdown, + rpart, viridis, - testthat (>= 3.0.0), - torchvision + torchvision, + testthat (>= 3.0.0) Remotes: - mlr-org/mlr3, - mlr-org/mlr3pipelines, + mlverse/torch, + mlr-org/paradox, mlverse/torchvision Config/testthat/edition: 3 NeedsCompilation: no diff --git a/NAMESPACE b/NAMESPACE index fc31420f..4a3f9a92 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -29,6 +29,7 @@ S3method(c,lazy_tensor) S3method(col_info,DataBackendLazy) S3method(format,lazy_tensor) S3method(hash_input,lazy_tensor) +S3method(marshal_model,learner_torch_state) S3method(materialize,data.frame) S3method(materialize,lazy_tensor) S3method(materialize,list) @@ -49,6 +50,7 @@ S3method(t_opt,"NULL") S3method(t_opt,character) S3method(t_opts,"NULL") S3method(t_opts,character) +S3method(unmarshal_model,learner_torch_state_marshaled) export(CallbackSet) export(CallbackSetCheckpoint) export(CallbackSetHistory) diff --git a/R/LearnerTorch.R b/R/LearnerTorch.R index 648670a9..9be3c98d 100644 --- a/R/LearnerTorch.R +++ b/R/LearnerTorch.R @@ -35,7 +35,12 @@ #' Defaults to an empty `list()`, i.e. no callbacks. #' #' @section State: -#' The state is a list with elements `network`, `optimizer`, `loss_fn`, `callbacks` and `seed`. +#' The state is a list with elements: +#' * `network` :: The trained [network][torch::nn_module]. +#' * `optimizer` :: The `$state_dict()` of the [optimizer][torch::optimizer] used to train the network. +#' * `loss_fn` :: The `$state_dict()` of the [loss][torch::nn_module] used to train the network. +#' * `callbacks` :: The [callbacks][mlr3torch::mlr_callback_set] used to train the network. +#' * `seed` :: The seed that was / is used for training and prediction. #' #' @template paramset_torchlearner #' #' @@ -91,7 +96,6 @@ LearnerTorch = R6Class("LearnerTorch", #' @description Creates a new instance of this [R6][R6::R6Class] class. 
initialize = function(id, task_type, param_set, properties, man, label, feature_types, optimizer = NULL, loss = NULL, packages = NULL, predict_types = NULL, callbacks = list()) { - assert_choice(task_type, c("regr", "classif")) predict_types = predict_types %??% switch(task_type, regr = "response", @@ -115,9 +119,6 @@ LearnerTorch = R6Class("LearnerTorch", private$.optimizer = as_torch_optimizer(optimizer, clone = TRUE) } - private$.optimizer$param_set$set_id = "opt" - private$.loss$param_set$set_id = "loss" - callbacks = as_torch_callbacks(callbacks, clone = TRUE) callback_ids = ids(callbacks) if (!test_names(callback_ids, type = "unique")) { @@ -127,10 +128,6 @@ LearnerTorch = R6Class("LearnerTorch", } private$.callbacks = set_names(callbacks, callback_ids) - walk(private$.callbacks, function(cb) { - cb$param_set$set_id = paste0("cb.", cb$id) - }) - packages = unique(c( packages, unlist(map(private$.callbacks, "packages")), @@ -138,7 +135,9 @@ LearnerTorch = R6Class("LearnerTorch", private$.optimizer$packages )) + assert_subset(properties, mlr_reflections$learner_properties[[task_type]]) + properties = union(properties, "marshal") assert_subset(predict_types, names(mlr_reflections$learner_predict_types[[task_type]])) if (any(grepl("^(loss\\.|opt\\.|cb\\.)", param_set$ids()))) { stopf("Prefixes 'loss.', 'opt.', and 'cb.' are reserved for dynamically constructed parameters.") @@ -148,7 +147,7 @@ LearnerTorch = R6Class("LearnerTorch", paramset_torch = paramset_torchlearner(task_type) if (param_set$length > 0) { - private$.param_set_base = ParamSetCollection$new(list(param_set, paramset_torch)) + private$.param_set_base = ParamSetCollection$new(sets = list(param_set, paramset_torch)) } else { private$.param_set_base = paramset_torch } @@ -202,9 +201,31 @@ LearnerTorch = R6Class("LearnerTorch", if (length(e)) { catn(str_indent("* Errors:", e)) } + }, + #' @description + #' Marshal the learner. + #' @param ... (any)\cr + #' Additional parameters. + #' @return self + marshal = function(...) { + learner_marshal(.learner = self, ...) + }, + #' @description + #' Unmarshal the learner. + #' @param ... (any)\cr + #' Additional parameters. + #' @return self + unmarshal = function(...) { + learner_unmarshal(.learner = self, ...) } ), active = list( + #' @field marshaled (`logical(1)`)\cr + #' Whether the learner is marshaled. + marshaled = function(rhs) { + assert_ro_binding(rhs) + learner_marshaled(self) + }, #' @field network ([`nn_module()`][torch::nn_module])\cr #' The network (only available after training). network = function(rhs) { @@ -218,9 +239,9 @@ LearnerTorch = R6Class("LearnerTorch", #' The parameter set param_set = function(rhs) { if (is.null(private$.param_set)) { - private$.param_set = ParamSetCollection$new(c( - list(private$.param_set_base, private$.optimizer$param_set, private$.loss$param_set), - map(private$.callbacks, "param_set")) + private$.param_set = ParamSetCollection$new(sets = c( + list(private$.param_set_base, opt = private$.optimizer$param_set, loss = private$.loss$param_set), + set_names(map(private$.callbacks, "param_set"), sprintf("cb.%s", ids(private$.callbacks)))) ) } private$.param_set @@ -335,6 +356,27 @@ LearnerTorch = R6Class("LearnerTorch", ) ) +#' @export +marshal_model.learner_torch_state = function(model, inplace = FALSE, ...) 
{ + # FIXME: optimizer and loss_fn + model$network = torch_serialize(model$network) + model$loss_fn = torch_serialize(model$loss_fn) + model$optimizer = torch_serialize(model$optimizer) + + structure(list( + marshaled = model, + packages = "mlr3torch" + ), class = c("learner_torch_state_marshaled", "list_marshaled", "marshaled")) +} + +#' @export +unmarshal_model.learner_torch_state_marshaled = function(model, inplace = FALSE, device = "cpu", ...) { + model = model$marshaled + model$network = torch_load(model$network, device = device) + model$loss_fn = torch_load(model$loss_fn, device = device) + model$optimizer = torch_load(model$optimizer, device = device) + return(model) +} deep_clone = function(self, private, super, name, value) { private$.param_set = NULL # required to keep clone identical to original, otherwise tests get really ugly diff --git a/R/LearnerTorchFeatureless.R b/R/LearnerTorchFeatureless.R index 7c4d22a5..d1b62354 100644 --- a/R/LearnerTorchFeatureless.R +++ b/R/LearnerTorchFeatureless.R @@ -22,8 +22,8 @@ LearnerTorchFeatureless = R6Class("LearnerTorchFeatureless", #' @description Creates a new instance of this [R6][R6::R6Class] class. initialize = function(task_type, optimizer = NULL, loss = NULL, callbacks = list()) { properties = switch(task_type, - classif = c("twoclass", "multiclass", "missings", "featureless"), - regr = c("missings", "featureless") + classif = c("twoclass", "multiclass", "missings", "featureless", "marshal"), + regr = c("missings", "featureless", "marshal") ) super$initialize( id = paste0(task_type, ".torch_featureless"), diff --git a/R/LearnerTorchMLP.R b/R/LearnerTorchMLP.R index 58553206..f411c9cc 100644 --- a/R/LearnerTorchMLP.R +++ b/R/LearnerTorchMLP.R @@ -39,7 +39,7 @@ LearnerTorchMLP = R6Class("LearnerTorchMLP", check_activation = crate(function(x) check_class(x, "nn_module")) check_activation_args = crate(function(x) check_list(x, names = "unique")) check_neurons = crate(function(x) check_integerish(x, any.missing = FALSE, lower = 1)) - cechk_shape = crate(function(x) check_shape(x, null_ok = TRUE, len = 2L)) + check_shape = crate(function(x) check_shape(x, null_ok = TRUE, len = 2L), .parent = topenv()) param_set = ps( neurons = p_uty(tags = c("train", "predict"), custom_check = check_neurons), diff --git a/R/PipeOpTaskPreprocTorch.R b/R/PipeOpTaskPreprocTorch.R index d662915d..c6ac6b24 100644 --- a/R/PipeOpTaskPreprocTorch.R +++ b/R/PipeOpTaskPreprocTorch.R @@ -187,7 +187,7 @@ PipeOpTaskPreprocTorch = R6Class("PipeOpTaskPreprocTorch", private$.rowwise = assert_flag(rowwise) param_set = assert_param_set(param_set$clone(deep = TRUE)) - param_set$add(ps( + param_set = c(param_set, ps( stages = p_fct(c("train", "predict", "both"), tags = c("train", "required")) )) param_set$set_values(stages = stages_init) diff --git a/R/PipeOpTorchAvgPool.R b/R/PipeOpTorchAvgPool.R index 0736808d..2fb9f079 100644 --- a/R/PipeOpTorchAvgPool.R +++ b/R/PipeOpTorchAvgPool.R @@ -17,7 +17,9 @@ PipeOpTorchAvgPool = R6Class("PipeOpTorchAvgPool", count_include_pad = p_lgl(default = TRUE, tags = "train") ) if (d >= 2L) { - param_set$add(ParamDbl$new("divisor_override", default = NULL, lower = 0, tags = "train", special_vals = list(NULL))) + param_set = c(param_set, ps( + divisor_override = p_dbl(default = NULL, lower = 0, tags = "train", special_vals = list(NULL)) + )) } super$initialize( diff --git a/R/PipeOpTorchCallbacks.R b/R/PipeOpTorchCallbacks.R index 442c23a5..6f3fe869 100644 --- a/R/PipeOpTorchCallbacks.R +++ b/R/PipeOpTorchCallbacks.R @@ -47,17 +47,16 @@ 
PipeOpTorchCallbacks = R6Class("PipeOpTorchCallbacks", cbids = ids(private$.callbacks) assert_names(cbids, type = "unique") walk(private$.callbacks, function(cb) { - cb$param_set$set_id = cb$id - walk(cb$param_set$params, function(p) { - p$tags = union(p$tags, "train") - }) + if (length(cb$param_set$tags)) { + cb$param_set$tags = map(cb$param_set$tags, function(tags) union(tags, "train")) + } }) private$.callbacks = set_names(private$.callbacks, cbids) input = data.table(name = "input", train = "ModelDescriptor", predict = "Task") output = data.table(name = "output", train = "ModelDescriptor", predict = "Task") super$initialize( id = id, - param_set = alist(invoke(ParamSetCollection$new, sets = map(private$.callbacks, "param_set"))), + param_set = alist(ParamSetCollection$new(sets = map(private$.callbacks, "param_set"))), param_vals = param_vals, input = input, output = output, diff --git a/R/TorchCallback.R b/R/TorchCallback.R index fa8bfba0..a9ec8d77 100644 --- a/R/TorchCallback.R +++ b/R/TorchCallback.R @@ -143,7 +143,7 @@ as_torch_callbacks.character = function(x, clone = FALSE, ...) { # nolint #' @section Parameters: #' Defined by the constructor argument `param_set`. #' If no parameter set is provided during construction, the parameter set is constructed by creating a parameter -#' for each argument of the wrapped callback, where the parameters are then of type [`ParamUty`]. +#' for each argument of the wrapped callback, where the parameters are then of type `ParamUty`. #' #' @family Callback #' @family Torch Descriptor diff --git a/R/TorchDescriptor.R b/R/TorchDescriptor.R index 757c8b16..843e02a1 100644 --- a/R/TorchDescriptor.R +++ b/R/TorchDescriptor.R @@ -42,7 +42,10 @@ TorchDescriptor = R6Class("TorchDescriptor", self$generator = generator # TODO: Assert that all parameters are tagged with "train" self$param_set = assert_r6(param_set, "ParamSet", null.ok = TRUE) %??% inferps(generator) - walk(self$param_set$params, function(param) param$tags = union(param$tags, "train")) + if (length(self$param_set$tags)) { + self$param_set$tags = map(self$param_set$tags, function(tags) union(tags, "train")) + + } if (is.function(generator)) { args = formalArgs(generator) } else { diff --git a/R/TorchLoss.R b/R/TorchLoss.R index 715be0f9..005a564c 100644 --- a/R/TorchLoss.R +++ b/R/TorchLoss.R @@ -51,7 +51,7 @@ as_torch_loss.character = function(x, clone = FALSE, ...) { # nolint #' @section Parameters: #' Defined by the constructor argument `param_set`. #' If no parameter set is provided during construction, the parameter set is constructed by creating a parameter -#' for each argument of the wrapped loss function, where the parameters are then of type [`ParamUty`]. +#' for each argument of the wrapped loss function, where the parameters are then of type `ParamUty`. #' #' @family Torch Descriptor #' @export diff --git a/R/TorchOptimizer.R b/R/TorchOptimizer.R index dfea10b2..97282cd9 100644 --- a/R/TorchOptimizer.R +++ b/R/TorchOptimizer.R @@ -53,7 +53,7 @@ as_torch_optimizer.character = function(x, clone = FALSE, ...) { # nolint #' @section Parameters: #' Defined by the constructor argument `param_set`. #' If no parameter set is provided during construction, the parameter set is constructed by creating a parameter -#' for each argument of the wrapped optimizer, where the parameters are then of type [`ParamUty`]. +#' for each argument of the wrapped optimizer, where the parameters are then of type [ParamUty][paradox::p_uty]. 
#' #' @family Torch Descriptor #' @export @@ -278,7 +278,7 @@ mlr3torch_optimizers$add("sgd", mlr3torch_optimizers$add("asgd", function() { p = ps( - lr = p_dbl(default = 1e-2, lower = 0, tags = c("required", "train")), + lr = p_dbl(default = 1e-2, lower = 0, tags = "train"), lambda = p_dbl(lower = 0, upper = 1, default = 1e-4, tags = "train"), alpha = p_dbl(lower = 0, upper = Inf, default = 0.75, tags = "train"), t0 = p_int(lower = 1L, upper = Inf, default = 1e6, tags = "train"), @@ -308,7 +308,7 @@ mlr3torch_optimizers$add("rprop", p = ps( lr = p_dbl(default = 0.01, lower = 0, tags = "train"), etas = p_uty(default = c(0.5, 1.2), tags = "train"), - step_sizes = p_uty(c(1e-06, 50), tags = "train") + step_sizes = p_uty(default = c(1e-06, 50), tags = "train") ) TorchOptimizer$new( torch_optimizer = torch::optim_rprop, diff --git a/R/learner_torch_methods.R b/R/learner_torch_methods.R index 07f322d1..86f104df 100644 --- a/R/learner_torch_methods.R +++ b/R/learner_torch_methods.R @@ -66,7 +66,8 @@ learner_torch_train = function(self, private, super, task, param_vals) { # In case the seed was "random" initially we want to make the sampled seed available in the state. model$seed = param_vals$seed - return(model) + + structure(model, class = c("learner_torch_state", "list")) } @@ -166,8 +167,8 @@ train_loop = function(ctx, cbs) { # The seed is added later list( network = ctx$network, - loss_fn = ctx$loss_fn, - optimizer = ctx$optimizer, + loss_fn = ctx$loss_fn$state_dict(), + optimizer = ctx$optimizer$state_dict(), callbacks = cbs ) } diff --git a/man/PipeOpPreprocTorchAugmentCenterCrop.Rd b/man/PipeOpPreprocTorchAugmentCenterCrop.Rd index 6a0c706e..9087f0a6 100644 --- a/man/PipeOpPreprocTorchAugmentCenterCrop.Rd +++ b/man/PipeOpPreprocTorchAugmentCenterCrop.Rd @@ -17,7 +17,7 @@ The preprocessing is applied row wise (no batch dimension). Id \tab Type \tab Default \tab Levels \cr size \tab untyped \tab - \tab \cr stages \tab character \tab - \tab train, predict, both \cr - affect_columns \tab untyped \tab selector_all \tab \cr + affect_columns \tab untyped \tab selector_all() \tab \cr } } diff --git a/man/PipeOpPreprocTorchAugmentColorJitter.Rd b/man/PipeOpPreprocTorchAugmentColorJitter.Rd index bf55d6b6..a94c830f 100644 --- a/man/PipeOpPreprocTorchAugmentColorJitter.Rd +++ b/man/PipeOpPreprocTorchAugmentColorJitter.Rd @@ -20,7 +20,7 @@ The preprocessing is applied row wise (no batch dimension). saturation \tab numeric \tab 0 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr hue \tab numeric \tab 0 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr stages \tab character \tab - \tab train, predict, both \tab - \cr - affect_columns \tab untyped \tab selector_all \tab \tab - \cr + affect_columns \tab untyped \tab selector_all() \tab \tab - \cr } } diff --git a/man/PipeOpPreprocTorchAugmentCrop.Rd b/man/PipeOpPreprocTorchAugmentCrop.Rd index 15aa7624..bc0ad2cc 100644 --- a/man/PipeOpPreprocTorchAugmentCrop.Rd +++ b/man/PipeOpPreprocTorchAugmentCrop.Rd @@ -20,7 +20,7 @@ The preprocessing is applied row wise (no batch dimension). 
height \tab integer \tab - \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr width \tab integer \tab - \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr stages \tab character \tab - \tab train, predict, both \tab - \cr - affect_columns \tab untyped \tab selector_all \tab \tab - \cr + affect_columns \tab untyped \tab selector_all() \tab \tab - \cr } } diff --git a/man/PipeOpPreprocTorchAugmentHflip.Rd b/man/PipeOpPreprocTorchAugmentHflip.Rd index 39f23af7..1cb729bc 100644 --- a/man/PipeOpPreprocTorchAugmentHflip.Rd +++ b/man/PipeOpPreprocTorchAugmentHflip.Rd @@ -16,7 +16,7 @@ The preprocessing is applied row wise (no batch dimension). \tabular{llll}{ Id \tab Type \tab Default \tab Levels \cr stages \tab character \tab - \tab train, predict, both \cr - affect_columns \tab untyped \tab selector_all \tab \cr + affect_columns \tab untyped \tab selector_all() \tab \cr } } diff --git a/man/PipeOpPreprocTorchAugmentRandomAffine.Rd b/man/PipeOpPreprocTorchAugmentRandomAffine.Rd index 043feaab..72ce2e23 100644 --- a/man/PipeOpPreprocTorchAugmentRandomAffine.Rd +++ b/man/PipeOpPreprocTorchAugmentRandomAffine.Rd @@ -16,12 +16,12 @@ The preprocessing is applied row wise (no batch dimension). \tabular{lllll}{ Id \tab Type \tab Default \tab Levels \tab Range \cr degrees \tab untyped \tab - \tab \tab - \cr - translate \tab untyped \tab \tab \tab - \cr - scale \tab untyped \tab \tab \tab - \cr + translate \tab untyped \tab NULL \tab \tab - \cr + scale \tab untyped \tab NULL \tab \tab - \cr resample \tab integer \tab 0 \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr fillcolor \tab untyped \tab 0 \tab \tab - \cr stages \tab character \tab - \tab train, predict, both \tab - \cr - affect_columns \tab untyped \tab selector_all \tab \tab - \cr + affect_columns \tab untyped \tab selector_all() \tab \tab - \cr } } diff --git a/man/PipeOpPreprocTorchAugmentRandomApply.Rd b/man/PipeOpPreprocTorchAugmentRandomApply.Rd index 169d6b1c..a433785c 100644 --- a/man/PipeOpPreprocTorchAugmentRandomApply.Rd +++ b/man/PipeOpPreprocTorchAugmentRandomApply.Rd @@ -18,7 +18,7 @@ The preprocessing is applied row wise (no batch dimension). transforms \tab untyped \tab - \tab \tab - \cr p \tab numeric \tab 0.5 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr stages \tab character \tab - \tab train, predict, both \tab - \cr - affect_columns \tab untyped \tab selector_all \tab \tab - \cr + affect_columns \tab untyped \tab selector_all() \tab \tab - \cr } } diff --git a/man/PipeOpPreprocTorchAugmentRandomChoice.Rd b/man/PipeOpPreprocTorchAugmentRandomChoice.Rd index f73d415f..70129668 100644 --- a/man/PipeOpPreprocTorchAugmentRandomChoice.Rd +++ b/man/PipeOpPreprocTorchAugmentRandomChoice.Rd @@ -17,7 +17,7 @@ The preprocessing is applied row wise (no batch dimension). Id \tab Type \tab Default \tab Levels \cr transforms \tab untyped \tab - \tab \cr stages \tab character \tab - \tab train, predict, both \cr - affect_columns \tab untyped \tab selector_all \tab \cr + affect_columns \tab untyped \tab selector_all() \tab \cr } } diff --git a/man/PipeOpPreprocTorchAugmentRandomCrop.Rd b/man/PipeOpPreprocTorchAugmentRandomCrop.Rd index c07d5c90..1ba24179 100644 --- a/man/PipeOpPreprocTorchAugmentRandomCrop.Rd +++ b/man/PipeOpPreprocTorchAugmentRandomCrop.Rd @@ -16,12 +16,12 @@ The preprocessing is applied row wise (no batch dimension). 
\tabular{llll}{ Id \tab Type \tab Default \tab Levels \cr size \tab untyped \tab - \tab \cr - padding \tab untyped \tab \tab \cr + padding \tab untyped \tab NULL \tab \cr pad_if_needed \tab logical \tab FALSE \tab TRUE, FALSE \cr - fill \tab untyped \tab 0 \tab \cr + fill \tab untyped \tab 0L \tab \cr padding_mode \tab character \tab constant \tab constant, edge, reflect, symmetric \cr stages \tab character \tab - \tab train, predict, both \cr - affect_columns \tab untyped \tab selector_all \tab \cr + affect_columns \tab untyped \tab selector_all() \tab \cr } } diff --git a/man/PipeOpPreprocTorchAugmentRandomHorizontalFlip.Rd b/man/PipeOpPreprocTorchAugmentRandomHorizontalFlip.Rd index 49522bb6..ce21c73e 100644 --- a/man/PipeOpPreprocTorchAugmentRandomHorizontalFlip.Rd +++ b/man/PipeOpPreprocTorchAugmentRandomHorizontalFlip.Rd @@ -17,7 +17,7 @@ The preprocessing is applied row wise (no batch dimension). Id \tab Type \tab Default \tab Levels \tab Range \cr p \tab numeric \tab 0.5 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr stages \tab character \tab - \tab train, predict, both \tab - \cr - affect_columns \tab untyped \tab selector_all \tab \tab - \cr + affect_columns \tab untyped \tab selector_all() \tab \tab - \cr } } diff --git a/man/PipeOpPreprocTorchAugmentRandomOrder.Rd b/man/PipeOpPreprocTorchAugmentRandomOrder.Rd index ed629317..4cfe236c 100644 --- a/man/PipeOpPreprocTorchAugmentRandomOrder.Rd +++ b/man/PipeOpPreprocTorchAugmentRandomOrder.Rd @@ -17,7 +17,7 @@ The preprocessing is applied row wise (no batch dimension). Id \tab Type \tab Default \tab Levels \cr transforms \tab untyped \tab - \tab \cr stages \tab character \tab - \tab train, predict, both \cr - affect_columns \tab untyped \tab selector_all \tab \cr + affect_columns \tab untyped \tab selector_all() \tab \cr } } diff --git a/man/PipeOpPreprocTorchAugmentRandomResizedCrop.Rd b/man/PipeOpPreprocTorchAugmentRandomResizedCrop.Rd index a5f74d63..5c1cf8a7 100644 --- a/man/PipeOpPreprocTorchAugmentRandomResizedCrop.Rd +++ b/man/PipeOpPreprocTorchAugmentRandomResizedCrop.Rd @@ -16,11 +16,11 @@ The preprocessing is applied row wise (no batch dimension). \tabular{lllll}{ Id \tab Type \tab Default \tab Levels \tab Range \cr size \tab untyped \tab - \tab \tab - \cr - scale \tab untyped \tab c , 0.08, 1 \tab \tab - \cr - ratio \tab untyped \tab c , 3/4, 4/3 \tab \tab - \cr + scale \tab untyped \tab c(0.08, 1) \tab \tab - \cr + ratio \tab untyped \tab c(3/4, 4/3) \tab \tab - \cr interpolation \tab integer \tab 2 \tab \tab \eqn{[0, 3]}{[0, 3]} \cr stages \tab character \tab - \tab train, predict, both \tab - \cr - affect_columns \tab untyped \tab selector_all \tab \tab - \cr + affect_columns \tab untyped \tab selector_all() \tab \tab - \cr } } diff --git a/man/PipeOpPreprocTorchAugmentRandomVerticalFlip.Rd b/man/PipeOpPreprocTorchAugmentRandomVerticalFlip.Rd index 5781cfba..353b1914 100644 --- a/man/PipeOpPreprocTorchAugmentRandomVerticalFlip.Rd +++ b/man/PipeOpPreprocTorchAugmentRandomVerticalFlip.Rd @@ -17,7 +17,7 @@ The preprocessing is applied row wise (no batch dimension). 
Id \tab Type \tab Default \tab Levels \tab Range \cr p \tab numeric \tab 0.5 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr stages \tab character \tab - \tab train, predict, both \tab - \cr - affect_columns \tab untyped \tab selector_all \tab \tab - \cr + affect_columns \tab untyped \tab selector_all() \tab \tab - \cr } } diff --git a/man/PipeOpPreprocTorchAugmentResizedCrop.Rd b/man/PipeOpPreprocTorchAugmentResizedCrop.Rd index 799dbc10..d4aaacf7 100644 --- a/man/PipeOpPreprocTorchAugmentResizedCrop.Rd +++ b/man/PipeOpPreprocTorchAugmentResizedCrop.Rd @@ -22,7 +22,7 @@ The preprocessing is applied row wise (no batch dimension). size \tab untyped \tab - \tab \tab - \cr interpolation \tab integer \tab 2 \tab \tab \eqn{[0, 3]}{[0, 3]} \cr stages \tab character \tab - \tab train, predict, both \tab - \cr - affect_columns \tab untyped \tab selector_all \tab \tab - \cr + affect_columns \tab untyped \tab selector_all() \tab \tab - \cr } } diff --git a/man/PipeOpPreprocTorchAugmentRotate.Rd b/man/PipeOpPreprocTorchAugmentRotate.Rd index 6508b070..356c284d 100644 --- a/man/PipeOpPreprocTorchAugmentRotate.Rd +++ b/man/PipeOpPreprocTorchAugmentRotate.Rd @@ -18,10 +18,10 @@ The preprocessing is applied row wise (no batch dimension). angle \tab untyped \tab - \tab \tab - \cr resample \tab integer \tab 0 \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr expand \tab logical \tab FALSE \tab TRUE, FALSE \tab - \cr - center \tab untyped \tab \tab \tab - \cr - fill \tab untyped \tab \tab \tab - \cr + center \tab untyped \tab NULL \tab \tab - \cr + fill \tab untyped \tab NULL \tab \tab - \cr stages \tab character \tab - \tab train, predict, both \tab - \cr - affect_columns \tab untyped \tab selector_all \tab \tab - \cr + affect_columns \tab untyped \tab selector_all() \tab \tab - \cr } } diff --git a/man/PipeOpPreprocTorchAugmentVflip.Rd b/man/PipeOpPreprocTorchAugmentVflip.Rd index 36c15cfe..84a928b2 100644 --- a/man/PipeOpPreprocTorchAugmentVflip.Rd +++ b/man/PipeOpPreprocTorchAugmentVflip.Rd @@ -16,7 +16,7 @@ The preprocessing is applied row wise (no batch dimension). \tabular{llll}{ Id \tab Type \tab Default \tab Levels \cr stages \tab character \tab - \tab train, predict, both \cr - affect_columns \tab untyped \tab selector_all \tab \cr + affect_columns \tab untyped \tab selector_all() \tab \cr } } diff --git a/man/PipeOpPreprocTorchTrafoAdjustBrightness.Rd b/man/PipeOpPreprocTorchTrafoAdjustBrightness.Rd index fe636d29..43c6abdb 100644 --- a/man/PipeOpPreprocTorchTrafoAdjustBrightness.Rd +++ b/man/PipeOpPreprocTorchTrafoAdjustBrightness.Rd @@ -17,7 +17,7 @@ The preprocessing is applied row wise (no batch dimension). Id \tab Type \tab Default \tab Levels \tab Range \cr brightness_factor \tab numeric \tab - \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr stages \tab character \tab - \tab train, predict, both \tab - \cr - affect_columns \tab untyped \tab selector_all \tab \tab - \cr + affect_columns \tab untyped \tab selector_all() \tab \tab - \cr } } diff --git a/man/PipeOpPreprocTorchTrafoAdjustGamma.Rd b/man/PipeOpPreprocTorchTrafoAdjustGamma.Rd index a2bcef11..2784ba21 100644 --- a/man/PipeOpPreprocTorchTrafoAdjustGamma.Rd +++ b/man/PipeOpPreprocTorchTrafoAdjustGamma.Rd @@ -18,7 +18,7 @@ The preprocessing is applied row wise (no batch dimension). 
gamma \tab numeric \tab - \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr gain \tab numeric \tab 1 \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr stages \tab character \tab - \tab train, predict, both \tab - \cr - affect_columns \tab untyped \tab selector_all \tab \tab - \cr + affect_columns \tab untyped \tab selector_all() \tab \tab - \cr } } diff --git a/man/PipeOpPreprocTorchTrafoAdjustHue.Rd b/man/PipeOpPreprocTorchTrafoAdjustHue.Rd index 1cc75a96..4bd82be6 100644 --- a/man/PipeOpPreprocTorchTrafoAdjustHue.Rd +++ b/man/PipeOpPreprocTorchTrafoAdjustHue.Rd @@ -17,7 +17,7 @@ The preprocessing is applied row wise (no batch dimension). Id \tab Type \tab Default \tab Levels \tab Range \cr hue_factor \tab numeric \tab - \tab \tab \eqn{[-0.5, 0.5]}{[-0.5, 0.5]} \cr stages \tab character \tab - \tab train, predict, both \tab - \cr - affect_columns \tab untyped \tab selector_all \tab \tab - \cr + affect_columns \tab untyped \tab selector_all() \tab \tab - \cr } } diff --git a/man/PipeOpPreprocTorchTrafoAdjustSaturation.Rd b/man/PipeOpPreprocTorchTrafoAdjustSaturation.Rd index ced93824..ea69257e 100644 --- a/man/PipeOpPreprocTorchTrafoAdjustSaturation.Rd +++ b/man/PipeOpPreprocTorchTrafoAdjustSaturation.Rd @@ -17,7 +17,7 @@ The preprocessing is applied row wise (no batch dimension). Id \tab Type \tab Default \tab Levels \tab Range \cr saturation_factor \tab numeric \tab - \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr stages \tab character \tab - \tab train, predict, both \tab - \cr - affect_columns \tab untyped \tab selector_all \tab \tab - \cr + affect_columns \tab untyped \tab selector_all() \tab \tab - \cr } } diff --git a/man/PipeOpPreprocTorchTrafoGrayscale.Rd b/man/PipeOpPreprocTorchTrafoGrayscale.Rd index 014d8fee..bcefcca1 100644 --- a/man/PipeOpPreprocTorchTrafoGrayscale.Rd +++ b/man/PipeOpPreprocTorchTrafoGrayscale.Rd @@ -17,7 +17,7 @@ The preprocessing is applied row wise (no batch dimension). Id \tab Type \tab Default \tab Levels \tab Range \cr num_output_channels \tab integer \tab - \tab \tab \eqn{[1, 3]}{[1, 3]} \cr stages \tab character \tab - \tab train, predict, both \tab - \cr - affect_columns \tab untyped \tab selector_all \tab \tab - \cr + affect_columns \tab untyped \tab selector_all() \tab \tab - \cr } } diff --git a/man/PipeOpPreprocTorchTrafoNormalize.Rd b/man/PipeOpPreprocTorchTrafoNormalize.Rd index a802514b..6d28fdc4 100644 --- a/man/PipeOpPreprocTorchTrafoNormalize.Rd +++ b/man/PipeOpPreprocTorchTrafoNormalize.Rd @@ -18,7 +18,7 @@ The preprocessing is applied row wise (no batch dimension). mean \tab untyped \tab - \tab \cr std \tab untyped \tab - \tab \cr stages \tab character \tab - \tab train, predict, both \cr - affect_columns \tab untyped \tab selector_all \tab \cr + affect_columns \tab untyped \tab selector_all() \tab \cr } } diff --git a/man/PipeOpPreprocTorchTrafoPad.Rd b/man/PipeOpPreprocTorchTrafoPad.Rd index b9c7f1ca..9cf97a62 100644 --- a/man/PipeOpPreprocTorchTrafoPad.Rd +++ b/man/PipeOpPreprocTorchTrafoPad.Rd @@ -19,7 +19,7 @@ The preprocessing is applied row wise (no batch dimension). 
fill \tab untyped \tab 0 \tab \cr padding_mode \tab character \tab constant \tab constant, edge, reflect, symmetric \cr stages \tab character \tab - \tab train, predict, both \cr - affect_columns \tab untyped \tab selector_all \tab \cr + affect_columns \tab untyped \tab selector_all() \tab \cr } } diff --git a/man/PipeOpPreprocTorchTrafoResize.Rd b/man/PipeOpPreprocTorchTrafoResize.Rd index b031940b..f4dcc278 100644 --- a/man/PipeOpPreprocTorchTrafoResize.Rd +++ b/man/PipeOpPreprocTorchTrafoResize.Rd @@ -18,7 +18,7 @@ The preprocessing is applied to the whole batch. size \tab untyped \tab - \tab \cr interpolation \tab character \tab 2 \tab Undefined, Bartlett, Blackman, Bohman, Box, Catrom, Cosine, Cubic, Gaussian, Hamming, \link{...} \cr stages \tab character \tab - \tab train, predict, both \cr - affect_columns \tab untyped \tab selector_all \tab \cr + affect_columns \tab untyped \tab selector_all() \tab \cr } } diff --git a/man/PipeOpPreprocTorchTrafoRgbToGrayscale.Rd b/man/PipeOpPreprocTorchTrafoRgbToGrayscale.Rd index 10a89e1a..ca13fc86 100644 --- a/man/PipeOpPreprocTorchTrafoRgbToGrayscale.Rd +++ b/man/PipeOpPreprocTorchTrafoRgbToGrayscale.Rd @@ -16,7 +16,7 @@ The preprocessing is applied row wise (no batch dimension). \tabular{llll}{ Id \tab Type \tab Default \tab Levels \cr stages \tab character \tab - \tab train, predict, both \cr - affect_columns \tab untyped \tab selector_all \tab \cr + affect_columns \tab untyped \tab selector_all() \tab \cr } } diff --git a/man/TorchCallback.Rd b/man/TorchCallback.Rd index b7fa6f48..e6399a22 100644 --- a/man/TorchCallback.Rd +++ b/man/TorchCallback.Rd @@ -17,7 +17,7 @@ To conveniently retrieve a \code{\link{TorchCallback}}, use \code{\link[=t_clbk] Defined by the constructor argument \code{param_set}. If no parameter set is provided during construction, the parameter set is constructed by creating a parameter -for each argument of the wrapped callback, where the parameters are then of type \code{\link{ParamUty}}. +for each argument of the wrapped callback, where the parameters are then of type \link[paradox:Domain]{ParamUty}. } \examples{ diff --git a/man/TorchLoss.Rd b/man/TorchLoss.Rd index 51a60296..b6d54e60 100644 --- a/man/TorchLoss.Rd +++ b/man/TorchLoss.Rd @@ -17,7 +17,7 @@ Items from this dictionary can be retrieved using \code{\link[=t_loss]{t_loss()} Defined by the constructor argument \code{param_set}. If no parameter set is provided during construction, the parameter set is constructed by creating a parameter -for each argument of the wrapped loss function, where the parameters are then of type \code{\link{ParamUty}}. +for each argument of the wrapped loss function, where the parameters are then of type \link[paradox:Domain]{ParamUty}. } \examples{ diff --git a/man/TorchOptimizer.Rd b/man/TorchOptimizer.Rd index 1361d466..511fc997 100644 --- a/man/TorchOptimizer.Rd +++ b/man/TorchOptimizer.Rd @@ -17,7 +17,7 @@ Items from this dictionary can be retrieved using \code{\link[=t_opt]{t_opt()}}. Defined by the constructor argument \code{param_set}. If no parameter set is provided during construction, the parameter set is constructed by creating a parameter -for each argument of the wrapped optimizer, where the parameters are then of type \code{\link{ParamUty}}. +for each argument of the wrapped optimizer, where the parameters are then of type \link[paradox:p_uty]{ParamUty}. 
} \examples{ diff --git a/man/mlr_learners.alexnet.Rd b/man/mlr_learners.alexnet.Rd index 55a85bc9..7ef0e071 100644 --- a/man/mlr_learners.alexnet.Rd +++ b/man/mlr_learners.alexnet.Rd @@ -84,7 +84,9 @@ Other Learner:
  • mlr3::Learner$reset()
  • mlr3::Learner$train()
  • mlr3torch::LearnerTorch$format()
  • + mlr3torch::LearnerTorch$marshal()
  • mlr3torch::LearnerTorch$print()
  • + mlr3torch::LearnerTorch$unmarshal()
  • }} diff --git a/man/mlr_learners.mlp.Rd b/man/mlr_learners.mlp.Rd index 30e45c16..979fc931 100644 --- a/man/mlr_learners.mlp.Rd +++ b/man/mlr_learners.mlp.Rd @@ -117,7 +117,9 @@ Other Learner:
  • mlr3::Learner$reset()
  • mlr3::Learner$train()
  • mlr3torch::LearnerTorch$format()
  • + mlr3torch::LearnerTorch$marshal()
  • mlr3torch::LearnerTorch$print()
  • + mlr3torch::LearnerTorch$unmarshal()
  • }} diff --git a/man/mlr_learners.torch_featureless.Rd b/man/mlr_learners.torch_featureless.Rd index ad87873e..509f9919 100644 --- a/man/mlr_learners.torch_featureless.Rd +++ b/man/mlr_learners.torch_featureless.Rd @@ -103,7 +103,9 @@ Other Learner:
  • mlr3::Learner$reset()
  • mlr3::Learner$train()
  • mlr3torch::LearnerTorch$format()
  • + mlr3torch::LearnerTorch$marshal()
  • mlr3torch::LearnerTorch$print()
  • + mlr3torch::LearnerTorch$unmarshal()
  • }} diff --git a/man/mlr_learners_torch.Rd b/man/mlr_learners_torch.Rd index 3b18b5f0..2d4531f7 100644 --- a/man/mlr_learners_torch.Rd +++ b/man/mlr_learners_torch.Rd @@ -12,7 +12,14 @@ It also allows hooking into the training loop via a callback mechanism. } \section{State}{ -The state is a list with elements \code{network}, \code{optimizer}, \code{loss_fn}, \code{callbacks} and \code{seed}. +The state is a list with elements: +\itemize{ +\item \code{network} :: The trained \link[torch:nn_module]{network}. +\item \code{optimizer} :: The \verb{$state_dict()} of the \link[torch:optimizer]{optimizer} used to train the network. +\item \code{loss_fn} :: The \verb{$state_dict()} of the \link[torch:nn_module]{loss} used to train the network. +\item \code{callbacks} :: The \link[=mlr_callback_set]{callbacks} used to train the network. +\item \code{seed} :: The seed that was / is used for training and prediction. +} } \section{Parameters}{ @@ -111,6 +118,9 @@ Other Learner: \section{Active bindings}{ \if{html}{\out{
    }} \describe{ +\item{\code{marshaled}}{(\code{logical(1)})\cr +Whether the learner is marshaled.} + \item{\code{network}}{(\code{\link[torch:nn_module]{nn_module()}})\cr The network (only available after training).} @@ -128,6 +138,8 @@ Shortcut for \code{learner$model$callbacks$history}.} \item \href{#method-LearnerTorch-new}{\code{LearnerTorch$new()}} \item \href{#method-LearnerTorch-format}{\code{LearnerTorch$format()}} \item \href{#method-LearnerTorch-print}{\code{LearnerTorch$print()}} +\item \href{#method-LearnerTorch-marshal}{\code{LearnerTorch$marshal()}} +\item \href{#method-LearnerTorch-unmarshal}{\code{LearnerTorch$unmarshal()}} \item \href{#method-LearnerTorch-clone}{\code{LearnerTorch$clone()}} } } @@ -254,6 +266,48 @@ Currently unused.} } } \if{html}{\out{
    }} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-LearnerTorch-marshal}{}}} +\subsection{Method \code{marshal()}}{ +Marshal the learner. +\subsection{Usage}{ +\if{html}{\out{
    }}\preformatted{LearnerTorch$marshal(...)}\if{html}{\out{
    }} +} + +\subsection{Arguments}{ +\if{html}{\out{
    }} +\describe{ +\item{\code{...}}{(any)\cr +Additional parameters.} +} +\if{html}{\out{
    }} +} +\subsection{Returns}{ +self +} +} +\if{html}{\out{
    }} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-LearnerTorch-unmarshal}{}}} +\subsection{Method \code{unmarshal()}}{ +Unmarshal the learner. +\subsection{Usage}{ +\if{html}{\out{
    }}\preformatted{LearnerTorch$unmarshal(...)}\if{html}{\out{
    }} +} + +\subsection{Arguments}{ +\if{html}{\out{
    }} +\describe{ +\item{\code{...}}{(any)\cr +Additional parameters.} +} +\if{html}{\out{
    }} +} +\subsection{Returns}{ +self +} +} +\if{html}{\out{
    }} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-LearnerTorch-clone}{}}} \subsection{Method \code{clone()}}{ diff --git a/man/mlr_learners_torch_image.Rd b/man/mlr_learners_torch_image.Rd index 36bb8938..57695857 100644 --- a/man/mlr_learners_torch_image.Rd +++ b/man/mlr_learners_torch_image.Rd @@ -42,7 +42,9 @@ Other Learner:
  • mlr3::Learner$reset()
  • mlr3::Learner$train()
  • mlr3torch::LearnerTorch$format()
  • + mlr3torch::LearnerTorch$marshal()
  • mlr3torch::LearnerTorch$print()
  • + mlr3torch::LearnerTorch$unmarshal()
  • }} diff --git a/man/mlr_learners_torch_model.Rd b/man/mlr_learners_torch_model.Rd index af40d4db..91c60c2f 100644 --- a/man/mlr_learners_torch_model.Rd +++ b/man/mlr_learners_torch_model.Rd @@ -86,7 +86,9 @@ Other Graph Network:
  • mlr3::Learner$reset()
  • mlr3::Learner$train()
  • mlr3torch::LearnerTorch$format()
  • + mlr3torch::LearnerTorch$marshal()
  • mlr3torch::LearnerTorch$print()
  • + mlr3torch::LearnerTorch$unmarshal()
  • }} diff --git a/tests/testthat/helper_autotest.R b/tests/testthat/helper_autotest.R index f8c26741..0af1c9cd 100644 --- a/tests/testthat/helper_autotest.R +++ b/tests/testthat/helper_autotest.R @@ -133,8 +133,8 @@ expect_pipeop_torch = function(graph, id, task, module_class = id, exclude_args # parameters must only ce active during training - walk(po_test$param_set$params, function(p) { - if (!(("train" %in% p$tags) && !("predict" %in% p$tags))) { + walk(po_test$param_set$tags, function(tags) { + if (!(("train" %in% tags) && !("predict" %in% tags))) { stopf("Parameters of PipeOps inheriting from PipeOpTorch must only be active during training.") } }) @@ -332,8 +332,8 @@ expect_pipeop_torch_preprocess = function(obj, shapes_in, exclude = character(0) expect_class(obj, "PipeOpTaskPreprocTorch") # a) Check that all parameters but stages have tags train and predict (this should hold in basically all cases) # parameters must only ce active during training - walk(obj$param_set$params, function(p) { - if (!(("train" %in% p$tags) && !("predict" %in% p$tags))) { + walk(obj$param_set$tags, function(tags) { + if (!(("train" %in% tags) && !("predict" %in% tags))) { stopf("Parameters of PipeOps inheriting from PipeOpTorch must only be active during training.") } }) diff --git a/tests/testthat/helper_functions.R b/tests/testthat/helper_functions.R index 090a702b..c0dfe9b4 100644 --- a/tests/testthat/helper_functions.R +++ b/tests/testthat/helper_functions.R @@ -32,91 +32,7 @@ expect_man_exists = function(man) { } # expect that 'one' is a deep clone of 'two' -expect_deep_clone = function(one, two) { - # is equal - expect_equal(one, two) - visited = new.env() - visited_b = new.env() - expect_references_differ = function(a, b, path) { - - force(path) - if (length(path) > 400) { - stop("Recursion too deep in expect_deep_clone()") - } - - # don't go in circles - addr_a = data.table::address(a) - addr_b = data.table::address(b) - if (!is.null(visited[[addr_a]])) { - return(invisible(NULL)) - } - visited[[addr_a]] = path - visited_b[[addr_b]] = path - - #if (inherits(a, "nn_module_generator") || inherits(a, "torch_optimizer_generator")) { - # stopf("Not implemented yet") - #} - if (inherits(a, "R6ClassGenerator")) { - return(NULL) - } - - # follow attributes, even for non-recursive objects - if (utils::tail(path, 1) != "[attributes]" && !is.null(base::attributes(a))) { - expect_references_differ(base::attributes(a), base::attributes(b), c(path, "[attributes]")) - } - - # don't recurse if there is nowhere to go - if (!base::is.recursive(a)) { - return(invisible(NULL)) - } - - # check that environments differ - if (base::is.environment(a)) { - # some special environments - if (identical(a, baseenv()) || identical(a, globalenv()) || identical(a, emptyenv())) { - return(invisible(NULL)) - } - if (length(path) > 1 && R6::is.R6(a) && "clone" %nin% names(a)) { - return(invisible(NULL)) # don't check if smth is not cloneable - } - if (identical(utils::tail(path, 1), c("[element train_task] 'train_task'"))) { - return(invisible(NULL)) # workaround for https://github.com/mlr-org/mlr3/issues/382 - } - if (identical(utils::tail(path, 1), c("[element fallback] 'fallback'"))) { - return(invisible(NULL)) # workaround for https://github.com/mlr-org/mlr3/issues/511 - } - label = sprintf("Object addresses differ at path %s", paste0(path, collapse = "->")) - expect_true(addr_a != addr_b, label = label) - expect_null(visited_b[[addr_a]], label = label) - } else { - a = unclass(a) - b = unclass(b) - } - # recurse - if 
(base::is.function(a)) { - return(invisible(NULL)) - ## # maybe this is overdoing it - ## expect_references_differ(base::formals(a), base::formals(b), c(path, "[function args]")) - ## expect_references_differ(base::body(a), base::body(b), c(path, "[function body]")) - } - objnames = base::names(a) - if (is.null(objnames) || anyDuplicated(objnames)) { - index = seq_len(base::length(a)) - } else { - index = objnames - if (base::is.environment(a)) { - index = Filter(function(x) !bindingIsActive(x, a), index) - } - } - for (i in index) { - if (utils::tail(path, 1) == "[attributes]" && i %in% c("srcref", "srcfile", ".Environment")) next - expect_references_differ(base::`[[`(a, i), base::`[[`(b, i), c(path, sprintf("[element %s]%s", i, - if (!is.null(objnames)) sprintf(" '%s'", if (is.character(index)) i else objnames[[i]]) else ""))) - } - } - expect_references_differ(one, two, "ROOT") -} expect_shallow_clone = function(one, two) { expect_equal(one, two) @@ -158,15 +74,29 @@ expect_valid_pipeop_param_set = function(po, check_ps_default_values = TRUE) { ps = po$param_set expect_true(every(ps$tags, function(x) length(intersect(c("train", "predict"), x)) > 0L)) - uties = ps$params[ps$ids("ParamUty")] - if (length(uties)) { - test_value = NO_DEF # custom_checks should fail for NO_DEF - results = map(uties, function(uty) { - uty$custom_check(test_value) - }) - expect_true(all(map_lgl(results, function(result) { - length(result) == 1L && (is.character(result) || result == TRUE) # result == TRUE is necessary because default is function(x) TRUE - })), label = "custom_check returns string on failure") + if (mlr3pipelines:::paradox_info$is_old) { + uties = ps$params[ps$ids("ParamUty")] + if (length(uties)) { + test_value = NO_DEF # custom_checks should fail for NO_DEF + results = map(uties, function(uty) { + uty$custom_check(test_value) + }) + expect_true(all(map_lgl(results, function(result) { + length(result) == 1L && (is.character(result) || result == TRUE) # result == TRUE is necessary because default is function(x) TRUE + })), label = "custom_check returns string on failure") + } + } else { + uties = ps$ids("ParamUty") + if (length(uties)) { + test_value = NO_DEF # custom_checks should fail for NO_DEF + results = map(uties, function(uty) { + psn = ps$subset(uty, allow_dangling_dependencies = TRUE) + psn$check(structure(list(test_value), names = uty)) + }) + expect_true(all(map_lgl(results, function(result) { + length(result) == 1L && (is.character(result) || result == TRUE) # result == TRUE is necessary because default is function(x) TRUE + })), label = "custom_check returns string on failure") + } } if (check_ps_default_values) { diff --git a/tests/testthat/helper_mlr3pipelines.R b/tests/testthat/helper_mlr3pipelines.R new file mode 100644 index 00000000..80e29708 --- /dev/null +++ b/tests/testthat/helper_mlr3pipelines.R @@ -0,0 +1,100 @@ +library("mlr3pipelines") +library("checkmate") +library("testthat") +library("R6") +library("mlr3misc") +library("paradox") + +lapply(list.files(system.file("testthat", package = "mlr3pipelines"), pattern = "^helper.*\\.[rR]", full.names = TRUE), source) + +mlr_helpers = list.files(system.file("testthat", package = "mlr3pipelines"), pattern = "^helper.*\\.[rR]", full.names = TRUE) +lapply(mlr_helpers, FUN = source) + +# expect that 'one' is a deep clone of 'two' +expect_deep_clone = function(one, two) { + # is equal + expect_equal(one, two) + visited = new.env() + visited_b = new.env() + expect_references_differ = function(a, b, path) { + + force(path) + if 
(length(path) > 400) { + stop("Recursion too deep in expect_deep_clone()") + } + + # don't go in circles + addr_a = data.table::address(a) + addr_b = data.table::address(b) + if (!is.null(visited[[addr_a]])) { + return(invisible(NULL)) + } + visited[[addr_a]] = path + visited_b[[addr_b]] = path + + #if (inherits(a, "nn_module_generator") || inherits(a, "torch_optimizer_generator")) { + # stopf("Not implemented yet") + #} + if (inherits(a, "R6ClassGenerator")) { + return(NULL) + } + + # follow attributes, even for non-recursive objects + if (utils::tail(path, 1) != "[attributes]" && !is.null(base::attributes(a))) { + expect_references_differ(base::attributes(a), base::attributes(b), c(path, "[attributes]")) + } + + # don't recurse if there is nowhere to go + if (!base::is.recursive(a)) { + return(invisible(NULL)) + } + + # check that environments differ + if (base::is.environment(a)) { + # some special environments + if (identical(a, baseenv()) || identical(a, globalenv()) || identical(a, emptyenv())) { + return(invisible(NULL)) + } + if (length(path) > 1 && R6::is.R6(a) && "clone" %nin% names(a)) { + return(invisible(NULL)) # don't check if smth is not cloneable + } + if (identical(utils::tail(path, 1), c("[element train_task] 'train_task'"))) { + return(invisible(NULL)) # workaround for https://github.com/mlr-org/mlr3/issues/382 + } + if (identical(utils::tail(path, 1), c("[element fallback] 'fallback'"))) { + return(invisible(NULL)) # workaround for https://github.com/mlr-org/mlr3/issues/511 + } + label = sprintf("Object addresses differ at path %s", paste0(path, collapse = "->")) + expect_true(addr_a != addr_b, label = label) + expect_null(visited_b[[addr_a]], label = label) + } else { + a = unclass(a) + b = unclass(b) + } + + # recurse + if (base::is.function(a)) { + return(invisible(NULL)) + ## # maybe this is overdoing it + ## expect_references_differ(base::formals(a), base::formals(b), c(path, "[function args]")) + ## expect_references_differ(base::body(a), base::body(b), c(path, "[function body]")) + } + objnames = base::names(a) + if (is.null(objnames) || anyDuplicated(objnames)) { + index = seq_len(base::length(a)) + } else { + index = objnames + if (base::is.environment(a)) { + index = Filter(function(x) !bindingIsActive(x, a), index) + } + } + for (i in index) { + if (utils::tail(path, 1) == "[attributes]" && i %in% c("srcref", "srcfile", ".Environment")) next + expect_references_differ(base::`[[`(a, i), base::`[[`(b, i), c(path, sprintf("[element %s]%s", i, + if (!is.null(objnames)) sprintf(" '%s'", if (is.character(index)) i else objnames[[i]]) else ""))) + } + } + expect_references_differ(one, two, "ROOT") +} + + diff --git a/tests/testthat/test_LearnerTorch.R b/tests/testthat/test_LearnerTorch.R index 6a2a6d42..79b4a7bb 100644 --- a/tests/testthat/test_LearnerTorch.R +++ b/tests/testthat/test_LearnerTorch.R @@ -24,7 +24,7 @@ test_that("Basic tests: Classification", { expect_equal(learner$id, "classif.test1") expect_equal(learner$label, "Test1 Learner") expect_set_equal(learner$feature_types, c("numeric", "integer")) - expect_set_equal(learner$properties, c("multiclass", "twoclass")) + expect_set_equal(learner$properties, c("multiclass", "twoclass", "marshal")) # default predict types are correct expect_set_equal(learner$predict_types, c("response", "prob")) @@ -48,7 +48,7 @@ test_that("Basic tests: Regression", { expect_equal(learner$id, "regr.test1") expect_equal(learner$label, "Test1 Learner") expect_set_equal(learner$feature_types, c("numeric", "integer")) - 
expect_set_equal(learner$properties, c()) + expect_set_equal(learner$properties, "marshal") # default predict types are correct expect_set_equal(learner$predict_types, "response") @@ -167,8 +167,8 @@ test_that("the state of a trained network contains what it should", { expect_permutation(names(learner$model), c("seed", "network", "optimizer", "loss_fn", "task_col_info", "callbacks")) expect_true(is.integer(learner$model$seed)) expect_class(learner$model$network, "nn_module") - expect_class(learner$model$loss_fn, "nn_l1_loss") - expect_class(learner$model$optimizer, "optim_sgd") + expect_class(learner$model$loss_fn, "list") + expect_class(learner$model$optimizer, "list") expect_list(learner$model$callbacks, types = "CallbackSet", len = 1L) expect_equal(names(learner$model$callbacks), "history1") expect_true(is.integer(learner$model$seed)) @@ -387,6 +387,23 @@ test_that("resample() works", { expect_r6(rr, "ResampleResult") }) +test_that("callr encapsulation and marshaling", { + task = tsk("mtcars")$filter(1:5) + learner = lrn("regr.mlp", batch_size = 150, epochs = 1, device = "cpu", encapsulate = c(train = "callr"), + neurons = 20 + ) + learner$train(task) + expect_false(learner$marshaled) + learner$marshal()$unmarshal() + expect_prediction(learner$predict(task)) + + learner = lrn("regr.mlp", batch_size = 150, epochs = 1, device = "cpu", encapsulate = c(train = "callr"), + neurons = 20 + ) + learner$train(task) + expect_prediction(learner$predict(task)) +}) + test_that("Input verification works during `$train()` (train-predict shapes work together)", { task = nano_mnist() diff --git a/tests/testthat/test_LearnerTorchFeatureless.R b/tests/testthat/test_LearnerTorchFeatureless.R index bd0f29a6..ead9def0 100644 --- a/tests/testthat/test_LearnerTorchFeatureless.R +++ b/tests/testthat/test_LearnerTorchFeatureless.R @@ -11,7 +11,7 @@ test_that("dataset_featureless works", { test_that("Basic checks: Classification", { learner = lrn("classif.torch_featureless") expect_learner_torch(learner) - expect_set_equal(learner$properties, c("twoclass", "multiclass", "missings", "featureless")) + expect_set_equal(learner$properties, c("twoclass", "multiclass", "missings", "featureless", "marshal")) }) test_that("LearnerTorchFeatureless works", { @@ -28,5 +28,5 @@ test_that("LearnerTorchFeatureless works", { test_that("Basic checks: Regression", { learner = lrn("regr.torch_featureless") expect_learner_torch(learner) - expect_set_equal(learner$properties, c("missings", "featureless")) + expect_set_equal(learner$properties, c("missings", "featureless", "marshal")) }) diff --git a/tests/testthat/test_PipeOpTaskPreprocTorch.R b/tests/testthat/test_PipeOpTaskPreprocTorch.R index 5b16bd83..4c967095 100644 --- a/tests/testthat/test_PipeOpTaskPreprocTorch.R +++ b/tests/testthat/test_PipeOpTaskPreprocTorch.R @@ -166,13 +166,13 @@ test_that("shapes_out", { test_that("lazy tensor modified as expected", { d = data.table( y = 1:10, - x = as_lazy_tensor(rnorm(10)) + x = as_lazy_tensor(1:10) ) taskin = as_task_regr(d, target = "y") po_test = po("preproc_torch", fn = crate(function(x, a) x + a), param_set = ps(a = p_int(tags = c("train", "required"))), - a = -10, rowwise = FALSE, stages_init = "both") + a = 10, rowwise = FALSE, stages_init = "both") taskout_train = po_test$train(list(taskin))[[1L]] taskout_pred = po_test$predict(list(taskin))[[1L]] @@ -193,7 +193,7 @@ test_that("lazy tensor modified as expected", { expect_equal( as_array(materialize(taskin$data(cols = "x")[[1L]], rbind = TRUE)), - 
as_array(materialize(taskout_train$data(cols = "x")[[1L]], rbind = TRUE) + 10), + as_array(materialize(taskout_train$data(cols = "x")[[1L]], rbind = TRUE) - 10), tolerance = 1e-5 ) }) diff --git a/tests/testthat/test_PipeOpTorchModel.R b/tests/testthat/test_PipeOpTorchModel.R index f45ab707..47b8754a 100644 --- a/tests/testthat/test_PipeOpTorchModel.R +++ b/tests/testthat/test_PipeOpTorchModel.R @@ -44,14 +44,13 @@ test_that("Manual test: Classification and Regression", { expect_class(obj$state, "LearnerTorchModel") expect_class(obj$state$model$network, c("nn_graph", "nn_module")) # Defaults are used - expect_class(obj$state$model$optimizer, "optim_adam") - expect_class(obj$state$model$loss_fn, "nn_cross_entropy_loss") + expect_class(obj$state$model$optimizer, "list") + expect_class(obj$state$model$loss_fn, "list") # It is possible to change parameter values md$optimizer = t_opt("adagrad", lr = 0.123) obj = po("torch_model_classif", epochs = 0, batch_size = 2) obj$train(list(md)) - expect_class(obj$state$model$optimizer, "optim_adagrad") expect_true(obj$state$state$param_vals$opt.lr == 0.123) expect_true(obj$state$state$param_vals$batch_size == 2) diff --git a/tests/testthat/test_TorchCallback.R b/tests/testthat/test_TorchCallback.R index 7c075688..721de328 100644 --- a/tests/testthat/test_TorchCallback.R +++ b/tests/testthat/test_TorchCallback.R @@ -14,7 +14,7 @@ test_that("Basic checks", { Cbt4 = R6Class("CallbackSetTest4", public = list(initialize = function(x) NULL)) tcb41 = TorchCallback$new(Cbt4) expect_identical(tcb41$param_set$ids(), "x") - expect_class(tcb41$param_set$params$x, "ParamUty") + expect_equal(tcb41$param_set$params[list("x"), "cls", on = "id"][[1L]], "ParamUty") ps42 = ps(x = p_int()) tcb42 = TorchCallback$new(Cbt4, param_set = ps42) diff --git a/tests/testthat/test_TorchDescriptor.R b/tests/testthat/test_TorchDescriptor.R index b2ac893e..8a2aaf16 100644 --- a/tests/testthat/test_TorchDescriptor.R +++ b/tests/testthat/test_TorchDescriptor.R @@ -9,7 +9,7 @@ test_that("TorchDescriptor basic checks", { ) # train tag is added - expect_true("train" %in% descriptor$param_set$params$reduction$tags) + expect_true("train" %in% descriptor$param_set$tags$reduction) expect_identical(descriptor$generator, nn_mse_loss) expect_identical(descriptor$id, "mse") expect_identical(descriptor$param_set$ids(), "reduction") diff --git a/tests/testthat/test_autotests.R b/tests/testthat/test_autotests.R index a17c8e8a..756a56d7 100644 --- a/tests/testthat/test_autotests.R +++ b/tests/testthat/test_autotests.R @@ -171,6 +171,6 @@ test_that("expect_torch_callback works", { ) ) cbd = as_torch_callback(CallbackSetD) - expect_error(expect_torch_callback(cbd, check_man = FALSE), regexp = "not equal to") + expect_error(expect_torch_callback(cbd, check_man = FALSE)) })
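
A minimal usage sketch of the marshaling support introduced in this patch (not part of the diff; the learner configuration mirrors the "callr encapsulation and marshaling" test above, and the batch size is an arbitrary choice):

library(mlr3)
library(mlr3torch)

task = tsk("mtcars")
learner = lrn("regr.mlp", batch_size = 32, epochs = 1, device = "cpu", neurons = 20)
learner$train(task)

# $marshal() routes the state through marshal_model.learner_torch_state(),
# which serializes the network and the optimizer/loss state dicts with
# torch_serialize() so the trained learner can cross process boundaries.
learner$marshal()
stopifnot(learner$marshaled)

# $unmarshal() reloads the torch objects via torch_load(), onto the CPU by default.
learner$unmarshal()
prediction = learner$predict(task)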