Skip to content

Commit

Permalink
fix: remaining mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Oct 4, 2023
1 parent 250202e commit 541d31b
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 20 deletions.
3 changes: 1 addition & 2 deletions examples/advection-esweno-conservation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: MIT

import pathlib
from typing import cast

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -32,7 +31,7 @@ def cosine_pulse(
r = (0.5 + 0.5 * jnp.cos(w * (x - xc))) ** p
mask = (jnp.abs(x - xc) < sigma).astype(x.dtype)

return cast(Array, r * mask)
return r * mask


def main(
Expand Down
4 changes: 2 additions & 2 deletions pyshocks/burgers/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: MIT

from dataclasses import dataclass, field, replace
from typing import ClassVar, cast
from typing import ClassVar

import jax.numpy as jnp

Expand Down Expand Up @@ -324,7 +324,7 @@ def hesthaven_limiter(u: Array, *, variant: int = 1) -> Array:
else:
raise ValueError(f"Unknown variant: {variant!r}")

return cast(Array, phi)
return phi


@numerical_flux.register(SSMUSCL)
Expand Down
3 changes: 1 addition & 2 deletions pyshocks/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from __future__ import annotations

import enum
from typing import cast

import jax.numpy as jnp

Expand Down Expand Up @@ -123,4 +122,4 @@ def convolve1d(
result = u[n:-n]

assert result.shape == ary.shape
return cast(Array, result)
return result
14 changes: 6 additions & 8 deletions pyshocks/sbp.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
import enum
from dataclasses import dataclass, replace
from functools import singledispatch
from typing import Any, cast
from typing import Any

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -363,7 +363,7 @@ def make_sbp_mattsson2012_second_derivative(
assert jnp.linalg.norm(M - M.T) < 1.0e-8
assert jnp.linalg.norm(jnp.sum(M, axis=1)) < 1.0e-8

return cast(Array, invP @ (-M + BS))
return invP @ (-M + BS)


@singledispatch
Expand Down Expand Up @@ -562,7 +562,7 @@ def make_boundary(b1: Array, b2: Array, b3: Array, b4: Array) -> Array:
n, m = mb_r.shape
M = M.at[-n:, -m:].set(mb_r)

return cast(Array, M / dx)
return M / dx


def make_sbp_21_norm_stencil(dtype: Any = None) -> Stencil:
Expand Down Expand Up @@ -780,10 +780,8 @@ def make_sbp_42_second_derivative_r_matrix(
C34 = jnp.diag(make_sbp_matrix_from_stencil(bc, n, c34))
C44 = jnp.diag(make_sbp_matrix_from_stencil(bc, n, c44))

return cast(
Array,
dx**5 / 18 * D34.T @ C34 @ B34 @ D34
+ dx**7 / 144 * D44.T @ C44 @ B44 @ D44,
return (
dx**5 / 18 * D34.T @ C34 @ B34 @ D34 + dx**7 / 144 * D44.T @ C44 @ B44 @ D44
)


Expand Down Expand Up @@ -1042,7 +1040,7 @@ def make_boundary(
n, m = mb_r.shape
M = M.at[-n:, -m:].set(mb_r)

return cast(Array, -M / dx)
return -M / dx


def make_sbp_42_norm_stencil(dtype: Any = None) -> Stencil:
Expand Down
2 changes: 1 addition & 1 deletion pyshocks/timestepping.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def step(
"""
if tfinal is None:
# NOTE: can't set this to jnp.inf because we have a debug check for it
tfinal = jnp.finfo(u0.dtype).max
tfinal = jnp.finfo(u0.dtype).max # type: ignore[no-untyped-call]

m = 0
t = jnp.array(tstart, dtype=u0.dtype)
Expand Down
6 changes: 3 additions & 3 deletions pyshocks/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def estimate_order_of_convergence(x: Array, y: Array) -> tuple[Scalar, Scalar]:
if x.size <= 1:
raise RuntimeError("Need at least two values to estimate order.")

eps = jnp.finfo(x.dtype).eps
eps = jnp.finfo(x.dtype).eps # type: ignore[no-untyped-call]
c = jnp.polyfit(jnp.log10(x + eps), jnp.log10(y + eps), 1)
return 10 ** c[-1], c[-2]

Expand Down Expand Up @@ -272,7 +272,7 @@ def satisfied(

_, error = self._history
if atol is None:
atol = 1.0e2 * jnp.finfo(error.dtype).eps
atol = 1.0e2 * jnp.finfo(error.dtype).eps # type: ignore[no-untyped-call]

return bool(self.estimated_order >= (order - slack) or jnp.max(error) < atol)

Expand Down Expand Up @@ -877,7 +877,7 @@ def _anim_func(n: int) -> tuple[Any, ...]:
if legends is not None:
ax.set_ylabel(ylabel)
ax.set_xlim((float(x[0]), float(x[-1])))
ax.set_ylim((ymin - 0.1 * jnp.abs(ymin), ymax + 0.1 * jnp.abs(ymax)))
ax.set_ylim((float(ymin - 0.1 * jnp.abs(ymin)), float(ymax + 0.1 * jnp.abs(ymax))))
ax.margins(0.05)

if legends:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def test_finite_difference_taylor_stencil(*, visualize: bool = False) -> None:
ax.set_xlabel("$k h$")
ax.set_ylabel(r"$\tilde{k} h$")
ax.set_xlim((0.0, jnp.pi))
ax.set_ylim((0.0, sign * jnp.pi**s.derivative))
ax.set_ylim((0.0, float(sign * jnp.pi**s.derivative)))

fig.savefig(f"finite_difference_wavenumber_{s.derivative}_{s.order}")
fig.clf()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sbp.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_sbp_matrices(name: str, bc: BoundaryType, *, visualize: bool = False) -
fig.clf()

# NOTE: allow negative values larger than eps because floating point..
mask = jnp.real(s) > -jnp.finfo(dtype).eps
mask = jnp.real(s) > -jnp.finfo(dtype).eps # type: ignore[no-untyped-call]
assert jnp.all(mask), jnp.real(s[~mask])

# }}}
Expand Down

0 comments on commit 541d31b

Please sign in to comment.