Skip to content

Commit

Permalink
Merge pull request #2479 from jfoster17/fix-region-flipping-logic
Browse files Browse the repository at this point in the history
Fix the logic for region display when exchanging x/y world coordinates
  • Loading branch information
astrofrog authored Apr 11, 2024
2 parents 9ea3636 + 297dcca commit 58fc3a3
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 48 deletions.
4 changes: 3 additions & 1 deletion glue/tests/visual/py311-test-visual.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,7 @@
"glue.viewers.image.tests.test_viewer.test_region_layer_flip": "a142142f34961aba7e98188ad43abafe0e6e5b82e13e8cdab5131d297ed5832c",
"glue.viewers.profile.tests.test_viewer.test_simple_viewer": "f68a21be5080fec513388b2d2b220512e7b0df5498e2489da54e58708de435b3",
"glue.viewers.scatter.tests.test_viewer.test_simple_viewer": "1020a7bd3abe40510b9e03047c3b423b75c3c64ac18e6dcd6257173cec1ed53f",
"glue.viewers.scatter.tests.test_viewer.test_scatter_density_map": "3379d655262769a6ccbdbaf1970bffa9237adbec23a93d3ab75da51b9a3e7f8b"
"glue.viewers.scatter.tests.test_viewer.test_scatter_density_map": "3379d655262769a6ccbdbaf1970bffa9237adbec23a93d3ab75da51b9a3e7f8b",
"glue.viewers.image.tests.test_viewer.TestWCSRegionDisplay.test_wcs_viewer": "651e0d954c6b77becb7064de24f5101b9f2882adabc1d5aedbc183b9762c59b1",
"glue.viewers.image.tests.test_viewer.TestWCSRegionDisplay.test_flipped_wcs_viewer": "fcc18a1398d1f3f61b6989a722402cde9365a34483eedf8eac222647e20eb8ab"
}
176 changes: 172 additions & 4 deletions glue/viewers/image/tests/test_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from glue.core.data import Data
from glue.core.link_helpers import LinkSame
from glue.core.data_region import RegionData
from astropy.wcs import WCS

from shapely.geometry import Polygon, MultiPolygon
from shapely.geometry import Polygon, MultiPolygon, Point
import shapely


@visual_test
Expand Down Expand Up @@ -41,24 +43,82 @@ def test_region_layer():
polygons = MultiPolygon([poly_3, poly_4])

geoms = np.array([poly_1, poly_2, polygons])
values = np.array([1, 2, 3])
region_data = RegionData(regions=geoms, values=values)
a_values = np.array([1, 2, 3])
b_values = np.array([1, 2, 3])

region_data = RegionData(regions=geoms, a=a_values, b=b_values)

image_data = Data(x=np.arange(10000).reshape((100, 100)), label='data1')
app = Application()
app.data_collection.append(image_data)
app.data_collection.append(region_data)

viewer = app.new_data_viewer(SimpleImageViewer)
viewer.add_data(image_data)
viewer.add_data(region_data)

# Link to region data components that are the not the x,y coordinates
link1 = LinkSame(region_data.id['a'], image_data.pixel_component_ids[0])
link2 = LinkSame(region_data.id['b'], image_data.pixel_component_ids[1])
app.data_collection.add_link(link1)
app.data_collection.add_link(link2)

app.data_collection.remove_link(link1)
app.data_collection.remove_link(link2)

link1 = LinkSame(region_data.center_x_id, image_data.pixel_component_ids[0])
link2 = LinkSame(region_data.center_y_id, image_data.pixel_component_ids[1])
app.data_collection.add_link(link1)
app.data_collection.add_link(link2)

return viewer.figure


def test_region_layer_logic():
poly_1 = Polygon([(20, 20), (60, 20), (60, 40), (20, 40)])
poly_2 = Polygon([(60, 50), (60, 70), (80, 70), (80, 50)])
poly_3 = Polygon([(10, 10), (15, 10), (15, 15), (10, 15)])
poly_4 = Polygon([(10, 20), (15, 20), (15, 30), (10, 30), (12, 25)])

polygons = MultiPolygon([poly_3, poly_4])

geoms = np.array([poly_1, poly_2, polygons])
a_values = np.array([1, 2, 3])
b_values = np.array([1, 2, 3])

region_data = RegionData(regions=geoms, a=a_values, b=b_values)

image_data = Data(x=np.arange(10000).reshape((100, 100)), label='data1')
app = Application()
app.data_collection.append(image_data)
app.data_collection.append(region_data)

viewer = app.new_data_viewer(SimpleImageViewer)
viewer.add_data(image_data)
viewer.add_data(region_data)

return viewer.figure
assert viewer.layers[0].enabled # image
assert not viewer.layers[1].enabled # regions

# Link to region data components that are the not the x,y coordinates
link1 = LinkSame(region_data.id['a'], image_data.pixel_component_ids[0])
link2 = LinkSame(region_data.id['b'], image_data.pixel_component_ids[1])
app.data_collection.add_link(link1)
app.data_collection.add_link(link2)

assert viewer.layers[0].enabled # image
assert not viewer.layers[1].enabled # regions

app.data_collection.remove_link(link1)
app.data_collection.remove_link(link2)

link1 = LinkSame(region_data.center_x_id, image_data.pixel_component_ids[0])
link2 = LinkSame(region_data.center_y_id, image_data.pixel_component_ids[1])
app.data_collection.add_link(link1)
app.data_collection.add_link(link2)

assert viewer.layers[0].enabled # image
assert viewer.layers[1].enabled # regions


@visual_test
Expand Down Expand Up @@ -97,3 +157,111 @@ def test_region_layer_flip():
viewer.state.y_att = image_data.pixel_component_ids[1]

return viewer.figure


class TestWCSRegionDisplay(object):
def setup_method(self, method):

wcs1 = WCS(naxis=2)
wcs1.wcs.ctype = 'RA---TAN', 'DEC--TAN'
wcs1.wcs.crpix = 15, 15
wcs1.wcs.cd = [[2, -1], [1, 2]]

wcs1.wcs.set()

np.random.seed(2)
self.image1 = Data(label='image1', a=np.random.rand(30, 30), coords=wcs1)
SHAPELY_ARRAY = np.array([Point(1.5, 2.5).buffer(4),
Polygon([(10, 10), (10, 15), (20, 15), (20, 10)])])
self.region_data = RegionData(label='My Regions',
color=np.array(['red', 'blue']),
area=shapely.area(SHAPELY_ARRAY),
boundary=SHAPELY_ARRAY)
self.application = Application()

self.application.data_collection.append(self.image1)
self.application.data_collection.append(self.region_data)

self.viewer = self.application.new_data_viewer(SimpleImageViewer)

def test_wcs_viewer_bad_link(self):
self.viewer.add_data(self.image1)

link1 = LinkSame(self.region_data.id['color'], self.image1.world_component_ids[1])
link2 = LinkSame(self.region_data.id['area'], self.image1.world_component_ids[0])

self.application.data_collection.add_link(link1)
self.application.data_collection.add_link(link2)

self.viewer.add_data(self.region_data)

assert self.viewer.state._display_world is True
assert len(self.viewer.state.layers) == 2
assert self.viewer.layers[0].enabled
assert not self.viewer.layers[1].enabled

def test_wcs_viewer_good_link(self):
self.viewer.add_data(self.image1)

link1 = LinkSame(self.region_data.center_x_id, self.image1.world_component_ids[1])
link2 = LinkSame(self.region_data.center_y_id, self.image1.world_component_ids[0])

self.application.data_collection.add_link(link1)
self.application.data_collection.add_link(link2)

self.viewer.add_data(self.region_data)

assert self.viewer.state._display_world is True
assert len(self.viewer.state.layers) == 2
assert self.viewer.layers[0].enabled
assert self.viewer.layers[1].enabled

