
Bayesian IV slow for large data sets #235

Closed
NathanielF opened this issue Sep 1, 2023 · 6 comments

@NathanielF (Contributor)

Just adding a note here to investigate the parameterisation options of the Bayesian IV model.
I was exploring this while fitting large IV designs with lots of data and many instruments for this ticket: #229

The Bayesian IV design just takes a very long time and doesn't always give good results on large data sets. Because CausalPy hides some of the model-building complexity, it's harder for the user to iteratively debug and re-parameterise. I'm wondering if there is anything in the model specification we can do to address this.
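For reference, here is a simplified sketch of the kind of joint-likelihood IV parameterisation under discussion: a first stage regressing treatment on instruments, a second stage regressing outcome on treatment, and correlated errors across the two equations via an LKJ Cholesky prior. All names and data below are illustrative, not CausalPy's actual internals:

```python
import numpy as np
import pymc as pm

# Illustrative synthetic data: Z = instruments, t = treatment, y = outcome
rng = np.random.default_rng(42)
N = 1000
Z = rng.normal(size=(N, 2))
t = Z @ np.array([0.5, 0.3]) + rng.normal(size=N)
y = 1.2 * t + rng.normal(size=N)

with pm.Model() as iv_model:
    # First stage: treatment regressed on instruments
    beta_z = pm.Normal("beta_z", 0.0, 1.0, shape=Z.shape[1])
    mu_t = Z @ beta_z
    # Second stage: outcome regressed on the (endogenous) treatment
    beta_t = pm.Normal("beta_t", 0.0, 1.0)
    mu_y = beta_t * t
    # Correlated errors between the two equations via an LKJ Cholesky prior
    chol, corr, sigmas = pm.LKJCholeskyCov(
        "chol_cov", n=2, eta=2.0, sd_dist=pm.Exponential.dist(1.0)
    )
    # Joint MvNormal likelihood over (outcome, treatment)
    pm.MvNormal(
        "obs",
        mu=pm.math.stack([mu_y, mu_t]).T,
        chol=chol,
        observed=np.column_stack([y, t]),
    )
```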

@NathanielF (Contributor, Author)

FYI @juanitorduz

I think I'm going to read up on this paper: https://www.nber.org/papers/t0204

It seems like a good example of the advantages of a Bayesian solution, but we need to ensure some kind of efficiency.

@drbenvincent (Collaborator)

Just to add that the IV methods slow down the tests and doctests. So a great solution to this issue would also address that :)

@jessegrabowski

Have you tried benchmarking the model with the JAX backend? I peeked at the code and saw it uses an MvNormal likelihood. I see orders-of-magnitude speedups on statespace models (also MvNormal) by switching to the JAX sampler.
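For anyone trying this, switching samplers is a one-liner (a sketch assuming a model like the `iv_model` above; requires `numpyro` to be installed):

```python
import pymc as pm

with iv_model:
    # Use the JAX-based NUTS sampler from numpyro instead of the
    # default PyMC sampler ("blackjax" is another JAX option)
    idata = pm.sample(nuts_sampler="numpyro")
```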

@drbenvincent (Collaborator)

Just wondering if any lessons were learnt from #345, @NathanielF, which could be useful for this issue? If the default backend is particularly slow, could this be worth an issue in the PyMC repo?

Or maybe the speed has improved given changes to the IV code?

@jessegrabowski commented Jun 19, 2024

The issue for this is pymc-devs/pymc#7348.

But the specific error @NathanielF reported in his testing is that there's no JAX funcify for LKJCholeskyCov, which needs its own issue. TFP has it, so it should be quite trivial to implement (famous last words).

@NathanielF (Contributor, Author)

The main finding was just that the actual model fit can be quite quick with both the default sampler and the numpyro sampler. Honestly, there's not a huge difference: pm.sample takes ~5 minutes for about 3000 rows of data with priors informed by the 2SLS preprocessing step.

Good prior management is important with IV regression, especially where the scales of the parameters can differ in magnitude. I'd generally recommend standardising inputs, but didn't here because I wanted to replicate the Card parameter recovery.
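As a sketch of what that standardisation looks like (using the illustrative `t` and `y` arrays from the earlier snippet):

```python
# Put treatment and outcome on unit scale, so that weakly informative
# priors like Normal(0, 1) are reasonable regardless of the raw units
t_std = (t - t.mean()) / t.std()
y_std = (y - y.mean()) / y.std()
```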

So broadly I think IV model fits can be achieved in reasonable time with the base model.

The main issue before was that we had bundled all the posterior predictive (ppc) sampling into the model fit and hid the progress bars, so it took me a minute to realise the majority of the time was spent in the posterior predictive sampling. This was greatly sped up (20 minutes -> 2 seconds) with @jessegrabowski's JAX trick. So I think the issue is perhaps just some inefficiency in the ppc sampling with multivariate normal distributions in the base PyMC instantiation...
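For reference, a minimal sketch of that speedup, assuming the "JAX trick" is compiling the predictive function with PyTensor's JAX mode via `compile_kwargs`, which `pm.sample_posterior_predictive` accepts:

```python
import pymc as pm

with iv_model:
    # Compile the posterior predictive function with the JAX backend
    # instead of the default C backend
    ppc = pm.sample_posterior_predictive(
        idata, compile_kwargs={"mode": "JAX"}
    )
```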

The same slowness occurs with the prior predictive checks, but is less pronounced because we only sample 500 draws by default.
