Skip to content

Commit

Permalink
feat: torchvision learners (#238)
Browse files Browse the repository at this point in the history
* wip

* add torchvision learners

* desc

* ...

* ...

* ...

* ...

* fix example

* ...

* ...
  • Loading branch information
sebffischer authored Jul 1, 2024
1 parent c089d8a commit f3acbec
Show file tree
Hide file tree
Showing 27 changed files with 595 additions and 220 deletions.
37 changes: 37 additions & 0 deletions .github/workflows/r-cmd-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
TORCH_INSTALL: 1
INCLUDE_IGNORED: "true"

strategy:
fail-fast: false
Expand All @@ -53,6 +54,42 @@ jobs:
extra-packages: any::rcmdcheck
needs: check

- name: Install additional packages
run: |
Rscript -e 'install.packages("rappdirs")'
- name: Get torchvision package version (Linux/macOS)
if: runner.os != 'Windows'
id: get_package_version_unix
run: |
echo "TORCHVISION_PACKAGE_VERSION=$(Rscript -e 'cat(as.character(packageVersion("torchvision")))')" >> $GITHUB_ENV
- name: Get torchvision package version (Windows)
if: runner.os == 'Windows'
id: get_package_version_windows
run: |
$version = Rscript -e 'cat(as.character(packageVersion("torchvision")))'
echo "TORCHVISION_PACKAGE_VERSION=$version" >> $env:GITHUB_ENV
- name: Get torch cache path (Linux/macOS)
if: runner.os != 'Windows'
id: get_cache_path_unix
run: |
echo "TORCH_CACHE_PATH=$(Rscript -e 'cat(rappdirs::user_cache_dir("torch"))')" >> $GITHUB_ENV
- name: Get torch cache path (Windows)
if: runner.os == 'Windows'
id: get_cache_path_windows
run: |
$cachePath = Rscript -e 'cat(rappdirs::user_cache_dir("torch"))'
echo "TORCH_CACHE_PATH=$cachePath" >> $env:GITHUB_ENV
- name: Cache Torchvision Downloads
uses: actions/cache@v3
with:
path: ${{ env.TORCH_CACHE_PATH }}
key: ${{ runner.os }}-r-${{ env.TORCHVISION_PACKAGE_VERSION }}

- uses: r-lib/actions/check-r-package@v2

- uses: mxschmitt/action-tmate@v3
Expand Down
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ Collate:
'utils.R'
'DataDescriptor.R'
'LearnerTorch.R'
'LearnerTorchImage.R'
'LearnerTorchAlexNet.R'
'LearnerTorchFeatureless.R'
'LearnerTorchImage.R'
'LearnerTorchMLP.R'
'task_dataset.R'
'shape.R'
'PipeOpTorchIngress.R'
'LearnerTorchModel.R'
'LearnerTorchTabResNet.R'
'LearnerTorchVision.R'
'ModelDescriptor.R'
'PipeOpModule.R'
'PipeOpTorch.R'
Expand Down
7 changes: 6 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ S3method(print,TorchIngressToken)
S3method(print,lazy_tensor)
S3method(rep,lazy_tensor)
S3method(rep_len,lazy_tensor)
S3method(replace_head,AlexNet)
S3method(replace_head,VGG)
S3method(replace_head,mobilenet_v2)
S3method(replace_head,resnet)
S3method(reset_last_layer,AlexNet)
S3method(reset_last_layer,resnet)
S3method(t_clbk,"NULL")
Expand All @@ -65,12 +69,12 @@ export(ContextTorch)
export(DataBackendLazy)
export(DataDescriptor)
export(LearnerTorch)
export(LearnerTorchAlexNet)
export(LearnerTorchFeatureless)
export(LearnerTorchImage)
export(LearnerTorchMLP)
export(LearnerTorchModel)
export(LearnerTorchTabResNet)
export(LearnerTorchVision)
export(ModelDescriptor)
export(PipeOpModule)
export(PipeOpTaskPreprocTorch)
Expand Down Expand Up @@ -168,6 +172,7 @@ export(nn_reshape)
export(nn_squeeze)
export(nn_unsqueeze)
export(pipeop_preproc_torch)
export(replace_head)
export(reset_last_layer)
export(t_clbk)
export(t_clbks)
Expand Down
61 changes: 0 additions & 61 deletions R/LearnerTorchAlexNet.R

This file was deleted.

180 changes: 180 additions & 0 deletions R/LearnerTorchVision.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
#' @title AlexNet Image Classifier
#'
#' @name mlr_learners.torchvision
#'
#' @description
#' Classic image classification networks from `torchvision`.
#'
#' @section Parameters:
#' Parameters from [`LearnerTorchImage`] and
#'
#' * `pretrained` :: `logical(1)`\cr
#' Whether to use the pretrained model.
#' The final linear layer will be replaced with a new `nn_linear` with the
#' number of classes inferred from the [`Task`][mlr3::Task].
#'
#' @section Properties:
#' * Supported task types: `"classif"`
#' * Predict Types: `"response"` and `"prob"`
#' * Feature Types: `"lazy_tensor"`
#' * Required packages: `"mlr3torch"`, `"torch"`, `"torchvision"`
#' @template params_learner
#' @param name (`character(1)`)\cr
#' The name of the network.
#' @param module_generator (`function(pretrained, num_classes)`)\cr
#' Function that generates the network.
#' @param label (`character(1)`)\cr
#' The label of the network.
#'#' @references
#' `r format_bib("krizhevsky2017imagenet")`
#' `r format_bib("sandler2018mobilenetv2")`
#' `r format_bib("he2016deep")`
#' `r format_bib("simonyan2014very")`
#' @include LearnerTorchImage.R
#' @export
LearnerTorchVision = R6Class("LearnerTorchVision",
inherit = LearnerTorchImage,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(name, module_generator, label, optimizer = NULL, loss = NULL, callbacks = list()) { # nolint
task_type = "classif"
private$.module_generator = module_generator
param_set = ps(
pretrained = p_lgl(tags = c("required", "train"))
)
param_set$values = list(pretrained = TRUE)
super$initialize(
task_type = task_type,
id = paste0(task_type, ".", name),
param_set = param_set,
man = paste0("mlr3torch::mlr_learners.torchvision"),
optimizer = optimizer,
loss = loss,
callbacks = callbacks,
label = label
)
}
),
private = list(
.module_generator = NULL,
.network = function(task, param_vals) {
nout = get_nout(task)
if (param_vals$pretrained) {
network = replace_head(private$.module_generator(pretrained = TRUE), nout)
return(network)
}
private$.module_generator(pretrained = FALSE, num_classes = nout)
},
.additional_phash_input = function() {
list(private$.module_generator)
}
)
)

