Possible Improvements to FixedContext
#710
base: torfjelde/context-cleanup
Conversation
cases where current `fix` is failing
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
I certainly wish we didn't need the added complexity, but I'm thoroughly convinced by the profiling 😄 Thank you for looking into it.
To mitigate that, I think we should at least elaborate in the docstring of `contextual_isfixed` about this case (where the `FixedContext` may contain variables that don't match those in the model)? I think that will at least help us (or me) the next time we revisit haha.
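For example, a note along these lines (the signature of `contextual_isfixed` is assumed here; only the admonition text reflects the point above):

```julia
"""
    contextual_isfixed(context, vn)

Return `true` if the variable `vn` is fixed in `context`.

!!! note
    A `FixedContext` may contain varnames that do not match those appearing in
    the model (e.g. a non-concretised slice such as `s[:, 1]` versus
    element-wise varnames like `s[1, 1]`). In that case this check alone is not
    sufficient, and the fixed values have to be picked up later in the tilde
    pipeline instead.
"""
function contextual_isfixed end
```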
CI
Separately, there are a few tests that are failing. Most of them should be fixed by the suggested changes (a small typo). Some others will be fixed by #704, so we just need to merge master into this branch to fix those.
However, there's one more newly failing test at `test/turing/model.jl:6`, specifically for `demo_dot_assume_matrix_dot_observe_matrix`. Here's a DPPL-only MWE:
```julia
using DynamicPPL, Distributions

@model function f()
    s = Array{Float64}(undef, 1, 2)
    s .~ product_distribution([InverseGamma(2, 3)])
    # also fails
    # s .~ MvNormal([0.0], [1.0])
end

model = f()

# this doesn't fix the variables, because the varnames are not
# concretised -- although this probably isn't a particularly huge deal
fix(model, @varname(s[:, 1], false) => [1.0], @varname(s[:, 2], false) => [2.0])()

# however, this version with concretised varnames errors, and
# `generated_quantities` calls this and in turn errors
s = Array{Float64}(undef, 1, 2)
fix(model, @varname(s[:, 1], true) => [1.0], @varname(s[:, 2], true) => [2.0])()

# e.g. like this (which is a simplified version of test/turing/model.jl:6)
using Turing
chain = sample(model, Prior(), 10)
generated_quantities(model, MCMCChains.get_sections(chain, :parameters))
```
Co-authored-by: Penelope Yong <[email protected]>
Thanks for catching those typos @penelopeysm! Regarding the failing test, this feels like something we should be able to fix 👍
Thoughts on the inconsistency of overriding
Btw, concretization doesn't handle this; see Lines 1060 to 1085 in da6f9a0
So the error is caused (after fixing some broadcasting bug with
That's true:

```julia
using Turing

@model function f()
    s = Array{Float64}(undef, 1, 2)
    s .~ product_distribution([InverseGamma(2, 3)])
end

chain = sample(f(), Prior(), 10)
dump(collect(keys(chain.info.varname_to_symbol))[1])
#=
AbstractPPL.VarName{:s, ComposedFunction{Accessors.IndexLens{Tuple{Int64}}, Accessors.IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}
  optic: (@o _[:, 1][1]) (function of type ComposedFunction{Accessors.IndexLens{Tuple{Int64}}, Accessors.IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}})
    outer: Accessors.IndexLens{Tuple{Int64}}
      indices: Tuple{Int64}
        1: Int64 1
    inner: Accessors.IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}
      indices: Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}
        1: AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}
          range: Base.OneTo{Int64}
            stop: Int64 1
        2: Int64 1
=#
```
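To make the mismatch concrete, a minimal sketch (the comparison result is assumed from the structure above rather than verified):

```julia
using DynamicPPL

# Concretising the `:` slice wraps it in a ConcretizedSlice, so the varname
# recorded in the chain is structurally different from the plain form.
s = Array{Float64}(undef, 1, 2)
vn_plain = @varname(s[:, 1])           # indexes with a plain Colon()
vn_concrete = @varname(s[:, 1], true)  # indexes with a ConcretizedSlice
# Presumably these do not compare equal, which is why values fixed under one
# form are not found when looked up under the other.
vn_plain == vn_concrete
```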
Because we use
Hmm, this is actually quite an annoying issue 😕 It raises the question of whether `DynamicPPL.hasvalue(OrderedDict(@varname(s[1,1]) => 0.0), @varname(s[:, 1]))` should hold, which I'm somewhat uncertain we want 😕 The current implementation assumes it does not (see Lines 872 to 895 in da6f9a0).
In an ideal world, this would also handle stuff like

```julia
DynamicPPL.hasvalue(OrderedDict(@varname(s[1,1]) => 0.0), @varname(s[:, 1]))
DynamicPPL.hasvalue(OrderedDict(@varname(s[:,1]) => [0.0]), @varname(s[1, 1]))
DynamicPPL.hasvalue(OrderedDict(@varname(s[:,1]) => [0.0]), @varname(s))
```

but this will complicate the implementation of both. EDIT: Similarly, we also need to add support for these in
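To make the first of these cases concrete, here is a toy sketch of the kind of subsumption-aware lookup being discussed. This is only an illustration, not DynamicPPL's actual implementation: varnames are modelled as plain `(symbol, index-tuple)` pairs, and slices are written as explicit ranges because a bare `:` needs the array size (which is exactly what concretisation provides).

```julia
# Expand a (possibly slice-containing) index tuple into concrete index tuples.
expand_indices(idxs) =
    Iterators.product((i isa AbstractRange ? i : (i,) for i in idxs)...)

# A slice query succeeds only if every element it covers is stored per index.
function toy_hasvalue(vals::AbstractDict, (sym, idxs))
    return all(haskey(vals, (sym, ix)) for ix in expand_indices(idxs))
end

vals = Dict((:s, (1, 1)) => 0.0)
toy_hasvalue(vals, (:s, (1:1, 1)))  # true:  s[1:1, 1] is fully covered by s[1, 1]
toy_hasvalue(vals, (:s, (1:1, 2)))  # false: s[1, 2] is not present
```

The other two cases (values stored as a slice, queried per element or as the whole variable) would need the reverse decomposition, which is part of why this complicates the implementation.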
This is a bit of a drive-by comment, but I've so far failed to wrap my head around how we use it. I don't really have a proposal for how to change this, but for many cases like your above
As pointed out by @penelopeysm and @mhauru in #702, `FixedContext` and `ConditionContext` don't quite do what we want for `.~` statements. One of the reasons we first introduced `fix` was to avoid hitting the `tilde_*_assume` pipeline, to improve performance.

This PR implements a possible way of fixing #702, which involves overloading `tilde_dot_assume` for `FixedContext` to handle cases where only part of the LHS is `fix`ed.

With this branch we can do stuff like fixing only a subset of a `.~` statement:
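For instance, something along these lines (the model and values here are illustrative assumptions, not taken from the PR):

```julia
using DynamicPPL, Distributions

@model function demo()
    x = Vector{Float64}(undef, 3)
    x .~ Normal()
    return x
end

# Fix only one element of the `.~` statement; the remaining elements are
# still treated as random variables when the model is evaluated.
m = fix(demo(), @varname(x[1]) => 1.0)
m()
```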
However, it requires overloading `tilde_dot_assume` for `FixedContext`, which does go slightly against what `tilde_*_assume` is meant to do (it's meant to be used for random variables, but clearly fixed variables are not random).

Performance implications
IMO the interesting "case" is when we use `fix(::Model, ::NamedTuple)`, since this is consistently what we consider as the "fast mode" in Turing.jl / DynamicPPL.jl, and we can always ask the user to provide the values as a `NamedTuple` if they really want performance.

There are a few different "approaches" we can take with `fixed` (and equally `condition`):

1. `contextual_isfixed` + `getfixed_nested(__context__, vn)` in the main body of a `@model`. When it works, this is very performant, as it's just a compile-time-generated check of `sym in names` for `VarName{sym}` and `NamedTuple{names}`.
2. The `tilde_*_assume` pipeline to extract the fixed values.
3. `tilde_*_assume`, but we also check there for `tilde_dot_assume` (so that we cover the cases listed in "FixedContext and ConditionedContext don't use the same varnames as tilde-pipeline" #702) by iterating over all the variables and deferring to `tilde_assume` (i.e. without the `dot`).

I ran the following snippet for the different approaches:
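(A sketch in the same spirit, with an assumed toy model rather than the original snippet, would be something like the following.)

```julia
using DynamicPPL, Distributions, BenchmarkTools

@model function bench()
    x ~ Normal()
    y ~ Normal()
    return (x, y)
end

# "fast mode": values supplied as a NamedTuple
m_nt = fix(bench(), (; x = 1.0))
@btime $(m_nt)()

# VarName-based fixing, for comparison
m_vn = fix(bench(), @varname(x) => 1.0)
@btime $(m_vn)()
```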
On #master (Approach 1)

On this branch (Approach 3)
As we can see, the performance difference is very, very minor. However, note that this PR still includes the `contextual_isfixed` check in the main body of the model.

Replace current approach fully by overloading tilde (Approach 2)
If we remove this, i.e. only rely on overloading tilde-pipeline, we get the following result:
As we see here, once we have to rely on a for-loop over the variables to check, we do incur a "significant" runtime overhead.
Conclusion
Performing the check in `dot_tilde_assume` only when explicitly needed doesn't really hurt performance much for `fix(::Model, ::NamedTuple)` (i.e. Approach 3 vs. Approach 1).

However, purely relying on Approach 2 (i.e. replacing the current approach completely with overloading tilde-assume) does have quite a significant overhead for just evaluation (and assuming this will be even worse when computing gradients).

Soooo I'm leaning towards Approach 3 (as is implemented in this branch), even though it does make things a bit uglier.