diff --git a/R/CallbackSet.R b/R/CallbackSet.R
index 79dd6e1d..187359a0 100644
--- a/R/CallbackSet.R
+++ b/R/CallbackSet.R
@@ -42,12 +42,35 @@ CallbackSet = R6Class("CallbackSet",
#' @field ctx ([`ContextTorch`] or `NULL`)\cr
#' The evaluation context for the callback.
#' This field should always be `NULL` except during the `$train()` call of the torch learner.
- ctx = NULL
+ ctx = NULL,
+ #' @description
+ #' Prints the object.
+ #' @param ... (any)\cr
+ #' Currently unused.
+ print = function(...) {
+ catn(sprintf("<%s>", class(self)[[1L]]))
+ catn(str_indent("* Stages:", self$stages))
+ }
+ ),
+ active = list(
+ #' @field stages (`character()`)\cr
+ #' The active stages of this callback set.
+ stages = function(rhs) {
+ assert_ro_binding(rhs)
+ if (is.null(private$.stages)) {
+ private$.stages = mlr_reflections$torch$callback_stages[
+ map_lgl(mlr_reflections$torch$callback_stages, function(stage) exists(stage, self, inherits = FALSE))]
+ }
+
+ private$.stages
+ }
+
),
private = list(
+ .stages = NULL,
deep_clone = function(name, value) {
if (name == "ctx" && !is.null(value)) {
- stopf("CallbackSet instances must never be cloned unless the ctx is NULL.")
+ stopf("CallbackSet instances can only be cloned when the 'ctx' is NULL.")
} else {
value
}
diff --git a/man/mlr_callback_set.Rd b/man/mlr_callback_set.Rd
index 3191efde..48dabbd4 100644
--- a/man/mlr_callback_set.Rd
+++ b/man/mlr_callback_set.Rd
@@ -66,13 +66,40 @@ This field should always be \code{NULL} except during the \verb{$train()} call o
}
\if{html}{\out{}}
}
+\section{Active bindings}{
+\if{html}{\out{
}}
+\describe{
+\item{\code{stages}}{(\code{character()})\cr
+The active stages of this callback set.}
+}
+\if{html}{\out{
}}
+}
\section{Methods}{
\subsection{Public methods}{
\itemize{
+\item \href{#method-CallbackSet-print}{\code{CallbackSet$print()}}
\item \href{#method-CallbackSet-clone}{\code{CallbackSet$clone()}}
}
}
\if{html}{\out{
}}
+\if{html}{\out{}}
+\if{latex}{\out{\hypertarget{method-CallbackSet-print}{}}}
+\subsection{Method \code{print()}}{
+Prints the object.
+\subsection{Usage}{
+\if{html}{\out{}}\preformatted{CallbackSet$print(...)}\if{html}{\out{
}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{}}
+\describe{
+\item{\code{...}}{(any)\cr
+Currently unused.}
+}
+\if{html}{\out{
}}
+}
+}
+\if{html}{\out{
}}
\if{html}{\out{}}
\if{latex}{\out{\hypertarget{method-CallbackSet-clone}{}}}
\subsection{Method \code{clone()}}{
diff --git a/man/mlr_callback_set.checkpoint.Rd b/man/mlr_callback_set.checkpoint.Rd
index a3cd1944..c22e7273 100644
--- a/man/mlr_callback_set.checkpoint.Rd
+++ b/man/mlr_callback_set.checkpoint.Rd
@@ -33,6 +33,13 @@ Other Callback:
\item \href{#method-CallbackSetCheckpoint-clone}{\code{CallbackSetCheckpoint$clone()}}
}
}
+\if{html}{\out{
+Inherited methods
+
+
+}}
\if{html}{\out{
}}
\if{html}{\out{}}
\if{latex}{\out{\hypertarget{method-CallbackSetCheckpoint-new}{}}}
diff --git a/man/mlr_callback_set.history.Rd b/man/mlr_callback_set.history.Rd
index 3da31ae6..edcde46c 100644
--- a/man/mlr_callback_set.history.Rd
+++ b/man/mlr_callback_set.history.Rd
@@ -23,6 +23,13 @@ The first column is always \code{epoch}.
\item \href{#method-CallbackSetHistory-clone}{\code{CallbackSetHistory$clone()}}
}
}
+\if{html}{\out{
+Inherited methods
+
+
+}}
\if{html}{\out{
}}
\if{html}{\out{}}
\if{latex}{\out{\hypertarget{method-CallbackSetHistory-on_begin}{}}}
diff --git a/man/mlr_callback_set.progress.Rd b/man/mlr_callback_set.progress.Rd
index e86fc697..88802a37 100644
--- a/man/mlr_callback_set.progress.Rd
+++ b/man/mlr_callback_set.progress.Rd
@@ -36,6 +36,13 @@ Other Callback:
\item \href{#method-CallbackSetProgress-clone}{\code{CallbackSetProgress$clone()}}
}
}
+\if{html}{\out{
+Inherited methods
+
+
+}}
\if{html}{\out{
}}
\if{html}{\out{}}
\if{latex}{\out{\hypertarget{method-CallbackSetProgress-on_epoch_begin}{}}}
diff --git a/man/mlr_learners.alexnet.Rd b/man/mlr_learners.alexnet.Rd
index 78ecd032..87e70654 100644
--- a/man/mlr_learners.alexnet.Rd
+++ b/man/mlr_learners.alexnet.Rd
@@ -78,11 +78,13 @@ Other Learner:
Inherited methods
diff --git a/man/mlr_learners.mlp.Rd b/man/mlr_learners.mlp.Rd
index d443563b..86c71288 100644
--- a/man/mlr_learners.mlp.Rd
+++ b/man/mlr_learners.mlp.Rd
@@ -106,11 +106,13 @@ Other Learner:
Inherited methods
diff --git a/man/mlr_learners.torch_featureless.Rd b/man/mlr_learners.torch_featureless.Rd
index 12855c68..c1d64214 100644
--- a/man/mlr_learners.torch_featureless.Rd
+++ b/man/mlr_learners.torch_featureless.Rd
@@ -95,11 +95,13 @@ Other Learner:
Inherited methods
diff --git a/man/mlr_learners_torch.Rd b/man/mlr_learners_torch.Rd
index 6975304b..6a044155 100644
--- a/man/mlr_learners_torch.Rd
+++ b/man/mlr_learners_torch.Rd
@@ -132,11 +132,13 @@ Shortcut for \code{learner$model$callbacks$history}.}
Inherited methods
}}
diff --git a/man/mlr_learners_torch_image.Rd b/man/mlr_learners_torch_image.Rd
index 033e94a5..bb090266 100644
--- a/man/mlr_learners_torch_image.Rd
+++ b/man/mlr_learners_torch_image.Rd
@@ -55,11 +55,13 @@ Other Learner:
Inherited methods
diff --git a/man/mlr_learners_torch_model.Rd b/man/mlr_learners_torch_model.Rd
index fc1f816e..a738f488 100644
--- a/man/mlr_learners_torch_model.Rd
+++ b/man/mlr_learners_torch_model.Rd
@@ -77,11 +77,13 @@ Other Graph Network:
Inherited methods
diff --git a/tests/testthat/helper_autotest.R b/tests/testthat/helper_autotest.R
index e5c65832..665a643f 100644
--- a/tests/testthat/helper_autotest.R
+++ b/tests/testthat/helper_autotest.R
@@ -301,5 +301,5 @@ autotest_torch_callback = function(torch_callback, check_man = TRUE) {
expect_class(cb_trained, "CallbackSet")
expect_deep_clone(cb_trained, cb_trained$clone(deep = TRUE))
cb_trained$ctx = "placeholder"
- expect_error(cb_trained$clone(deep = TRUE), "must never be cloned unless")
+ expect_error(cb_trained$clone(deep = TRUE), "can only be cloned")
}