Skip to content

Commit

Permalink
Merge pull request #385 from scipp/refactor-colormapper-update
Browse files Browse the repository at this point in the history
Refactor colormapper update mechanism
  • Loading branch information
nvaytet authored Oct 31, 2024
2 parents 70ad803 + abd78d8 commit ddd351d
Show file tree
Hide file tree
Showing 9 changed files with 283 additions and 134 deletions.
42 changes: 34 additions & 8 deletions src/plopp/backends/matplotlib/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from ...core.utils import coord_as_bin_edges, merge_masks, repeat, scalar_to_string
from ...graphics.bbox import BoundingBox, axis_bounds
from ...graphics.colormapper import ColorMapper
from ..common import check_ndim
from .canvas import Canvas

Expand Down Expand Up @@ -63,9 +64,13 @@ class Image:
Parameters
----------
canvas:
The canvas that will display the line.
The canvas that will display the image.
colormapper:
The colormapper to use for the image.
data:
The initial data to create the line from.
The initial data to create the image from.
uid:
The unique identifier of the artist. If None, a random UUID is generated.
shading:
The shading to use for the ``pcolormesh``.
rasterized:
Expand All @@ -77,16 +82,19 @@ class Image:
def __init__(
self,
canvas: Canvas,
colormapper: ColorMapper,
data: sc.DataArray,
uid: str | None = None,
shading: str = 'auto',
rasterized: bool = True,
**kwargs,
):
check_ndim(data, ndim=2, origin='Image')
self.uid = uid if uid is not None else uuid.uuid4().hex
self._canvas = canvas
self._colormapper = colormapper
self._ax = self._canvas.ax
self._data = data
self._id = uuid.uuid4().hex
# Because all keyword arguments from the figure are forwarded to both the canvas
# and the line, we need to remove the arguments that belong to the canvas.
kwargs.pop('ax', None)
Expand Down Expand Up @@ -129,7 +137,10 @@ def __init__(
rasterized=rasterized,
**kwargs,
)

self._colormapper.add_artist(self.uid, self)
self._mesh.set_array(None)
self._update_colors()

for xy, var in string_labels.items():
getattr(self._ax, f'set_{xy}ticks')(np.arange(float(var.shape[0])))
Expand Down Expand Up @@ -163,17 +174,24 @@ def data(self):
)
return out

def set_colors(self, rgba: np.ndarray):
def notify_artist(self, message: str) -> None:
"""
Set the mesh's rgba colors:
Receive notification from the colormapper that its state has changed.
We thus need to update the colors of the mesh.
Parameters
----------
rgba:
The array of rgba colors.
message:
The message from the colormapper.
"""
self._update_colors()

def _update_colors(self):
"""
Update the mesh colors.
"""
rgba = self._colormapper.rgba(self.data)
self._mesh.set_facecolors(rgba.reshape(np.prod(rgba.shape[:-1]), 4))
self._canvas.draw()

def update(self, new_values: sc.DataArray):
"""
Expand All @@ -187,6 +205,7 @@ def update(self, new_values: sc.DataArray):
check_ndim(new_values, ndim=2, origin='Image')
self._data = new_values
self._data_with_bin_edges.data = new_values.data
self._update_colors()

def format_coord(
self, xslice: tuple[str, sc.Variable], yslice: tuple[str, sc.Variable]
Expand Down Expand Up @@ -223,3 +242,10 @@ def bbox(self, xscale: Literal['linear', 'log'], yscale: Literal['linear', 'log'
**{**axis_bounds(('xmin', 'xmax'), image_x, xscale)},
**{**axis_bounds(('ymin', 'ymax'), image_y, yscale)},
)

def remove(self):
"""
Remove the image artist from the canvas.
"""
self._mesh.remove()
self._colormapper.remove_artist(self.uid)
9 changes: 8 additions & 1 deletion src/plopp/backends/matplotlib/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,29 @@ class Line:
The canvas that will display the line.
data:
The initial data to create the line from.
uid:
The unique identifier of the artist. If None, a random UUID is generated.
artist_number:
The canvas keeps track of how many lines have been added to it. This number is
used to set the color and marker parameters of the line.
errorbars:
Whether to add error bars to the line.
mask_color:
The color of the masked points.
"""

