Skip to content

Commit

Permalink
Add compatibility with MCMCDiagnosticTools v0.3 (#401)
Browse files Browse the repository at this point in the history
* Bump MCMCDiagnosticTools compat

* Update imported/exported methods

* Remove type constraint on classifier

* Overload and export mcse

* Overload and update ess and rhat

* Update summarystats

* Update tests

* Increment major version

* Rename ess.jl to ess_rhat.jl

* Add back ess_per_sec

* Fix bug constructing ess_per_sec

* Update ess_rhat tests

* Test mcse

* Update docs

* Remove deprecations

* Remove unused import

* Revert "Fix MLJDecisionTreeInterface to 0.3.0 (#402)"

This reverts commit 991f10b.

* Always include ess_per_sec in table

* Use isequal to pass with missing values

* Use isequal for missing

* Remove naive_se

Fixes #351

* Test Tables interface before loading StatsPlots

DataValues (a StatsPlots dependency) pirates a convert method that causes the Tables equality tests with `missing` to fail. See https://github.com/queryverse/DataValues.jl

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

---------

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
sethaxen and devmotion authored Feb 28, 2023
1 parent 991f10b commit ddac60f
Show file tree
Hide file tree
Showing 16 changed files with 268 additions and 151 deletions.
5 changes: 2 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "Chain types and utility functions for MCMC simulations."
version = "5.7.1"
version = "6.0.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -21,7 +21,6 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Expand All @@ -35,7 +34,7 @@ Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
Formatting = "0.4"
IteratorInterfaceExtensions = "0.1.1, 1"
KernelDensity = "0.6.2"
MCMCDiagnosticTools = "0.2"
MCMCDiagnosticTools = "0.3"
MLJModelInterface = "0.3.5, 0.4, 1.0"
NaturalSort = "1"
OrderedCollections = "1.4"
Expand Down
4 changes: 2 additions & 2 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ CategoricalArrays = "0.8, 0.9, 0.10"
DataFrames = "0.22, 1"
Documenter = "0.26, 0.27"
Gadfly = "1.3.4"
MCMCChains = "5"
MCMCChains = "6"
MLJBase = "0.19, 0.20, 0.21"
MLJDecisionTreeInterface = "=0.3.0"
MLJDecisionTreeInterface = "0.3"
StatsPlots = "0.14, 0.15"
julia = "1.7"
3 changes: 2 additions & 1 deletion docs/src/diagnostics.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Pages = [
"heideldiag.jl",
"rafterydiag.jl",
"rstar.jl",
"ess.jl"
"ess_rhat.jl",
"mcse.jl",
]
```
27 changes: 7 additions & 20 deletions src/MCMCChains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import IteratorInterfaceExtensions

import LinearAlgebra
import Random
import Serialization
import Statistics: std, cor, mean, var, mean!

export Chains, chains, chainscat
Expand All @@ -36,13 +35,15 @@ export ChainDataFrame
export summarize

# Reexport diagnostics functions
using MCMCDiagnosticTools: discretediag, ess_rhat, ESSMethod, FFTESSMethod, BDAESSMethod,
gelmandiag, gelmandiag_multivariate, gewekediag, heideldiag, rafterydiag, rstar
using MCMCDiagnosticTools: discretediag, ess, ess_rhat, AutocovMethod, FFTAutocovMethod,
BDAAutocovMethod, gelmandiag, gelmandiag_multivariate, gewekediag, heideldiag, mcse,
rafterydiag, rhat, rstar
export discretediag
export ess_rhat, ESSMethod, FFTESSMethod, BDAESSMethod
export ess, ess_rhat, rhat, AutocovMethod, FFTAutocovMethod, BDAAutocovMethod
export gelmandiag, gelmandiag_multivariate
export gewekediag
export heideldiag
export mcse
export rafterydiag
export rstar

Expand All @@ -69,13 +70,14 @@ end
include("utils.jl")
include("chains.jl")
include("constructors.jl")
include("ess.jl")
include("ess_rhat.jl")
include("summarize.jl")
include("discretediag.jl")
include("fileio.jl")
include("gelmandiag.jl")
include("gewekediag.jl")
include("heideldiag.jl")
include("mcse.jl")
include("rafterydiag.jl")
include("sampling.jl")
include("stats.jl")
Expand All @@ -84,19 +86,4 @@ include("plot.jl")
include("tables.jl")
include("rstar.jl")

# deprecations
# TODO: Remove dependency on Serialization if this deprecation is removed
# somehow `@deprecate` doesn't work with qualified function names,
# so we use the following hack
const _read = Base.read
const _write = Base.write
Base.@deprecate _read(
f::AbstractString,
::Type{T}
) where {T<:Chains} Serialization.deserialize(f) false
Base.@deprecate _write(
f::AbstractString,
c::Chains
) Serialization.serialize(f, c) false

end # module
33 changes: 0 additions & 33 deletions src/ess.jl

This file was deleted.

85 changes: 85 additions & 0 deletions src/ess_rhat.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
ess(chains::Chains; duration=compute_duration, kwargs...)
Estimate the effective sample size.
ESS per second options include `duration=MCMCChains.compute_duration` (the default)
and `duration=MCMCChains.wall_duration`.
"""
function MCMCDiagnosticTools.ess(
chains::Chains;
sections = _default_sections(chains), duration = compute_duration, kwargs...
)
# Subset the chain
_chains = Chains(chains, _clean_sections(chains, sections))

# Estimate the effective sample size
ess = MCMCDiagnosticTools.ess(
_permutedims_diagnostics(_chains.value.data);
kwargs...,
)

# Calculate ESS/minute if available
dur = duration(chains)

# Convert to NamedTuple
ess_per_sec = ess ./ dur
nt = merge((parameters = names(_chains),), (; ess, ess_per_sec))

return ChainDataFrame("ESS", nt)
end

"""
rhat(chains::Chains; kwargs...)
Estimate the ``\\widehat{R}`` diagnostic.
"""
function MCMCDiagnosticTools.rhat(
chains::Chains;
sections = _default_sections(chains), kwargs...
)
# Subset the chain
_chains = Chains(chains, _clean_sections(chains, sections))

# Estimate the rhat
rhat = MCMCDiagnosticTools.rhat(
_permutedims_diagnostics(_chains.value.data);
kwargs...,
)

# Convert to NamedTuple
nt = merge((parameters = names(_chains),), (; rhat))

return ChainDataFrame("R-hat", nt)
end

"""
ess_rhat(chains::Chains; duration=compute_duration, kwargs...)
Estimate the effective sample size and the ``\\widehat{R}`` diagnostic
ESS per second options include `duration=MCMCChains.compute_duration` (the default)
and `duration=MCMCChains.wall_duration`.
"""
function MCMCDiagnosticTools.ess_rhat(
chains::Chains;
sections = _default_sections(chains), duration = compute_duration, kwargs...
)
# Subset the chain
_chains = Chains(chains, _clean_sections(chains, sections))

# Estimate the effective sample size and rhat
ess_rhat = MCMCDiagnosticTools.ess_rhat(
_permutedims_diagnostics(_chains.value.data);
kwargs...,
)

# Calculate ESS/minute if available
dur = duration(chains)

# Convert to NamedTuple
ess_per_sec = ess_rhat.ess ./ dur
nt = merge((parameters = names(_chains),), ess_rhat, (; ess_per_sec))

return ChainDataFrame("ESS/R-hat", nt)
end
22 changes: 22 additions & 0 deletions src/mcse.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
mcse(chains::Chains; duration=compute_duration, kwargs...)
Estimate the Monte Carlo standard error.
"""
function MCMCDiagnosticTools.mcse(
chains::Chains;
sections = _default_sections(chains), kwargs...
)
# Subset the chain
_chains = Chains(chains, _clean_sections(chains, sections))

# Estimate the effective sample size
mcse = MCMCDiagnosticTools.mcse(
_permutedims_diagnostics(_chains.value.data);
kwargs...,
)

nt = merge((parameters = names(_chains),), (; mcse))

return ChainDataFrame("MCSE", nt)
end
4 changes: 2 additions & 2 deletions src/rstar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ true
```
"""
function MCMCDiagnosticTools.rstar(
classif::MLJModelInterface.Supervised, chn::Chains; kwargs...
classif, chn::Chains; kwargs...
)
return MCMCDiagnosticTools.rstar(Random.GLOBAL_RNG, classif, chn; kwargs...)
end

function MCMCDiagnosticTools.rstar(
rng::Random.AbstractRNG,
classif::MLJModelInterface.Supervised,
classif,
chn::Chains;
sections = _default_sections(chn),
kwargs...
Expand Down
36 changes: 25 additions & 11 deletions src/stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,14 +270,13 @@ end
chains;
sections = _default_sections(chains),
append_chains= true,
method::AbstractESSMethod = ESSMethod(),
autocov_method::AbstractAutocovMethod = AutocovMethod(),
maxlag = 250,
etype = :bm,
kwargs...
)
Compute the mean, standard deviation, naive standard error, Monte Carlo standard error,
and effective sample size for each parameter in the chain.
Compute the mean, standard deviation, Monte Carlo standard error, bulk- and tail- effective
sample size, and ``\\widehat{R}`` diagnostic for each parameter in the chain.
Setting `append_chains=false` will return a vector of dataframes containing the summary
statistics for each chain.
Expand All @@ -288,27 +287,42 @@ function summarystats(
chains::Chains;
sections = _default_sections(chains),
append_chains::Bool = true,
method::MCMCDiagnosticTools.AbstractESSMethod = ESSMethod(),
autocov_method::MCMCDiagnosticTools.AbstractAutocovMethod = AutocovMethod(),
maxlag = 250,
etype = :bm,
kwargs...
)
# Store everything.
funs = [meancskip, stdcskip, semcskip, x -> MCMCDiagnosticTools.mcse(cskip(x); method=etype, kwargs...)]
func_names = [:mean, :std, :naive_se, :mcse]
funs = [meancskip, stdcskip]
func_names = [:mean, :std]

# Subset the chain.
_chains = Chains(chains, _clean_sections(chains, sections))

# Calculate ESS separately.
ess_df = MCMCDiagnosticTools.ess_rhat(_chains; sections = nothing, method = method, maxlag = maxlag)
# Calculate MCSE and ESS/R-hat separately.
mcse_df = MCMCDiagnosticTools.mcse(
_chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag,
)
ess_rhat_rank_df = MCMCDiagnosticTools.ess_rhat(
_chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:rank
)
ess_tail_df = MCMCDiagnosticTools.ess(
_chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:tail
)
nt_additional = (
mcse=mcse_df.nt.mcse,
ess_bulk=ess_rhat_rank_df.nt.ess,
ess_tail=ess_tail_df.nt.ess,
rhat=ess_rhat_rank_df.nt.rhat,
ess_per_sec=ess_rhat_rank_df.nt.ess_per_sec,
)
additional_df = ChainDataFrame("Additional", nt_additional)

# Summarize.
summary_df = summarize(
_chains, funs...;
func_names = func_names,
append_chains = append_chains,
additional_df = ess_df,
additional_df = additional_df,
name = "Summary Statistics",
sections = nothing
)
Expand Down
4 changes: 2 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ Documenter = "0.26, 0.27"
FFTW = "1.1"
IteratorInterfaceExtensions = "1"
KernelDensity = "0.6.2"
MCMCChains = "5"
MCMCChains = "6"
MLJBase = "0.18, 0.19, 0.20, 0.21"
MLJDecisionTreeInterface = "=0.3.0"
MLJDecisionTreeInterface = "0.3"
StatsBase = "0.33.2"
StatsPlots = "0.14.17, 0.15"
TableTraits = "1"
Expand Down
Loading

2 comments on commit ddac60f

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/78718

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v6.0.0 -m "<description of version>" ddac60fb57b8dba6955922cec46573b310dbc0c1
git push origin v6.0.0

Please sign in to comment.