@visual_test
def test_wcs_viewer(self):
self.viewer.add_data(self.image1)

link1 = LinkSame(self.region_data.center_x_id, self.image1.world_component_ids[1])
link2 = LinkSame(self.region_data.center_y_id, self.image1.world_component_ids[0])

self.application.data_collection.add_link(link1)
self.application.data_collection.add_link(link2)

self.viewer.add_data(self.region_data)

assert self.viewer.state._display_world is True
assert len(self.viewer.state.layers) == 2
assert self.viewer.layers[0].enabled
assert self.viewer.layers[1].enabled

return self.viewer.figure

@visual_test
def test_flipped_wcs_viewer(self):
self.viewer.add_data(self.image1)

link1 = LinkSame(self.region_data.center_x_id, self.image1.world_component_ids[1])
link2 = LinkSame(self.region_data.center_y_id, self.image1.world_component_ids[0])

self.application.data_collection.add_link(link1)
self.application.data_collection.add_link(link2)

self.viewer.add_data(self.region_data)
original_path_patch = self.viewer.layers[1].region_collection.patches[1].get_path().vertices

# Flip x,y in the viewer
with delay_callback(self.viewer.state, 'x_att_world', 'y_att_world', 'x_att', 'y_att'):
self.viewer.state.x_att_world = self.image1.world_component_ids[0]
self.viewer.state.y_att_world = self.image1.world_component_ids[1]
self.viewer.state.x_att = self.image1.pixel_component_ids[0]
self.viewer.state.y_att = self.image1.pixel_component_ids[1]

assert self.viewer.state._display_world is True
assert len(self.viewer.state.layers) == 2
assert self.viewer.layers[0].enabled
assert self.viewer.layers[1].enabled
new_path_patch = self.viewer.layers[1].region_collection.patches[1].get_path().vertices

# Because we have flipped the viewer, the patches should have changed
assert np.array_equal(original_path_patch, np.flip(new_path_patch, axis=1))

return self.viewer.figure
105 changes: 62 additions & 43 deletions glue/viewers/scatter/layer_artist.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,67 +636,64 @@ def _update_data(self):
data = self.layer.data
region_att = data.extended_component_id

# In order to display a region we need to check that the representative
# points (center_x_id and center_y_id) are linked to the viewer x and y
# attributes. While we check this we also get the _values_ of the center
# attributes (x,y) and the _values_ for the viewer x and y attributes
# (xx,yy). We can compare these to figure out if the regions need to be
# flipped x,y <-> y,x. For now, if the image is displayed in world coordinates,
# the regions must be specified in world coordinates as well.

try:
# These must be special attributes that are linked to a region_att
if ((not data.linked_to_center_comp(self._viewer_state.x_att)) and
(not data.linked_to_center_comp(self._viewer_state.x_att_world))):
raise IncompatibleAttribute
x = ensure_numerical(self.layer[self._viewer_state.x_att].ravel())
xx = ensure_numerical(data[data.center_x_id].ravel())
# These must be attributes that are linked to center_x_id
if self._viewer_state._display_world:
if not data.linked_to_center_comp(self._viewer_state.x_att_world):
raise IncompatibleAttribute
else:
xx = ensure_numerical(self.layer[self._viewer_state.x_att_world].ravel())
else:
if not data.linked_to_center_comp(self._viewer_state.x_att):
raise IncompatibleAttribute
else:
xx = ensure_numerical(self.layer[self._viewer_state.x_att].ravel())
x = ensure_numerical(self.layer[data.center_x_id].ravel())
except (IncompatibleAttribute, IndexError):
# The following includes a call to self.clear()
self.disable_invalid_attributes(self._viewer_state.x_att)
return
else:
self.enable()

