Skip to content

Commit

Permalink
better error messages for lazy tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed May 10, 2024
1 parent 0516743 commit a05742e
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
16 changes: 11 additions & 5 deletions R/DataDescriptor.R
Original file line number Diff line number Diff line change
Expand Up @@ -230,17 +230,23 @@ assert_compatible_shapes = function(shapes, dataset) {

# prevent user from e.g. forgetting to wrap the return in a list
example = if (is.null(dataset$.getbatch)) {
example = dataset$.getitem(1L)
if (!test_list(example)) {
stopf("dataset must return names list")
}
map(example, function(x) x$unsqueeze(1))
dataset$.getitem(1L)
} else {
dataset$.getbatch(1L)
}
if (!test_list(example, names = "unique") || !test_permutation(names(example), names(shapes))) {
stopf("Dataset must return a list with named elements that are a permutation of the dataset_shapes names.")
}
iwalk(example, function(x, nm) {
if (!test_class(x, "torch_tensor")) {
stopf("The dataset must return torch tensors, but element '%s' is of class %s", nm, class(x)[[1L]])
}
})

if (is.null(dataset$.getbatch)) {
example = map(example, function(x) x$unsqueeze(1))
}

iwalk(shapes, function(dataset_shape, name) {
if (!is.null(dataset_shape) && !test_equal(shapes[[name]][-1], example[[name]]$shape[-1L])) {
expected_shape = example[[name]]$shape
Expand Down
44 changes: 44 additions & 0 deletions tests/testthat/test_lazy_tensor.R
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,47 @@ test_that("comparison", {
expect_equal(x[c(1, 1)] == x, c(TRUE, FALSE))
expect_equal(x == y, c(FALSE, FALSE))
})

test_that("error messages: no torch tensor or no unique names", {
ds = dataset(
initialize = function() self$x = torch_randn(10, 3, 3),
.getitem = function(i) list(x = self$x[i, ], y = sample.int(1)),
.length = function() nrow(self$x)
)()

expect_error(
as_lazy_tensor(ds, dataset_shapes = list(x = c(NA, 3, 3), y = NULL)),
regexp = "must return torch tensors"
)

dsb = dataset(
initialize = function() self$x = torch_randn(10, 3, 3),
.getbatch = function(i) list(x = self$x[i, , drop = FALSE], y = sample.int(1)),
.length = function() nrow(self$x)
)()

expect_error(
as_lazy_tensor(dsb, dataset_shapes = list(x = c(NA, 3, 3), y = NULL)),
regexp = "must return torch tensors"
)

ds1 = dataset(
initialize = function() self$x = torch_randn(10, 3, 3),
.getitem = function(i) list(self$x[i, ]),
.length = function() nrow(self$x)
)()
expect_error(
as_lazy_tensor(ds1, dataset_shapes = list(x = c(NA, 3, 3))),
regexp = "list with named elements"
)

ds1b = dataset(
initialize = function() self$x = torch_randn(10, 3, 3),
.getbatch = function(i) list(self$x[i, drop = FALSE]),
.length = function() nrow(self$x)
)()
expect_error(
as_lazy_tensor(ds1, dataset_shapes = list(x = c(NA, 3, 3))),
regexp = "list with named elements"
)
})

0 comments on commit a05742e

Please sign in to comment.