Skip to content

Commit

Permalink
access path field of callback, ifelse -> if
Browse files Browse the repository at this point in the history
  • Loading branch information
cxzhang4 committed Oct 16, 2024
1 parent 5f7d56f commit a1ddb7d
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions tests/testthat/test_CallbackSetTB.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
event_tag_is = function(event, tag_name) {
ifelse(is.null(event), FALSE, event["tag"] == tag_name)
if (is.null(event)) FALSE else event["tag"] == tag_name
}

test_that("autotest", {
Expand All @@ -13,9 +13,11 @@ test_that("metrics are logged correctly", {

task = tsk("iris")

n_epochs = 10

mlp = lrn("classif.mlp",
callbacks = cb,
epochs = 10, batch_size = 150, neurons = 10,
epochs = n_epochs, batch_size = 150, neurons = 10,
validate = 0.2,
measures_valid = msrs(c("classif.acc", "classif.ce")),
measures_train = msrs(c("classif.acc", "classif.ce"))
Expand All @@ -25,7 +27,7 @@ test_that("metrics are logged correctly", {

mlp$train(task)

events = mlr3misc::map(tfevents::collect_events(pth0)$summary, unlist)
events = mlr3misc::map(tfevents::collect_events(mlp$param_set$get_values()$cb.tb.path)$summary, unlist)

n_train_loss_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.loss"))
n_train_acc_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.classif.acc"))
Expand All @@ -43,20 +45,23 @@ test_that("metrics are logged correctly", {
test_that("eval_freq works", {
task = tsk("iris")

n_epochs = 9
eval_freq = 4

mlp = lrn("classif.mlp",
callbacks = t_clbk("tb"),
epochs = 9, batch_size = 150, neurons = 200,
epochs = n_epochs, batch_size = 150, neurons = 200,
validate = 0.2,
measures_valid = msrs(c("classif.acc", "classif.ce")),
measures_train = msrs(c("classif.acc", "classif.ce")),
eval_freq = 4
eval_freq = eval_freq
)
mlp$param_set$set_values(cb.tb.path = tempfile())
mlp$param_set$set_values(cb.tb.log_train_loss = TRUE)

mlp$train(task)

events = mlr3misc::map(tfevents::collect_events(pth0)$summary, unlist)
events = mlr3misc::map(tfevents::collect_events(mlp$param_set$get_values()$cb.tb.path)$summary, unlist)

n_train_loss_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.loss"))
n_train_acc_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.classif.acc"))
Expand All @@ -71,22 +76,19 @@ test_that("eval_freq works", {
expect_equal(n_valid_ce_events, ceiling(n_epochs / eval_freq))
})

test_that("the flag for tracking the train loss works", {
test_that("we can disable training loss tracking", {
task = tsk("iris")

mlp = lrn("classif.mlp",
callbacks = t_clbk("tb"),
epochs = 10, batch_size = 150, neurons = 200,
validate = 0.2,
measures_valid = msrs(c("classif.acc", "classif.ce")),
measures_train = msrs(c("classif.acc", "classif.ce"))
epochs = 10, batch_size = 150, neurons = 200
)
mlp$param_set$set_values(cb.tb.path = tempfile()
mlp$param_set$set_values(cb.tb.path = tempfile())
mlp$param_set$set_values(cb.tb.log_train_loss = FALSE)

mlp$train(task)

events = mlr3misc::map(tfevents::collect_events(pth0)$summary, unlist)
events = mlr3misc::map(tfevents::collect_events(mlp$param_set$get_values()$cb.tb.path)$summary, unlist)

n_train_loss_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.loss"))

Expand Down

0 comments on commit a1ddb7d

Please sign in to comment.