Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: lazy tensor datatype #139

Merged
merged 94 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
94 commits
Select commit Hold shift + click to select a range
87c2b88
WIP: lazy tensor
sebffischer Oct 11, 2023
a1a997c
materialization starts working
sebffischer Oct 12, 2023
e5c5a32
progress on lazy tensor transform
sebffischer Oct 12, 2023
ca0c6c9
sperate preproc and trafo class, generate pipeops programatically
sebffischer Oct 13, 2023
ce54f27
add test
sebffischer Oct 13, 2023
e29f6b6
WIP ...
sebffischer Oct 17, 2023
d7df862
task preprocessing is working
sebffischer Oct 18, 2023
9d4404f
...
sebffischer Oct 20, 2023
8636cd3
tests passing
sebffischer Oct 26, 2023
ccbca0a
some progress
sebffischer Oct 27, 2023
5db7750
bug fix + important test
sebffischer Oct 27, 2023
7097cb4
tests are passing
sebffischer Nov 3, 2023
4c6d5a6
more fixes
sebffischer Nov 8, 2023
becbcff
add implementations, change augment -> stages
sebffischer Nov 24, 2023
3cfafdf
...
sebffischer Dec 5, 2023
32b327d
Merge branch 'main' into feat/lazy_tensor
sebffischer Dec 5, 2023
e57b3c0
broken but almost working
sebffischer Dec 6, 2023
97435d6
lets go
sebffischer Dec 6, 2023
ceced06
various improvements
sebffischer Jan 17, 2024
cbc234e
some fixes
sebffischer Jan 17, 2024
b79aedc
fix documentation issue
sebffischer Jan 17, 2024
45f3d3d
fix some bugs, remove unneeded deps
sebffischer Jan 19, 2024
acf6239
remove vctrs from desc
sebffischer Jan 19, 2024
5849806
remove vctrs from imports
sebffischer Jan 19, 2024
da9b40d
rcmdcheck passes
sebffischer Jan 22, 2024
bf9212a
pkgdown
sebffischer Jan 22, 2024
83ff680
some more fixes
sebffischer Jan 22, 2024
ad25184
specify hash input
sebffischer Jan 22, 2024
559ade7
fixes
sebffischer Jan 23, 2024
6990a2b
...
sebffischer Jan 23, 2024
a60abb6
...
sebffischer Jan 23, 2024
75a7274
fixes
sebffischer Jan 23, 2024
60e14e8
Update R/PipeOpTaskPreprocTorch.R
sebffischer Jan 23, 2024
0af2f7a
Update R/PipeOpTaskPreprocTorch.R
sebffischer Jan 23, 2024
1570059
Update R/PipeOpTaskPreprocTorch.R
sebffischer Jan 23, 2024
d097126
Update R/PipeOpTaskPreprocTorch.R
sebffischer Jan 23, 2024
b87bfc2
Update R/PipeOpTaskPreprocTorch.R
sebffischer Jan 23, 2024
e33dc1a
Update R/PipeOpTaskPreprocTorch.R
sebffischer Jan 23, 2024
9dc507c
Update R/PipeOpTaskPreprocTorch.R
sebffischer Jan 23, 2024
fbe093f
Update R/PipeOpTaskPreprocTorch.R
sebffischer Jan 23, 2024
cf380e8
Update R/PipeOpTaskPreprocTorch.R
sebffischer Jan 23, 2024
c44d63c
Update R/PipeOpTaskPreprocTorch.R
sebffischer Jan 23, 2024
c9f912a
Update R/PipeOpTaskPreprocTorch.R
sebffischer Jan 23, 2024
6f661a8
Update R/PipeOpTaskPreprocTorch.R
sebffischer Jan 23, 2024
47d8e17
Update R/PipeOpTaskPreprocTorch.R
sebffischer Jan 23, 2024
0b4ea03
stages_init
sebffischer Jan 23, 2024
3911966
Update R/PipeOpTaskPreprocTorch.R
sebffischer Jan 23, 2024
560e9ee
Update R/PipeOpTorchIngress.R
sebffischer Jan 23, 2024
0f98936
Update R/shape.R
sebffischer Jan 23, 2024
ed6b014
Update R/shape.R
sebffischer Jan 23, 2024
9b2b1e4
...
sebffischer Jan 23, 2024
6d80fbe
...
sebffischer Jan 23, 2024
d13d986
address remaining issues
sebffischer Jan 23, 2024
8c36923
...
sebffischer Jan 23, 2024
1e3d3a7
fix test on macOS
sebffischer Jan 23, 2024
5896e06
more tests
sebffischer Jan 23, 2024
6b8755a
refactor lazy_tensor tests
sebffischer Jan 23, 2024
20e7e7c
better tests
sebffischer Jan 24, 2024
41bc9c1
tags
sebffischer Jan 25, 2024
4f32e59
fix tests
sebffischer Jan 25, 2024
7a25be1
suppor tags for preprocessing
sebffischer Jan 25, 2024
a1e8653
add missing docs
sebffischer Jan 25, 2024
d744cbd
...
sebffischer Jan 25, 2024
b2ae575
...
sebffischer Jan 25, 2024
f52922e
fix dumb bug
sebffischer Jan 25, 2024
62e4f44
...
sebffischer Jan 25, 2024
b2730c3
...
sebffischer Jan 26, 2024
05451e5
rcmdcheck passing
sebffischer Jan 26, 2024
132326c
support mps in auto_device
sebffischer Jan 26, 2024
93fafb0
Merge branch 'main' into feat/lazy_tensor
sebffischer Jan 26, 2024
88ad698
disable mps in gha tests
sebffischer Jan 26, 2024
2f66af1
...
sebffischer Jan 26, 2024
285736c
run examples on cpu
sebffischer Jan 26, 2024
19aca8b
...
sebffischer Jan 26, 2024
466b4a0
please run on macos
sebffischer Jan 26, 2024
9992ab9
clone is FALSE
sebffischer Jan 27, 2024
9f2450a
lazy tensor
sebffischer Jan 30, 2024
09957f8
fix docs
sebffischer Jan 30, 2024
671e5c7
...
sebffischer Jan 30, 2024
a52cd1d
...
sebffischer Jan 30, 2024
9c88bd3
pkgdown
sebffischer Jan 30, 2024
a818ea8
...
sebffischer Jan 31, 2024
15b74cd
...
sebffischer Jan 31, 2024
099bf5c
init vs default
sebffischer Jan 31, 2024
d9a8c98
many fixes
sebffischer Feb 7, 2024
e828722
pkgdown
sebffischer Feb 7, 2024
d995f74
Merge branch 'main' into feat/lazy_tensor
sebffischer Feb 7, 2024
7c9aa2c
docs: escape examples if torch not installed
sebffischer Feb 7, 2024
83b8e2d
Update R/task_dataset.R
sebffischer Feb 7, 2024
f9493af
Update DESCRIPTION
sebffischer Feb 7, 2024
a6317dc
Update R/materialize.R
sebffischer Feb 7, 2024
05644c4
Update R/utils.R
sebffischer Feb 7, 2024
eccccaa
cleanup
sebffischer Feb 7, 2024
c18f7d0
docs
sebffischer Feb 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
6 changes: 6 additions & 0 deletions .github/dependabot.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
5 changes: 1 addition & 4 deletions .github/workflows/pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }}
env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
TORCH_INSTALL: 1
steps:
- uses: actions/checkout@v3

