Skip to content

Commit

Permalink
Merge pull request #42 from TuringLang/logger-type
Browse files Browse the repository at this point in the history
Loosen the logger type in TensorBoardLogger
  • Loading branch information
yebai authored Jul 8, 2023
2 parents d9b2e48 + 2e9de56 commit 6473445
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TuringCallbacks"
uuid = "ea0860ee-d0ef-45ef-82e6-cc37d6be2f9c"
authors = ["Tor Erlend Fjelde <[email protected]> and contributors"]
version = "0.3.0"
version = "0.3.1"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand Down
4 changes: 2 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pkg> add TuringCallbacks.jl
```

## Visualizing sampling on-the-fly
`TensorBoardCallback` is a wrapper around `TensorBoardLogger.TBLogger` which can be used to create a `callback` compatible with `Turing.sample`.
`TensorBoardCallback` is a wrapper around `Base.CoreLogging.AbstractLogger` which can be used to create a `callback` compatible with `Turing.sample`.

To actually visualize the results of the logging, you need to have installed `tensorboard` in Python. If you do not have `tensorboard` installed,
it should hopefully be sufficient to just run
Expand All @@ -35,7 +35,7 @@ python3 -m tensorboard.main --logdir tensorboard_logs/run
```
Now we're ready to actually write some Julia code.

The following snippet demonstrates the usage of `TensorBoardCallback` on a simple model.
The following snippet demonstrates the usage of `TensorBoardCallback` on a simple model.
This will write a set of statistics at each iteration to an event-file compatible with Tensorboard:

```julia
Expand Down
23 changes: 14 additions & 9 deletions src/callbacks/tensorboard.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@ using Dates
"""
$(TYPEDEF)
Wraps a `TensorBoardLogger.TBLogger` to construct a callback to be passed to `AbstractMCMC.step`.
Wraps a `CoreLogging.AbstractLogger` to construct a callback to be
passed to `AbstractMCMC.step`.
# Usage
TensorBoardCallback(; kwargs...)
TensorBoardCallback(directory::string[, stats]; kwargs...)
TensorBoardCallback(lg::TBLogger[, stats]; kwargs...)
TensorBoardCallback(lg::AbstractLogger[, stats]; kwargs...)
Constructs an instance of a `TensorBoardCallback`, creating a `TBLogger` if `directory` is
Constructs an instance of a `TensorBoardCallback`, creating a `TBLogger` if `directory` is
provided instead of `lg`.
## Arguments
- `lg`: an instance of an `AbstractLogger` which implements `TuringCallbacks.increment_step!`.
- `stats = nothing`: `OnlineStat` or lookup for variable name to statistic estimator.
If `stats isa OnlineStat`, we will create a `DefaultDict` which copies `stats` for unseen
variable names.
Expand All @@ -24,9 +26,9 @@ provided instead of `lg`.
## Keyword arguments
- `num_bins::Int = 100`: Number of bins to use in the histograms.
- `filter = nothing`: Filter determining whether or not we should log stats for a
- `filter = nothing`: Filter determining whether or not we should log stats for a
particular variable and value; expected signature is `filter(varname, value)`.
If `isnothing` a default-filter constructed from `exclude` and
If `isnothing` a default-filter constructed from `exclude` and
`include` will be used.
- `exclude = nothing`: If non-empty, these variables will not be logged.
- `include = nothing`: If non-empty, only these variables will be logged.
Expand All @@ -41,7 +43,7 @@ $(TYPEDFIELDS)
"""
struct TensorBoardCallback{L,F,VI,VE}
"Underlying logger."
logger::TBLogger
logger::AbstractLogger
"Lookup for variable name to statistic estimate."
stats::L
"Filter determining whether or not we should log stats for a particular variable."
Expand All @@ -68,15 +70,15 @@ function TensorBoardCallback(args...; comment = "", directory = nothing, kwargs.
else
directory
end

# Set up the logger
lg = TBLogger(log_dir, min_level=Logging.Info; step_increment=0)

return TensorBoardCallback(lg, args...; kwargs...)
end

function TensorBoardCallback(
lg::TBLogger,
lg::AbstractLogger,
stats = nothing;
num_bins::Int = 100,
exclude = nothing,
Expand Down Expand Up @@ -162,6 +164,9 @@ extras(transition; kwargs...) = ()
extras(transition, state; kwargs...) = extras(transition; kwargs...)
extras(model, sampler, transition, state; kwargs...) = extras(transition, state; kwargs...)

increment_step!(lg::TensorBoardLogger.TBLogger, Δ_Step) =
TensorBoardLogger.increment_step!(lg, Δ_Step)

function (cb::TensorBoardCallback)(rng, model, sampler, transition, state, iteration; kwargs...)
stats = cb.stats
lg = cb.logger
Expand Down Expand Up @@ -189,6 +194,6 @@ function (cb::TensorBoardCallback)(rng, model, sampler, transition, state, itera
end
end
# Increment the step for the logger.
TensorBoardLogger.increment_step!(lg, 1)
increment_step!(lg, 1)
end
end
2 changes: 1 addition & 1 deletion src/tensorboardlogger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ end

# Unlike the `preprocess` overload, this allows us to specify if we want to normalize
function TBL.log_histogram(
logger::TBLogger, name::AbstractString, hist::OnlineStats.HistogramStat;
logger::AbstractLogger, name::AbstractString, hist::OnlineStats.HistogramStat;
step=nothing, normalize=false
)
edges = edges(hist)
Expand Down

2 comments on commit 6473445

@yebai
Copy link
Member Author

@yebai yebai commented on 6473445 Jul 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/87091

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.1 -m "<description of version>" 6473445b4dc186ec18c091af70a13960d32a9e82
git push origin v0.3.1

Please sign in to comment.