Skip to content

Commit

Permalink
added back in test for user-dists without AD
Browse files Browse the repository at this point in the history
  • Loading branch information
danielturek committed Jan 25, 2024
1 parent c3c08f4 commit c8f3c50
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
30 changes: 15 additions & 15 deletions nimbleHMC/R/HMC_samplers.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,21 +140,21 @@ hmc_checkTarget <- function(model, targetNodes, hmcType) {
stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dconstraint distributions, which do not support AD calculations: ', paste0(calcNodes[which(model$getDistribution(calcNodes) == 'dconstraint')], collapse = ', ')))
## next, check for:
## - target with user-defined distribution (without AD support)
####dists <- model$getDistribution(targetNodes)
####ADok <- rep(TRUE, length(dists))
####for(i in seq_along(dists)) {
#### ## these distributions get re-named to a nimble-version, and won't be found:
#### if(dists[i] %in% c('dweib', 'dmnorm', 'dmvt', 'dwish', 'dinvwish')) next
#### ## find the function or this distribution:
#### nfObj <- get(dists[i], envir = parent.frame(4)) ## this took a bit of an investigation to make work
#### ## is a user-defined distribution:
#### if(!is.null(environment(nfObj)$nfMethodRCobject)) {
#### ## check for AD support:
#### ADok[i] <- !isFALSE(environment(nfObj)$nfMethodRCobject[['buildDerivs']])
#### }
####}
####if(!all(ADok))
#### stop(paste0(hmcType, ' sampler cannot operate on user-defined distributions which do not support AD calculations. Try using buildDerivs = TRUE in the definition the distributions: ', paste0(dists[!ADok], collapse = ', ')))
dists <- model$getDistribution(targetNodes)
ADok <- rep(TRUE, length(dists))
for(i in seq_along(dists)) {
## these distributions get re-named to a nimble-version, and won't be found:
if(dists[i] %in% c('dweib', 'dmnorm', 'dmvt', 'dwish', 'dinvwish')) next
## find the function or this distribution:
nfObj <- get(dists[i], envir = parent.frame(4)) ## this took a bit of an investigation to make work
## is a user-defined distribution:
if(!is.null(environment(nfObj)$nfMethodRCobject)) {
## check for AD support:
ADok[i] <- !isFALSE(environment(nfObj)$nfMethodRCobject[['buildDerivs']])
}
}
if(!all(ADok))
stop(paste0(hmcType, ' sampler cannot operate on user-defined distributions which do not support AD calculations. Try using buildDerivs = TRUE in the definition the distributions: ', paste0(dists[!ADok], collapse = ', ')))
}


Expand Down
3 changes: 1 addition & 2 deletions nimbleHMC/tests/testthat/test-HMC.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,7 @@ test_that('hmc_checkTarget catches non-AD support for custom distributions', {
})
Rmodel <- nimbleModel(code, data = list(y=0), inits = list(x=0), buildDerivs = TRUE)
conf <- configureHMC(Rmodel)
###expect_error(buildMCMC(conf))
expect_no_error(buildMCMC(conf))
expect_error(buildMCMC(conf))
##
ddistAD <- nimbleFunction(
run = function(x = double(0), log = integer(0, default = 0)) {
Expand Down

0 comments on commit c8f3c50

Please sign in to comment.