Skip to content

Commit

Permalink
Make Turing.jl an optional dep (#36)
Browse files Browse the repository at this point in the history
* abstract away the extraction of stuff frm from the transitions

* make Turing.jl a weakdep or using Requires if not available

* update calls to params_and_values and extars

* fixed conditional loading and typo in params_and_values

* fixed requires usage

* breaking change

* added some simple tests with Turing

* dont run workflow on nightly for all the OSes

* fixed actions maybe

* only run GH actions upon push to main rather than any branch

* nvm

* maybe fixed actions

* run nightly on ubuntu

* nah

* okay maybe

* updated docs and workflow
  • Loading branch information
torfjelde authored Feb 23, 2023
1 parent 515ef9b commit 96af50f
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 55 deletions.
11 changes: 4 additions & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
name: CI
on:
- push
- pull_request
push:
branches:
- master
pull_request:
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
Expand All @@ -11,17 +13,12 @@ jobs:
matrix:
version:
- '1'
- 'nightly'
os:
- ubuntu-latest
- macOS-latest
- windows-latest
arch:
- x64
- x86
exclude:
- os: macOS-latest
arch: x86
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand Down
14 changes: 9 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TuringCallbacks"
uuid = "ea0860ee-d0ef-45ef-82e6-cc37d6be2f9c"
authors = ["Tor Erlend Fjelde <[email protected]> and contributors"]
version = "0.1.9"
version = "0.2.0"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -11,20 +11,24 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"

[weakdeps]
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[extensions]
TuringCallbacksTuringExt = "Turing"

[compat]
DataStructures = "0.18"
DocStringExtensions = "0.8, 0.9"
OnlineStats = "1.5"
Reexport = "0.2, 1.0"
Requires = "1"
TensorBoardLogger = "0.1"
Turing = "0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21, 0.22"
julia = "1"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
4 changes: 2 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,10 @@ callback = TensorBoardCallback(
```
Or you can create the filter (a mapping `variable_name -> ::Bool` yourself:
```julia
var_filter(varname) = varname != "m"
var_filter(varname, value1) = varname != "m"
callback = TensorBoardCallback(
"tensorboard_logs/run", stats;
variable_filter = var_filter
filter = var_filter
)
```

Expand Down
26 changes: 26 additions & 0 deletions ext/TuringCallbacksTuringExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module TuringCallbacksTuringExt

if isdefined(Base, :get_extension)
using Turing: Turing
using TuringCallbacks: TuringCallbacks
else
# Requires compatible.
using ..Turing: Turing
using ..TuringCallbacks: TuringCallbacks
end

const TuringTransition = Union{Turing.Inference.Transition,Turing.Inference.HMCTransition}

function TuringCallbacks.params_and_values(transition::TuringTransition; kwargs...)
return Iterators.map(zip(Turing.Inference._params_to_array([transition])...)) do (ksym, val)
return string(ksym), val
end
end

function TuringCallbacks.extras(transition::TuringTransition; kwargs...)
return Iterators.map(zip(Turing.Inference.get_transition_extras([transition])...)) do (ksym, val)
return string(ksym), val
end
end

end
14 changes: 12 additions & 2 deletions src/TuringCallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,28 @@ using Reexport
using LinearAlgebra
using Logging
using DocStringExtensions
import Turing

@reexport using OnlineStats # used to compute different statistics on-the-fly

using TensorBoardLogger
const TBL = TensorBoardLogger

import DataStructures: DefaultDict
using DataStructures: DefaultDict

@static if !isdefined(Base, :get_extension)
using Requires
end

export TensorBoardCallback, DefaultDict, WindowStat, Thin, Skip

include("stats.jl")
include("tensorboardlogger.jl")
include("callbacks/tensorboard.jl")

@static if !isdefined(Base, :get_extension)
function __init__()
@require Turing="fce5fe82-541a-59a6-adf8-730c64b5f9a0" include("../ext/TuringCallbacksTuringExt.jl")
end
end

end
99 changes: 64 additions & 35 deletions src/callbacks/tensorboard.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ using Dates
"""
$(TYPEDEF)
Wraps a `TensorBoardLogger.TBLogger` to construct a callback to be passed to
`Turing.sample`.
Wraps a `TensorBoardLogger.TBLogger` to construct a callback to be passed to `AbstractMCMC.step`.
# Usage
Expand All @@ -25,8 +24,8 @@ provided instead of `lg`.
## Keyword arguments
- `num_bins::Int = 100`: Number of bins to use in the histograms.
- `variable_filter = nothing`: Filter determining whether or not we should log stats for a
particular variable.
- `filter = nothing`: Filter determining whether or not we should log stats for a
particular variable and value; expected signature is `filter(varname, value)`.
If `isnothing` a default-filter constructed from `exclude` and
`include` will be used.
- `exclude = String[]`: If non-empty, these variables will not be logged.
Expand All @@ -40,15 +39,19 @@ provided instead of `lg`.
# Fields
$(TYPEDFIELDS)
"""
struct TensorBoardCallback{F, L}
struct TensorBoardCallback{L,F,VI,VE}
"Underlying logger."
logger::TBLogger
"Lookup for variable name to statistic estimate."
stats::L
"Filter determining whether or not we should log stats for a particular variable."
variable_filter::F
filter::F
"Variables to include in the logging."
include::VI
"Variables to exclude from the logging."
exclude::VE
"Include extra statistics from transitions."
include_extras::Bool
"Lookup for variable name to statistic estimate."
stats::L
end

function TensorBoardCallback(directory::String, args...; kwargs...)
Expand All @@ -72,26 +75,16 @@ function TensorBoardCallback(
lg::TBLogger,
stats = nothing;
num_bins::Int = 100,
exclude = String[],
include = String[],
exclude = nothing,
include = nothing,
include_extras::Bool = true,
variable_filter = nothing,
filter = nothing,
kwargs...
)
# Create the filter
filter = if !isnothing(variable_filter)
variable_filter
else
varname -> (
(isempty(exclude) || varname exclude) &&
(isempty(include) || varname include)
)
end

# Lookups: create default ones if not given
stats_lookup = if stats isa OnlineStat
# Warn the user if they've provided a non-empty `OnlineStat`
nobs(stats) > 0 && @warn("using statistic with observations as a base: $(stats)")
OnlineStats.nobs(stats) > 0 && @warn("using statistic with observations as a base: $(stats)")
let o = stats
DefaultDict{String, typeof(o)}(() -> deepcopy(o))
end
Expand All @@ -106,37 +99,73 @@ function TensorBoardCallback(
end

return TensorBoardCallback(
lg, filter, include_extras, stats_lookup
lg, stats_lookup, filter, include, exclude, include_extras
)
end

function (cb::TensorBoardCallback)(rng, model, sampler, transition, iteration, state; kwargs...)
"""
filter_param_and_value(cb::TensorBoardCallback, param_name, value)
Filter parameters and values from a `transition` based on the `filter` of `cb`.
"""
function filter_param_and_value(cb::TensorBoardCallback, param, value)
if !isnothing(cb.filter)
return cb.filter(param, value)
end

# Othnerwise we construct from `include` and `exclude`.
!isnothing(cb.exclude) && param cb.exclude && return false
!isnothing(cb.include) && param cb.include && return true

return true
end
filter_param_and_value(cb::TensorBoardCallback, param_and_value::Tuple) = filter_param_and_value(cb, param_and_value...)

"""
default_param_names_for_values(x)
Return an iterator of `θ[i]` for each element in `x`.
"""
default_param_names_for_values(x) = ("θ[$i]" for i = 1:length(x))


"""
params_and_values(transition[, state]; param_names = nothing)
Return an iterator over parameter names and values from a `transition`.
"""
params_and_values(transition, state; kwargs...) = params_and_values(transition; kwargs...)

"""
extras(transition[, state]; kwargs...)
Return an iterator with elements of the form `(name, value)` for additional statistics in `transition`.
"""
extras(transition, state; kwargs...) = extras(transition; kwargs...)

function (cb::TensorBoardCallback)(rng, model, sampler, transition, state, iteration; param_names=nothing, kwargs...)
stats = cb.stats
lg = cb.logger
filter = cb.variable_filter

filterf = Base.Fix1(filter_param_and_value, cb)

# TODO: Should we use the explicit interface for TensorBoardLogger?
with_logger(lg) do
for (ksym, val) in zip(Turing.Inference._params_to_array([transition])...)
k = string(ksym)
if !filter(k)
continue
end
for (k, val) in Iterators.filter(filterf, params_and_values(transition, state; param_names))
stat = stats[k]

# Log the raw value
@info k val

# Update statistic estimators
fit!(stat, val)
OnlineStats.fit!(stat, val)

# Need some iterations before we start showing the stats
@info k stat
end

# Transition statstics
if cb.include_extras
names, vals = Turing.Inference.get_transition_extras([transition])
for (name, val) in zip(string.(names), vec(vals))
for (name, val) in extras(transition, state; param_names)
@info ("extras/" * name) val
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/tensorboardlogger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ tb_name(arg1, arg2) = tb_name(arg1) * "/" * tb_name(arg2)
tb_name(arg, args...) = tb_name(arg) * "/" * tb_name(args...)

function TBL.preprocess(name, stat::OnlineStat, data)
if nobs(stat) > 0
if OnlineStats.nobs(stat) > 0
TBL.preprocess(tb_name(name, stat), value(stat), data)
end
end
Expand Down Expand Up @@ -56,7 +56,7 @@ function TBL.preprocess(name, stat::Series, data)
end

function TBL.preprocess(name, hist::KHist, data)
if nobs(hist) > 0
if OnlineStats.nobs(hist) > 0
# Creates a NORMALIZED histogram
edges = OnlineStats.edges(hist)
cnts = OnlineStats.counts(hist)
Expand Down
5 changes: 5 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[deps]
TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7"
39 changes: 37 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,41 @@
using TuringCallbacks
using Test
using Turing
using TuringCallbacks
using TensorBoardLogger, ValueHistories

@testset "TuringCallbacks.jl" begin
# Write your tests here.
# TODO: Improve.
@model function demo(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, s)
for i in eachindex(x)
x[i] ~ Normal(m, s)
end
end

xs = randn(100) .+ 1
model = demo(xs)

# Number of MCMC samples/steps
num_samples = 1_000
num_adapts = 500

# Sampling algorithm to use
alg = NUTS(num_adapts, 0.65)

# Create the callback
callback = TensorBoardCallback(mktempdir())

# Sample
chain = sample(model, alg, num_samples; callback=callback)

# Extract the values.
hist = convert(MVHistory, callback.logger)

# Compare the recorded values to the chain.
m_mean = last(last(hist["m/stat/Mean"]))
s_mean = last(last(hist["s/stat/Mean"]))

@test m_mean mean(chain[:m])
@test s_mean mean(chain[:s])
end

2 comments on commit 96af50f

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/78374

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.0 -m "<description of version>" 96af50f18e77c7ee483163fef3484e8a45517bd8
git push origin v0.2.0

Please sign in to comment.