Skip to content

Commit

Permalink
remover asarray
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Dec 21, 2024
1 parent 1bd90ea commit 19e7cfc
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _masked_observe(
obs: Optional[ArrayLike],
obs_mask,
**kwargs,
) -> Array:
) -> ArrayLike:
# Split into two auxiliary sample sites.
with numpyro.handlers.mask(mask=obs_mask):
observed = sample(f"{name}_observed", fn, **kwargs, obs=obs)
Expand Down Expand Up @@ -247,7 +247,7 @@ def sample(

def param(
name: str, init_value: Optional[Union[ArrayLike, Callable]] = None, **kwargs
) -> Array:
) -> Optional[ArrayLike]:
"""
Annotate the given site as an optimizable parameter for use with
:mod:`jax.example_libraries.optimizers`. For an example of how `param` statements
Expand Down Expand Up @@ -277,7 +277,7 @@ def param(
assert not callable(
init_value
), "A callable init_value needs to be put inside a numpyro.handlers.seed handler."
return jnp.asarray(init_value)
return init_value

if callable(init_value):

Expand All @@ -304,7 +304,7 @@ def fn(init_fn: Callable, *args, **kwargs):
return msg["value"]


def deterministic(name: str, value: ArrayLike) -> Array:
def deterministic(name: str, value: ArrayLike) -> ArrayLike:
"""
Used to designate deterministic sites in the model. Note that most effect
handlers will not operate on deterministic sites (except
Expand All @@ -316,7 +316,7 @@ def deterministic(name: str, value: ArrayLike) -> Array:
:param jnp.ndarray value: deterministic value to record in the trace.
"""
if not _PYRO_STACK:
return jnp.asarray(value)
return value

initial_msg: Message = {
"type": "deterministic",
Expand Down

0 comments on commit 19e7cfc

Please sign in to comment.