def __init__(
self,
canvas: Canvas,
data: sc.DataArray,
uid: str | None = None,
artist_number: int = 0,
errorbars: bool = True,
mask_color: str = 'black',
**kwargs,
):
check_ndim(data, ndim=1, origin='Line')
self.uid = uid if uid is not None else uuid.uuid4().hex
self._canvas = canvas
self._ax = self._canvas.ax
self._data = data
Expand All @@ -64,7 +72,6 @@ def __init__(
self._dim = self._data.dim
self._unit = self._data.unit
self._coord = self._data.coords[self._dim]
self._id = uuid.uuid4().hex

aliases = {'ls': 'linestyle', 'lw': 'linewidth', 'c': 'color'}
for key, alias in aliases.items():
Expand Down
64 changes: 58 additions & 6 deletions src/plopp/backends/matplotlib/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,65 @@

from ...core.utils import merge_masks
from ...graphics.bbox import BoundingBox, axis_bounds
from ...graphics.colormapper import ColorMapper
from ..common import check_ndim
from .canvas import Canvas
from .utils import parse_dicts_in_kwargs


class Scatter:
""" """
"""
Artist to represent a two-dimensional scatter plot.
Parameters
----------
canvas:
The canvas that will display the scatter plot.
data:
The initial data to create the line from.
x:
The name of the coordinate that is to be used for the X positions.
y:
The name of the coordinate that is to be used for the Y positions.
uid:
The unique identifier of the artist. If None, a random UUID is generated.
size:
The size of the markers.
color:
The color of the markers (this is ignored if a colorbar is used).
artist_number:
Number of the artist (can be used to set the color of the artist).
colormapper:
The colormapper to use for the scatter plot.
mask_color:
The color of the masked points.
cbar:
Whether to use a colorbar.
"""

def __init__(
self,
canvas: Canvas,
data: sc.DataArray,
x: str = 'x',
y: str = 'y',
uid: str | None = None,
size: str | float | None = None,
artist_number: int = 0,
colormapper: ColorMapper | None = None,
mask_color: str = 'black',
cbar: bool = False,
**kwargs,
):
check_ndim(data, ndim=1, origin='Scatter')
self.uid = uid if uid is not None else uuid.uuid4().hex
self._canvas = canvas
self._ax = self._canvas.ax
self._data = data
self._x = x
self._y = y
self._size = size
self._colormapper = colormapper
# Because all keyword arguments from the figure are forwarded to both the canvas
# and the line, we need to remove the arguments that belong to the canvas.
kwargs.pop('ax', None)
Expand Down Expand Up @@ -70,6 +102,9 @@ def __init__(
label=self.label,
**merged_kwargs,
)
if self._colormapper is not None:
self._colormapper.add_artist(self.uid, self)
self._scatter.set_array(None)

xmask = self._data.coords[self._x].values.copy()
ymask = self._data.coords[self._y].values.copy()
Expand All @@ -91,6 +126,24 @@ def __init__(
visible=visible_mask,
)

def notify_artist(self, message: str) -> None:
"""
Receive notification from the colormapper that its state has changed.
We thus need to update the colors of the points.
Parameters
----------
message:
The message from the colormapper.
"""
self._update_colors()

def _update_colors(self):
"""
Update the colors of the scatter points.
"""
self._scatter.set_facecolors(self._colormapper.rgba(self.data))

def update(self, new_values: sc.DataArray):
"""
Update the x and y positions of the data points from new data.
Expand All @@ -110,18 +163,17 @@ def update(self, new_values: sc.DataArray):
)
if isinstance(self._size, str):
self._scatter.set_sizes(self._data.coords[self._size].values)
if self._colormapper is not None:
self._update_colors()

def remove(self):
"""
Remove the scatter and mask artists from the canvas.
"""
self._scatter.remove()
self._mask.remove()

def set_colors(self, rgba: np.ndarray):
if self._scatter.get_array() is not None:
self._scatter.set_array(None)
self._scatter.set_facecolors(rgba)
if self._colormapper is not None:
self._colormapper.remove_artist(self.uid)

@property
def data(self):
Expand Down
13 changes: 8 additions & 5 deletions src/plopp/backends/plotly/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class Line:
The canvas that will display the line.
data:
The initial data to create the line from.
uid:
The unique identifier of the artist. If None, a random UUID is generated.
artist_number:
The canvas keeps track of how many lines have been added to it. This number is
used to set the color and marker parameters of the line.
Expand All @@ -54,6 +56,7 @@ def __init__(
self,
canvas: Canvas,
data: sc.DataArray,
uid: str | None = None,
artist_number: int = 0,
errorbars: bool = True,
mask_color: str = 'black',
Expand All @@ -62,6 +65,7 @@ def __init__(
**kwargs,
):
check_ndim(data, ndim=1, origin='Line')
self.uid = uid if uid is not None else uuid.uuid4().hex
self._fig = canvas.fig
self._data = data

Expand All @@ -75,7 +79,6 @@ def __init__(
self._dim = self._data.dim
self._unit = self._data.unit
self._coord = self._data.coords[self._dim]
self._id = uuid.uuid4().hex

line_data = make_line_data(data=self._data, dim=self._dim)

Expand Down Expand Up @@ -162,12 +165,12 @@ def __init__(
self._error = self._fig.data[-1]
self._fig.add_trace(self._mask)
self._mask = self._fig.data[-1]
self._line._plopp_id = self._id
self._line._plopp_id = self.uid
self.line_mask = sc.array(dims=['x'], values=~np.isnan(line_data['mask']['y']))

self._mask._plopp_id = self._id
self._mask._plopp_id = self.uid
if self._error is not None:
self._error._plopp_id = self._id
self._error._plopp_id = self.uid

def update(self, new_values: sc.DataArray):
"""
Expand Down Expand Up @@ -209,7 +212,7 @@ def remove(self):
Remove the line, masks and errorbar artists from the canvas.
"""
self._fig.data = [
trace for trace in list(self._fig.data) if trace._plopp_id != self._id
trace for trace in list(self._fig.data) if trace._plopp_id != self.uid
]

@property
Expand Down
Loading

0 comments on commit ddd351d

Please sign in to comment.