Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a generic fallback for reconstruct #520

Closed
wants to merge 3 commits into from
Closed

Add a generic fallback for reconstruct #520

wants to merge 3 commits into from

Conversation

yebai
Copy link
Member

@yebai yebai commented Aug 22, 2023

Fix the following issue reported by Christopher Fisher in slack.

using Turing
using SequentialSamplingModels
using Random
using LinearAlgebra

@model function model(data; min_rt = minimum(data[2]))
    ν ~ MvNormal(zeros(2), I * 2)
    A ~ truncated(Normal(.8, .4), 0.0, Inf)
    k ~ truncated(Normal(.2, .2), 0.0, Inf)
    τ  ~ Uniform(0.0, min_rt)
    data ~ LBA(;ν, A, k, τ )
end

# generate some data
Random.seed!(254)
dist = LBA(ν=[3.0,2.0], A = .8, k = .2, τ = .3) 
data = rand(dist, 100)

# estimate parameters
chain = sample(model(data), NUTS(200, .65), 100)
predictions = predict(model(missing; min_rt = minimum(data[2])), chain)

@itsdfish
Copy link

itsdfish commented Aug 24, 2023

Thank you for the PR. I confirmed locally that it works on the example above.

Update

I must have produced a false result above. copy does not work, but deepcopy does, e.g.,

copy((a=Int[],b=[]))

throws an error.

@yebai yebai mentioned this pull request Aug 25, 2023
@github-actions
Copy link
Contributor

Pull Request Test Coverage Report for Build 6023022641

  • 0 of 1 (0.0%) changed or added relevant line in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-0.03%) to 80.375%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/utils.jl 0 1 0.0%
Totals Coverage Status
Change from base Build 6018544130: -0.03%
Covered Lines: 2232
Relevant Lines: 2777

💛 - Coveralls

@codecov
Copy link

codecov bot commented Aug 30, 2023

Codecov Report

Patch coverage has no change and project coverage change: -0.03% ⚠️

Comparison is base (549d9b1) 80.40% compared to head (322aa1d) 80.37%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #520      +/-   ##
==========================================
- Coverage   80.40%   80.37%   -0.03%     
==========================================
  Files          24       24              
  Lines        2776     2777       +1     
==========================================
  Hits         2232     2232              
- Misses        544      545       +1     
Files Changed Coverage Δ
src/utils.jl 78.37% <0.00%> (-0.36%) ⬇️

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@torfjelde
Copy link
Member

torfjelde commented Aug 30, 2023

This PR shouldn't be necessary after #521 , no?

EDIT: Ah, sorry, I was thinking about vectorize, my bad!

@torfjelde
Copy link
Member

torfjelde commented Aug 30, 2023

Btw, how is this enough to make predict work? This should be failing when trying to convert into a MCMCChains.Chains since, AFAIK, we don't support NamedTuple.

EDIT: After running it locally I see what's happening 👍 We're just treating it as an iterable.

Copy link
Member

@torfjelde torfjelde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we have some tests for this?

src/utils.jl Outdated Show resolved Hide resolved
Co-authored-by: Tor Erlend Fjelde <[email protected]>
@yebai
Copy link
Member Author

yebai commented Aug 30, 2023

I'll add a test later this week.

@torfjelde
Copy link
Member

torfjelde commented Aug 30, 2023

I'm still confused why this example is even working; vectorize, for example, shouldn't be implemented. Is SequentialSamplingModels doing anything with DynamicPPL?

EDIT: Ah, found it: https://github.com/itsdfish/SequentialSamplingModels.jl/blob/728319eed7fac9c30fd2d2372834761eba97ee4f/src/type_system.jl#L91

But in this scenario, wouldn't it be better to just add a reconstruct implementation there too? These types are really breaking with what is expected of the abstract types they are sub-typing, and so I'm very worried things will just silently produce incorrect results somewhere 😕 And if we then add a default reconstruct, I'm worried this will happen in Turing too.

Of course, the better scenario, is probably to add proper support for something like NamedTupleSupport in Distributions.jl.

@yebai
Copy link
Member Author

yebai commented Aug 30, 2023

NamedTupleSupport random variate types from Distribution would be ideal, but that probably won't happen soon.

@torfjelde
Copy link
Member

True, but then shouldn't we instead make a push to loosen the requirement of using <:Distribution instead? It feels like we should either

  1. follow distributions and its assumptions, or
  2. allow non-distributions to be used on the RHS.

If we do something that is in the "in-between" space, then this increases the likelihood of silently producing incorrect results.

@itsdfish if you're already messing with DynamicPPL-internals such as overloading vectorize, then I'd suggest you just also overload reconstruct, i.e. you add

reconstruct(d::SSM2D, r::AbstractVector) = deepcopy(r)

to the package:)

@itsdfish
Copy link

reconstruct(d::SSM2D, r::AbstractVector) = deepcopy(r)

Thank you. That sounds reasonable. I will do that. Thanks for your help!

@torfjelde
Copy link
Member

Awesome:)

@yebai happy to close this PR then?

@yebai yebai closed this Sep 1, 2023
@yebai yebai deleted the yebai-patch-1 branch September 1, 2023 21:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants