Skip to content

Commit

Permalink
feat: add nn convenience function
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Aug 21, 2024
1 parent 31dee34 commit 5795a2a
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 5 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ Collate:
'learner_torch_methods.R'
'materialize.R'
'merge_graphs.R'
'nn.R'
'nn_graph.R'
'paramset_torchlearner.R'
'preprocess.R'
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ export(mlr3torch_optimizers)
export(model_descriptor_to_learner)
export(model_descriptor_to_module)
export(model_descriptor_union)
export(nn)
export(nn_graph)
export(nn_merge_cat)
export(nn_merge_prod)
Expand Down
13 changes: 13 additions & 0 deletions R/nn.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#' @title Create a Neural Network Layer
#' @description
#' Retrieve a neural network layer from the
#' [`mlr_pipeops`][mlr3pipelines::mlr_pipeops] dictionary.
#' @param .key (`character(1)`)\cr
#' @export
#' @examples
#' po1 = po("nn_linear", id = "linear")
#' # is the same as:
#' po2 = nn("linear")
nn = function(.key, ...) {
invoke(po, .obj = paste0("nn_", .key), .args = insert_named(list(id = .key), list(...)))
}
20 changes: 20 additions & 0 deletions man/nn.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 0 additions & 4 deletions nn.R

This file was deleted.

5 changes: 4 additions & 1 deletion tests/testthat/test_nn.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
test_that("nn works", {

x = nn("linear", out_features = 3)
expect_equal(x$id, "linear")
expect_class(x, "PipeOpTorchLinear")
expect_equal(x$param_set$values$out_features, 3)
})

0 comments on commit 5795a2a

Please sign in to comment.