diff --git a/src/sampler.jl b/src/sampler.jl index 7d1b7eb5..d66e716e 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -163,6 +163,7 @@ function sample( adaptor::AbstractAdaptor = NoAdaptation(), n_adapts::Int = min(div(n_samples, 10), 1_000); drop_warmup = false, + keep_gradients = false, verbose::Bool = true, progress::Bool = false, (pm_next!)::Function = pm_next!, @@ -225,7 +226,12 @@ function sample( # Store sample if !drop_warmup || i > n_adapts j = i - drop_warmup * n_adapts - θs[j], stats[j] = t.z.θ, tstat + if keep_gradients + sample = [t.z.θ; t.z.ℓπ] + else + sample = t.z.θ + end + θs[j], stats[j] = sample, tstat end end # Report end of sampling