From f3acbec0a7b2981073cf877737db65374c135fcd Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 1 Jul 2024 13:30:25 +0200 Subject: [PATCH] feat: torchvision learners (#238) * wip * add torchvision learners * desc * ... * ... * ... * ... * fix example * ... * ... --- .github/workflows/r-cmd-check.yml | 37 +++ DESCRIPTION | 4 +- NAMESPACE | 7 +- R/LearnerTorchAlexNet.R | 61 ----- R/LearnerTorchVision.R | 180 ++++++++++++++ R/TaskClassif_lazy_iris.R | 2 +- R/TaskClassif_mnist.R | 2 +- R/TaskClassif_tiny_imagenet.R | 2 +- R/bibentries.R | 42 +++- R/utils.R | 11 + R/zzz.R | 6 +- man-roxygen/learner.R | 10 +- man-roxygen/learner_example.R | 3 +- man/mlr_learners.mlp.Rd | 18 +- man/mlr_learners.tab_resnet.Rd | 20 +- man/mlr_learners.torch_featureless.Rd | 18 +- ...alexnet.Rd => mlr_learners.torchvision.Rd} | 112 ++++----- man/mlr_learners_torch.Rd | 1 - man/mlr_learners_torch_image.Rd | 1 - man/mlr_learners_torch_model.Rd | 1 - man/mlr_tasks_lazy_iris.Rd | 2 +- man/mlr_tasks_mnist.Rd | 2 +- man/mlr_tasks_tiny_imagenet.Rd | 2 +- man/replace_head.Rd | 20 ++ tests/testthat/test_LearnerTorch.R | 2 - tests/testthat/test_LearnerTorchAlexNet.R | 17 -- tests/testthat/test_LearnerTorchVision.R | 232 ++++++++++++++++++ 27 files changed, 595 insertions(+), 220 deletions(-) delete mode 100644 R/LearnerTorchAlexNet.R create mode 100644 R/LearnerTorchVision.R rename man/{mlr_learners.alexnet.Rd => mlr_learners.torchvision.Rd} (64%) create mode 100644 man/replace_head.Rd delete mode 100644 tests/testthat/test_LearnerTorchAlexNet.R create mode 100644 tests/testthat/test_LearnerTorchVision.R diff --git a/.github/workflows/r-cmd-check.yml b/.github/workflows/r-cmd-check.yml index c3633ffe..43c7c0b7 100644 --- a/.github/workflows/r-cmd-check.yml +++ b/.github/workflows/r-cmd-check.yml @@ -29,6 +29,7 @@ jobs: env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} TORCH_INSTALL: 1 + INCLUDE_IGNORED: "true" strategy: fail-fast: false @@ -53,6 +54,42 @@ jobs: extra-packages: any::rcmdcheck needs: check + - name: Install additional packages + run: | + Rscript -e 'install.packages("rappdirs")' + + - name: Get torchvision package version (Linux/macOS) + if: runner.os != 'Windows' + id: get_package_version_unix + run: | + echo "TORCHVISION_PACKAGE_VERSION=$(Rscript -e 'cat(as.character(packageVersion("torchvision")))')" >> $GITHUB_ENV + + - name: Get torchvision package version (Windows) + if: runner.os == 'Windows' + id: get_package_version_windows + run: | + $version = Rscript -e 'cat(as.character(packageVersion("torchvision")))' + echo "TORCHVISION_PACKAGE_VERSION=$version" >> $env:GITHUB_ENV + + - name: Get torch cache path (Linux/macOS) + if: runner.os != 'Windows' + id: get_cache_path_unix + run: | + echo "TORCH_CACHE_PATH=$(Rscript -e 'cat(rappdirs::user_cache_dir("torch"))')" >> $GITHUB_ENV + + - name: Get torch cache path (Windows) + if: runner.os == 'Windows' + id: get_cache_path_windows + run: | + $cachePath = Rscript -e 'cat(rappdirs::user_cache_dir("torch"))' + echo "TORCH_CACHE_PATH=$cachePath" >> $env:GITHUB_ENV + + - name: Cache Torchvision Downloads + uses: actions/cache@v3 + with: + path: ${{ env.TORCH_CACHE_PATH }} + key: ${{ runner.os }}-r-${{ env.TORCHVISION_PACKAGE_VERSION }} + - uses: r-lib/actions/check-r-package@v2 - uses: mxschmitt/action-tmate@v3 diff --git a/DESCRIPTION b/DESCRIPTION index 115dba87..ae2091a3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -86,15 +86,15 @@ Collate: 'utils.R' 'DataDescriptor.R' 'LearnerTorch.R' - 'LearnerTorchImage.R' - 'LearnerTorchAlexNet.R' 'LearnerTorchFeatureless.R' + 'LearnerTorchImage.R' 'LearnerTorchMLP.R' 'task_dataset.R' 'shape.R' 'PipeOpTorchIngress.R' 'LearnerTorchModel.R' 'LearnerTorchTabResNet.R' + 'LearnerTorchVision.R' 'ModelDescriptor.R' 'PipeOpModule.R' 'PipeOpTorch.R' diff --git a/NAMESPACE b/NAMESPACE index d1fa9259..73117495 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -41,6 +41,10 @@ S3method(print,TorchIngressToken) S3method(print,lazy_tensor) S3method(rep,lazy_tensor) S3method(rep_len,lazy_tensor) +S3method(replace_head,AlexNet) +S3method(replace_head,VGG) +S3method(replace_head,mobilenet_v2) +S3method(replace_head,resnet) S3method(reset_last_layer,AlexNet) S3method(reset_last_layer,resnet) S3method(t_clbk,"NULL") @@ -65,12 +69,12 @@ export(ContextTorch) export(DataBackendLazy) export(DataDescriptor) export(LearnerTorch) -export(LearnerTorchAlexNet) export(LearnerTorchFeatureless) export(LearnerTorchImage) export(LearnerTorchMLP) export(LearnerTorchModel) export(LearnerTorchTabResNet) +export(LearnerTorchVision) export(ModelDescriptor) export(PipeOpModule) export(PipeOpTaskPreprocTorch) @@ -168,6 +172,7 @@ export(nn_reshape) export(nn_squeeze) export(nn_unsqueeze) export(pipeop_preproc_torch) +export(replace_head) export(reset_last_layer) export(t_clbk) export(t_clbks) diff --git a/R/LearnerTorchAlexNet.R b/R/LearnerTorchAlexNet.R deleted file mode 100644 index 7035c718..00000000 --- a/R/LearnerTorchAlexNet.R +++ /dev/null @@ -1,61 +0,0 @@ -#' @title AlexNet Image Classifier -#' -#' @templateVar name alexnet -#' @templateVar task_types classif -#' @template learner -#' @template params_learner -#' -#' @description -#' Historic convolutional neural network for image classification. -#' -#' @section Parameters: -#' Parameters from [`LearnerTorchImage`] and -#' -#' * `pretrained` :: `logical(1)`\cr -#' Whether to use the pretrained model. -#' -#' @references `r format_bib("krizhevsky2017imagenet")` -#' @include LearnerTorchImage.R -#' @export -LearnerTorchAlexNet = R6Class("LearnerTorchAlexNet", - inherit = LearnerTorchImage, - public = list( - #' @description Creates a new instance of this [R6][R6::R6Class] class. - initialize = function(task_type, optimizer = NULL, loss = NULL, callbacks = list()) { - param_set = ps( - pretrained = p_lgl(tags = c("required", "train")) - ) - param_set$values = list(pretrained = TRUE) - super$initialize( - task_type = task_type, - id = paste0(task_type, ".alexnet"), - param_set = param_set, - man = "mlr3torch::mlr_learners.alexnet", - optimizer = optimizer, - loss = loss, - callbacks = callbacks, - label = "AlexNet Image Classifier" - ) - } - ), - private = list( - .network = function(task, param_vals) { - nout = get_nout(task) - if (param_vals$pretrained) { - network = torchvision::model_alexnet(pretrained = TRUE) - - network$classifier$`6` = torch::nn_linear( - in_features = network$classifier$`6`$in_features, - out_features = nout, - bias = TRUE - ) - return(network) - } - - torchvision::model_alexnet(pretrained = FALSE, num_classes = nout) - } - ) -) - -#' @include zzz.R -register_learner("classif.alexnet", LearnerTorchAlexNet) diff --git a/R/LearnerTorchVision.R b/R/LearnerTorchVision.R new file mode 100644 index 00000000..df27f4e2 --- /dev/null +++ b/R/LearnerTorchVision.R @@ -0,0 +1,180 @@ +#' @title AlexNet Image Classifier +#' +#' @name mlr_learners.torchvision +#' +#' @description +#' Classic image classification networks from `torchvision`. +#' +#' @section Parameters: +#' Parameters from [`LearnerTorchImage`] and +#' +#' * `pretrained` :: `logical(1)`\cr +#' Whether to use the pretrained model. +#' The final linear layer will be replaced with a new `nn_linear` with the +#' number of classes inferred from the [`Task`][mlr3::Task]. +#' +#' @section Properties: +#' * Supported task types: `"classif"` +#' * Predict Types: `"response"` and `"prob"` +#' * Feature Types: `"lazy_tensor"` +#' * Required packages: `"mlr3torch"`, `"torch"`, `"torchvision"` +#' @template params_learner +#' @param name (`character(1)`)\cr +#' The name of the network. +#' @param module_generator (`function(pretrained, num_classes)`)\cr +#' Function that generates the network. +#' @param label (`character(1)`)\cr +#' The label of the network. +#'#' @references +#' `r format_bib("krizhevsky2017imagenet")` +#' `r format_bib("sandler2018mobilenetv2")` +#' `r format_bib("he2016deep")` +#' `r format_bib("simonyan2014very")` +#' @include LearnerTorchImage.R +#' @export +LearnerTorchVision = R6Class("LearnerTorchVision", + inherit = LearnerTorchImage, + public = list( + #' @description Creates a new instance of this [R6][R6::R6Class] class. + initialize = function(name, module_generator, label, optimizer = NULL, loss = NULL, callbacks = list()) { # nolint + task_type = "classif" + private$.module_generator = module_generator + param_set = ps( + pretrained = p_lgl(tags = c("required", "train")) + ) + param_set$values = list(pretrained = TRUE) + super$initialize( + task_type = task_type, + id = paste0(task_type, ".", name), + param_set = param_set, + man = paste0("mlr3torch::mlr_learners.torchvision"), + optimizer = optimizer, + loss = loss, + callbacks = callbacks, + label = label + ) + } + ), + private = list( + .module_generator = NULL, + .network = function(task, param_vals) { + nout = get_nout(task) + if (param_vals$pretrained) { + network = replace_head(private$.module_generator(pretrained = TRUE), nout) + return(network) + } + private$.module_generator(pretrained = FALSE, num_classes = nout) + }, + .additional_phash_input = function() { + list(private$.module_generator) + } + ) +) + +#' @export +replace_head.AlexNet = function(network, d_out) { + network$classifier$`6` = torch::nn_linear( + in_features = network$classifier$`6`$in_features, + out_features = d_out, + bias = TRUE + ) + network +} + +# #' @export +# replace_head.Inception3 = function(network, d_out) { +# network$fc = nn_linear(2048, d_out) +# network +# } + +#' @export +replace_head.mobilenet_v2 = function(network, d_out) { + network$classifier$`1` = nn_linear(1280, d_out) + network +} + +#' @export +replace_head.resnet = function(network, d_out) { + in_features = network$fc$in_features + network$fc = nn_linear(in_features, d_out) + network +} + +#' @export +replace_head.VGG = function(network, d_out) { + network$classifier$`6` = nn_linear(4096, d_out) + network +} + +#' @include zzz.R +register_learner("classif.alexnet", function() { + LearnerTorchVision$new("alexnet", torchvision::model_alexnet, "AlexNet") +}) + +# register_learner("classif.inception_v3", function() { +# LearnerTorchVision$new("inception_v3", torchvision::model_inception_v3, "Inception V3") +# }) + +register_learner("classif.mobilenet_v2", function() { + LearnerTorchVision$new("mobilenet_v2", torchvision::model_mobilenet_v2, "Mobilenet V2") +}) + +register_learner("classif.resnet18", function() { + LearnerTorchVision$new("resnet18", torchvision::model_resnet18, "ResNet-18") +}) + +register_learner("classif.resnet34", function() { + LearnerTorchVision$new("resnet34", torchvision::model_resnet34, "ResNet-34") +}) + +register_learner("classif.resnet50", function() { + LearnerTorchVision$new("resnet50", torchvision::model_resnet50, "ResNet-50") +}) + +register_learner("classif.resnet101", function() { + LearnerTorchVision$new("resnet101", torchvision::model_resnet101, "ResNet-101") +}) + +register_learner("classif.resnet152", function() { + LearnerTorchVision$new("resnet152", torchvision::model_resnet152, "ResNet-152") +}) + +register_learner("classif.resnext101_32x8d", function() { + LearnerTorchVision$new("resnext101_32x8d", torchvision::model_resnext101_32x8d, "ResNeXt-101 32x8d") +}) + +register_learner("classif.resnext50_32x4d", function() { + LearnerTorchVision$new("resnext50_32x4d", torchvision::model_resnext50_32x4d, "ResNeXt-50 32x4d") +}) + +register_learner("classif.vgg11", function() { + LearnerTorchVision$new("vgg11", torchvision::model_vgg11, "VGG 11") +}) + +register_learner("classif.vgg11_bn", function() { + LearnerTorchVision$new("vgg11_bn", torchvision::model_vgg11_bn, "VGG 11") +}) + +register_learner("classif.vgg13", function() { + LearnerTorchVision$new("vgg13", torchvision::model_vgg13, "VGG 13") +}) + +register_learner("classif.vgg13_bn", function() { + LearnerTorchVision$new("vgg13_bn", torchvision::model_vgg13_bn, "VGG 13") +}) + +register_learner("classif.vgg16", function() { + LearnerTorchVision$new("vgg16", torchvision::model_vgg16, "VGG 16") +}) + +register_learner("classif.vgg16_bn", function() { + LearnerTorchVision$new("vgg16_bn", torchvision::model_vgg16_bn, "VGG 16") +}) + +register_learner("classif.vgg19", function() { + LearnerTorchVision$new("vgg19", torchvision::model_vgg19, "VGG 19") +}) + +register_learner("classif.vgg19_bn", function() { + LearnerTorchVision$new("vgg19_bn", torchvision::model_vgg19_bn, "VGG 19") +}) diff --git a/R/TaskClassif_lazy_iris.R b/R/TaskClassif_lazy_iris.R index be205b65..49e010f5 100644 --- a/R/TaskClassif_lazy_iris.R +++ b/R/TaskClassif_lazy_iris.R @@ -16,7 +16,7 @@ #' @source #' \url{https://en.wikipedia.org/wiki/Iris_flower_data_set} #' -#' @section Meta Information: +#' @section Properties: #' `r rd_info_task_torch("lazy_iris", missings = FALSE)` #' #' @references diff --git a/R/TaskClassif_mnist.R b/R/TaskClassif_mnist.R index 1b7d9638..c14d6ad6 100644 --- a/R/TaskClassif_mnist.R +++ b/R/TaskClassif_mnist.R @@ -18,7 +18,7 @@ #' @source #' \url{https://torchvision.mlverse.org/reference/mnist_dataset.html} #' -#' @section Meta Information: +#' @section Properties: #' `r rd_info_task_torch("mnist", missings = FALSE)` #' #' @references diff --git a/R/TaskClassif_tiny_imagenet.R b/R/TaskClassif_tiny_imagenet.R index 33ea093c..6f8d623f 100644 --- a/R/TaskClassif_tiny_imagenet.R +++ b/R/TaskClassif_tiny_imagenet.R @@ -19,7 +19,7 @@ #' #' @template task_download #' -#' @section Meta Information: +#' @section Properties: #' `r rd_info_task_torch("tiny_imagenet", missings = FALSE)` #' #' @references diff --git a/R/bibentries.R b/R/bibentries.R index 80174531..68d7d4fa 100644 --- a/R/bibentries.R +++ b/R/bibentries.R @@ -1,6 +1,6 @@ bibentries = c(# nolint start gorishniy2021revisiting = bibentry("article", - title = "Revisiting Deep Learning Models for Tabular Data", + title = "Revisiting Deep Learning for Tabular Data", author = "Yury Gorishniy and Ivan Rubachev and Valentin Khrulkov and Artem Babenko", journal = "arXiv", volume = "2106.11959", @@ -27,12 +27,6 @@ bibentries = c(# nolint start pages = "6679--6687", year = "2021" ), - krizhevsky2014one = bibentry("article", - title = "One weird trick for parallelizing convolutional neural networks", - author = "Krizhevsky, Alex", - journal = "arXiv preprint arXiv:1404.5997", - year = "2014" - ), ioffe2015batch = bibentry("inproceedings", title = "Batch normalization: Accelerating deep network training by reducing internal covariate shift", @@ -85,5 +79,39 @@ bibentries = c(# nolint start author = "Edgar Anderson", title = "The Species Problem in Iris", journal = "Annals of the Missouri Botanical Garden" + ), + sandler2018mobilenetv2 = bibentry("InProceedings", + title = "Mobilenetv2: Inverted residuals and linear bottlenecks", + author = "Sandler, Mark and Howard, Andrew and Zhu, Menglong and Zhmoginov, Andrey and Chen, Liang-Chieh", + booktitle= "Proceedings of the IEEE conference on computer vision and pattern recognition", + pages = "4510--4520", + year = "2018" + ), + simonyan2014very = bibentry("article", + title = "Very deep convolutional networks for large-scale image recognition", + author = "Simonyan, Karen and Zisserman, Andrew", + journal= "arXiv preprint arXiv:1409.1556", + year = "2014" + ), + he2016deep = bibentry("InProceedings", + title = "Deep residual learning for image recognition ", + author = "He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian ", + booktitle = "Proceedings of the IEEE conference on computer vision and pattern recognition ", + pages = "770--778 ", + year = "2016 " + ), + krizhevsky2014one = bibentry("article", + title = "One weird trick for parallelizing convolutional neural networks", + author = "Krizhevsky, Alex", + journal = "arXiv preprint arXiv:1404.5997", + year = "2014" + ), + szegedy2016rethinking = bibentry("InProceedings", + title = "Rethinking the inception architecture for computer vision ", + author = "Szegedy, Christian and Vanhoucke, Vincent and Ioffe, Sergey and Shlens, Jon and Wojna, Zbigniew ", + booktitle = "Proceedings of the IEEE conference on computer vision and pattern recognition ", + pages = "2818--2826 ", + year = "2016 " ) ) # nolint end + diff --git a/R/utils.R b/R/utils.R index 04fee71e..69633732 100644 --- a/R/utils.R +++ b/R/utils.R @@ -219,6 +219,17 @@ clone_graph_unique_ids = function(g) { return(g1) } +#' Replace the head of a network +#' Replaces the head of the network with a linear layer with d_out classes. +#' @param network ([`torch::nn_module`])\cr +#' The network +#' @param d_out (`integer(1)`)\cr +#' The number of output classes. +#' @export +replace_head = function(network, d_out) { + UseMethod("replace_head") +} + check_nn_module = function(x) { check_class(x, "nn_module") } diff --git a/R/zzz.R b/R/zzz.R index 84887759..205c7703 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -39,7 +39,11 @@ register_resampling = function(name, constructor) { } register_learner = function(.name, .constructor, ...) { - assert_class(.constructor, "R6ClassGenerator") + assert_multi_class(.constructor, c("function", "R6ClassGenerator")) + if (is.function(.constructor)) { + mlr3torch_learners[[.name]] = list(fn = .constructor, prototype_args = list(...)) + return(NULL) + } task_type = if (startsWith(.name, "classif")) "classif" else "regr" # What I am doing here: # The problem is that we wan't to set the task_type when creating the learner from the dictionary diff --git a/man-roxygen/learner.R b/man-roxygen/learner.R index b372d2f0..e4e3aed2 100644 --- a/man-roxygen/learner.R +++ b/man-roxygen/learner.R @@ -11,16 +11,8 @@ #' <%= if ("regr" %in% task_types_vec) paste0("lrn(\"regr.", name, "\", ...)") else ""%> #' ``` #' -#' @section Meta Information: +#' @section Properties: #' `r mlr3torch:::rd_info_learner_torch("<%=name%>", "<%=task_types%>")` #' @md #' -#' @section State: -#' The state is a list with elements: -#' * `network` :: The trained [network][torch::nn_module]. -#' * `optimizer` :: The [optimizer][torch::optimizer] used to train the network. -#' * `loss_fn` :: The [loss][torch::nn_module] used to train the network. -#' * `callbacks` :: The [callbacks][mlr3torch::mlr_callback_set] used to train the network. -#' * `seed` :: The actual seed that was / is used for training and prediction. -#' #' @family Learner diff --git a/man-roxygen/learner_example.R b/man-roxygen/learner_example.R index f2c763fd..1e6fa25f 100644 --- a/man-roxygen/learner_example.R +++ b/man-roxygen/learner_example.R @@ -8,7 +8,8 @@ #' # Define the Learner and set parameter values #' <%= sprintf("learner = lrn(\"%s\")", id)%> #' learner$param_set$set_values( -#' <%= paste0(" ", paste0(c(param_vals, "batch_size = 1", "epochs = 1", "device = \"cpu\""), collapse = ", "))%> +#' epochs = 1, batch_size = 16, device = "cpu"<%= if (length(param_vals)) "," else character()%> +#' <%= if (length(param_vals)) paste0(param_vals, collapse = ",\n ") else character()%> #' ) #' #' # Define a Task diff --git a/man/mlr_learners.mlp.Rd b/man/mlr_learners.mlp.Rd index ecc24d50..dabc178a 100644 --- a/man/mlr_learners.mlp.Rd +++ b/man/mlr_learners.mlp.Rd @@ -17,7 +17,7 @@ lrn("regr.mlp", ...) }\if{html}{\out{}} } -\section{Meta Information}{ +\section{Properties}{ \itemize{ \item Supported task types: 'classif', 'regr' @@ -31,18 +31,6 @@ lrn("regr.mlp", ...) } } -\section{State}{ - -The state is a list with elements: -\itemize{ -\item \code{network} :: The trained \link[torch:nn_module]{network}. -\item \code{optimizer} :: The \link[torch:optimizer]{optimizer} used to train the network. -\item \code{loss_fn} :: The \link[torch:nn_module]{loss} used to train the network. -\item \code{callbacks} :: The \link[=mlr_callback_set]{callbacks} used to train the network. -\item \code{seed} :: The actual seed that was / is used for training and prediction. -} -} - \section{Parameters}{ Parameters from \code{\link{LearnerTorch}}, as well as: @@ -70,7 +58,8 @@ Otherwise the input shape is inferred from the number of numeric features. # Define the Learner and set parameter values learner = lrn("classif.mlp") learner$param_set$set_values( - neurons = 10, batch_size = 1, epochs = 1, device = "cpu" + epochs = 1, batch_size = 16, device = "cpu", + neurons = 10 ) # Define a Task @@ -90,7 +79,6 @@ predictions$score() } \seealso{ Other Learner: -\code{\link{mlr_learners.alexnet}}, \code{\link{mlr_learners.tab_resnet}}, \code{\link{mlr_learners.torch_featureless}}, \code{\link{mlr_learners_torch}}, diff --git a/man/mlr_learners.tab_resnet.Rd b/man/mlr_learners.tab_resnet.Rd index fdeed55c..bda3660f 100644 --- a/man/mlr_learners.tab_resnet.Rd +++ b/man/mlr_learners.tab_resnet.Rd @@ -16,7 +16,7 @@ lrn("regr.tab_resnet", ...) }\if{html}{\out{}} } -\section{Meta Information}{ +\section{Properties}{ \itemize{ \item Supported task types: 'classif', 'regr' @@ -30,18 +30,6 @@ lrn("regr.tab_resnet", ...) } } -\section{State}{ - -The state is a list with elements: -\itemize{ -\item \code{network} :: The trained \link[torch:nn_module]{network}. -\item \code{optimizer} :: The \link[torch:optimizer]{optimizer} used to train the network. -\item \code{loss_fn} :: The \link[torch:nn_module]{loss} used to train the network. -\item \code{callbacks} :: The \link[=mlr_callback_set]{callbacks} used to train the network. -\item \code{seed} :: The actual seed that was / is used for training and prediction. -} -} - \section{Parameters}{ Parameters from \code{\link{LearnerTorch}}, as well as: @@ -67,7 +55,8 @@ Second dropout ratio. # Define the Learner and set parameter values learner = lrn("classif.tab_resnet") learner$param_set$set_values( - n_blocks = 2, d_block = 10, d_hidden = 20, dropout1 = 0.3, dropout2 = 0.3, batch_size = 1, epochs = 1, device = "cpu" + epochs = 1, batch_size = 16, device = "cpu", + n_blocks = 2, d_block = 10, d_hidden = 20, dropout1 = 0.3, dropout2 = 0.3 ) # Define a Task @@ -87,12 +76,11 @@ predictions$score() } \references{ Gorishniy Y, Rubachev I, Khrulkov V, Babenko A (2021). -\dQuote{Revisiting Deep Learning Models for Tabular Data.} +\dQuote{Revisiting Deep Learning for Tabular Data.} \emph{arXiv}, \bold{2106.11959}. } \seealso{ Other Learner: -\code{\link{mlr_learners.alexnet}}, \code{\link{mlr_learners.mlp}}, \code{\link{mlr_learners.torch_featureless}}, \code{\link{mlr_learners_torch}}, diff --git a/man/mlr_learners.torch_featureless.Rd b/man/mlr_learners.torch_featureless.Rd index 4b13dca9..b1fe59a6 100644 --- a/man/mlr_learners.torch_featureless.Rd +++ b/man/mlr_learners.torch_featureless.Rd @@ -19,7 +19,7 @@ lrn("regr.torch_featureless", ...) }\if{html}{\out{}} } -\section{Meta Information}{ +\section{Properties}{ \itemize{ \item Supported task types: 'classif', 'regr' @@ -33,18 +33,6 @@ lrn("regr.torch_featureless", ...) } } -\section{State}{ - -The state is a list with elements: -\itemize{ -\item \code{network} :: The trained \link[torch:nn_module]{network}. -\item \code{optimizer} :: The \link[torch:optimizer]{optimizer} used to train the network. -\item \code{loss_fn} :: The \link[torch:nn_module]{loss} used to train the network. -\item \code{callbacks} :: The \link[=mlr_callback_set]{callbacks} used to train the network. -\item \code{seed} :: The actual seed that was / is used for training and prediction. -} -} - \section{Parameters}{ Only those from \code{\link{LearnerTorch}}. @@ -56,7 +44,8 @@ Only those from \code{\link{LearnerTorch}}. # Define the Learner and set parameter values learner = lrn("classif.torch_featureless") learner$param_set$set_values( - batch_size = 1, epochs = 1, device = "cpu" + epochs = 1, batch_size = 16, device = "cpu" + ) # Define a Task @@ -76,7 +65,6 @@ predictions$score() } \seealso{ Other Learner: -\code{\link{mlr_learners.alexnet}}, \code{\link{mlr_learners.mlp}}, \code{\link{mlr_learners.tab_resnet}}, \code{\link{mlr_learners_torch}}, diff --git a/man/mlr_learners.alexnet.Rd b/man/mlr_learners.torchvision.Rd similarity index 64% rename from man/mlr_learners.alexnet.Rd rename to man/mlr_learners.torchvision.Rd index 2829caf2..90afa5a6 100644 --- a/man/mlr_learners.alexnet.Rd +++ b/man/mlr_learners.torchvision.Rd @@ -1,78 +1,41 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/LearnerTorchAlexNet.R -\name{mlr_learners.alexnet} -\alias{mlr_learners.alexnet} -\alias{LearnerTorchAlexNet} +% Please edit documentation in R/LearnerTorchVision.R +\name{mlr_learners.torchvision} +\alias{mlr_learners.torchvision} +\alias{LearnerTorchVision} \title{AlexNet Image Classifier} \description{ -Historic convolutional neural network for image classification. +Classic image classification networks from \code{torchvision}. } -\section{Dictionary}{ - -This \link{Learner} can be instantiated using the sugar function \code{\link[=lrn]{lrn()}}: - -\if{html}{\out{
}}\preformatted{lrn("classif.alexnet", ...) - -}\if{html}{\out{
}} -} - -\section{Meta Information}{ - -\itemize{ -\item Supported task types: 'classif' -\item Predict Types: -\itemize{ -\item classif: 'response', 'prob' -} -\item Feature Types: \dQuote{lazy_tensor} -\item Required Packages: \CRANpkg{mlr3}, \CRANpkg{mlr3torch}, \CRANpkg{torchvision}, \CRANpkg{magick}, \CRANpkg{torch} -} -} - -\section{State}{ - -The state is a list with elements: -\itemize{ -\item \code{network} :: The trained \link[torch:nn_module]{network}. -\item \code{optimizer} :: The \link[torch:optimizer]{optimizer} used to train the network. -\item \code{loss_fn} :: The \link[torch:nn_module]{loss} used to train the network. -\item \code{callbacks} :: The \link[=mlr_callback_set]{callbacks} used to train the network. -\item \code{seed} :: The actual seed that was / is used for training and prediction. -} -} - \section{Parameters}{ Parameters from \code{\link{LearnerTorchImage}} and \itemize{ \item \code{pretrained} :: \code{logical(1)}\cr Whether to use the pretrained model. +The final linear layer will be replaced with a new \code{nn_linear} with the +number of classes inferred from the \code{\link[mlr3:Task]{Task}}. } } -\references{ -Krizhevsky, Alex, Sutskever, Ilya, Hinton, E. G (2017). -\dQuote{Imagenet classification with deep convolutional neural networks.} -\emph{Communications of the ACM}, \bold{60}(6), 84--90. +\section{Properties}{ + +\itemize{ +\item Supported task types: \code{"classif"} +\item Predict Types: \code{"response"} and \code{"prob"} +\item Feature Types: \code{"lazy_tensor"} +\item Required packages: \code{"mlr3torch"}, \code{"torch"}, \code{"torchvision"} } -\seealso{ -Other Learner: -\code{\link{mlr_learners.mlp}}, -\code{\link{mlr_learners.tab_resnet}}, -\code{\link{mlr_learners.torch_featureless}}, -\code{\link{mlr_learners_torch}}, -\code{\link{mlr_learners_torch_image}}, -\code{\link{mlr_learners_torch_model}} } -\concept{Learner} + \section{Super classes}{ -\code{\link[mlr3:Learner]{mlr3::Learner}} -> \code{\link[mlr3torch:LearnerTorch]{mlr3torch::LearnerTorch}} -> \code{\link[mlr3torch:LearnerTorchImage]{mlr3torch::LearnerTorchImage}} -> \code{LearnerTorchAlexNet} +\code{\link[mlr3:Learner]{mlr3::Learner}} -> \code{\link[mlr3torch:LearnerTorch]{mlr3torch::LearnerTorch}} -> \code{\link[mlr3torch:LearnerTorchImage]{mlr3torch::LearnerTorchImage}} -> \code{LearnerTorchVision} } \section{Methods}{ \subsection{Public methods}{ \itemize{ -\item \href{#method-LearnerTorchAlexNet-new}{\code{LearnerTorchAlexNet$new()}} -\item \href{#method-LearnerTorchAlexNet-clone}{\code{LearnerTorchAlexNet$clone()}} +\item \href{#method-LearnerTorchVision-new}{\code{LearnerTorchVision$new()}} +\item \href{#method-LearnerTorchVision-clone}{\code{LearnerTorchVision$clone()}} } } \if{html}{\out{ @@ -92,13 +55,15 @@ Other Learner: }} \if{html}{\out{
}} -\if{html}{\out{}} -\if{latex}{\out{\hypertarget{method-LearnerTorchAlexNet-new}{}}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-LearnerTorchVision-new}{}}} \subsection{Method \code{new()}}{ Creates a new instance of this \link[R6:R6Class]{R6} class. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{LearnerTorchAlexNet$new( - task_type, +\if{html}{\out{
}}\preformatted{LearnerTorchVision$new( + name, + module_generator, + label, optimizer = NULL, loss = NULL, callbacks = list() @@ -108,8 +73,27 @@ Creates a new instance of this \link[R6:R6Class]{R6} class. \subsection{Arguments}{ \if{html}{\out{
}} \describe{ -\item{\code{task_type}}{(\code{character(1)})\cr -The task type, either \verb{"classif}" or \code{"regr"}.} +\item{\code{name}}{(\code{character(1)})\cr +The name of the network.} + +\item{\code{module_generator}}{(\verb{function(pretrained, num_classes)})\cr +Function that generates the network.} + +\item{\code{label}}{(\code{character(1)})\cr +The label of the network. +#' @references +Krizhevsky, Alex, Sutskever, Ilya, Hinton, E. G (2017). +\dQuote{Imagenet classification with deep convolutional neural networks.} +\emph{Communications of the ACM}, \bold{60}(6), 84--90. +Sandler, Mark, Howard, Andrew, Zhu, Menglong, Zhmoginov, Andrey, Chen, Liang-Chieh (2018). +\dQuote{Mobilenetv2: Inverted residuals and linear bottlenecks.} +In \emph{Proceedings of the IEEE conference on computer vision and pattern recognition}, 4510--4520. +He, Kaiming, Zhang, Xiangyu, Ren, Shaoqing, Sun, Jian (2016 ). +\dQuote{Deep residual learning for image recognition .} +In \emph{Proceedings of the IEEE conference on computer vision and pattern recognition }, 770--778 . +Simonyan, Karen, Zisserman, Andrew (2014). +\dQuote{Very deep convolutional networks for large-scale image recognition.} +\emph{arXiv preprint arXiv:1409.1556}.} \item{\code{optimizer}}{(\code{\link{TorchOptimizer}})\cr The optimizer to use for training. @@ -126,12 +110,12 @@ The callbacks. Must have unique ids.} } } \if{html}{\out{
}} -\if{html}{\out{}} -\if{latex}{\out{\hypertarget{method-LearnerTorchAlexNet-clone}{}}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-LearnerTorchVision-clone}{}}} \subsection{Method \code{clone()}}{ The objects of this class are cloneable with this method. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{LearnerTorchAlexNet$clone(deep = FALSE)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{LearnerTorchVision$clone(deep = FALSE)}\if{html}{\out{
}} } \subsection{Arguments}{ diff --git a/man/mlr_learners_torch.Rd b/man/mlr_learners_torch.Rd index 8860cd5e..4ca6d130 100644 --- a/man/mlr_learners_torch.Rd +++ b/man/mlr_learners_torch.Rd @@ -166,7 +166,6 @@ to implement the private \verb{$.additional_phash_input()}. \seealso{ Other Learner: -\code{\link{mlr_learners.alexnet}}, \code{\link{mlr_learners.mlp}}, \code{\link{mlr_learners.tab_resnet}}, \code{\link{mlr_learners.torch_featureless}}, diff --git a/man/mlr_learners_torch_image.Rd b/man/mlr_learners_torch_image.Rd index 3bb6835a..c59277bf 100644 --- a/man/mlr_learners_torch_image.Rd +++ b/man/mlr_learners_torch_image.Rd @@ -15,7 +15,6 @@ Parameters include those inherited from \code{\link{LearnerTorch}} and the \code \seealso{ Other Learner: -\code{\link{mlr_learners.alexnet}}, \code{\link{mlr_learners.mlp}}, \code{\link{mlr_learners.tab_resnet}}, \code{\link{mlr_learners.torch_featureless}}, diff --git a/man/mlr_learners_torch_model.Rd b/man/mlr_learners_torch_model.Rd index 37eb069d..217a52e4 100644 --- a/man/mlr_learners_torch_model.Rd +++ b/man/mlr_learners_torch_model.Rd @@ -44,7 +44,6 @@ learner$predict(task, ids$test) } \seealso{ Other Learner: -\code{\link{mlr_learners.alexnet}}, \code{\link{mlr_learners.mlp}}, \code{\link{mlr_learners.tab_resnet}}, \code{\link{mlr_learners.torch_featureless}}, diff --git a/man/mlr_tasks_lazy_iris.Rd b/man/mlr_tasks_lazy_iris.Rd index f13ce8fc..63e308a8 100644 --- a/man/mlr_tasks_lazy_iris.Rd +++ b/man/mlr_tasks_lazy_iris.Rd @@ -20,7 +20,7 @@ Just like the iris task, but the features are represented as one lazy tensor col }\if{html}{\out{
}} } -\section{Meta Information}{ +\section{Properties}{ \itemize{ \item Task type: \dQuote{classif} diff --git a/man/mlr_tasks_mnist.Rd b/man/mlr_tasks_mnist.Rd index fdb4e009..fb73e183 100644 --- a/man/mlr_tasks_mnist.Rd +++ b/man/mlr_tasks_mnist.Rd @@ -29,7 +29,7 @@ You can cache these datasets by setting the \code{mlr3torch.cache} option to \co as the cache directory. } -\section{Meta Information}{ +\section{Properties}{ \itemize{ \item Task type: \dQuote{classif} diff --git a/man/mlr_tasks_tiny_imagenet.Rd b/man/mlr_tasks_tiny_imagenet.Rd index 36a8a4f3..d3c820c6 100644 --- a/man/mlr_tasks_tiny_imagenet.Rd +++ b/man/mlr_tasks_tiny_imagenet.Rd @@ -28,7 +28,7 @@ You can cache these datasets by setting the \code{mlr3torch.cache} option to \co as the cache directory. } -\section{Meta Information}{ +\section{Properties}{ \itemize{ \item Task type: \dQuote{classif} diff --git a/man/replace_head.Rd b/man/replace_head.Rd new file mode 100644 index 00000000..04a88a1f --- /dev/null +++ b/man/replace_head.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{replace_head} +\alias{replace_head} +\title{Replace the head of a network +Replaces the head of the network with a linear layer with d_out classes.} +\usage{ +replace_head(network, d_out) +} +\arguments{ +\item{network}{(\code{\link[torch:nn_module]{torch::nn_module}})\cr +The network} + +\item{d_out}{(\code{integer(1)})\cr +The number of output classes.} +} +\description{ +Replace the head of a network +Replaces the head of the network with a linear layer with d_out classes. +} diff --git a/tests/testthat/test_LearnerTorch.R b/tests/testthat/test_LearnerTorch.R index e01f4e6c..087c3051 100644 --- a/tests/testthat/test_LearnerTorch.R +++ b/tests/testthat/test_LearnerTorch.R @@ -13,7 +13,6 @@ test_that("deep cloning", { learner_cloned = learner$clone(deep = TRUE) expect_deep_clone(learner, learner_cloned) - # just because we are paranoid network = learner$network network_cloned = learner_cloned$network expect_true(torch_equal(network$weights, network_cloned$weights)) @@ -21,7 +20,6 @@ test_that("deep cloning", { network$weights[1] = network$weights[1] + 1 expect_false(torch_equal(network$weights, network_cloned$weights)) - # but the generators are not cloned expect_identical( get_private(learner)$.loss$generator, get_private(learner_cloned)$.loss$generator diff --git a/tests/testthat/test_LearnerTorchAlexNet.R b/tests/testthat/test_LearnerTorchAlexNet.R deleted file mode 100644 index 1c9b6785..00000000 --- a/tests/testthat/test_LearnerTorchAlexNet.R +++ /dev/null @@ -1,17 +0,0 @@ -test_that("LearnerAlexnet runs", { - learner = lrn("classif.alexnet", - epochs = 1L, - batch_size = 1L, - callbacks = list(), - optimizer = "adam", - loss = "cross_entropy", - pretrained = FALSE - ) - task = nano_imagenet() - resampling = rsmp("holdout") - task$row_roles$use = sample(task$nrow, size = 2) - learner$train(task) - - pred = learner$predict(task) - expect_class(pred, "PredictionClassif") -}) diff --git a/tests/testthat/test_LearnerTorchVision.R b/tests/testthat/test_LearnerTorchVision.R new file mode 100644 index 00000000..8900cdfe --- /dev/null +++ b/tests/testthat/test_LearnerTorchVision.R @@ -0,0 +1,232 @@ +# different number of classes than the predefined ones +task = as_task_classif(data.table( + y = as.factor(rep(c("a", "b", "c"), each = 2)), + x = as_lazy_tensor(torch_randn(6, 3, 64, 64)) +), id = "test_task", target = "y") + +test_that("LearnerTorchVision basic checks", { + alexnet = lrn("classif.alexnet", epochs = 1L, batch_size = 1L, pretrained = FALSE) + expect_deep_clone(alexnet, alexnet$clone(deep = TRUE)) + + alexnet$train(task) + expect_class(alexnet$predict(task), "PredictionClassif") + + expect_learner_torch(alexnet) + alexnet$id = "a" + vgg13 = lrn("classif.vgg13", pretrained = FALSE) + vgg13$id = "a" + expect_false(alexnet$phash == vgg13$phash) + expect_true("torchvision" %in% alexnet$packages) + expect_true("magick" %in% alexnet$packages) +}) + +test_that("alexnet", { + learner = lrn("classif.alexnet", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +}) + +# weird warning regarding weight initialization from torchvision +# test_that("inception_v3", { +# learner = lrn("classif.inception_v3", epochs = 0L, batch_size = 2L, pretrained = FALSE) +# learner$train(task, sample(task$nrow, 1L)) +# pred = learner$predict(task, sample(task$nrow, 1L)) +# expect_class(pred, "PredictionClassif") +# }) + +# these tests are run the CI, but they should basically never fail, so +# we skip them in the local run +# models are also cached in the CI, so it is not too slow +skip_if(Sys.getenv("INCLUDE_IGNORED") != "true", "Slow vision tests") + +test_that("mobilenet_v2", { + learner = lrn("classif.mobilenet_v2", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") + learner = lrn("classif.mobilenet_v2", epochs = 0L, batch_size = 2L, pretrained = TRUE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +}) + +test_that("resnet18", { + learner = lrn("classif.resnet18", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") + + learner = lrn("classif.resnet18", epochs = 0L, batch_size = 2L, pretrained = TRUE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +}) + +test_that("resnet34", { + learner = lrn("classif.resnet34", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") + + learner = lrn("classif.resnet34", epochs = 0L, batch_size = 2L, pretrained = TRUE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +}) + +test_that("resnet50", { + learner = lrn("classif.resnet50", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") + + learner = lrn("classif.resnet50", epochs = 0L, batch_size = 2L, pretrained = TRUE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +}) + +test_that("resnet101", { + learner = lrn("classif.resnet101", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") + + learner = lrn("classif.resnet101", epochs = 0L, batch_size = 2L, pretrained = TRUE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +}) + +test_that("resnet152", { + learner = lrn("classif.resnet152", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") + + learner = lrn("classif.resnet152", epochs = 0L, batch_size = 2L, pretrained = TRUE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +}) + +test_that("resnet101_32x8d", { + learner = lrn("classif.resnext101_32x8d", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") + + learner = lrn("classif.resnext101_32x8d", epochs = 0L, batch_size = 2L, pretrained = TRUE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +}) + +test_that("resnet50_32x4d", { + learner = lrn("classif.resnext50_32x4d", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") + + learner = lrn("classif.resnext50_32x4d", epochs = 0L, batch_size = 2L, pretrained = TRUE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +}) + +test_that("vgg11", { + learner = lrn("classif.vgg11", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") + + learner = lrn("classif.vgg11", epochs = 0L, batch_size = 2L, pretrained = TRUE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +}) + +test_that("vgg11_bn", { + learner = lrn("classif.vgg11_bn", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") + + learner = lrn("classif.vgg11_bn", epochs = 0L, batch_size = 2L, pretrained = TRUE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +}) + +test_that("vgg13", { + learner = lrn("classif.vgg13", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") + + learner = lrn("classif.vgg13", epochs = 0L, batch_size = 2L, pretrained = TRUE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +}) + +test_that("vgg13_bn", { + learner = lrn("classif.vgg13_bn", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") + + learner = lrn("classif.vgg13_bn", epochs = 0L, batch_size = 2L, pretrained = TRUE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +}) + +test_that("vgg16", { + learner = lrn("classif.vgg16", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") + + learner = lrn("classif.vgg16", epochs = 0L, batch_size = 2L, pretrained = TRUE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +}) + +test_that("vgg16_bn", { + learner = lrn("classif.vgg16_bn", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") + + learner = lrn("classif.vgg16_bn", epochs = 0L, batch_size = 2L, pretrained = TRUE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +}) + +test_that("vgg19", { + learner = lrn("classif.vgg19", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") + + learner = lrn("classif.vgg19", epochs = 0L, batch_size = 2L, pretrained = TRUE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +}) + +test_that("vgg19_bn", { + learner = lrn("classif.vgg19_bn", epochs = 0L, batch_size = 2L, pretrained = FALSE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") + + learner = lrn("classif.vgg19_bn", epochs = 0L, batch_size = 2L, pretrained = TRUE) + learner$train(task, sample(task$nrow, 1L)) + pred = learner$predict(task, sample(task$nrow, 1L)) + expect_class(pred, "PredictionClassif") +})