Expand All @@ -34,10 +35,6 @@ jobs:
extra-packages: any::pkgdown, local::.
needs: website

- name: Install torch
run: torch::install_torch()
shell: Rscript {0}

- name: Install template
run: pak::pkg_install("mlr-org/mlr3pkgdowntemplate")
shell: Rscript {0}
Expand Down
7 changes: 3 additions & 4 deletions .github/workflows/r-cmd-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@ jobs:

env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
TORCH_INSTALL: 1

strategy:
fail-fast: false
matrix:
config:
- {os: ubuntu-latest, r: 'devel'}
- {os: ubuntu-latest, r: 'release'}
- {os: macos-latest, r: 'release'}
- {os: windows-latest, r: 'release'}

steps:
- uses: actions/checkout@v3
Expand All @@ -50,10 +53,6 @@ jobs:
extra-packages: any::rcmdcheck
needs: check

- name: Install torch
run: torch::install_torch()
shell: Rscript {0}

- uses: r-lib/actions/check-r-package@v2

- uses: mxschmitt/action-tmate@v3
Expand Down
18 changes: 9 additions & 9 deletions .lintr
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
linters: with_defaults(
# lintr defaults: https://github.com/jimhester/lintr#available-linters
# the following setup changes/removes certain linters
assignment_linter = NULL, # do not force using <- for assignments
object_name_linter = object_name_linter(c("snake_case", "CamelCase")), # only allow snake case and camel case object names
cyclocomp_linter = NULL, # do not check function complexity
commented_code_linter = NULL, # allow code in comments
line_length_linter = line_length_linter(180)
)
linters: linters_with_defaults(
assignment_linter=NULL,
object_name_linter=NULL,
cyclocomp_linter=NULL,
commented_code_linter=NULL,
line_length_linter=line_length_linter(180),
indentation_linter=NULL,
object_length_linter=object_length_linter(40))

