Skip to content

Commit

Permalink
fix: deep-cloning, callback states, hashes (#198)
Browse files Browse the repository at this point in the history
* fixes the deep clone methods of objects
* callbacks now store a state and not themselves in the learner's model
* fix some hashes
* some other smaller improvements and fixes
  • Loading branch information
sebffischer authored Jun 14, 2024
1 parent 1a77da8 commit d41d116
Show file tree
Hide file tree
Showing 54 changed files with 846 additions and 506 deletions.
14 changes: 8 additions & 6 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,18 @@ Description: Deep Learning library that extends the mlr3 framework by building
License: LGPL (>= 3)
Depends:
mlr3 (>= 0.19.0),
mlr3pipelines (>= 0.5.2),
mlr3pipelines,
torch (>= 0.13.0),
R (>= 3.5.0)
Imports:
backports,
checkmate (>= 2.2.0),
coro,
data.table,
lgr,
mlr3misc (>= 0.14.0),
methods,
data.table,
paradox (>= 0.11.0),
mlr3misc (>= 0.14.0),
paradox (>= 1.0.0),
R6,
withr
Suggests:
Expand All @@ -60,10 +60,12 @@ Suggests:
rmarkdown,
rpart,
viridis,
testthat (>= 3.0.0),
torchvision,
testthat (>= 3.0.0)
waldo
Remotes:
mlr-org/paradox,
mlr-org/mlr3,
mlr-org/mlr3pipelines,
mlverse/torchvision
Config/testthat/edition: 3
NeedsCompilation: no
Expand Down
6 changes: 4 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ S3method(as_torch_optimizer,torch_optimizer_generator)
S3method(c,lazy_tensor)
S3method(col_info,DataBackendLazy)
S3method(format,lazy_tensor)
S3method(hash_input,TorchIngressToken)
S3method(hash_input,lazy_tensor)
S3method(marshal_model,learner_torch_state)
S3method(hash_input,nn_module)
S3method(marshal_model,learner_torch_model)
S3method(materialize,data.frame)
S3method(materialize,lazy_tensor)
S3method(materialize,list)
Expand All @@ -52,7 +54,7 @@ S3method(t_opt,"NULL")
S3method(t_opt,character)
S3method(t_opts,"NULL")
S3method(t_opts,character)
S3method(unmarshal_model,learner_torch_state_marshaled)
S3method(unmarshal_model,learner_torch_model_marshaled)
export(CallbackSet)
export(CallbackSetCheckpoint)
export(CallbackSetHistory)
Expand Down
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
# mlr3torch 0.0.0-900
# mlr3torch dev
52 changes: 49 additions & 3 deletions R/CallbackSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,20 @@
#' This context is assigned at the beginning of the training loop and removed afterwards.
#' Different stages of a callback can communicate with each other by assigning values to `$self`.
#'
#' *State*:
#' To be able to store information in the `$model` slot of a [`LearnerTorch`], callbacks support a state API.
#' You can overload the `$state_dict()` public method to define what will be stored in `learner$model$callbacks$<id>`
#' after training finishes.
#' This then also requires implementing a `$load_state_dict(state_dict)` method that defines how to load a previously saved
#' callback state into a different callback.
#' Note that the `$state_dict()` should not include the parameter values that were used to initialize the callback.
#'
#' For creating custom callbacks, the function [`torch_callback()`] is recommended, which creates a
#' `CallbackSet` and then wraps it in a [`TorchCallback`].
#' To create a `CallbackSet` the convenience function [`callback_set()`] can be used.
#' These functions perform checks such as that the stages are not accidentally misspelled.
#'
#'
#' @section Stages:
#' * `begin` :: Run before the training loop begins.
#' * `epoch_begin` :: Run at the beginning of each epoch.
Expand All @@ -33,7 +42,8 @@
#' * `batch_valid_begin` :: Run before the forward call in the validation loop.
#' * `batch_valid_end` :: Run after the forward call in the validation loop.
#' * `epoch_end` :: Run at the end of each epoch.
#' * `end` :: Run at last, using `on.exit()`.
#' * `end` :: Run after last epoch.
#' * `exit` :: Run at last, using `on.exit()`.
#' @family Callback
#' @export
CallbackSet = R6Class("CallbackSet",
Expand All @@ -50,6 +60,21 @@ CallbackSet = R6Class("CallbackSet",
print = function(...) {
catn(sprintf("<%s>", class(self)[[1L]]))
catn(str_indent("* Stages:", self$stages))
},
#' @description
#' Returns information that is kept in the [`LearnerTorch`]'s state after training.
#' This information should be loadable into the callback using `$load_state_dict()` to be able to continue training.
#' This returns `NULL` by default.
state_dict = function() {
NULL
},
#' @description
#' Loads the state dict into the callback to continue training.
#' @param state_dict (any)\cr
#' The state dict as retrieved via `$state_dict()`.
load_state_dict = function(state_dict) {
assert_true(is.null(state_dict))
NULL
}
),
active = list(
Expand All @@ -71,6 +96,10 @@ CallbackSet = R6Class("CallbackSet",
deep_clone = function(name, value) {
if (name == "ctx" && !is.null(value)) {
stopf("CallbackSet instances can only be cloned when the 'ctx' is NULL.")
} else if (is.R6(value)) {
value$clone(deep = TRUE)
} else if (is.data.table(value)) {
copy(value)
} else {
value
}
Expand All @@ -90,7 +119,7 @@ CallbackSet = R6Class("CallbackSet",
#'
#' @param classname (`character(1)`)\cr
#' The class name.
#' @param on_begin,on_end,on_epoch_begin,on_before_valid,on_epoch_end,on_batch_begin,on_batch_end,on_after_backward,on_batch_valid_begin,on_batch_valid_end (`function`)\cr
#' @param on_begin,on_end,on_epoch_begin,on_before_valid,on_epoch_end,on_batch_begin,on_batch_end,on_after_backward,on_batch_valid_begin,on_batch_valid_end,on_exit (`function`)\cr
#' Function to execute at the given stage, see section *Stages*.
#' @param initialize (`function()`)\cr
#' The initialization method of the callback.
Expand All @@ -101,6 +130,11 @@ CallbackSet = R6Class("CallbackSet",
#' @param inherit (`R6ClassGenerator`)\cr
#' From which class to inherit.
#' This class must either be [`CallbackSet`] (default) or inherit from it.
#' @param state_dict (`function()`)\cr
#' The function that retrieves the state dict from the callback.
#' This is what will be available in the learner after training.
#' @param load_state_dict (`function(state_dict)`)\cr
#' Function that loads a callback state.
#' @param lock_objects (`logical(1)`)\cr
#' Whether to lock the objects of the resulting [`R6Class`].
#' If `FALSE` (default), values can be freely assigned to `self` without declaring them in the
Expand All @@ -115,6 +149,7 @@ callback_set = function(
# training
on_begin = NULL,
on_end = NULL,
on_exit = NULL,
on_epoch_begin = NULL,
on_before_valid = NULL,
on_epoch_end = NULL,
Expand All @@ -125,11 +160,16 @@ callback_set = function(
on_batch_valid_begin = NULL,
on_batch_valid_end = NULL,
# other methods
state_dict = NULL,
load_state_dict = NULL,
initialize = NULL,
public = NULL, private = NULL, active = NULL, parent_env = parent.frame(), inherit = CallbackSet,
lock_objects = FALSE
) {
assert_true(startsWith(classname, "CallbackSet"))
assert_false(xor(is.null(state_dict), is.null(load_state_dict)))
assert_function(state_dict, nargs = 0, null.ok = TRUE)
assert_function(load_state_dict, args = "state_dict", nargs = 1, null.ok = TRUE)
more_public = list(
on_begin = assert_function(on_begin, nargs = 0, null.ok = TRUE),
on_end = assert_function(on_end, nargs = 0, null.ok = TRUE),
Expand All @@ -140,7 +180,8 @@ callback_set = function(
on_batch_end = assert_function(on_batch_end, nargs = 0, null.ok = TRUE),
on_after_backward = assert_function(on_after_backward, nargs = 0, null.ok = TRUE),
on_batch_valid_begin = assert_function(on_batch_valid_begin, nargs = 0, null.ok = TRUE),
on_batch_valid_end = assert_function(on_batch_valid_end, nargs = 0, null.ok = TRUE)
on_batch_valid_end = assert_function(on_batch_valid_end, nargs = 0, null.ok = TRUE),
on_exit = assert_function(on_exit, nargs = 0, null.ok = TRUE)
)

assert_function(initialize, null.ok = TRUE)
Expand All @@ -153,6 +194,11 @@ callback_set = function(
assert_list(public, null.ok = TRUE, names = "unique")
if (length(public)) assert_names(names(public), disjunct.from = names(more_public))

if (!is.null(state_dict)) {
public$state_dict = state_dict
public$load_state_dict = load_state_dict
}

invalid_stages = names(public)[grepl("^on_", names(public))]

if (length(invalid_stages)) {
Expand Down
77 changes: 58 additions & 19 deletions R/CallbackSetCheckpoint.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@
#' @name mlr_callback_set.checkpoint
#'
#' @description
#' Saves the model during training.
#' Saves the optimizer and network states during training.
#' The final network and optimizer are always stored.
#' @details
#' Saving the learner itself in the callback with a trained model is impossible,
#' as the model slot is set *after* the last callback step is executed.
#'
#' @param path (`character(1)`)\cr
#' The path to a folder where the models are saved. This path must not exist before.
#' The path to a folder where the models are saved.
#' @param freq (`integer(1)`)\cr
#' The frequency how often the model is saved (epoch frequency).
#'
#' The frequency how often the model is saved.
#' Frequency is either per step or epoch, which can be configured through the `freq_type` parameter.
#' @param freq_type (`character(1)`)\cr
#' Can be either `"epoch"` (default) or `"step"`.
#' @family Callback
#' @export
#' @include CallbackSet.R
Expand All @@ -19,27 +26,58 @@ CallbackSetCheckpoint = R6Class("CallbackSetCheckpoint",
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(path, freq) {
# TODO: Maybe we want to be able to give gradient steps here instead of epochs?
assert_path_for_output(path)
dir.create(path, recursive = TRUE)
self$path = path
initialize = function(path, freq, freq_type = "epoch") {
self$freq = assert_int(freq, lower = 1L)
self$path = assert_path_for_output(path)
self$freq_type = assert_choice(freq_type, c("epoch", "step"))
if (!dir.exists(path)) {
dir.create(path, recursive = TRUE)
}
},
#' @description
#' Saves the network state dict.
#' Saves the network and optimizer state dict.
#' Does nothing if `freq_type` or `freq` are not met.
on_epoch_end = function() {
if ((self$ctx$epoch %% self$freq) == 0) {
torch::torch_save(self$ctx$network, file.path(self$path, paste0("network", self$ctx$epoch, ".pt")))
if (self$freq_type == "step" || (self$ctx$epoch %% self$freq != 0)) {
return(NULL)
}
private$.save(self$ctx$epoch)
},
#' @description
#' Saves the selected objects defined in `save`.
#' Does nothing if `freq_type` or `freq` are not met.
on_batch_end = function() {
if (self$freq_type == "epoch" || (self$ctx$step %% self$freq != 0)) {
return(NULL)
}
private$.save(self$ctx$step)
},
#' @description
#' Saves the final network.
on_end = function() {
path = file.path(self$path, paste0("network", self$ctx$epoch, ".pt"))
if (!file.exists(path)) { # no need to save the last network twice if it was already saved.
torch::torch_save(self$ctx$network, path)
#' Saves the learner.
on_exit = function() {
if (self$ctx$epoch == 0) return(NULL)
if (self$freq_type == "epoch") {
if (self$ctx$epoch %% self$freq == 0) {
# already saved
return(NULL)
} else {
private$.save(self$ctx$epoch)
}
}
if (self$freq_type == "step") {
if (self$ctx$step %% self$freq == 0) {
# already saved
return(NULL)
} else {
private$.save(self$ctx$epoch)
}
}
}
),
private = list(
.save = function(suffix) {
torch_save(self$ctx$network$state_dict(), file.path(self$path, paste0("network", suffix, ".pt")))
torch_save(self$ctx$optimizer$state_dict(), file.path(self$path, paste0("optimizer", suffix, ".pt")))
}
)
)
Expand All @@ -49,8 +87,9 @@ mlr3torch_callbacks$add("checkpoint", function() {
TorchCallback$new(
callback_generator = CallbackSetCheckpoint,
param_set = ps(
path = p_uty(),
freq = p_int(lower = 1L)
path = p_uty(tags = c("train", "required")),
freq = p_int(lower = 1L, tags = c("train", "required")),
freq_type = p_fct(default = "epoch", c("epoch", "step"), tags = "train")
),
id = "checkpoint",
label = "Checkpoint",
Expand Down
69 changes: 14 additions & 55 deletions R/CallbackSetHistory.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,20 @@ CallbackSetHistory = R6Class("CallbackSetHistory",
},
#' @description
#' Converts the lists to data.tables.
on_end = function() {
self$train = rbindlist(self$train, fill = TRUE)
self$valid = rbindlist(self$valid, fill = TRUE)
state_dict = function() {
structure(list(
train = rbindlist(self$train, fill = TRUE),
valid = rbindlist(self$valid, fill = TRUE)
), class = "callback_state_history")
},
#' @description
#' Sets the field `$train` and `$valid` to those contained in the state dict.
#' @param state_dict (`callback_state_history`)\cr
#' The state dict as retrieved via `$state_dict()`.
load_state_dict = function(state_dict) {
assert_class(state_dict, "callback_state_history")
self$train = state_dict$train
self$valid = state_dict$valid
},
#' @description
#' Add the latest training scores to the history.
Expand All @@ -42,58 +53,6 @@ CallbackSetHistory = R6Class("CallbackSetHistory",
list(epoch = self$ctx$epoch), self$ctx$last_scores_valid
)
}
},
#' @description Plots the history.
#' @param measures (`character()`)\cr
#' Which measures to plot. No default.
#' @param set (`character(1)`)\cr
#' Which set to plot. Either `"train"` or `"valid"`. Default is `"valid"`.
#' @param epochs (`integer()`)\cr
#' An integer vector restricting which epochs to plot. Default is `NULL`, which plots all epochs.
#' @param theme ([ggplot2::theme()])\cr
#' The theme, [ggplot2::theme_minimal()] is the default.
#' @param ... (any)\cr
#' Currently unused.
plot = function(measures, set = "valid", epochs = NULL, theme = ggplot2::theme_minimal(), ...) {
assert_choice(set, c("valid", "train"))
data = self[[set]]
assert_subset(measures, colnames(data))

if (is.null(epochs)) {
data = data[, c("epoch", measures), with = FALSE]
} else {
assert_integerish(epochs, unique = TRUE)
data = data[get("epoch") %in% epochs, c("epoch", measures), with = FALSE]
}

if ((!nrow(data)) || (ncol(data) < 2)) {
stopf("No eligible measures to plot for set '%s'.", set)
}

epoch = score = measure = .data = NULL
if (ncol(data) == 2L) {
ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = .data[[measures]])) +
ggplot2::geom_line() +
ggplot2::geom_point() +
ggplot2::labs(
x = "Epoch",
y = measures,
title = sprintf("%s Loss", switch(set, valid = "Validation", train = "Training"))
) +
theme
} else {
data = melt(data, id.vars = "epoch", variable.name = "measure", value.name = "score")
ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = score, color = measure)) +
viridis::scale_color_viridis(discrete = TRUE) +
ggplot2::geom_line() +
ggplot2::geom_point() +
ggplot2::labs(
x = "Epoch",
y = "Score",
title = sprintf("%s Loss", switch(set, valid = "Validation", train = "Training"))
) +
theme
}
}
),
private = list(
Expand Down
6 changes: 3 additions & 3 deletions R/ContextTorch.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ ContextTorch = R6Class("ContextTorch",
#' @field epoch (`integer(1)`)\cr
#' The current epoch.
epoch = NULL,
#' @field batch (`integer(1)`)\cr
#' The current iteration of the batch.
batch = NULL,
#' @field step (`integer(1)`)\cr
#' The current iteration.
step = NULL,
#' @field prediction_encoder (`function()`)\cr
#' The learner's prediction encoder.
prediction_encoder = NULL
Expand Down
Loading

0 comments on commit d41d116

Please sign in to comment.