#' @export
replace_head.AlexNet = function(network, d_out) {
network$classifier$`6` = torch::nn_linear(
in_features = network$classifier$`6`$in_features,
out_features = d_out,
bias = TRUE
)
network
}

# #' @export
# replace_head.Inception3 = function(network, d_out) {
# network$fc = nn_linear(2048, d_out)
# network
# }

#' @export
replace_head.mobilenet_v2 = function(network, d_out) {
network$classifier$`1` = nn_linear(1280, d_out)
network
}

#' @export
replace_head.resnet = function(network, d_out) {
in_features = network$fc$in_features
network$fc = nn_linear(in_features, d_out)
network
}

#' @export
replace_head.VGG = function(network, d_out) {
network$classifier$`6` = nn_linear(4096, d_out)
network
}

#' @include zzz.R
register_learner("classif.alexnet", function() {
LearnerTorchVision$new("alexnet", torchvision::model_alexnet, "AlexNet")
})

# register_learner("classif.inception_v3", function() {
# LearnerTorchVision$new("inception_v3", torchvision::model_inception_v3, "Inception V3")
# })

register_learner("classif.mobilenet_v2", function() {
LearnerTorchVision$new("mobilenet_v2", torchvision::model_mobilenet_v2, "Mobilenet V2")
})

register_learner("classif.resnet18", function() {
LearnerTorchVision$new("resnet18", torchvision::model_resnet18, "ResNet-18")
})

register_learner("classif.resnet34", function() {
LearnerTorchVision$new("resnet34", torchvision::model_resnet34, "ResNet-34")
})

register_learner("classif.resnet50", function() {
LearnerTorchVision$new("resnet50", torchvision::model_resnet50, "ResNet-50")
})

register_learner("classif.resnet101", function() {
LearnerTorchVision$new("resnet101", torchvision::model_resnet101, "ResNet-101")
})

register_learner("classif.resnet152", function() {
LearnerTorchVision$new("resnet152", torchvision::model_resnet152, "ResNet-152")
})

register_learner("classif.resnext101_32x8d", function() {
LearnerTorchVision$new("resnext101_32x8d", torchvision::model_resnext101_32x8d, "ResNeXt-101 32x8d")
})

register_learner("classif.resnext50_32x4d", function() {
LearnerTorchVision$new("resnext50_32x4d", torchvision::model_resnext50_32x4d, "ResNeXt-50 32x4d")
})

register_learner("classif.vgg11", function() {
LearnerTorchVision$new("vgg11", torchvision::model_vgg11, "VGG 11")
})

register_learner("classif.vgg11_bn", function() {
LearnerTorchVision$new("vgg11_bn", torchvision::model_vgg11_bn, "VGG 11")
})

register_learner("classif.vgg13", function() {
LearnerTorchVision$new("vgg13", torchvision::model_vgg13, "VGG 13")
})

register_learner("classif.vgg13_bn", function() {
LearnerTorchVision$new("vgg13_bn", torchvision::model_vgg13_bn, "VGG 13")
})

register_learner("classif.vgg16", function() {
LearnerTorchVision$new("vgg16", torchvision::model_vgg16, "VGG 16")
})

register_learner("classif.vgg16_bn", function() {
LearnerTorchVision$new("vgg16_bn", torchvision::model_vgg16_bn, "VGG 16")
})

register_learner("classif.vgg19", function() {
LearnerTorchVision$new("vgg19", torchvision::model_vgg19, "VGG 19")
})

register_learner("classif.vgg19_bn", function() {
LearnerTorchVision$new("vgg19_bn", torchvision::model_vgg19_bn, "VGG 19")
})
2 changes: 1 addition & 1 deletion R/TaskClassif_lazy_iris.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#' @source
#' \url{https://en.wikipedia.org/wiki/Iris_flower_data_set}
#'
#' @section Meta Information:
#' @section Properties:
#' `r rd_info_task_torch("lazy_iris", missings = FALSE)`
#'
#' @references
Expand Down
2 changes: 1 addition & 1 deletion R/TaskClassif_mnist.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#' @source
#' \url{https://torchvision.mlverse.org/reference/mnist_dataset.html}
#'
#' @section Meta Information:
#' @section Properties:
#' `r rd_info_task_torch("mnist", missings = FALSE)`
#'
#' @references
Expand Down
2 changes: 1 addition & 1 deletion R/TaskClassif_tiny_imagenet.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#'
#' @template task_download
#'
#' @section Meta Information:
#' @section Properties:
#' `r rd_info_task_torch("tiny_imagenet", missings = FALSE)`
#'
#' @references
Expand Down
Loading

0 comments on commit f3acbec

Please sign in to comment.