24 changes: 16 additions & 8 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Imports:
checkmate (>= 2.2.0),
coro,
lgr,
mlr3misc (>= 0.11.0),
mlr3misc (>= 0.14.0),
methods,
data.table,
paradox (>= 0.11.0),
Expand All @@ -59,19 +59,19 @@ Suggests:
progress,
rmarkdown,
viridis,
torchvision,
testthat (>= 3.0.0),
zip
torchvision
Remotes:
r-lib/zip,
mlr-org/mlr3
mlr-org/mlr3,
mlr-org/mlr3pipelines@feat/keep_results,
mlverse/torchvision
Config/testthat/edition: 3
NeedsCompilation: no
ByteCompile: no
VignetteBuilder: knitr
Encoding: UTF-8
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.2.3.9000
RoxygenNote: 7.3.1
Collate:
'CallbackSet.R'
'zzz.R'
Expand All @@ -81,6 +81,8 @@ Collate:
'CallbackSetProgress.R'
'ContextTorch.R'
'DataBackendLazy.R'
'utils.R'
'DataDescriptor.R'
'LearnerTorch.R'
'LearnerTorchImage.R'
'LearnerTorchAlexNet.R'
Expand All @@ -90,6 +92,7 @@ Collate:
'ModelDescriptor.R'
'PipeOpModule.R'
'PipeOpTorch.R'
'PipeOpTaskPreprocTorch.R'
'PipeOpTorchActivation.R'
'PipeOpTorchAvgPool.R'
'PipeOpTorchBatchNorm.R'
Expand All @@ -98,6 +101,7 @@ Collate:
'PipeOpTorchConvTranspose.R'
'PipeOpTorchDropout.R'
'PipeOpTorchHead.R'
'shape.R'
'PipeOpTorchIngress.R'
'PipeOpTorchLayerNorm.R'
'PipeOpTorchLinear.R'
Expand All @@ -110,17 +114,21 @@ Collate:
'PipeOpTorchReshape.R'
'PipeOpTorchSoftmax.R'
'ResamplingRowRoles.R'
'TaskClassif_lazy_iris.R'
'TaskClassif_mnist.R'
'TaskClassif_tiny_imagenet.R'
'TorchDescriptor.R'
'TorchOptimizer.R'
'bibentries.R'
'cache.R'
'imageuri.R'
'lazy_tensor.R'
'learner_torch_methods.R'
'materialize.R'
'merge_graphs.R'
'nn_graph.R'
'paramset_torchlearner.R'
'preprocess.R'
'rd_info.R'
'reset_last_layer.R'
'task_dataset.R'
'utils.R'
'with_torch_settings.R'
36 changes: 27 additions & 9 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# Generated by roxygen2: do not edit by hand

