diff --git a/Project.toml b/Project.toml
index de9f4b8..177f8a5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "TuringCallbacks"
 uuid = "ea0860ee-d0ef-45ef-82e6-cc37d6be2f9c"
 authors = ["Tor Erlend Fjelde and contributors"]
-version = "0.3.0"
+version = "0.3.1"
 
 [deps]
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
diff --git a/docs/src/index.md b/docs/src/index.md
index 8b743f4..ab29cd5 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -22,7 +22,7 @@ pkg> add TuringCallbacks.jl
 ```
 
 ## Visualizing sampling on-the-fly
-`TensorBoardCallback` is a wrapper around `TensorBoardLogger.TBLogger` which can be used to create a `callback` compatible with `Turing.sample`.
+`TensorBoardCallback` is a wrapper around `Base.CoreLogging.AbstractLogger` which can be used to create a `callback` compatible with `Turing.sample`.
 
 To actually visualize the results of the logging, you need to have installed `tensorboard` in Python.
 If you do not have `tensorboard` installed, it should hopefully be sufficient to just run
@@ -35,7 +35,7 @@ python3 -m tensorboard.main --logdir tensorboard_logs/run
 ```
 
 Now we're ready to actually write some Julia code.
-The following snippet demonstrates the usage of `TensorBoardCallback` on a simple model.
+The following snippet demonstrates the usage of `TensorBoardCallback` on a simple model. This will write a set of statistics at each iteration to an event-file compatible with Tensorboard:
 
 ```julia
diff --git a/src/callbacks/tensorboard.jl b/src/callbacks/tensorboard.jl
index 1b727aa..9d28f9a 100644
--- a/src/callbacks/tensorboard.jl
+++ b/src/callbacks/tensorboard.jl
@@ -3,18 +3,20 @@ using Dates
 """
 $(TYPEDEF)
 
-Wraps a `TensorBoardLogger.TBLogger` to construct a callback to be passed to `AbstractMCMC.step`.
+Wraps a `CoreLogging.AbstractLogger` to construct a callback to be
+passed to `AbstractMCMC.step`.
 
 # Usage
 
     TensorBoardCallback(; kwargs...)
     TensorBoardCallback(directory::string[, stats]; kwargs...)
-    TensorBoardCallback(lg::TBLogger[, stats]; kwargs...)
+    TensorBoardCallback(lg::AbstractLogger[, stats]; kwargs...)
 
-Constructs an instance of a `TensorBoardCallback`, creating a `TBLogger` if `directory` is 
+Constructs an instance of a `TensorBoardCallback`, creating a `TBLogger` if `directory` is
 provided instead of `lg`.
 
 ## Arguments
+- `lg`: an instance of an `AbstractLogger` which implements `TuringCallbacks.increment_step!`.
 - `stats = nothing`: `OnlineStat` or lookup for variable name to statistic estimator.
   If `stats isa OnlineStat`, we will create a `DefaultDict` which copies `stats` for unseen
   variable names.
@@ -24,9 +26,9 @@ provided instead of `lg`.
 
 ## Keyword arguments
 - `num_bins::Int = 100`: Number of bins to use in the histograms.
-- `filter = nothing`: Filter determining whether or not we should log stats for a 
+- `filter = nothing`: Filter determining whether or not we should log stats for a
   particular variable and value; expected signature is `filter(varname, value)`.
-  If `isnothing` a default-filter constructed from `exclude` and 
+  If `isnothing` a default-filter constructed from `exclude` and
   `include` will be used.
 - `exclude = nothing`: If non-empty, these variables will not be logged.
 - `include = nothing`: If non-empty, only these variables will be logged.
@@ -41,7 +43,7 @@ $(TYPEDFIELDS)
 """
 struct TensorBoardCallback{L,F,VI,VE}
     "Underlying logger."
-    logger::TBLogger
+    logger::AbstractLogger
     "Lookup for variable name to statistic estimate."
     stats::L
     "Filter determining whether or not we should log stats for a particular variable."
@@ -68,7 +70,7 @@ function TensorBoardCallback(args...; comment = "", directory = nothing, kwargs.
     else
         directory
     end
-    
+
     # Set up the logger
     lg = TBLogger(log_dir, min_level=Logging.Info; step_increment=0)
 
@@ -76,7 +78,7 @@ function TensorBoardCallback(args...; comment = "", directory = nothing, kwargs.
 end
 
 function TensorBoardCallback(
-    lg::TBLogger,
+    lg::AbstractLogger,
     stats = nothing;
     num_bins::Int = 100,
     exclude = nothing,
@@ -162,6 +164,9 @@ extras(transition; kwargs...) = ()
 extras(transition, state; kwargs...) = extras(transition; kwargs...)
 extras(model, sampler, transition, state; kwargs...) = extras(transition, state; kwargs...)
 
+increment_step!(lg::TensorBoardLogger.TBLogger, Δ_Step) =
+    TensorBoardLogger.increment_step!(lg, Δ_Step)
+
 function (cb::TensorBoardCallback)(rng, model, sampler, transition, state, iteration; kwargs...)
     stats = cb.stats
     lg = cb.logger
@@ -189,6 +194,6 @@ function (cb::TensorBoardCallback)(rng, model, sampler, transition, state, itera
             end
         end
         # Increment the step for the logger.
-        TensorBoardLogger.increment_step!(lg, 1)
+        increment_step!(lg, 1)
     end
 end
diff --git a/src/tensorboardlogger.jl b/src/tensorboardlogger.jl
index 9bfb37b..16a74fe 100644
--- a/src/tensorboardlogger.jl
+++ b/src/tensorboardlogger.jl
@@ -68,7 +68,7 @@ end
 
 # Unlike the `preprocess` overload, this allows us to specify if we want to normalize
 function TBL.log_histogram(
-    logger::TBLogger, name::AbstractString, hist::OnlineStats.HistogramStat;
+    logger::AbstractLogger, name::AbstractString, hist::OnlineStats.HistogramStat;
     step=nothing, normalize=false
 )
     edges = edges(hist)
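The upshot of the `TBLogger` → `AbstractLogger` relaxation above is that any logger can now back a `TensorBoardCallback`, provided it has a `TuringCallbacks.increment_step!` method. Below is a minimal sketch of what a downstream user could do under that assumption; the `StepCountingLogger` name and its behaviour are hypothetical and only for illustration, not part of the diff:

```julia
using Logging
using TuringCallbacks

# A toy logger that swallows all log records and only tracks how many
# sampling steps the callback has reported.
mutable struct StepCountingLogger <: Logging.AbstractLogger
    step::Int
end
StepCountingLogger() = StepCountingLogger(0)

# Minimal `AbstractLogger` interface: accept every message and discard it.
Logging.min_enabled_level(::StepCountingLogger) = Logging.Info
Logging.shouldlog(::StepCountingLogger, args...) = true
Logging.handle_message(::StepCountingLogger, args...; kwargs...) = nothing

# The hook the `TensorBoardCallback(lg::AbstractLogger, ...)` constructor
# relies on: advance the logger's notion of the current step.
TuringCallbacks.increment_step!(lg::StepCountingLogger, Δ) = (lg.step += Δ)

# With the hook in place, the custom logger can be passed where a `TBLogger`
# used to be required, e.g. (assuming a Turing `model` is defined):
#     callback = TensorBoardCallback(StepCountingLogger())
#     chain = sample(model, NUTS(), 1_000; callback=callback)
```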