diff --git a/R/CallbackSet.R b/R/CallbackSet.R index 79dd6e1d..187359a0 100644 --- a/R/CallbackSet.R +++ b/R/CallbackSet.R @@ -42,12 +42,35 @@ CallbackSet = R6Class("CallbackSet", #' @field ctx ([`ContextTorch`] or `NULL`)\cr #' The evaluation context for the callback. #' This field should always be `NULL` except during the `$train()` call of the torch learner. - ctx = NULL + ctx = NULL, + #' @description + #' Prints the object. + #' @param ... (any)\cr + #' Currently unused. + print = function(...) { + catn(sprintf("<%s>", class(self)[[1L]])) + catn(str_indent("* Stages:", self$stages)) + } + ), + active = list( + #' @field stages (`character()`)\cr + #' The active stages of this callback set. + stages = function(rhs) { + assert_ro_binding(rhs) + if (is.null(private$.stages)) { + private$.stages = mlr_reflections$torch$callback_stages[ + map_lgl(mlr_reflections$torch$callback_stages, function(stage) exists(stage, self, inherits = FALSE))] + } + + private$.stages + } + ), private = list( + .stages = NULL, deep_clone = function(name, value) { if (name == "ctx" && !is.null(value)) { - stopf("CallbackSet instances must never be cloned unless the ctx is NULL.") + stopf("CallbackSet instances can only be cloned when the 'ctx' is NULL.") } else { value } diff --git a/man/mlr_callback_set.Rd b/man/mlr_callback_set.Rd index 3191efde..48dabbd4 100644 --- a/man/mlr_callback_set.Rd +++ b/man/mlr_callback_set.Rd @@ -66,13 +66,40 @@ This field should always be \code{NULL} except during the \verb{$train()} call o } \if{html}{\out{}} } +\section{Active bindings}{ +\if{html}{\out{
}} +\describe{ +\item{\code{stages}}{(\code{character()})\cr +The active stages of this callback set.} +} +\if{html}{\out{
}} +} \section{Methods}{ \subsection{Public methods}{ \itemize{ +\item \href{#method-CallbackSet-print}{\code{CallbackSet$print()}} \item \href{#method-CallbackSet-clone}{\code{CallbackSet$clone()}} } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-CallbackSet-print}{}}} +\subsection{Method \code{print()}}{ +Prints the object. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{CallbackSet$print(...)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{...}}{(any)\cr +Currently unused.} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-CallbackSet-clone}{}}} \subsection{Method \code{clone()}}{ diff --git a/man/mlr_callback_set.checkpoint.Rd b/man/mlr_callback_set.checkpoint.Rd index a3cd1944..c22e7273 100644 --- a/man/mlr_callback_set.checkpoint.Rd +++ b/man/mlr_callback_set.checkpoint.Rd @@ -33,6 +33,13 @@ Other Callback: \item \href{#method-CallbackSetCheckpoint-clone}{\code{CallbackSetCheckpoint$clone()}} } } +\if{html}{\out{ +
Inherited methods + +
+}} \if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-CallbackSetCheckpoint-new}{}}} diff --git a/man/mlr_callback_set.history.Rd b/man/mlr_callback_set.history.Rd index 3da31ae6..edcde46c 100644 --- a/man/mlr_callback_set.history.Rd +++ b/man/mlr_callback_set.history.Rd @@ -23,6 +23,13 @@ The first column is always \code{epoch}. \item \href{#method-CallbackSetHistory-clone}{\code{CallbackSetHistory$clone()}} } } +\if{html}{\out{ +
Inherited methods + +
+}} \if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-CallbackSetHistory-on_begin}{}}} diff --git a/man/mlr_callback_set.progress.Rd b/man/mlr_callback_set.progress.Rd index e86fc697..88802a37 100644 --- a/man/mlr_callback_set.progress.Rd +++ b/man/mlr_callback_set.progress.Rd @@ -36,6 +36,13 @@ Other Callback: \item \href{#method-CallbackSetProgress-clone}{\code{CallbackSetProgress$clone()}} } } +\if{html}{\out{ +
Inherited methods + +
+}} \if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-CallbackSetProgress-on_epoch_begin}{}}} diff --git a/man/mlr_learners.alexnet.Rd b/man/mlr_learners.alexnet.Rd index 78ecd032..87e70654 100644 --- a/man/mlr_learners.alexnet.Rd +++ b/man/mlr_learners.alexnet.Rd @@ -78,11 +78,13 @@ Other Learner:
Inherited methods diff --git a/man/mlr_learners.mlp.Rd b/man/mlr_learners.mlp.Rd index d443563b..86c71288 100644 --- a/man/mlr_learners.mlp.Rd +++ b/man/mlr_learners.mlp.Rd @@ -106,11 +106,13 @@ Other Learner:
Inherited methods diff --git a/man/mlr_learners.torch_featureless.Rd b/man/mlr_learners.torch_featureless.Rd index 12855c68..c1d64214 100644 --- a/man/mlr_learners.torch_featureless.Rd +++ b/man/mlr_learners.torch_featureless.Rd @@ -95,11 +95,13 @@ Other Learner:
Inherited methods diff --git a/man/mlr_learners_torch.Rd b/man/mlr_learners_torch.Rd index 6975304b..6a044155 100644 --- a/man/mlr_learners_torch.Rd +++ b/man/mlr_learners_torch.Rd @@ -132,11 +132,13 @@ Shortcut for \code{learner$model$callbacks$history}.}
Inherited methods
}} diff --git a/man/mlr_learners_torch_image.Rd b/man/mlr_learners_torch_image.Rd index 033e94a5..bb090266 100644 --- a/man/mlr_learners_torch_image.Rd +++ b/man/mlr_learners_torch_image.Rd @@ -55,11 +55,13 @@ Other Learner:
Inherited methods diff --git a/man/mlr_learners_torch_model.Rd b/man/mlr_learners_torch_model.Rd index fc1f816e..a738f488 100644 --- a/man/mlr_learners_torch_model.Rd +++ b/man/mlr_learners_torch_model.Rd @@ -77,11 +77,13 @@ Other Graph Network:
Inherited methods diff --git a/tests/testthat/helper_autotest.R b/tests/testthat/helper_autotest.R index e5c65832..665a643f 100644 --- a/tests/testthat/helper_autotest.R +++ b/tests/testthat/helper_autotest.R @@ -301,5 +301,5 @@ autotest_torch_callback = function(torch_callback, check_man = TRUE) { expect_class(cb_trained, "CallbackSet") expect_deep_clone(cb_trained, cb_trained$clone(deep = TRUE)) cb_trained$ctx = "placeholder" - expect_error(cb_trained$clone(deep = TRUE), "must never be cloned unless") + expect_error(cb_trained$clone(deep = TRUE), "can only be cloned") }