Bundle #217

Merged: 14 commits, May 3, 2024
15 changes: 8 additions & 7 deletions DESCRIPTION
@@ -33,9 +33,9 @@ Description: Deep Learning library that extends the mlr3 framework by building
defined in 'mlr3pipelines'.
License: LGPL (>= 3)
Depends:
mlr3 (>= 0.16.1),
mlr3pipelines (>= 0.5.0),
torch (>= 0.11.0),
mlr3 (>= 0.19.0),
mlr3pipelines (>= 0.5.2),
torch (>= 0.12.0),
R (>= 3.5.0)
Imports:
backports,
@@ -58,12 +58,13 @@ Suggests:
magick,
progress,
rmarkdown,
rpart,
viridis,
testthat (>= 3.0.0),
torchvision
torchvision,
testthat (>= 3.0.0)
Remotes:
mlr-org/mlr3,
mlr-org/mlr3pipelines,
mlverse/torch,
mlr-org/paradox,
mlverse/torchvision
Config/testthat/edition: 3
NeedsCompilation: no
2 changes: 2 additions & 0 deletions NAMESPACE
@@ -29,6 +29,7 @@ S3method(c,lazy_tensor)
S3method(col_info,DataBackendLazy)
S3method(format,lazy_tensor)
S3method(hash_input,lazy_tensor)
S3method(marshal_model,learner_torch_state)
S3method(materialize,data.frame)
S3method(materialize,lazy_tensor)
S3method(materialize,list)
@@ -49,6 +50,7 @@ S3method(t_opt,"NULL")
S3method(t_opt,character)
S3method(t_opts,"NULL")
S3method(t_opts,character)
S3method(unmarshal_model,learner_torch_state_marshaled)
export(CallbackSet)
export(CallbackSetCheckpoint)
export(CallbackSetHistory)
68 changes: 55 additions & 13 deletions R/LearnerTorch.R
@@ -35,7 +35,12 @@
#' Defaults to an empty `list()`, i.e. no callbacks.
#'
#' @section State:
#' The state is a list with elements `network`, `optimizer`, `loss_fn`, `callbacks` and `seed`.
#' The state is a list with elements:
#' * `network` :: The trained [network][torch::nn_module].
#' * `optimizer` :: The `$state_dict()` of the [optimizer][torch::optimizer] used to train the network.
#' * `loss_fn` :: The `$state_dict()` of the [loss][torch::nn_module] used to train the network.
#' * `callbacks` :: The [callbacks][mlr3torch::mlr_callback_set] used to train the network.
#' * `seed` :: The seed that was, or will be, used for training and prediction.
#'
#' @template paramset_torchlearner
#'
@@ -91,7 +96,6 @@ LearnerTorch = R6Class("LearnerTorch",
#' @description Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id, task_type, param_set, properties, man, label, feature_types,
optimizer = NULL, loss = NULL, packages = NULL, predict_types = NULL, callbacks = list()) {

assert_choice(task_type, c("regr", "classif"))
predict_types = predict_types %??% switch(task_type,
regr = "response",
@@ -115,9 +119,6 @@
private$.optimizer = as_torch_optimizer(optimizer, clone = TRUE)
}

private$.optimizer$param_set$set_id = "opt"
private$.loss$param_set$set_id = "loss"

callbacks = as_torch_callbacks(callbacks, clone = TRUE)
callback_ids = ids(callbacks)
if (!test_names(callback_ids, type = "unique")) {
@@ -127,18 +128,16 @@ LearnerTorch = R6Class("LearnerTorch",
}

private$.callbacks = set_names(callbacks, callback_ids)
walk(private$.callbacks, function(cb) {
cb$param_set$set_id = paste0("cb.", cb$id)
})

packages = unique(c(
packages,
unlist(map(private$.callbacks, "packages")),
private$.loss$packages,
private$.optimizer$packages
))


assert_subset(properties, mlr_reflections$learner_properties[[task_type]])
properties = union(properties, "marshal")
assert_subset(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]))
if (any(grepl("^(loss\\.|opt\\.|cb\\.)", param_set$ids()))) {
stopf("Prefixes 'loss.', 'opt.', and 'cb.' are reserved for dynamically constructed parameters.")
@@ -148,7 +147,7 @@ LearnerTorch = R6Class("LearnerTorch",

paramset_torch = paramset_torchlearner(task_type)
if (param_set$length > 0) {
private$.param_set_base = ParamSetCollection$new(list(param_set, paramset_torch))
private$.param_set_base = ParamSetCollection$new(sets = list(param_set, paramset_torch))
} else {
private$.param_set_base = paramset_torch
}
@@ -202,9 +201,31 @@ LearnerTorch = R6Class("LearnerTorch",
if (length(e)) {
catn(str_indent("* Errors:", e))
}
},
#' @description
#' Marshal the learner.
#' @param ... (any)\cr
#' Additional parameters.
#' @return self
marshal = function(...) {
learner_marshal(.learner = self, ...)
},
#' @description
#' Unmarshal the learner.
#' @param ... (any)\cr
#' Additional parameters.
#' @return self
unmarshal = function(...) {
learner_unmarshal(.learner = self, ...)
}
),
active = list(
#' @field marshaled (`logical(1)`)\cr
#' Whether the learner is marshaled.
marshaled = function(rhs) {
assert_ro_binding(rhs)
learner_marshaled(self)
},
#' @field network ([`nn_module()`][torch::nn_module])\cr
#' The network (only available after training).
network = function(rhs) {
@@ -218,9 +239,9 @@ LearnerTorch = R6Class("LearnerTorch",
#' The parameter set
param_set = function(rhs) {
if (is.null(private$.param_set)) {
private$.param_set = ParamSetCollection$new(c(
list(private$.param_set_base, private$.optimizer$param_set, private$.loss$param_set),
map(private$.callbacks, "param_set"))
private$.param_set = ParamSetCollection$new(sets = c(
list(private$.param_set_base, opt = private$.optimizer$param_set, loss = private$.loss$param_set),
set_names(map(private$.callbacks, "param_set"), sprintf("cb.%s", ids(private$.callbacks))))
)
}
private$.param_set
@@ -335,6 +356,27 @@ LearnerTorch = R6Class("LearnerTorch",
)
)

#' @export
marshal_model.learner_torch_state = function(model, inplace = FALSE, ...) {
# FIXME: optimizer and loss_fn
model$network = torch_serialize(model$network)
model$loss_fn = torch_serialize(model$loss_fn)
model$optimizer = torch_serialize(model$optimizer)

structure(list(
marshaled = model,
packages = "mlr3torch"
), class = c("learner_torch_state_marshaled", "list_marshaled", "marshaled"))
}

#' @export
unmarshal_model.learner_torch_state_marshaled = function(model, inplace = FALSE, device = "cpu", ...) {
model = model$marshaled
model$network = torch_load(model$network, device = device)
model$loss_fn = torch_load(model$loss_fn, device = device)
model$optimizer = torch_load(model$optimizer, device = device)
return(model)
}

deep_clone = function(self, private, super, name, value) {
private$.param_set = NULL # required to keep clone identical to original, otherwise tests get really ugly
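Taken together, the new `marshal()`/`unmarshal()` methods and S3 hooks let a trained learner be converted to a serialization-safe representation and back. A minimal usage sketch, assuming the package's `classif.mlp` learner with its `epochs` and `batch_size` parameters:

```r
library(mlr3)
library(mlr3torch)

learner = lrn("classif.mlp", epochs = 1L, batch_size = 32L)
learner$train(tsk("iris"))

learner$marshal()    # network, loss_fn and optimizer state pass through torch_serialize()
learner$marshaled    # TRUE
learner$unmarshal()  # torch_load() restores the tensors, onto the CPU by default
learner$predict(tsk("iris"))
```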
4 changes: 2 additions & 2 deletions R/LearnerTorchFeatureless.R
@@ -22,8 +22,8 @@ LearnerTorchFeatureless = R6Class("LearnerTorchFeatureless",
#' @description Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(task_type, optimizer = NULL, loss = NULL, callbacks = list()) {
properties = switch(task_type,
classif = c("twoclass", "multiclass", "missings", "featureless"),
regr = c("missings", "featureless")
classif = c("twoclass", "multiclass", "missings", "featureless", "marshal"),
regr = c("missings", "featureless", "marshal")
)
super$initialize(
id = paste0(task_type, ".torch_featureless"),
2 changes: 1 addition & 1 deletion R/LearnerTorchMLP.R
@@ -39,7 +39,7 @@ LearnerTorchMLP = R6Class("LearnerTorchMLP",
check_activation = crate(function(x) check_class(x, "nn_module"))
check_activation_args = crate(function(x) check_list(x, names = "unique"))
check_neurons = crate(function(x) check_integerish(x, any.missing = FALSE, lower = 1))
cechk_shape = crate(function(x) check_shape(x, null_ok = TRUE, len = 2L))
check_shape = crate(function(x) check_shape(x, null_ok = TRUE, len = 2L), .parent = topenv())

param_set = ps(
neurons = p_uty(tags = c("train", "predict"), custom_check = check_neurons),
2 changes: 1 addition & 1 deletion R/PipeOpTaskPreprocTorch.R
@@ -187,7 +187,7 @@ PipeOpTaskPreprocTorch = R6Class("PipeOpTaskPreprocTorch",
private$.rowwise = assert_flag(rowwise)
param_set = assert_param_set(param_set$clone(deep = TRUE))

param_set$add(ps(
param_set = c(param_set, ps(
stages = p_fct(c("train", "predict", "both"), tags = c("train", "required"))
))
param_set$set_values(stages = stages_init)
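The move from `param_set$add()` to `c()` follows the paradox 1.0 API, where `ParamSet`s are combined into a new set rather than mutated in place. A standalone sketch of the idiom, with an illustrative base set:

```r
library(paradox)

base = ps(fraction = p_dbl(lower = 0, upper = 1, tags = "train"))

# c() returns a fresh ParamSet holding the parameters of both sets;
# `base` itself is left unchanged.
combined = c(base, ps(
  stages = p_fct(c("train", "predict", "both"), tags = c("train", "required"))
))
combined$set_values(stages = "both")
combined$ids()
```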
4 changes: 3 additions & 1 deletion R/PipeOpTorchAvgPool.R
@@ -17,7 +17,9 @@ PipeOpTorchAvgPool = R6Class("PipeOpTorchAvgPool",
count_include_pad = p_lgl(default = TRUE, tags = "train")
)
if (d >= 2L) {
param_set$add(ParamDbl$new("divisor_override", default = NULL, lower = 0, tags = "train", special_vals = list(NULL)))
param_set = c(param_set, ps(
divisor_override = p_dbl(default = NULL, lower = 0, tags = "train", special_vals = list(NULL))
))
}

super$initialize(
9 changes: 4 additions & 5 deletions R/PipeOpTorchCallbacks.R
@@ -47,17 +47,16 @@ PipeOpTorchCallbacks = R6Class("PipeOpTorchCallbacks",
cbids = ids(private$.callbacks)
assert_names(cbids, type = "unique")
walk(private$.callbacks, function(cb) {
cb$param_set$set_id = cb$id
walk(cb$param_set$params, function(p) {
p$tags = union(p$tags, "train")
})
if (length(cb$param_set$tags)) {
cb$param_set$tags = map(cb$param_set$tags, function(tags) union(tags, "train"))
}
})
private$.callbacks = set_names(private$.callbacks, cbids)
input = data.table(name = "input", train = "ModelDescriptor", predict = "Task")
output = data.table(name = "output", train = "ModelDescriptor", predict = "Task")
super$initialize(
id = id,
param_set = alist(invoke(ParamSetCollection$new, sets = map(private$.callbacks, "param_set"))),
param_set = alist(ParamSetCollection$new(sets = map(private$.callbacks, "param_set"))),
param_vals = param_vals,
input = input,
output = output,
2 changes: 1 addition & 1 deletion R/TorchCallback.R
@@ -143,7 +143,7 @@ as_torch_callbacks.character = function(x, clone = FALSE, ...) { # nolint
#' @section Parameters:
#' Defined by the constructor argument `param_set`.
#' If no parameter set is provided during construction, the parameter set is constructed by creating a parameter
#' for each argument of the wrapped loss function, where the parameters are then of type [`ParamUty`].
#' for each argument of the wrapped loss function, where the parameters are then of type `ParamUty`.
#'
#' @family Callback
#' @family Torch Descriptor
5 changes: 4 additions & 1 deletion R/TorchDescriptor.R
@@ -42,7 +42,10 @@ TorchDescriptor = R6Class("TorchDescriptor",
self$generator = generator
# TODO: Assert that all parameters are tagged with "train"
self$param_set = assert_r6(param_set, "ParamSet", null.ok = TRUE) %??% inferps(generator)
walk(self$param_set$params, function(param) param$tags = union(param$tags, "train"))
if (length(self$param_set$tags)) {
self$param_set$tags = map(self$param_set$tags, function(tags) union(tags, "train"))

}
if (is.function(generator)) {
args = formalArgs(generator)
} else {
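With paradox 1.0 there are no per-parameter `Param` objects left to walk over, so tags are read and replaced through the set-level `$tags` binding, as the hunk above does. A small illustration on a made-up parameter set:

```r
library(paradox)

param_set = ps(
  lr = p_dbl(lower = 0),
  momentum = p_dbl(lower = 0, upper = 1, tags = "train")
)

# $tags is a named list with one character vector per parameter;
# assigning to it replaces all tags in one go.
if (length(param_set$tags)) {
  param_set$tags = lapply(param_set$tags, function(tags) union(tags, "train"))
}
param_set$tags
```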
2 changes: 1 addition & 1 deletion R/TorchLoss.R
@@ -51,7 +51,7 @@ as_torch_loss.character = function(x, clone = FALSE, ...) { # nolint
#' @section Parameters:
#' Defined by the constructor argument `param_set`.
#' If no parameter set is provided during construction, the parameter set is constructed by creating a parameter
#' for each argument of the wrapped loss function, where the parameters are then of type [`ParamUty`].
#' for each argument of the wrapped loss function, where the parameters are then of type `ParamUty`.
#'
#' @family Torch Descriptor
#' @export
6 changes: 3 additions & 3 deletions R/TorchOptimizer.R
@@ -53,7 +53,7 @@ as_torch_optimizer.character = function(x, clone = FALSE, ...) { # nolint
#' @section Parameters:
#' Defined by the constructor argument `param_set`.
#' If no parameter set is provided during construction, the parameter set is constructed by creating a parameter
#' for each argument of the wrapped loss function, where the parameters are then of type [`ParamUty`].
#' for each argument of the wrapped loss function, where the parameters are then of type [ParamUty][`paradox::p_uty`].
#'
#' @family Torch Descriptor
#' @export
@@ -278,7 +278,7 @@ mlr3torch_optimizers$add("sgd",
mlr3torch_optimizers$add("asgd",
function() {
p = ps(
lr = p_dbl(default = 1e-2, lower = 0, tags = c("required", "train")),
lr = p_dbl(default = 1e-2, lower = 0, tags = "train"),
lambda = p_dbl(lower = 0, upper = 1, default = 1e-4, tags = "train"),
alpha = p_dbl(lower = 0, upper = Inf, default = 0.75, tags = "train"),
t0 = p_int(lower = 1L, upper = Inf, default = 1e6, tags = "train"),
@@ -308,7 +308,7 @@ mlr3torch_optimizers$add("rprop",
p = ps(
lr = p_dbl(default = 0.01, lower = 0, tags = "train"),
etas = p_uty(default = c(0.5, 1.2), tags = "train"),
step_sizes = p_uty(c(1e-06, 50), tags = "train")
step_sizes = p_uty(default = c(1e-06, 50), tags = "train")
)
TorchOptimizer$new(
torch_optimizer = torch::optim_rprop,
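The `step_sizes` correction fixes a positional-argument slip: the first formal of `p_uty()` appears to be `custom_check` rather than `default`, so the unnamed vector was never recorded as a default value. A sketch of the difference:

```r
library(paradox)

# Unnamed, the vector presumably lands in custom_check instead of default:
# p_uty(c(1e-06, 50), tags = "train")

# Named, the default is recorded as intended:
fixed = ps(step_sizes = p_uty(default = c(1e-06, 50), tags = "train"))
fixed$default
```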
7 changes: 4 additions & 3 deletions R/learner_torch_methods.R
@@ -66,7 +66,8 @@ learner_torch_train = function(self, private, super, task, param_vals) {

# In case the seed was "random" initially we want to make the sampled seed available in the state.
model$seed = param_vals$seed
return(model)

structure(model, class = c("learner_torch_state", "list"))
}


@@ -166,8 +167,8 @@ train_loop = function(ctx, cbs) {
# The seed is added later
list(
network = ctx$network,
loss_fn = ctx$loss_fn,
optimizer = ctx$optimizer,
loss_fn = ctx$loss_fn$state_dict(),
optimizer = ctx$optimizer$state_dict(),
callbacks = cbs
)
}
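Storing `$state_dict()`s rather than the live loss and optimizer objects is what makes the state marshalable: a state dict is a plain list of tensors that `torch_serialize()` can round-trip. A minimal sketch using only torch:

```r
library(torch)

net = nn_linear(4L, 2L)
opt = optim_adam(net$parameters, lr = 0.01)

# The state dict survives serialization to a raw vector and back.
raw = torch_serialize(opt$state_dict())
state = torch_load(raw)

# Restoring into a freshly constructed optimizer:
opt2 = optim_adam(net$parameters, lr = 0.01)
opt2$load_state_dict(state)
```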
2 changes: 1 addition & 1 deletion man/PipeOpPreprocTorchAugmentCenterCrop.Rd

2 changes: 1 addition & 1 deletion man/PipeOpPreprocTorchAugmentColorJitter.Rd

2 changes: 1 addition & 1 deletion man/PipeOpPreprocTorchAugmentCrop.Rd

2 changes: 1 addition & 1 deletion man/PipeOpPreprocTorchAugmentHflip.Rd

6 changes: 3 additions & 3 deletions man/PipeOpPreprocTorchAugmentRandomAffine.Rd

2 changes: 1 addition & 1 deletion man/PipeOpPreprocTorchAugmentRandomApply.Rd

2 changes: 1 addition & 1 deletion man/PipeOpPreprocTorchAugmentRandomChoice.Rd