Skip to content

Commit

Permalink
Merge pull request #5063 from cphyc/bugfix/5062
Browse files Browse the repository at this point in the history
[BUG] Ensure positions are dimensionless in plot callbacks
  • Loading branch information
chrishavlin authored Dec 9, 2024
2 parents a9a56b7 + 3a5dd48 commit e6e5f1f
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 11 deletions.
39 changes: 28 additions & 11 deletions yt/visualization/plot_modifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,8 +999,8 @@ def __call__(self, plot) -> None:

if plot._type_name in ["CuttingPlane", "Projection", "Slice"]:
if plot._type_name == "CuttingPlane":
x = data["px"] * dx
y = data["py"] * dy
x = (data["px"] * dx).to("1")
y = (data["py"] * dy).to("1")
z = data[self.field]
elif plot._type_name in ["Projection", "Slice"]:
# Makes a copy of the position fields "px" and "py" and adds the
Expand Down Expand Up @@ -1028,8 +1028,10 @@ def __call__(self, plot) -> None:
wI = AllX & AllY

# This converts XShifted and YShifted into plot coordinates
x = ((XShifted[wI] - x0) * dx).ndarray_view() + xx0
y = ((YShifted[wI] - y0) * dy).ndarray_view() + yy0
# Note: we force conversion into "1" to prevent issues in case
# one of the length has some dimensionless factor (Mpc/h)
x = ((XShifted[wI] - x0) * dx).to("1").ndarray_view() + xx0
y = ((YShifted[wI] - y0) * dy).to("1").ndarray_view() + yy0
z = data[self.field][wI]

# Both the input and output from the triangulator are in plot
Expand Down Expand Up @@ -1134,8 +1136,8 @@ def __call__(self, plot):

x0, x1, y0, y1 = self._physical_bounds(plot)
xx0, xx1, yy0, yy1 = self._plot_bounds(plot)
(dx, dy) = self._pixel_scale(plot)
(ypix, xpix) = plot.raw_image_shape
dx, dy = self._pixel_scale(plot)
ypix, xpix = plot.raw_image_shape
ax = plot.data.axis
px_index = plot.data.ds.coordinates.x_axis[ax]
py_index = plot.data.ds.coordinates.y_axis[ax]
Expand Down Expand Up @@ -1169,10 +1171,17 @@ def __call__(self, plot):
for px_off, py_off in zip(pxs.ravel(), pys.ravel(), strict=True):
pxo = px_off * DW[px_index]
pyo = py_off * DW[py_index]
left_edge_x = np.array((GLE[:, px_index] + pxo - x0) * dx) + xx0
left_edge_y = np.array((GLE[:, py_index] + pyo - y0) * dy) + yy0
right_edge_x = np.array((GRE[:, px_index] + pxo - x0) * dx) + xx0
right_edge_y = np.array((GRE[:, py_index] + pyo - y0) * dy) + yy0
# Note: [dx] = 1/length, [GLE] = length
# we force conversion into "1" to prevent issues if e.g. GLE is in Mpc/h
# where dx * GLE would have units 1/h rather than being truly dimensionless
left_edge_x = np.array((((GLE[:, px_index] + pxo - x0) * dx) + xx0).to("1"))
left_edge_y = np.array((((GLE[:, py_index] + pyo - y0) * dy) + yy0).to("1"))
right_edge_x = np.array(
(((GRE[:, px_index] + pxo - x0) * dx) + xx0).to("1")
)
right_edge_y = np.array(
(((GRE[:, py_index] + pyo - y0) * dy) + yy0).to("1")
)
xwidth = xpix * (right_edge_x - left_edge_x) / (xx1 - xx0)
ywidth = ypix * (right_edge_y - left_edge_y) / (yy1 - yy0)
visible = np.logical_and(
Expand Down Expand Up @@ -2075,12 +2084,20 @@ def __call__(self, plot):
units = "code_length"
self.radius = self.radius.to(units)

if not hasattr(self.radius, "units"):
self.radius = plot.data.ds.quan(self.radius, "code_length")

if not hasattr(self.center, "units"):
self.center = plot.data.ds.arr(self.center, "code_length")

# This assures the radius has the appropriate size in
# the different coordinate systems, since one cannot simply
# apply a different transform for a length in the same way
# you can for a coordinate.
if self.coord_system == "data" or self.coord_system == "plot":
scaled_radius = self.radius * self._pixel_scale(plot)[0]
# Note: we force conversion into "1" to prevent issues in case
# one of the length has some dimensionless factor (Mpc/h)
scaled_radius = (self.radius * self._pixel_scale(plot)[0]).to("1")
else:
scaled_radius = self.radius / (plot.xlim[1] - plot.xlim[0])

Expand Down
31 changes: 31 additions & 0 deletions yt/visualization/tests/test_image_comp_2D_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,37 @@ def test_sliceplot_set_unit_and_zlim_order():
npt.assert_allclose(im0, im1)


def test_annotation_parse_h():
ds = fake_random_ds(16)

# Make sure `h` (reduced Hubble constant) is not equal to 1
ds.unit_registry.modify("h", 0.7)

rad = ds.quan(0.1, "cm/h")
center = ds.arr([0.5] * 3, "code_length")

# Twice the same slice plot
p1 = SlicePlot(ds, "x", "density")
p2 = SlicePlot(ds, "x", "density")

# But the *same* center is given in different units
p1.annotate_sphere(center.to("cm"), rad, circle_args={"color": "black"})
p2.annotate_sphere(center.to("cm/h"), rad, circle_args={"color": "black"})

# Render annotations, and extract matplotlib image
# as an RGB array
p1.render()
p1.plots["gas", "density"].figure.canvas.draw()
img1 = p1.plots["gas", "density"].figure.canvas.renderer.buffer_rgba()

p2.render()
p2.plots["gas", "density"].figure.canvas.draw()
img2 = p2.plots["gas", "density"].figure.canvas.renderer.buffer_rgba()

# This should be the same image
npt.assert_allclose(img1, img2)


@pytest.mark.mpl_image_compare
def test_inf_and_finite_values_set_zlim():
# see https://github.com/yt-project/yt/issues/3901
Expand Down

0 comments on commit e6e5f1f

Please sign in to comment.