Bundle (#217)
* support bundling property

* use torch functions

* fix tests

* tests

* fix tests for featureless

* ...

* ...

* remove browser, document

* chore: require mlr3 0.19.0

* typo in desc

* cleanup, use test helpers from pipelines

* fix pipelines version

---------

Co-authored-by: be-marc <[email protected]>
sebffischer and be-marc authored May 3, 2024
1 parent b702f64 commit 131db14
Showing 56 changed files with 347 additions and 187 deletions.
15 changes: 8 additions & 7 deletions DESCRIPTION
@@ -33,9 +33,9 @@ Description: Deep Learning library that extends the mlr3 framework by building
     defined in 'mlr3pipelines'.
 License: LGPL (>= 3)
 Depends:
-    mlr3 (>= 0.16.1),
-    mlr3pipelines (>= 0.5.0),
-    torch (>= 0.11.0),
+    mlr3 (>= 0.19.0),
+    mlr3pipelines (>= 0.5.2),
+    torch (>= 0.12.0),
     R (>= 3.5.0)
 Imports:
     backports,
@@ -58,12 +58,13 @@ Suggests:
     magick,
     progress,
     rmarkdown,
+    rpart,
     viridis,
-    testthat (>= 3.0.0),
-    torchvision
+    torchvision,
+    testthat (>= 3.0.0)
 Remotes:
     mlr-org/mlr3,
     mlr-org/mlr3pipelines,
-    mlverse/torch,
+    mlr-org/paradox,
     mlverse/torchvision
 Config/testthat/edition: 3
 NeedsCompilation: no
2 changes: 2 additions & 0 deletions NAMESPACE
@@ -29,6 +29,7 @@ S3method(c,lazy_tensor)
 S3method(col_info,DataBackendLazy)
 S3method(format,lazy_tensor)
 S3method(hash_input,lazy_tensor)
+S3method(marshal_model,learner_torch_state)
 S3method(materialize,data.frame)
 S3method(materialize,lazy_tensor)
 S3method(materialize,list)
@@ -49,6 +50,7 @@ S3method(t_opt,"NULL")
 S3method(t_opt,character)
 S3method(t_opts,"NULL")
 S3method(t_opts,character)
+S3method(unmarshal_model,learner_torch_state_marshaled)
 export(CallbackSet)
 export(CallbackSetCheckpoint)
 export(CallbackSetHistory)
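These two registrations hook the torch model state into the marshaling generics that mlr3 0.19.0 introduced (hence the version bump in DESCRIPTION). The generics dispatch on the class of the model object: marshal_model() wraps the state into a marshaled payload together with the packages needed to restore it, and unmarshal_model() reverses this. A minimal sketch of the convention, using a hypothetical "my_state" class rather than the classes from this commit:

    library(mlr3)

    # hypothetical model class following the mlr3 marshaling convention
    marshal_model.my_state = function(model, inplace = FALSE, ...) {
      structure(
        list(marshaled = unclass(model), packages = "mypkg"),
        class = c("my_state_marshaled", "marshaled")
      )
    }

    unmarshal_model.my_state_marshaled = function(model, inplace = FALSE, ...) {
      structure(model$marshaled, class = c("my_state", "list"))
    }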
68 changes: 55 additions & 13 deletions R/LearnerTorch.R
@@ -35,7 +35,12 @@
 #' Defaults to an empty `list()`, i.e. no callbacks.
 #'
 #' @section State:
-#' The state is a list with elements `network`, `optimizer`, `loss_fn`, `callbacks` and `seed`.
+#' The state is a list with elements:
+#' * `network` :: The trained [network][torch::nn_module].
+#' * `optimizer` :: The `$state_dict()` of the [optimizer][torch::optimizer] used to train the network.
+#' * `loss_fn` :: The `$state_dict()` of the [loss][torch::nn_module] used to train the network.
+#' * `callbacks` :: The [callbacks][mlr3torch::mlr_callback_set] used to train the network.
+#' * `seed` :: The seed that was / is used for training and prediction.
 #'
 #' @template paramset_torchlearner
 #'
Expand Down Expand Up @@ -91,7 +96,6 @@ LearnerTorch = R6Class("LearnerTorch",
#' @description Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id, task_type, param_set, properties, man, label, feature_types,
optimizer = NULL, loss = NULL, packages = NULL, predict_types = NULL, callbacks = list()) {

assert_choice(task_type, c("regr", "classif"))
predict_types = predict_types %??% switch(task_type,
regr = "response",
@@ -115,9 +119,6 @@
         private$.optimizer = as_torch_optimizer(optimizer, clone = TRUE)
       }
 
-      private$.optimizer$param_set$set_id = "opt"
-      private$.loss$param_set$set_id = "loss"
-
       callbacks = as_torch_callbacks(callbacks, clone = TRUE)
       callback_ids = ids(callbacks)
       if (!test_names(callback_ids, type = "unique")) {
@@ -127,18 +128,16 @@
       }
 
       private$.callbacks = set_names(callbacks, callback_ids)
-      walk(private$.callbacks, function(cb) {
-        cb$param_set$set_id = paste0("cb.", cb$id)
-      })
 
       packages = unique(c(
         packages,
         unlist(map(private$.callbacks, "packages")),
         private$.loss$packages,
         private$.optimizer$packages
       ))
-
+
       assert_subset(properties, mlr_reflections$learner_properties[[task_type]])
+      properties = union(properties, "marshal")
       assert_subset(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]))
       if (any(grepl("^(loss\\.|opt\\.|cb\\.)", param_set$ids()))) {
         stopf("Prefixes 'loss.', 'opt.', and 'cb.' are reserved for dynamically constructed parameters.")
@@ -148,7 +147,7 @@ LearnerTorch = R6Class("LearnerTorch",
 
       paramset_torch = paramset_torchlearner(task_type)
       if (param_set$length > 0) {
-        private$.param_set_base = ParamSetCollection$new(list(param_set, paramset_torch))
+        private$.param_set_base = ParamSetCollection$new(sets = list(param_set, paramset_torch))
       } else {
         private$.param_set_base = paramset_torch
       }
@@ -202,9 +201,31 @@ LearnerTorch = R6Class("LearnerTorch",
       if (length(e)) {
         catn(str_indent("* Errors:", e))
       }
     },
+    #' @description
+    #' Marshal the learner.
+    #' @param ... (any)\cr
+    #'   Additional parameters.
+    #' @return self
+    marshal = function(...) {
+      learner_marshal(.learner = self, ...)
+    },
+    #' @description
+    #' Unmarshal the learner.
+    #' @param ... (any)\cr
+    #'   Additional parameters.
+    #' @return self
+    unmarshal = function(...) {
+      learner_unmarshal(.learner = self, ...)
+    }
   ),
   active = list(
+    #' @field marshaled (`logical(1)`)\cr
+    #'   Whether the learner is marshaled.
+    marshaled = function(rhs) {
+      assert_ro_binding(rhs)
+      learner_marshaled(self)
+    },
     #' @field network ([`nn_module()`][torch::nn_module])\cr
     #'   The network (only available after training).
     network = function(rhs) {
@@ -218,9 +239,9 @@ LearnerTorch = R6Class("LearnerTorch",
     #'   The parameter set
     param_set = function(rhs) {
       if (is.null(private$.param_set)) {
-        private$.param_set = ParamSetCollection$new(c(
-          list(private$.param_set_base, private$.optimizer$param_set, private$.loss$param_set),
-          map(private$.callbacks, "param_set"))
+        private$.param_set = ParamSetCollection$new(sets = c(
+          list(private$.param_set_base, opt = private$.optimizer$param_set, loss = private$.loss$param_set),
+          set_names(map(private$.callbacks, "param_set"), sprintf("cb.%s", ids(private$.callbacks))))
         )
       }
       private$.param_set
@@ -335,6 +356,27 @@ LearnerTorch = R6Class("LearnerTorch",
   )
 )
 
+#' @export
+marshal_model.learner_torch_state = function(model, inplace = FALSE, ...) {
+  # FIXME: optimizer and loss_fn
+  model$network = torch_serialize(model$network)
+  model$loss_fn = torch_serialize(model$loss_fn)
+  model$optimizer = torch_serialize(model$optimizer)
+
+  structure(list(
+    marshaled = model,
+    packages = "mlr3torch"
+  ), class = c("learner_torch_state_marshaled", "list_marshaled", "marshaled"))
+}
+
+#' @export
+unmarshal_model.learner_torch_state_marshaled = function(model, inplace = FALSE, device = "cpu", ...) {
+  model = model$marshaled
+  model$network = torch_load(model$network, device = device)
+  model$loss_fn = torch_load(model$loss_fn, device = device)
+  model$optimizer = torch_load(model$optimizer, device = device)
+  return(model)
+}
+
 deep_clone = function(self, private, super, name, value) {
   private$.param_set = NULL # required to keep clone identical to original, otherwise tests get really ugly
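Taken together, these changes give every torch learner the "marshal" property plus $marshal()/$unmarshal() methods, with torch_serialize() turning the torch objects into plain raw vectors and torch_load() restoring them. A usage sketch (parameter values are illustrative and assume this branch together with mlr3 >= 0.19.0):

    library(mlr3)
    library(mlr3torch)

    learner = lrn("classif.torch_featureless", epochs = 1, batch_size = 16)
    learner$train(tsk("iris"))

    learner$marshal()      # network, loss_fn and optimizer become raw vectors
    learner$marshaled      # TRUE; the state can now cross process boundaries
    learner$unmarshal()    # torch_load()s everything back, onto the CPU by default
    learner$predict(tsk("iris"))

This is what lets torch learners work with parallel resampling or be saved to disk: live torch objects are external pointers that do not survive ordinary R serialization.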
4 changes: 2 additions & 2 deletions R/LearnerTorchFeatureless.R
@@ -22,8 +22,8 @@ LearnerTorchFeatureless = R6Class("LearnerTorchFeatureless",
     #' @description Creates a new instance of this [R6][R6::R6Class] class.
     initialize = function(task_type, optimizer = NULL, loss = NULL, callbacks = list()) {
       properties = switch(task_type,
-        classif = c("twoclass", "multiclass", "missings", "featureless"),
-        regr = c("missings", "featureless")
+        classif = c("twoclass", "multiclass", "missings", "featureless", "marshal"),
+        regr = c("missings", "featureless", "marshal")
       )
       super$initialize(
         id = paste0(task_type, ".torch_featureless"),
2 changes: 1 addition & 1 deletion R/LearnerTorchMLP.R
@@ -39,7 +39,7 @@ LearnerTorchMLP = R6Class("LearnerTorchMLP",
       check_activation = crate(function(x) check_class(x, "nn_module"))
       check_activation_args = crate(function(x) check_list(x, names = "unique"))
       check_neurons = crate(function(x) check_integerish(x, any.missing = FALSE, lower = 1))
-      cechk_shape = crate(function(x) check_shape(x, null_ok = TRUE, len = 2L))
+      check_shape = crate(function(x) check_shape(x, null_ok = TRUE, len = 2L), .parent = topenv())
 
       param_set = ps(
         neurons = p_uty(tags = c("train", "predict"), custom_check = check_neurons),
2 changes: 1 addition & 1 deletion R/PipeOpTaskPreprocTorch.R
@@ -187,7 +187,7 @@ PipeOpTaskPreprocTorch = R6Class("PipeOpTaskPreprocTorch",
       private$.rowwise = assert_flag(rowwise)
       param_set = assert_param_set(param_set$clone(deep = TRUE))
 
-      param_set$add(ps(
+      param_set = c(param_set, ps(
         stages = p_fct(c("train", "predict", "both"), tags = c("train", "required"))
       ))
       param_set$set_values(stages = stages_init)
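The switch from param_set$add(...) to param_set = c(param_set, ps(...)) follows the paradox 1.0 API, where ParamSets are value objects combined with c() rather than mutated in place; the PipeOpTorchAvgPool change below applies the same pattern. A small sketch of the idiom (assuming paradox >= 1.0; the "size" parameter is illustrative):

    library(paradox)

    param_set = ps(size = p_int(lower = 1, tags = "train"))
    param_set = c(param_set, ps(
      stages = p_fct(c("train", "predict", "both"), tags = c("train", "required"))
    ))
    param_set$ids()  # "size" "stages"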
4 changes: 3 additions & 1 deletion R/PipeOpTorchAvgPool.R
@@ -17,7 +17,9 @@ PipeOpTorchAvgPool = R6Class("PipeOpTorchAvgPool",
         count_include_pad = p_lgl(default = TRUE, tags = "train")
       )
       if (d >= 2L) {
-        param_set$add(ParamDbl$new("divisor_override", default = NULL, lower = 0, tags = "train", special_vals = list(NULL)))
+        param_set = c(param_set, ps(
+          divisor_override = p_dbl(default = NULL, lower = 0, tags = "train", special_vals = list(NULL))
+        ))
       }
 
       super$initialize(
9 changes: 4 additions & 5 deletions R/PipeOpTorchCallbacks.R
@@ -47,17 +47,16 @@ PipeOpTorchCallbacks = R6Class("PipeOpTorchCallbacks",
       cbids = ids(private$.callbacks)
       assert_names(cbids, type = "unique")
       walk(private$.callbacks, function(cb) {
-        cb$param_set$set_id = cb$id
-        walk(cb$param_set$params, function(p) {
-          p$tags = union(p$tags, "train")
-        })
+        if (length(cb$param_set$tags)) {
+          cb$param_set$tags = map(cb$param_set$tags, function(tags) union(tags, "train"))
+        }
       })
       private$.callbacks = set_names(private$.callbacks, cbids)
       input = data.table(name = "input", train = "ModelDescriptor", predict = "Task")
       output = data.table(name = "output", train = "ModelDescriptor", predict = "Task")
       super$initialize(
         id = id,
-        param_set = alist(invoke(ParamSetCollection$new, sets = map(private$.callbacks, "param_set"))),
+        param_set = alist(ParamSetCollection$new(sets = map(private$.callbacks, "param_set"))),
         param_vals = param_vals,
         input = input,
         output = output,
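The same migration shows up here: paradox 1.0 no longer exposes mutable Param objects through param_set$params, so per-parameter tags are rewritten through the ParamSet's $tags field in a single assignment. An illustrative sketch (assuming paradox >= 1.0; parameter names are made up):

    library(paradox)

    pset = ps(lr = p_dbl(lower = 0), patience = p_int(lower = 1))
    # append "train" to the tags of every parameter at once
    pset$tags = lapply(pset$tags, function(tags) union(tags, "train"))
    pset$tags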
2 changes: 1 addition & 1 deletion R/TorchCallback.R
@@ -143,7 +143,7 @@ as_torch_callbacks.character = function(x, clone = FALSE, ...) { # nolint
 #' @section Parameters:
 #' Defined by the constructor argument `param_set`.
 #' If no parameter set is provided during construction, the parameter set is constructed by creating a parameter
-#' for each argument of the wrapped loss function, where the parameters are then of type [`ParamUty`].
+#' for each argument of the wrapped loss function, where the parameters are then of type `ParamUty`.
 #'
 #' @family Callback
 #' @family Torch Descriptor
5 changes: 4 additions & 1 deletion R/TorchDescriptor.R
@@ -42,7 +42,10 @@ TorchDescriptor = R6Class("TorchDescriptor",
       self$generator = generator
       # TODO: Assert that all parameters are tagged with "train"
       self$param_set = assert_r6(param_set, "ParamSet", null.ok = TRUE) %??% inferps(generator)
-      walk(self$param_set$params, function(param) param$tags = union(param$tags, "train"))
+      if (length(self$param_set$tags)) {
+        self$param_set$tags = map(self$param_set$tags, function(tags) union(tags, "train"))
+
+      }
       if (is.function(generator)) {
         args = formalArgs(generator)
       } else {
2 changes: 1 addition & 1 deletion R/TorchLoss.R
@@ -51,7 +51,7 @@ as_torch_loss.character = function(x, clone = FALSE, ...) { # nolint
 #' @section Parameters:
 #' Defined by the constructor argument `param_set`.
 #' If no parameter set is provided during construction, the parameter set is constructed by creating a parameter
-#' for each argument of the wrapped loss function, where the parameters are then of type [`ParamUty`].
+#' for each argument of the wrapped loss function, where the parameters are then of type `ParamUty`.
 #'
 #' @family Torch Descriptor
 #' @export
6 changes: 3 additions & 3 deletions R/TorchOptimizer.R
@@ -53,7 +53,7 @@ as_torch_optimizer.character = function(x, clone = FALSE, ...) { # nolint
 #' @section Parameters:
 #' Defined by the constructor argument `param_set`.
 #' If no parameter set is provided during construction, the parameter set is constructed by creating a parameter
-#' for each argument of the wrapped loss function, where the parameters are then of type [`ParamUty`].
+#' for each argument of the wrapped loss function, where the parameters are then of type [ParamUty][`paradox::p_uty`].
 #'
 #' @family Torch Descriptor
 #' @export
@@ -278,7 +278,7 @@ mlr3torch_optimizers$add("sgd",
 mlr3torch_optimizers$add("asgd",
   function() {
     p = ps(
-      lr = p_dbl(default = 1e-2, lower = 0, tags = c("required", "train")),
+      lr = p_dbl(default = 1e-2, lower = 0, tags = "train"),
       lambda = p_dbl(lower = 0, upper = 1, default = 1e-4, tags = "train"),
      alpha = p_dbl(lower = 0, upper = Inf, default = 0.75, tags = "train"),
       t0 = p_int(lower = 1L, upper = Inf, default = 1e6, tags = "train"),
@@ -308,7 +308,7 @@ mlr3torch_optimizers$add("rprop",
     p = ps(
       lr = p_dbl(default = 0.01, lower = 0, tags = "train"),
       etas = p_uty(default = c(0.5, 1.2), tags = "train"),
-      step_sizes = p_uty(c(1e-06, 50), tags = "train")
+      step_sizes = p_uty(default = c(1e-06, 50), tags = "train")
     )
     TorchOptimizer$new(
       torch_optimizer = torch::optim_rprop,
7 changes: 4 additions & 3 deletions R/learner_torch_methods.R
@@ -66,7 +66,8 @@ learner_torch_train = function(self, private, super, task, param_vals) {
 
   # In case the seed was "random" initially we want to make the sampled seed available in the state.
   model$seed = param_vals$seed
-  return(model)
+
+  structure(model, class = c("learner_torch_state", "list"))
 }
 
 
@@ -166,8 +167,8 @@ train_loop = function(ctx, cbs) {
   # The seed is added later
   list(
     network = ctx$network,
-    loss_fn = ctx$loss_fn,
-    optimizer = ctx$optimizer,
+    loss_fn = ctx$loss_fn$state_dict(),
+    optimizer = ctx$optimizer$state_dict(),
     callbacks = cbs
   )
 }
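Storing $state_dict() results instead of the live R6 optimizer and loss is what makes the state serializable: a state dict is a plain list of tensors and values that torch_serialize() can handle, and it can later be loaded into a freshly constructed optimizer. A hedged sketch of that round trip in plain torch (not code from this commit):

    library(torch)

    net = nn_linear(4, 2)
    opt = optim_adam(net$parameters, lr = 0.01)

    sd = opt$state_dict()                    # serializable snapshot of optimizer state
    opt2 = optim_adam(net$parameters, lr = 0.01)
    opt2$load_state_dict(sd)                 # restore step counts, moment buffers, ...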
2 changes: 1 addition & 1 deletion man/PipeOpPreprocTorchAugmentCenterCrop.Rd

2 changes: 1 addition & 1 deletion man/PipeOpPreprocTorchAugmentColorJitter.Rd
2 changes: 1 addition & 1 deletion man/PipeOpPreprocTorchAugmentCrop.Rd
2 changes: 1 addition & 1 deletion man/PipeOpPreprocTorchAugmentHflip.Rd
6 changes: 3 additions & 3 deletions man/PipeOpPreprocTorchAugmentRandomAffine.Rd
2 changes: 1 addition & 1 deletion man/PipeOpPreprocTorchAugmentRandomApply.Rd
2 changes: 1 addition & 1 deletion man/PipeOpPreprocTorchAugmentRandomChoice.Rd