S3method("[",imageuri)
S3method("[<-",imageuri)
S3method("[[",imageuri)
S3method("[[<-",imageuri)
S3method("==",lazy_tensor)
S3method("[",lazy_tensor)
S3method("[<-",lazy_tensor)
S3method("[[",lazy_tensor)
S3method("[[<-",lazy_tensor)
S3method(as.data.table,DictionaryMlr3torchCallbacks)
S3method(as.data.table,DictionaryMlr3torchLosses)
S3method(as.data.table,DictionaryMlr3torchOptimizers)
S3method(as_imageuri,character)
S3method(as_imageuri,imageuri)
S3method(as_lazy_tensor,DataDescriptor)
S3method(as_lazy_tensor,dataset)
S3method(as_lazy_tensor,numeric)
S3method(as_lazy_tensor,torch_tensor)
S3method(as_torch_callback,R6ClassGenerator)
S3method(as_torch_callback,TorchCallback)
S3method(as_torch_callback,character)
Expand All @@ -22,10 +25,16 @@ S3method(as_torch_loss,nn_loss)
S3method(as_torch_optimizer,TorchOptimizer)
S3method(as_torch_optimizer,character)
S3method(as_torch_optimizer,torch_optimizer_generator)
S3method(c,imageuri)
S3method(c,lazy_tensor)
S3method(col_info,DataBackendLazy)
S3method(format,lazy_tensor)
S3method(hash_input,lazy_tensor)
S3method(materialize,data.frame)
S3method(materialize,lazy_tensor)
S3method(materialize,list)
S3method(print,ModelDescriptor)
S3method(print,TorchIngressToken)
S3method(print,lazy_tensor)
S3method(reset_last_layer,AlexNet)
S3method(reset_last_layer,resnet)
S3method(t_clbk,"NULL")
Expand All @@ -46,6 +55,7 @@ export(CallbackSetHistory)
export(CallbackSetProgress)
export(ContextTorch)
export(DataBackendLazy)
export(DataDescriptor)
export(LearnerTorch)
export(LearnerTorchAlexNet)
export(LearnerTorchFeatureless)
Expand All @@ -54,6 +64,7 @@ export(LearnerTorchMLP)
export(LearnerTorchModel)
export(ModelDescriptor)
export(PipeOpModule)
export(PipeOpTaskPreprocTorch)
export(PipeOpTorch)
export(PipeOpTorchAvgPool1D)
export(PipeOpTorchAvgPool2D)
Expand All @@ -80,7 +91,7 @@ export(PipeOpTorchHardTanh)
export(PipeOpTorchHead)
export(PipeOpTorchIngress)
export(PipeOpTorchIngressCategorical)
export(PipeOpTorchIngressImage)
export(PipeOpTorchIngressLazyTensor)
export(PipeOpTorchIngressNumeric)
export(PipeOpTorchLayerNorm)
export(PipeOpTorchLeakyReLU)
Expand Down Expand Up @@ -120,14 +131,19 @@ export(TorchDescriptor)
export(TorchIngressToken)
export(TorchLoss)
export(TorchOptimizer)
export(as_lazy_tensor)
export(as_torch_callback)
export(as_torch_callbacks)
export(as_torch_loss)
export(as_torch_optimizer)
export(assert_lazy_tensor)
export(auto_device)
export(batchgetter_categ)
export(batchgetter_num)
export(callback_set)
export(imageuri)
export(is_lazy_tensor)
export(lazy_tensor)
export(materialize)
export(mlr3torch_callbacks)
export(mlr3torch_losses)
export(mlr3torch_optimizers)
Expand All @@ -141,6 +157,7 @@ export(nn_merge_sum)
export(nn_reshape)
export(nn_squeeze)
export(nn_unsqueeze)
export(pipeop_preproc_torch)
export(reset_last_layer)
export(t_clbk)
export(t_clbks)
Expand All @@ -150,6 +167,7 @@ export(t_opt)
export(t_opts)
export(task_dataset)
export(torch_callback)
exportPattern("^PipeOpPreprocTorch")
import(checkmate)
import(data.table)
import(mlr3)
Expand Down
4 changes: 2 additions & 2 deletions R/CallbackSetHistory.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ CallbackSetHistory = R6Class("CallbackSetHistory",
stopf("No eligible measures to plot for set '%s'.", set)
}

