Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add way to set backend fn random generators #7629

Open
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

lucianopaz
Copy link
Contributor

@lucianopaz lucianopaz commented Dec 26, 2024

This PR depends on #7540. Do not merge before that one has been merged.

Description

This PR adds the set_function_rngs function that takes a compiled pytensor function, looks for any random generators in it, and makes a copy that sets all generators to spawned versions of a supplied generator. The implementation is taken from @ricardoV94's function here (and that's why I listed him as coauthor of this commit). This function is then used in init_traces and base trace initialization to make the backend's fn have a properly seeded random generator. Thanks to this change, we can reproducible sampling results even when the function has Deterministic that depend on raw pytensor random variables.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@ricardoV94
Copy link
Member

Need to raise an error if the linker is JAXLinker as function.get_shared() doesn't retrieve the shared RNGs that are actually used by JAX functions (they are incompatible with numpy Generators): https://github.com/pymc-devs/pymc-extras/blob/c1809e8149fc89ac6eadf4bf73050ea6fe82955c/pymc_experimental/sampling/optimizations/conjugate_sampler.py#L79-L82

@ricardoV94
Copy link
Member

Perhaps a more explicit name like copy_function_with_new_rngs? That gives the mental model that the old function is not affected?

@lucianopaz
Copy link
Contributor Author

Need to raise an error if the linker is JAXLinker as function.get_shared() doesn't retrieve the shared RNGs that are actually used by JAX functions (they are incompatible with numpy Generators): https://github.com/pymc-devs/pymc-extras/blob/c1809e8149fc89ac6eadf4bf73050ea6fe82955c/pymc_experimental/sampling/optimizations/conjugate_sampler.py#L79-L82

I don't like errorring out here. Maybe a warning will be good for now, until we figure out how to control jax's random keys.

@lucianopaz
Copy link
Contributor Author

@ricardoV94, I can't get the function used by the trace to compile with JAX or NUMBA mode. Could you have a look at the test and tell me if I'm messing up the config context?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

trace fn does not respect random seed for stochastic "Deterministics"
2 participants