Skip to content

Commit

Permalink
fix: update type annotations for matplotlib
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Oct 4, 2023
1 parent f4b23fe commit b410658
Show file tree
Hide file tree
Showing 18 changed files with 110 additions and 97 deletions.
8 changes: 7 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,17 @@ help: ## Show this help

# {{{ linting

fmt: black ## Run all formatting scripts
format: black ## Run all formatting scripts
$(PYTHON) -m pyproject_fmt --indent 4 pyproject.toml
$(PYTHON) -m isort pyshocks tests examples docs drivers
.PHONY: format

fmt: format
.PHONY: fmt

lint: ruff mypy reuse codespell sphinxlint ## Run all linting scripts
.PHONY: lint

black: ## Run black over the source code
$(PYTHON) -m black \
--safe --target-version py38 --preview \
Expand Down
12 changes: 6 additions & 6 deletions drivers/advection-adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def evolve_forward(
mp.ion()

_, ln1 = ax.plot(grid.x[s], u0[s], "k--", grid.x[s], u0[s], "o-", ms=1)
ax.set_xlim([grid.a, grid.b])
ax.set_ylim([jnp.min(u0) - 1.0, jnp.max(u0) + 1.0])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_ylim((float(jnp.min(u0) - 1.0), float(jnp.max(u0) + 1.0)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")

Expand Down Expand Up @@ -110,7 +110,7 @@ def evolve_forward(
ax.plot(grid.x[grid.i_], event.u[grid.i_], label="$u(T)$")
ax.plot(grid.x[grid.i_], u0[grid.i_], "k--", label="$u(0)$")

ax.set_ylim([umin - 0.1 * umag, umax + 0.1 * umag])
ax.set_ylim((float(umin - 0.1 * umag), float(umax + 0.1 * umag)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")
ax.set_title(f"$T = {event.t:.3f}$")
Expand Down Expand Up @@ -171,8 +171,8 @@ def _apply_boundary(t: ScalarLike, u: Array, p: Array) -> Array:

ln0, ln1 = ax.plot(grid.x[s], uf[s], "k--", grid.x[s], p0[s], "o-", ms=1)

ax.set_xlim([grid.a, grid.b])
ax.set_ylim([pmin - 0.25 * pmag, pmax + 0.25 * pmag])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_ylim((float(pmin - 0.25 * pmag), float(pmax + 0.25 * pmag)))
ax.set_xlabel("$x$")

from pyshocks import norm
Expand Down Expand Up @@ -209,7 +209,7 @@ def _apply_boundary(t: ScalarLike, u: Array, p: Array) -> Array:
ax.plot(grid.x[s], event.p[s])
ax.plot(grid.x[s], p0[s], "k--")

ax.set_ylim([pmin - 0.1 * pmag, pmax + 0.1 * pmag])
ax.set_ylim((float(pmin - 0.1 * pmag), float(pmax + 0.1 * pmag)))
ax.set_xlabel("$x$")
ax.set_ylabel("$p$")

Expand Down
12 changes: 6 additions & 6 deletions drivers/burgers-adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def evolve_forward(
mp.ion()

_, ln1 = ax.plot(grid.x[s], u0[s], "k--", grid.x[s], u0[s], "o-", ms=1)
ax.set_xlim([grid.a, grid.b])
ax.set_ylim([jnp.min(u0) - 1.0, jnp.max(u0) + 1.0])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_ylim((float(jnp.min(u0) - 1.0), float(jnp.max(u0) + 1.0)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")

Expand Down Expand Up @@ -236,8 +236,8 @@ def _apply_boundary(t: ScalarLike, u: Array, p: Array) -> Array:
# chosen in `main`; modify as needed
ax.axhline(0.5, color="k", linestyle="--", lw=1)

ax.set_xlim([grid.a, grid.b])
ax.set_ylim([pmin - 0.1 * pmag, pmax + 0.1 * pmag])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_ylim((float(pmin - 0.1 * pmag), float(pmax + 0.1 * pmag)))
ax.set_xlabel("$x$")

from pyshocks import norm
Expand Down Expand Up @@ -339,7 +339,7 @@ def main(
ax.plot(grid.x[i], uf[i], label="$u(T)$")
ax.plot(grid.x[i], sim.u0[i], "k--", label="$u(0)$")

ax.set_ylim([umin - 0.1 * umag, umax + 0.1 * umag])
ax.set_ylim((float(umin - 0.1 * umag), float(umax + 0.1 * umag)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")
ax.legend()
Expand All @@ -361,7 +361,7 @@ def main(
ax.plot(grid.x[i], sim.u0[i], "k--", label="$u(0)$")
ax.axhline(0.5, color="k", linestyle=":", lw=1)

ax.set_ylim([pmin - 0.1 * pmag, pmax + 0.1 * pmag])
ax.set_ylim((float(pmin - 0.1 * pmag), float(pmax + 0.1 * pmag)))
ax.set_xlabel("$x$")
ax.legend()

Expand Down
6 changes: 3 additions & 3 deletions examples/advection-esweno-conservation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def main(

(ln0,) = ax.plot(grid.x[s], u0[s], "o-", ms=1)

ax.set_xlim([grid.a, grid.b])
ax.set_ylim([jnp.min(u0) - 1, jnp.max(u0) + 1])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_ylim((float(jnp.min(u0) - 1), float(jnp.max(u0) + 1)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")

Expand Down Expand Up @@ -154,7 +154,7 @@ def _apply_operator(_t: ScalarLike, _u: Array) -> Array:
ax.plot(grid.x[s], event.u[s], "o-")
ax.plot(grid.x[s], u0[s], "k--")
ax.axhline(1.0, color="k", ls=":", lw=1)
ax.set_xlim([grid.a, grid.b])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")

Expand Down
6 changes: 3 additions & 3 deletions examples/advection-sbp-sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def main(
plt.ion()

_, ln1 = ax.plot(grid.x[s], u0[s], "k--", grid.x[s], u0[s], "o-", ms=1)
ax.set_xlim([grid.a, grid.b])
ax.set_ylim([jnp.min(u0) - 1, jnp.max(u0) + 1])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_ylim((float(jnp.min(u0) - 1), float(jnp.max(u0) + 1)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")

Expand Down Expand Up @@ -145,7 +145,7 @@ def _apply_operator(_t: ScalarLike, _u: Array) -> Array:
ax.plot(grid.x[s], event.u[s])
ax.plot(grid.x[s], u0[s], "k--")
ax.plot(grid.x[s], func(tfinal, grid.x), "k:")
ax.set_xlim([grid.a, grid.b])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")

Expand Down
6 changes: 3 additions & 3 deletions examples/advection-splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def main(
plt.ion()

ln0, ln1 = ax.plot(grid.x[s], u0[s], "k--", grid.x[s], u0[s], "o-", ms=1)
ax.set_xlim([grid.a, grid.b])
ax.set_ylim([jnp.min(u0) - 1, jnp.max(u0) + 1])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_ylim((float(jnp.min(u0) - 1), float(jnp.max(u0) + 1)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")

Expand Down Expand Up @@ -147,7 +147,7 @@ def _apply_operator(_t: ScalarLike, _u: Array) -> Array:
ax.plot(grid.x[s], event.u[s])
ax.plot(grid.x[s], u0[s], "k--")
ax.plot(grid.x[s], func(tfinal, grid.x[s]), "k:")
ax.set_xlim([grid.a, grid.b])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")

Expand Down
8 changes: 4 additions & 4 deletions examples/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def main(

ln0, ln1 = ax.plot(grid.x[s], u0[s], "k--", grid.x[s], u0[s], "o-", ms=1)
ax.axhline(1.0, color="k", ls=":", lw=1)
ax.set_xlim([grid.a, grid.b])
ax.set_ylim([jnp.min(u0) - 1, jnp.max(u0) + 1])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_ylim((float(jnp.min(u0) - 1), float(jnp.max(u0) + 1)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")

Expand Down Expand Up @@ -215,7 +215,7 @@ def _apply_operator(_t: ScalarLike, _u: Array) -> Array:
ax.plot(grid.x[s], event.u[s])
ax.plot(grid.x[s], uhat[s])
ax.axhline(1.0, color="k", ls=":", lw=1)
ax.set_xlim([grid.a, grid.b])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")

Expand Down Expand Up @@ -288,7 +288,7 @@ def convergence(

if visualize:
ax.axhline(1.0, color="k", ls=":", lw=2)
ax.set_xlim([-1.0, 1.0])
ax.set_xlim((-1.0, 1.0))
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")

Expand Down
8 changes: 5 additions & 3 deletions examples/burgers-sbp-numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ def main(
except ImportError:
return

from matplotlib.lines import Line2D

fig = plt.figure()
ax = plt.axes()

Expand All @@ -320,12 +322,12 @@ def main(
# prepare for animated lines
(line,) = ax.plot(x, x, color="b", linestyle="-", linewidth=2, marker="*")

def animate(n: int) -> tuple[plt.Line2D, ...]:
def animate(n: int) -> tuple[Line2D, ...]:
line.set_ydata(sol.y[:, n])
return (line,)

# Init only required for blitting to give a clean slate.
def init() -> tuple[plt.Line2D, ...]:
def init() -> tuple[Line2D, ...]:
line.set_ydata(x)
return (line,)

Expand All @@ -334,7 +336,7 @@ def init() -> tuple[plt.Line2D, ...]:
anim = animation.FuncAnimation( # noqa: F841
fig, animate, np.arange(1, nt), init_func=init, interval=25, blit=True
)
plt.show()
plt.show() # type: ignore[no-untyped-call]


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions examples/burgers-splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def main(
plt.ion()

ln0, ln1 = ax.plot(grid.x[s], u0[s], "k--", grid.x[s], u0[s], "o-", ms=1)
ax.set_xlim([grid.a, grid.b])
ax.set_ylim([jnp.min(u0) - 1, jnp.max(u0) + 1])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_ylim((float(jnp.min(u0) - 1), float(jnp.max(u0) + 1)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")

Expand Down Expand Up @@ -140,7 +140,7 @@ def _apply_operator(_t: ScalarLike, _u: Array) -> Array:
ax.plot(grid.x[s], event.u[s])
ax.plot(grid.x[s], u0[s], "k--")
ax.plot(grid.x[s], func(tfinal, grid.x[s]), "k:")
ax.set_xlim([grid.a, grid.b])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")

Expand Down
26 changes: 14 additions & 12 deletions examples/burgers-ssweno-numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,14 +948,16 @@ def rhs_func(t: float, u: Array) -> Array:
# https://matplotlib.org/examples/animation/simple_anim.html

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure

def gca(f: plt.Figure) -> plt.Axes:
def gca(f: Figure) -> Axes:
ax = f.gca()
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$u(x, t)$")
ax.set_xlim([a, b])
ax.set_ylim([np.min(solution.y) - 0.25, np.max(solution.y) + 0.25])
# ax.set_ylim([-2.0, 2.0])
ax.set_xlim((a, b))
ax.set_ylim((np.min(solution.y) - 0.25, np.max(solution.y) + 0.25))
# ax.set_ylim((-2.0, 2.0))
ax.margins(0.05)

return ax
Expand Down Expand Up @@ -1089,10 +1091,10 @@ def gca(f: plt.Figure) -> plt.Axes:
if not animate:
# NOTE: this needs updating when the domain changes
ax = gca(fig)
# ax.set_xlim([1.7, 2.1])
# ax.set_ylim([1.2, 1.55])
ax.set_xlim([-0.15, +0.15])
ax.set_ylim([0.75, 1.05])
# ax.set_xlim((1.7, 2.1))
# ax.set_ylim((1.2, 1.55))
ax.set_xlim((-0.15, +0.15))
ax.set_ylim((0.75, 1.05))

if uf is not None:
ax.plot(x, uf, "k--")
Expand All @@ -1102,10 +1104,10 @@ def gca(f: plt.Figure) -> plt.Axes:
fig.clf()

ax = gca(fig)
# ax.set_xlim([-1.5, -1.1])
# ax.set_ylim([0.4, 0.6])
ax.set_xlim([-0.15, +0.15])
ax.set_ylim([-1.05, -0.75])
# ax.set_xlim((-1.5, -1.1))
# ax.set_ylim((0.4, 0.6))
ax.set_xlim((-0.15, +0.15))
ax.set_ylim((-1.05, -0.75))

if uf is not None:
ax.plot(x, uf, "k--")
Expand Down
28 changes: 15 additions & 13 deletions examples/burgers-ssweno-periodic-numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,14 +591,16 @@ def rhs_func(t: float, u: Array) -> Array:
# https://matplotlib.org/examples/animation/simple_anim.html

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure

def gca(f: plt.Figure) -> plt.Axes:
def gca(f: Figure) -> Axes:
ax = f.gca()
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$u(x, t)$")
ax.set_xlim([a, b])
ax.set_ylim([np.min(solution.y) - 0.25, np.max(solution.y) + 0.25])
# ax.set_ylim([-2.0, 2.0])
ax.set_xlim((a, b))
ax.set_ylim((np.min(solution.y) - 0.25, np.max(solution.y) + 0.25))
# ax.set_ylim((-2.0, 2.0))
ax.margins(0.05)

return ax
Expand Down Expand Up @@ -720,10 +722,10 @@ def gca(f: plt.Figure) -> plt.Axes:
if not animate:
# NOTE: this needs updating when the domain changes
ax = gca(fig)
# ax.set_xlim([1.7, 2.1])
# ax.set_ylim([1.2, 1.55])
ax.set_xlim([-0.15, +0.15])
ax.set_ylim([0.75, 1.05])
# ax.set_xlim((1.7, 2.1))
# ax.set_ylim((1.2, 1.55))
ax.set_xlim((-0.15, +0.15))
ax.set_ylim((0.75, 1.05))

if uf is not None:
ax.plot(x, uf, "k--")
Expand All @@ -733,10 +735,10 @@ def gca(f: plt.Figure) -> plt.Axes:
fig.clf()

ax = gca(fig)
# ax.set_xlim([-1.5, -1.1])
# ax.set_ylim([0.4, 0.6])
ax.set_xlim([-0.15, +0.15])
ax.set_ylim([-1.05, -0.75])
# ax.set_xlim((-1.5, -1.1))
# ax.set_ylim((0.4, 0.6))
ax.set_xlim((-0.15, +0.15))
ax.set_ylim((-1.05, -0.75))

if uf is not None:
ax.plot(x, uf, "k--")
Expand Down Expand Up @@ -875,7 +877,7 @@ def test_interpolation(*, visualize: bool = True) -> None:
# ax.plot(y, fhat[R], label="$f^R$")
# ax.semilogy(np.abs(fbar - fhathat), "o-")
ax.set_xlabel(r"$\bar{x}$")
# ax.set_ylim([1.0e-16, 1])
# ax.set_ylim((1.0e-16, 1))
# ax.legend()

fig.savefig("burgers_ssweno_interp_convergence")
Expand Down
6 changes: 3 additions & 3 deletions examples/burgers.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def main(
phi = hesthaven_limiter(u0, variant=scheme.variant)
(ln2,) = ax.plot(grid.f[1:-1], phi, "k", lw=1)

ax.set_xlim([grid.a, grid.b])
ax.set_ylim([jnp.min(u0) - 1.5, jnp.max(u0) + 1])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_ylim((float(jnp.min(u0) - 1.5), float(jnp.max(u0) + 1)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")
ax.set_title(f"t = {0.0:.3f}")
Expand Down Expand Up @@ -245,7 +245,7 @@ def _apply_operator(_t: ScalarLike, _u: Array) -> Array:
ax = fig.gca()
ax.plot(grid.x[s], event.u[s])
ax.plot(grid.x[s], uhat[s])
ax.set_xlim([grid.a, grid.b])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")

Expand Down
6 changes: 3 additions & 3 deletions examples/diffusion-sbp-sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def main(
plt.ion()

ln0, ln1 = ax.plot(grid.x[s], u0[s], "o-", grid.x[s], u0[s], "k--", ms=1)
ax.set_xlim([grid.a, grid.b])
ax.set_ylim([jnp.min(u0) - 1, jnp.max(u0) + 1])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_ylim((float(jnp.min(u0) - 1), float(jnp.max(u0) + 1)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")

Expand Down Expand Up @@ -152,7 +152,7 @@ def _apply_operator(_t: ScalarLike, _u: Array) -> Array:
ax.plot(grid.x[s], event.u[s])
ax.plot(grid.x[s], u0[s], "k--")
ax.plot(grid.x[s], func(tfinal, grid.x), "k:")
ax.set_xlim([grid.a, grid.b])
ax.set_xlim((float(grid.a), float(grid.b)))
ax.set_xlabel("$x$")
ax.set_ylabel("$u$")

Expand Down
Loading

0 comments on commit b410658

Please sign in to comment.