epoch = score = measure = NULL
epoch = score = measure = .data = NULL
if (ncol(data) == 2L) {
ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = !!rlang::sym(measures))) +
ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = .data[[measures]])) +
ggplot2::geom_line() +
ggplot2::geom_point() +
ggplot2::labs(
Expand Down
18 changes: 9 additions & 9 deletions R/DataBackendLazy.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@
#'
#' Beware that accessing the backend's hash also contructs the backend.
#'
#' Note that while in most cases the data contains [`lazy_tensor`] columns, this is not necessary and the naming
#' of this class has nothing to do with the [`lazy_tensor`] data type.
#'
#' **Important**
#'
#' When the constructor generates `factor()` variables it is important that the ordering of the levels in data
#' corresponds to the ordering of the levels in the `col_info` argument.
#' Because the ordering of the level depends on the locale, it is recommended to e.g. use the `C` locale in the
#' `constructor` function, e.g. with `withr::with_locale()`.
#'
#' @param constructor (`function()`)\cr
#' A function with no arguments, whose return value must be the actual backend.
#' @param constructor (`function`)\cr
#' A function with argument `backend` (the lazy backend), whose return value must be the actual backend.
#' This function is called the first time the field `$backend` is accessed.
#' @param rownames (`integer()`)\cr
#' The row names. Must be a permutation of the rownames of the lazily constructed backend.
Expand All @@ -53,7 +54,7 @@
#' @export
#' @examples
#' # We first define a backend constructor
#' constructor = function() {
#' constructor = function(backend) {
#' cat("Data is constructed!\n")
#' DataBackendDataTable$new(
#' data.table(x = rnorm(10), y = rnorm(10), row_id = 1:10),
Expand Down Expand Up @@ -100,7 +101,7 @@ DataBackendLazy = R6Class("DataBackendLazy",
assert_choice(primary_key, col_info$id)
private$.colnames = col_info$id
assert_choice(primary_key, col_info$id)
private$.constructor = assert_function(constructor, nargs = 0)
private$.constructor = assert_function(constructor, args = "backend")

super$initialize(data = NULL, primary_key = primary_key, data_formats = data_formats)
},
Expand Down Expand Up @@ -172,7 +173,7 @@ DataBackendLazy = R6Class("DataBackendLazy",
backend = function(rhs) {
assert_ro_binding(rhs)
if (is.null(private$.backend)) {
private$.backend = assert_backend(private$.constructor())
private$.backend = assert_backend(private$.constructor(self))

f = function(test, x, y, var_name) {
if (!test(x, y)) {
Expand All @@ -184,8 +185,7 @@ DataBackendLazy = R6Class("DataBackendLazy",
}
}

# test_equal does not exist so we abuse test_permutation for primary_key
f(test_permutation, private$.backend$primary_key, self$primary_key, "primary key")
f(identical, private$.backend$primary_key, self$primary_key, "primary key")
f(test_permutation, private$.backend$rownames, self$rownames, "row identifiers")
f(test_permutation, private$.backend$colnames, private$.colnames, "column names")
f(test_equal_col_info, col_info(private$.backend), private$.col_info, "column information")
Expand Down
Loading
Loading