Skip to content

Commit

Permalink
bring other tests up to date
Browse files Browse the repository at this point in the history
  • Loading branch information
danielturek committed Jan 24, 2024
1 parent 34fde82 commit a04a3d2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
22 changes: 14 additions & 8 deletions nimbleHMC/R/HMC_samplers.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ sampler_langevin <- nimbleFunction(
hmc_checkTarget <- function(model, targetNodes, hmcType) {
## checks for:
## - target with discrete or truncated distribution
## - target with user-defined distribution (without AD support)
## - dependencies with truncated, dinterval, or dconstraint distribution
calcNodes <- model$getDependencies(targetNodes, stochOnly = TRUE)
if(any(model$isDiscrete(targetNodes)))
Expand All @@ -139,14 +138,21 @@ hmc_checkTarget <- function(model, targetNodes, hmcType) {
stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dinterval distributions, which do not support AD calculations: ', paste0(calcNodes[which(model$getDistribution(calcNodes) == 'dinterval')], collapse = ', ')))
if(any(model$getDistribution(calcNodes) == 'dconstraint'))
stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dconstraint distributions, which do not support AD calculations: ', paste0(calcNodes[which(model$getDistribution(calcNodes) == 'dconstraint')], collapse = ', ')))
##
## next, check for:
## - target with user-defined distribution (without AD support)
####dists <- model$getDistribution(targetNodes)
####ADok <- sapply(dists,
#### function(dist) {
#### if(!is.null(environment(get(dist))$nfMethodRCobject)) ## user-defined:
#### return(!isFALSE(environment(get(dist))$nfMethodRCobject[['buildDerivs']]))
#### else return(TRUE) ## non-user-defined
#### })
####ADok <- rep(TRUE, length(dists))
####for(i in seq_along(dists)) {
#### ## these distributions get re-named to a nimble-version, and won't be found:
#### if(dists[i] %in% c('dweib', 'dmnorm', 'dmvt', 'dwish', 'dinvwish')) next
#### ## find the function or this distribution:
#### nfObj <- get(dists[i], envir = parent.frame(4)) ## this took a bit of an investigation to make work
#### ## is a user-defined distribution:
#### if(!is.null(environment(nfObj)$nfMethodRCobject)) {
#### ## check for AD support:
#### ADok[i] <- !isFALSE(environment(nfObj)$nfMethodRCobject[['buildDerivs']])
#### }
####}
####if(!all(ADok))
#### stop(paste0(hmcType, ' sampler cannot operate on user-defined distributions which do not support AD calculations. Try using buildDerivs = TRUE in the definition the distributions: ', paste0(dists[!ADok], collapse = ', ')))
}
Expand Down
4 changes: 1 addition & 3 deletions nimbleHMC/tests/testthat/test-HMC.R
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,7 @@ test_that('hmc_checkTarget catches non-AD support for custom distributions', {
})
Rmodel <- nimbleModel(code, data = list(y=0), inits = list(x=0), buildDerivs = TRUE)
conf <- configureHMC(Rmodel)
e <- try(Rmcmc <- buildMCMC(conf), silent = TRUE)
## expect_no_error(buildMCMC(conf))
expect_true(class(e) == 'MCMC')
expect_no_error(buildMCMC(conf))
})

test_that('HMC sampler error messages for invalid M mass matrix arguments', {
Expand Down

0 comments on commit a04a3d2

Please sign in to comment.