From c6def4f33e305e1cc76556d859274f19715d8ffb Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Fri, 18 Oct 2024 08:38:36 +0200 Subject: [PATCH] feat(LearnerTorch): interop threads parameter (#295) --- NEWS.md | 4 ++++ R/LearnerTorch.R | 6 ++++-- R/paramset_torchlearner.R | 1 + R/with_torch_settings.R | 17 ++++++++++----- man-roxygen/paramset_torchlearner.R | 4 ++++ man/mlr_learners_torch.Rd | 4 ++++ man/mlr_pipeops_torch_model.Rd | 4 ++++ tests/testthat/test_with_torch_settings.R | 25 +++++++++++++++++++---- 8 files changed, 54 insertions(+), 11 deletions(-) diff --git a/NEWS.md b/NEWS.md index 59454838..758bcec3 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,7 @@ +# mlr3torch dev + +* feat: Add parameter `num_interop_threads` to `LearnerTorch` + # mlr3torch 0.1.2 * Don't use deprecated `data_formats` anymore diff --git a/R/LearnerTorch.R b/R/LearnerTorch.R index f1930015..ce2248ab 100644 --- a/R/LearnerTorch.R +++ b/R/LearnerTorch.R @@ -433,7 +433,8 @@ LearnerTorch = R6Class("LearnerTorch", param_vals$device = auto_device(param_vals$device) if (identical(param_vals$seed, "random")) param_vals$seed = sample.int(.Machine$integer.max, 1) - model = with_torch_settings(seed = param_vals$seed, num_threads = param_vals$num_threads, { + model = with_torch_settings(seed = param_vals$seed, num_threads = param_vals$num_threads, + num_interop_threads = param_vals$num_interop_threads, expr = { learner_torch_train(self, private, super, task, param_vals) }) model$task_col_info = copy(task$col_info[c(task$feature_names, task$target_names), c("id", "type", "levels")]) @@ -453,7 +454,8 @@ LearnerTorch = R6Class("LearnerTorch", param_vals$device = auto_device(param_vals$device) private$.verify_predict_task(task, param_vals) - with_torch_settings(seed = self$model$seed, num_threads = param_vals$num_threads, { + with_torch_settings(seed = self$model$seed, num_threads = param_vals$num_threads, + num_interop_threads = param_vals$num_interop_threads, expr = { 
learner_torch_predict(self, private, super, task, param_vals) }) }, diff --git a/R/paramset_torchlearner.R b/R/paramset_torchlearner.R index be034347..4acfb80b 100644 --- a/R/paramset_torchlearner.R +++ b/R/paramset_torchlearner.R @@ -54,6 +54,7 @@ paramset_torchlearner = function(task_type) { aggr = epochs_aggr, in_tune_fn = epochs_tune_fn, disable_in_tune = list(patience = 0)), device = p_fct(tags = c("train", "predict", "required"), levels = mlr_reflections$torch$devices, init = "auto"), num_threads = p_int(lower = 1L, tags = c("train", "predict", "required", "threads"), init = 1L), + num_interop_threads = p_int(lower = 1L, tags = c("train", "predict", "required"), init = 1L), seed = p_int(tags = c("train", "predict", "required"), special_vals = list("random", NULL), init = "random"), # evaluation eval_freq = p_int(lower = 1L, tags = c("train", "required"), init = 1L), diff --git a/R/with_torch_settings.R b/R/with_torch_settings.R index bd7f965b..7695348a 100644 --- a/R/with_torch_settings.R +++ b/R/with_torch_settings.R @@ -1,13 +1,20 @@ -with_torch_settings = function(seed, num_threads = 1, expr) { +with_torch_settings = function(seed, num_threads = 1, num_interop_threads = 1, expr) { old_num_threads = torch_get_num_threads() - if (!running_on_mac()) { + if (running_on_mac()) { + if (!isTRUE(all.equal(num_threads, 1L))) { + lg$warn("Cannot set number of threads on macOS.") + } + } else { on.exit({torch_set_num_threads(old_num_threads)}, add = TRUE ) torch_set_num_threads(num_threads) - } else { - if (!isTRUE(all.equal(num_threads, 1L))) { - lg$warn("Cannot set number of threads on macOS.") + } + + if (num_interop_threads != torch_get_num_interop_threads()) { + result = try(torch::torch_set_num_interop_threads(num_interop_threads), silent = TRUE) + if (inherits(result, "try-error")) { + lg$warn(sprintf("Can only set the interop threads once, keeping the previous value %s", torch_get_num_interop_threads())) } } # sets the seed back when exiting the function diff 
--git a/man-roxygen/paramset_torchlearner.R b/man-roxygen/paramset_torchlearner.R index b9e32fd9..cf5d7505 100644 --- a/man-roxygen/paramset_torchlearner.R +++ b/man-roxygen/paramset_torchlearner.R @@ -14,6 +14,10 @@ #' * `num_threads` :: `integer(1)`\cr #' The number of threads for intraop pararallelization (if `device` is `"cpu"`). #' This value is initialized to 1. +#' * `num_interop_threads` :: `integer(1)`\cr +#' The number of threads for interop parallelization (if `device` is `"cpu"`). +#' This value is initialized to 1. +#' Note that this can only be set once during a session and changing the value within an R session will raise a warning. #' * `seed` :: `integer(1)` or `"random"` or `NULL`\cr #' The torch seed that is used during training and prediction. #' This value is initialized to `"random"`, which means that a random seed will be sampled at the beginning of the diff --git a/man/mlr_learners_torch.Rd b/man/mlr_learners_torch.Rd index 5797f246..ebdfe53d 100644 --- a/man/mlr_learners_torch.Rd +++ b/man/mlr_learners_torch.Rd @@ -70,6 +70,10 @@ fall back to \code{"cpu"}. \item \code{num_threads} :: \code{integer(1)}\cr The number of threads for intraop pararallelization (if \code{device} is \code{"cpu"}). This value is initialized to 1. +\item \code{num_interop_threads} :: \code{integer(1)}\cr +The number of threads for interop parallelization (if \code{device} is \code{"cpu"}). +This value is initialized to 1. +Note that this can only be set once during a session and changing the value within an R session will raise a warning. \item \code{seed} :: \code{integer(1)} or \code{"random"} or \code{NULL}\cr The torch seed that is used during training and prediction. 
This value is initialized to \code{"random"}, which means that a random seed will be sampled at the beginning of the diff --git a/man/mlr_pipeops_torch_model.Rd b/man/mlr_pipeops_torch_model.Rd index bad5df39..7cac5719 100644 --- a/man/mlr_pipeops_torch_model.Rd +++ b/man/mlr_pipeops_torch_model.Rd @@ -37,6 +37,10 @@ fall back to \code{"cpu"}. \item \code{num_threads} :: \code{integer(1)}\cr The number of threads for intraop pararallelization (if \code{device} is \code{"cpu"}). This value is initialized to 1. +\item \code{num_interop_threads} :: \code{integer(1)}\cr +The number of threads for interop parallelization (if \code{device} is \code{"cpu"}). +This value is initialized to 1. +Note that this can only be set once during a session and changing the value within an R session will raise a warning. \item \code{seed} :: \code{integer(1)} or \code{"random"} or \code{NULL}\cr The torch seed that is used during training and prediction. This value is initialized to \code{"random"}, which means that a random seed will be sampled at the beginning of the diff --git a/tests/testthat/test_with_torch_settings.R b/tests/testthat/test_with_torch_settings.R index 32cdc326..f7c50720 100644 --- a/tests/testthat/test_with_torch_settings.R +++ b/tests/testthat/test_with_torch_settings.R @@ -9,11 +9,11 @@ test_that("with_torch_settings leaves global state untouched", { } prev_torch_rng_state = torch_get_rng_state() - with_torch_settings(1, 1, { + with_torch_settings(1, 1, 1, { y1 = torch_randn(1) }) - with_torch_settings(1, 1, { + with_torch_settings(1, 1, 1, { y2 = torch_randn(1) }) @@ -28,13 +28,30 @@ test_that("with_torch_settings leaves global state untouched", { # (This would happen if we did not set the seed afterwards back to the previous value) withr::with_seed(10, { - with_torch_settings(seed = 1, num_threads = 1, NULL) + with_torch_settings(seed = 1, num_threads = 1, num_interop_threads = 1, NULL) at = torch_randn(1) }) withr::with_seed(20, { - 
with_torch_settings(seed = 1, num_threads = 1, NULL) + with_torch_settings(seed = 1, num_threads = 1, num_interop_threads = 1, NULL) bt = torch_randn(1) }) expect_false(torch_equal(at, bt)) }) + +test_that("interop threads proper warning message", { + skip_if_not_installed("callr") + # otherwise capture.output does for some reason not capture the warning message + skip_if(!running_on_mac()) + + result = callr::r(function() { + library(torch) + with_torch_settings = getFromNamespace("with_torch_settings", "mlr3torch") + with_torch_settings(NULL, 1, 2, invisible(NULL)) + x1 = capture.output(with_torch_settings(NULL, 1, 2, invisible(NULL))) + x2 = capture.output(with_torch_settings(NULL, 1, 1, invisible(NULL))) + list(x1, x2) + }) + expect_true(length(result[[1]]) == 0) + expect_true(grepl("keeping the previous value 2", result[[2]], fixed = TRUE)) +})