try:
# These must be special attributes that are linked to a region_att
if ((not data.linked_to_center_comp(self._viewer_state.y_att)) and
(not data.linked_to_center_comp(self._viewer_state.y_att_world))):
raise IncompatibleAttribute
y = ensure_numerical(self.layer[self._viewer_state.y_att].ravel())
yy = ensure_numerical(data[data.center_y_id].ravel())
# These must be attributes that are linked to center_y_id
if self._viewer_state._display_world:
if not data.linked_to_center_comp(self._viewer_state.y_att_world):
raise IncompatibleAttribute
else:
yy = ensure_numerical(self.layer[self._viewer_state.y_att_world].ravel())
else:
if not data.linked_to_center_comp(self._viewer_state.y_att):
raise IncompatibleAttribute
else:
yy = ensure_numerical(self.layer[self._viewer_state.y_att].ravel())

y = ensure_numerical(self.layer[data.center_y_id].ravel())
except (IncompatibleAttribute, IndexError):
# The following includes a call to self.clear()
self.disable_invalid_attributes(self._viewer_state.y_att)
return
else:
self.enable()

# We need to make sure that x and y viewer attributes are
# really the center_x and center_y attributes of the underlying
# data, so we compare the values on the centroids using the
# glue data access machinery.

regions = self.layer[region_att]

def flip_xy(g):
return transform(lambda x, y: (y, x), g)

x_no_match = False
if np.array_equal(y, yy):
if np.array_equal(x, xx):
self.enable()
else:
x_no_match = True
else:
if np.array_equal(y, xx) and np.array_equal(x, yy): # This means x and y have been swapped
regions = [flip_xy(g) for g in regions]
self.enable()
else:
self.disable_invalid_attributes(self._viewer_state.y_att)
if x_no_match:
self.disable_invalid_attributes(self._viewer_state.x_att)
return

# If we are using world coordinates (i.e. the regions are specified in world coordinates)
# we need to transform the geometries of the regions into pixel coordinates for display
# Note that this calls a custom version of the transform function from shapely
# to accomodate glue WCS objects
if self._viewer_state._display_world:
# First, convert to world coordinates
# If we are using world coordinates (i.e. the regions are specified in world coordinates)
# we need to transform the geometries of the regions into pixel coordinates for display
# Note that this calls a custom version of the transform function from shapely
# to accomodate glue WCS objects

try:
# First, convert regions to world coordinates (relevant if there are multiple links or a
# transform is needed to get to world coordinates)
tfunc = data.get_transform_to_cids([self._viewer_state.x_att_world, self._viewer_state.y_att_world])
regions = np.array([transform(tfunc, g) for g in regions])

Expand All @@ -707,13 +704,35 @@ def flip_xy(g):
self.disable_invalid_attributes([self._viewer_state.x_att_world, self._viewer_state.y_att_world])
return
else:
# If the image is just in pixels we just lookup how to transform the points in the region
try:
tfunc = data.get_transform_to_cids([self._viewer_state.x_att, self._viewer_state.y_att])
regions = np.array([transform(tfunc, g) for g in regions])
except ValueError:
self.disable_invalid_attributes([self._viewer_state.x_att, self._viewer_state.y_att])
return

# Now we flip the x and y coordinates of each point in the regions if necessary.
# This has to happen after the transform into pixel coordinates.
def flip_xy(g):
return transform(lambda x, y: (y, x), g)

x_no_match = False
if np.array_equal(y, yy):
if np.array_equal(x, xx):
self.enable()
else:
x_no_match = True
else:
if np.array_equal(y, xx) and np.array_equal(x, yy): # This means x and y have been swapped
regions = np.array([flip_xy(g) for g in regions])
self.enable()
else:
self.disable_invalid_attributes(self._viewer_state.y_att)
if x_no_match:
self.disable_invalid_attributes(self._viewer_state.x_att)
return

# decompose GeometryCollections
geoms, multiindex = _sanitize_geoms(regions, prefix="Geom")
self.multiindex_geometry = multiindex
Expand Down

0 comments on commit 58fc3a3

Please sign in to comment.