Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to the primitives module #1940

Merged
merged 19 commits into from
Dec 23, 2024

Conversation

juanitorduz
Copy link
Contributor

Add type hints to the primitives module and refine other type hints accordingly.

@juanitorduz juanitorduz marked this pull request as draft December 20, 2024 12:17
@juanitorduz
Copy link
Contributor Author

In ba9a07a I suggest adding a DistributionLike type as a Protocol (https://mypy.readthedocs.io/en/stable/protocols.html) which just list the minimal required attributes and methods expected from a stochastic function (from the docstrings)

@juanitorduz juanitorduz marked this pull request as ready for review December 20, 2024 14:14
numpyro/contrib/hsgp/approximation.py Show resolved Hide resolved
@@ -25,7 +26,7 @@
]


def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
def _compute_chain_variance_stats(x: npt.NDArray) -> tuple[npt.NDArray, npt.NDArray]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why NDArray is used instead of ndarray?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MyPy started complaining while I was adding hints and decided to follow the typing guidelines from numpy https://numpy.org/doc/stable/reference/typing.html :) Then all errors were gone

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interesting. thanks! maybe import NDArray instead of npt? It makes the code easier to read.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! changed in 1bd90ea

numpyro/distributions/distribution.py Show resolved Hide resolved
numpyro/primitives.py Outdated Show resolved Hide resolved
@juanitorduz
Copy link
Contributor Author

I need to fix some stuff because of the last released version of mypy (https://pypi.org/project/mypy/#history) 😅

@juanitorduz juanitorduz marked this pull request as draft December 20, 2024 18:34
@juanitorduz juanitorduz marked this pull request as ready for review December 20, 2024 19:01
@juanitorduz juanitorduz requested a review from fehiepsi December 20, 2024 19:01
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd))

return phi @ beta
return jnp.asarray(phi @ beta)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this needed when inputs are Array now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, changed in a7984ce

with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0))

return phi @ (spd * beta)
return jnp.asarray(phi @ (spd * beta))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this needed when inputs are Array now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, changed in a7984ce

from numpyro.util import find_stack_level, identity

_PYRO_STACK = []
# Type aliases
MessageType = dict[str, Any]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess Message is enough. We don't have Message class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed! Changed in 6dd9396

@juanitorduz juanitorduz requested a review from fehiepsi December 21, 2024 19:16
@@ -292,9 +316,9 @@ def deterministic(name, value):
:param jnp.ndarray value: deterministic value to record in the trace.
"""
if not _PYRO_STACK:
return value
return jnp.asarray(value)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we return ArrayLike and avoid using asarray here and others? I use numpyro with string types so jnp.asarray does not work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! 19e7cfc

@@ -253,15 +277,15 @@ def param(name, init_value=None, **kwargs):
assert not callable(
init_value
), "A callable init_value needs to be put inside a numpyro.handlers.seed handler."
return init_value
return jnp.asarray(init_value)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

like below, keep init_value

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! 19e7cfc

@juanitorduz juanitorduz requested a review from fehiepsi December 21, 2024 19:32
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really cool! Thanks, @juanitorduz!

@juanitorduz
Copy link
Contributor Author

I will continue with the handlers, and once we crack them, the rest should be easier (I hope). Thankyou for your valuable feedback @fehiepsi 🙇

@fehiepsi fehiepsi merged commit e71aa62 into pyro-ppl:master Dec 23, 2024
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants