-
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* wip * add torchvision learners * desc * ... * ... * ... * ... * fix example * ... * ...
- Loading branch information
1 parent
c089d8a
commit f3acbec
Showing
27 changed files
with
595 additions
and
220 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
#' @title AlexNet Image Classifier | ||
#' | ||
#' @name mlr_learners.torchvision | ||
#' | ||
#' @description | ||
#' Classic image classification networks from `torchvision`. | ||
#' | ||
#' @section Parameters: | ||
#' Parameters from [`LearnerTorchImage`] and | ||
#' | ||
#' * `pretrained` :: `logical(1)`\cr | ||
#' Whether to use the pretrained model. | ||
#' The final linear layer will be replaced with a new `nn_linear` with the | ||
#' number of classes inferred from the [`Task`][mlr3::Task]. | ||
#' | ||
#' @section Properties: | ||
#' * Supported task types: `"classif"` | ||
#' * Predict Types: `"response"` and `"prob"` | ||
#' * Feature Types: `"lazy_tensor"` | ||
#' * Required packages: `"mlr3torch"`, `"torch"`, `"torchvision"` | ||
#' @template params_learner | ||
#' @param name (`character(1)`)\cr | ||
#' The name of the network. | ||
#' @param module_generator (`function(pretrained, num_classes)`)\cr | ||
#' Function that generates the network. | ||
#' @param label (`character(1)`)\cr | ||
#' The label of the network. | ||
#'#' @references | ||
#' `r format_bib("krizhevsky2017imagenet")` | ||
#' `r format_bib("sandler2018mobilenetv2")` | ||
#' `r format_bib("he2016deep")` | ||
#' `r format_bib("simonyan2014very")` | ||
#' @include LearnerTorchImage.R | ||
#' @export | ||
LearnerTorchVision = R6Class("LearnerTorchVision", | ||
inherit = LearnerTorchImage, | ||
public = list( | ||
#' @description Creates a new instance of this [R6][R6::R6Class] class. | ||
initialize = function(name, module_generator, label, optimizer = NULL, loss = NULL, callbacks = list()) { # nolint | ||
task_type = "classif" | ||
private$.module_generator = module_generator | ||
param_set = ps( | ||
pretrained = p_lgl(tags = c("required", "train")) | ||
) | ||
param_set$values = list(pretrained = TRUE) | ||
super$initialize( | ||
task_type = task_type, | ||
id = paste0(task_type, ".", name), | ||
param_set = param_set, | ||
man = paste0("mlr3torch::mlr_learners.torchvision"), | ||
optimizer = optimizer, | ||
loss = loss, | ||
callbacks = callbacks, | ||
label = label | ||
) | ||
} | ||
), | ||
private = list( | ||
.module_generator = NULL, | ||
.network = function(task, param_vals) { | ||
nout = get_nout(task) | ||
if (param_vals$pretrained) { | ||
network = replace_head(private$.module_generator(pretrained = TRUE), nout) | ||
return(network) | ||
} | ||
private$.module_generator(pretrained = FALSE, num_classes = nout) | ||
}, | ||
.additional_phash_input = function() { | ||
list(private$.module_generator) | ||
} | ||
) | ||
) | ||
|
||
#' @export | ||
replace_head.AlexNet = function(network, d_out) { | ||
network$classifier$`6` = torch::nn_linear( | ||
in_features = network$classifier$`6`$in_features, | ||
out_features = d_out, | ||
bias = TRUE | ||
) | ||
network | ||
} | ||
|
||
# #' @export | ||
# replace_head.Inception3 = function(network, d_out) { | ||
# network$fc = nn_linear(2048, d_out) | ||
# network | ||
# } | ||
|
||
#' @export | ||
replace_head.mobilenet_v2 = function(network, d_out) { | ||
network$classifier$`1` = nn_linear(1280, d_out) | ||
network | ||
} | ||
|
||
#' @export | ||
replace_head.resnet = function(network, d_out) { | ||
in_features = network$fc$in_features | ||
network$fc = nn_linear(in_features, d_out) | ||
network | ||
} | ||
|
||
#' @export | ||
replace_head.VGG = function(network, d_out) { | ||
network$classifier$`6` = nn_linear(4096, d_out) | ||
network | ||
} | ||
|
||
#' @include zzz.R | ||
register_learner("classif.alexnet", function() { | ||
LearnerTorchVision$new("alexnet", torchvision::model_alexnet, "AlexNet") | ||
}) | ||
|
||
# register_learner("classif.inception_v3", function() { | ||
# LearnerTorchVision$new("inception_v3", torchvision::model_inception_v3, "Inception V3") | ||
# }) | ||
|
||
register_learner("classif.mobilenet_v2", function() { | ||
LearnerTorchVision$new("mobilenet_v2", torchvision::model_mobilenet_v2, "Mobilenet V2") | ||
}) | ||
|
||
register_learner("classif.resnet18", function() { | ||
LearnerTorchVision$new("resnet18", torchvision::model_resnet18, "ResNet-18") | ||
}) | ||
|
||
register_learner("classif.resnet34", function() { | ||
LearnerTorchVision$new("resnet34", torchvision::model_resnet34, "ResNet-34") | ||
}) | ||
|
||
register_learner("classif.resnet50", function() { | ||
LearnerTorchVision$new("resnet50", torchvision::model_resnet50, "ResNet-50") | ||
}) | ||
|
||
register_learner("classif.resnet101", function() { | ||
LearnerTorchVision$new("resnet101", torchvision::model_resnet101, "ResNet-101") | ||
}) | ||
|
||
register_learner("classif.resnet152", function() { | ||
LearnerTorchVision$new("resnet152", torchvision::model_resnet152, "ResNet-152") | ||
}) | ||
|
||
register_learner("classif.resnext101_32x8d", function() { | ||
LearnerTorchVision$new("resnext101_32x8d", torchvision::model_resnext101_32x8d, "ResNeXt-101 32x8d") | ||
}) | ||
|
||
register_learner("classif.resnext50_32x4d", function() { | ||
LearnerTorchVision$new("resnext50_32x4d", torchvision::model_resnext50_32x4d, "ResNeXt-50 32x4d") | ||
}) | ||
|
||
register_learner("classif.vgg11", function() { | ||
LearnerTorchVision$new("vgg11", torchvision::model_vgg11, "VGG 11") | ||
}) | ||
|
||
register_learner("classif.vgg11_bn", function() { | ||
LearnerTorchVision$new("vgg11_bn", torchvision::model_vgg11_bn, "VGG 11") | ||
}) | ||
|
||
register_learner("classif.vgg13", function() { | ||
LearnerTorchVision$new("vgg13", torchvision::model_vgg13, "VGG 13") | ||
}) | ||
|
||
register_learner("classif.vgg13_bn", function() { | ||
LearnerTorchVision$new("vgg13_bn", torchvision::model_vgg13_bn, "VGG 13") | ||
}) | ||
|
||
register_learner("classif.vgg16", function() { | ||
LearnerTorchVision$new("vgg16", torchvision::model_vgg16, "VGG 16") | ||
}) | ||
|
||
register_learner("classif.vgg16_bn", function() { | ||
LearnerTorchVision$new("vgg16_bn", torchvision::model_vgg16_bn, "VGG 16") | ||
}) | ||
|
||
register_learner("classif.vgg19", function() { | ||
LearnerTorchVision$new("vgg19", torchvision::model_vgg19, "VGG 19") | ||
}) | ||
|
||
register_learner("classif.vgg19_bn", function() { | ||
LearnerTorchVision$new("vgg19_bn", torchvision::model_vgg19_bn, "VGG 19") | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.