Skip to content

Commit

Permalink
Fixed externalsampler (#2089)
Browse files Browse the repository at this point in the history
* fixed getparams call in Transition construction for AbstractMCMC

* bumped patch version

* fixed getparams implementations for AdvancedHMC and AdvancedMH

* fixed ESLogDensityFunction as suggested by @devmotion

* Revert "fixed ESLogDensityFunction as suggested by @devmotion"

This reverts commit 3780ca8.

* removed ESLogDensityFunction

* fixed typo
  • Loading branch information
torfjelde authored Oct 17, 2023
1 parent 4ffe2cd commit 878294f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.29.2"
version = "0.29.3"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
14 changes: 5 additions & 9 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,15 @@ Wrap a sampler so it can be used as an inference algorithm.
"""
externalsampler(sampler::AbstractSampler) = ExternalSampler(sampler)

"""
ESLogDensityFunction
A log density function for the External sampler.
"""
const ESLogDensityFunction{M<:Model,S<:Sampler{<:ExternalSampler},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,<:DynamicPPL.DefaultContext}
function LogDensityProblems.logdensity(f::ESLogDensityFunction, x::NamedTuple)
function LogDensityProblems.logdensity(
f::Turing.LogDensityFunction{<:AbstractVarInfo,<:Model,<:DynamicPPL.DefaultContext},
x::NamedTuple
)
return DynamicPPL.logjoint(f.model, DynamicPPL.unflatten(f.varinfo, x))
end

# TODO: make a nicer `set_namedtuple!` and move these functions to DynamicPPL.
function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple)
function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple)
set_namedtuple!(deepcopy(vi), θ)
return vi
end
Expand Down
6 changes: 3 additions & 3 deletions src/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@ state_to_turing(f::DynamicPPL.LogDensityFunction, state) = TuringState(state, f)
function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition)
# TODO: We should probably rename this `getparams` since it returns something
# very different from `Turing.Inference.getparams`.
θ = getparams(transition)
θ = getparams(f.model, transition)
varinfo = DynamicPPL.unflatten(f.varinfo, θ)
return Transition(f.model, varinfo, transition)
end

# NOTE: Only thing that depends on the underlying sampler.
# Something similar should be part of AbstractMCMC at some point:
# https://github.com/TuringLang/AbstractMCMC.jl/pull/86
getparams(transition::AdvancedHMC.Transition) = transition.z.θ
getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
getstats(transition::AdvancedHMC.Transition) = transition.stat

getparams(transition::AdvancedMH.Transition) = transition.params
getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params

getvarinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo
getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper) = getvarinfo(parent(f))
Expand Down

2 comments on commit 878294f

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/92953

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.29.3 -m "<description of version>" 878294f1f7879d508f794fff627b3214c18e3a80
git push origin v0.29.3

Please sign in to comment.