feat(LearnerTorch): interop threads parameter (#295)
sebffischer authored Oct 18, 2024
1 parent 2efffc8 commit c6def4f
Showing 8 changed files with 54 additions and 11 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
@@ -1,3 +1,7 @@
+# mlr3torch dev
+
+* feat: Add parameter `num_interop_threads` to `LearnerTorch`
+
 # mlr3torch 0.1.2
 
 * Don't use deprecated `data_formats` anymore
6 changes: 4 additions & 2 deletions R/LearnerTorch.R
@@ -433,7 +433,8 @@ LearnerTorch = R6Class("LearnerTorch",
 param_vals$device = auto_device(param_vals$device)
 if (identical(param_vals$seed, "random")) param_vals$seed = sample.int(.Machine$integer.max, 1)
 
-model = with_torch_settings(seed = param_vals$seed, num_threads = param_vals$num_threads, {
+model = with_torch_settings(seed = param_vals$seed, num_threads = param_vals$num_threads,
+  num_interop_threads = param_vals$num_interop_threads, expr = {
 learner_torch_train(self, private, super, task, param_vals)
 })
 model$task_col_info = copy(task$col_info[c(task$feature_names, task$target_names), c("id", "type", "levels")])
@@ -453,7 +454,8 @@ LearnerTorch = R6Class("LearnerTorch",
 param_vals$device = auto_device(param_vals$device)
 private$.verify_predict_task(task, param_vals)
 
-with_torch_settings(seed = self$model$seed, num_threads = param_vals$num_threads, {
+with_torch_settings(seed = self$model$seed, num_threads = param_vals$num_threads,
+  num_interop_threads = param_vals$num_interop_threads, expr = {
 learner_torch_predict(self, private, super, task, param_vals)
 })
 },
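Both `$train()` and `$predict()` now run their body inside `with_torch_settings()`, so the thread settings are applied in both phases and the intraop thread count is restored afterwards. A minimal sketch of the user-facing effect, assuming `torch` and `mlr3torch` are attached and using `classif.mlp` on the `iris` task purely as an illustration:

    library(torch)
    library(mlr3torch)

    before = torch_get_num_threads()

    # epochs and batch_size are required; the values here are only illustrative
    learner = lrn("classif.mlp", epochs = 1, batch_size = 32, num_threads = 2)
    learner$train(tsk("iris"))

    # the intraop thread count used during training does not leak into the session
    identical(torch_get_num_threads(), before)  # TRUE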
1 change: 1 addition & 0 deletions R/paramset_torchlearner.R
@@ -54,6 +54,7 @@ paramset_torchlearner = function(task_type) {
 aggr = epochs_aggr, in_tune_fn = epochs_tune_fn, disable_in_tune = list(patience = 0)),
 device = p_fct(tags = c("train", "predict", "required"), levels = mlr_reflections$torch$devices, init = "auto"),
 num_threads = p_int(lower = 1L, tags = c("train", "predict", "required", "threads"), init = 1L),
+num_interop_threads = p_int(lower = 1L, tags = c("train", "predict", "required"), init = 1L),
 seed = p_int(tags = c("train", "predict", "required"), special_vals = list("random", NULL), init = "random"),
 # evaluation
 eval_freq = p_int(lower = 1L, tags = c("train", "required"), init = 1L),
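Because the parameter is defined in `paramset_torchlearner()`, every `LearnerTorch` inherits it. A quick check, sketched under the assumption that `mlr3torch` is attached (`classif.mlp` stands in for any torch learner):

    library(mlr3torch)

    learner = lrn("classif.mlp")
    "num_interop_threads" %in% learner$param_set$ids()  # TRUE
    learner$param_set$values$num_interop_threads        # 1, the initial value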
17 changes: 12 additions & 5 deletions R/with_torch_settings.R
@@ -1,13 +1,20 @@
-with_torch_settings = function(seed, num_threads = 1, expr) {
+with_torch_settings = function(seed, num_threads = 1, num_interop_threads = 1, expr) {
   old_num_threads = torch_get_num_threads()
-  if (!running_on_mac()) {
+  if (running_on_mac()) {
+    if (!isTRUE(all.equal(num_threads, 1L))) {
+      lg$warn("Cannot set number of threads on macOS.")
+    }
+  } else {
     on.exit({torch_set_num_threads(old_num_threads)},
       add = TRUE
     )
     torch_set_num_threads(num_threads)
-  } else {
-    if (!isTRUE(all.equal(num_threads, 1L))) {
-      lg$warn("Cannot set number of threads on macOS.")
-    }
   }
+
+  if (num_interop_threads != torch_get_num_interop_threads()) {
+    result = try(torch::torch_set_num_interop_threads(num_interop_threads), silent = TRUE)
+    if (inherits(result, "try-error")) {
+      lg$warn(sprintf("Can only set the interop threads once, keeping the previous value %s", torch_get_num_interop_threads()))
+    }
+  }
   # sets the seed back when exiting the function
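The `try()` around `torch_set_num_interop_threads()` reflects a libtorch restriction: the interop thread pool can only be configured once per process, before any parallel work has started. A standalone sketch of the behavior being guarded against (assumes the `torch` package in a fresh R session; the exact error text comes from libtorch and may differ):

    library(torch)

    torch_set_num_interop_threads(2)  # the first call in a fresh session succeeds
    torch_get_num_interop_threads()   # 2

    # a later attempt to change the value errors; with_torch_settings()
    # catches this and only logs a warning, keeping the previous value
    res = try(torch_set_num_interop_threads(4), silent = TRUE)
    inherits(res, "try-error")        # TRUE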
4 changes: 4 additions & 0 deletions man-roxygen/paramset_torchlearner.R
@@ -14,6 +14,10 @@
 #' * `num_threads` :: `integer(1)`\cr
 #' The number of threads for intraop parallelization (if `device` is `"cpu"`).
 #' This value is initialized to 1.
+#' * `num_interop_threads` :: `integer(1)`\cr
+#' The number of threads for interop parallelization (if `device` is `"cpu"`).
+#' This value is initialized to 1.
+#' Note that this value can only be set once per R session; attempts to change it later are ignored and produce a warning.
 #' * `seed` :: `integer(1)` or `"random"` or `NULL`\cr
 #' The torch seed that is used during training and prediction.
 #' This value is initialized to `"random"`, which means that a random seed will be sampled at the beginning of the
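Putting the documented parameters together, a hedged usage sketch (the learner id `classif.mlp` and all hyperparameter values are illustrative only):

    library(mlr3torch)

    learner = lrn("classif.mlp",
      epochs = 10, batch_size = 32,
      num_threads = 2,          # intraop parallelization on the CPU
      num_interop_threads = 2   # interop parallelization, effectively fixed after first use
    )

    # values can also be changed later through the param_set
    learner$param_set$values$num_threads = 4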
4 changes: 4 additions & 0 deletions man/mlr_learners_torch.Rd

Some generated files are not rendered by default.

4 changes: 4 additions & 0 deletions man/mlr_pipeops_torch_model.Rd

Some generated files are not rendered by default.

25 changes: 21 additions & 4 deletions tests/testthat/test_with_torch_settings.R
@@ -9,11 +9,11 @@ test_that("with_torch_settings leaves global state untouched", {
 }
 prev_torch_rng_state = torch_get_rng_state()
 
-with_torch_settings(1, 1, {
+with_torch_settings(1, 1, 1, {
 y1 = torch_randn(1)
 })
 
-with_torch_settings(1, 1, {
+with_torch_settings(1, 1, 1, {
 y2 = torch_randn(1)
 })
 
@@ -28,13 +28,30 @@ test_that("with_torch_settings leaves global state untouched", {
 # (This would happen if we did not set the seed afterwards back to the previous value)
 
 withr::with_seed(10, {
-with_torch_settings(seed = 1, num_threads = 1, NULL)
+with_torch_settings(seed = 1, num_threads = 1, num_interop_threads = 1, NULL)
 at = torch_randn(1)
 })
 
 withr::with_seed(20, {
-with_torch_settings(seed = 1, num_threads = 1, NULL)
+with_torch_settings(seed = 1, num_threads = 1, num_interop_threads = 1, NULL)
 bt = torch_randn(1)
 })
 expect_false(torch_equal(at, bt))
 })
+
+test_that("interop threads proper warning message", {
+  skip_if_not_installed("callr")
+  # otherwise capture.output does for some reason not capture the warning message
+  skip_if(!running_on_mac())
+
+  result = callr::r(function() {
+    library(torch)
+    with_torch_settings = getFromNamespace("with_torch_settings", "mlr3torch")
+    with_torch_settings(NULL, 1, 2, invisible(NULL))
+    x1 = capture.output(with_torch_settings(NULL, 1, 2, invisible(NULL)))
+    x2 = capture.output(with_torch_settings(NULL, 1, 1, invisible(NULL)))
+    list(x1, x2)
+  })
+  expect_true(length(result[[1]]) == 0)
+  expect_true(grepl("keeping the previous value 2", result[[2]], fixed = TRUE))
+})
