-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add type hints to the primitives module #1940
Conversation
In ba9a07a I suggest adding a |
numpyro/diagnostics.py
Outdated
@@ -25,7 +26,7 @@ | |||
] | |||
|
|||
|
|||
def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]: | |||
def _compute_chain_variance_stats(x: npt.NDArray) -> tuple[npt.NDArray, npt.NDArray]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why NDArray is used instead of ndarray?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MyPy started complaining while I was adding hints and decided to follow the typing guidelines from numpy https://numpy.org/doc/stable/reference/typing.html :) Then all errors were gone
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
interesting. thanks! maybe import NDArray
instead of npt
? It makes the code easier to read.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure! changed in 1bd90ea
f8ffed9
to
6274eee
Compare
I need to fix some stuff because of the last released version of mypy (https://pypi.org/project/mypy/#history) 😅 |
6274eee
to
a810776
Compare
with numpyro.plate("basis", m): | ||
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd)) | ||
|
||
return phi @ beta | ||
return jnp.asarray(phi @ beta) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this needed when inputs are Array now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True, changed in a7984ce
with numpyro.plate("basis", m): | ||
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0)) | ||
|
||
return phi @ (spd * beta) | ||
return jnp.asarray(phi @ (spd * beta)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this needed when inputs are Array now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True, changed in a7984ce
numpyro/primitives.py
Outdated
from numpyro.util import find_stack_level, identity | ||
|
||
_PYRO_STACK = [] | ||
# Type aliases | ||
MessageType = dict[str, Any] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess Message is enough. We don't have Message class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed! Changed in 6dd9396
numpyro/primitives.py
Outdated
@@ -292,9 +316,9 @@ def deterministic(name, value): | |||
:param jnp.ndarray value: deterministic value to record in the trace. | |||
""" | |||
if not _PYRO_STACK: | |||
return value | |||
return jnp.asarray(value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we return ArrayLike and avoid using asarray here and others? I use numpyro with string types so jnp.asarray does not work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done! 19e7cfc
numpyro/primitives.py
Outdated
@@ -253,15 +277,15 @@ def param(name, init_value=None, **kwargs): | |||
assert not callable( | |||
init_value | |||
), "A callable init_value needs to be put inside a numpyro.handlers.seed handler." | |||
return init_value | |||
return jnp.asarray(init_value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
like below, keep init_value
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done! 19e7cfc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really cool! Thanks, @juanitorduz!
I will continue with the handlers, and once we crack them, the rest should be easier (I hope). Thankyou for your valuable feedback @fehiepsi 🙇 |
Add type hints to the primitives module and refine other type hints accordingly.