diff --git a/itkwidgets/cell_watcher.py b/itkwidgets/cell_watcher.py index e1175a90..c66fc52b 100644 --- a/itkwidgets/cell_watcher.py +++ b/itkwidgets/cell_watcher.py @@ -68,8 +68,8 @@ def setup(self): self.shell = get_ipython() self.kernel = self.shell.kernel self.shell_stream = getattr(self.kernel, "shell_stream", None) - self.execute_request_handler = self.kernel.shell_handlers["execute_request"] # Keep a reference to the ipykernel execute_request function + self.execute_request_handler = self.kernel.shell_handlers["execute_request"] self.current_request = None self.waiting_on_viewer = False self.results = {} @@ -204,24 +204,7 @@ def _callback(self, *args, **kwargs): self.update_namespace() self.create_task(self.execute_next_request) - def find_view_object_names(self): - from .viewer import Viewer - # Used to determine that all references to Viewer - # objects are ready before a cell is run - objs = self.viewers.viewer_objects - user_vars = [k for k in self.shell.user_ns.keys() if not k.startswith('_')] - for var in user_vars: - # Identify which variable the view object has been assigned to - value = self.shell.user_ns[var] - if isinstance(value, Viewer) and value.__str__() in objs: - idx = objs.index(value.__str__()) - self.viewers.set_name(objs[idx], var) - def post_run_cell(self, response): - # If a cell has been run and there are viewers with no variable - # associated with them check the user namespace to see if they have - # been added + # Abort remaining cells on error in execution if response.error_in_exec is not None: self.abort_all = True - if self.viewers.not_named: - self.find_view_object_names() diff --git a/itkwidgets/viewer.py b/itkwidgets/viewer.py index b1921fbd..3584de3d 100644 --- a/itkwidgets/viewer.py +++ b/itkwidgets/viewer.py @@ -1,13 +1,17 @@ import asyncio import functools +import queue +import threading import numpy as np from imjoy_rpc import api from inspect import isawaitable from typing import List, Union, Tuple from IPython.display import display, HTML +from IPython.lib import backgroundjobs as bg from ngff_zarr import from_ngff_zarr, to_ngff_image, NgffImage import uuid +from ._method_types import deferred_methods from ._type_aliases import Gaussians, Style, Image, PointSet from ._initialization_params import ( init_params_dict, @@ -56,6 +60,9 @@ def __init__( self.parent = parent if ENVIRONMENT is not Env.JUPYTERLITE: CellWatcher().add_viewer(self.parent) + if ENVIRONMENT is not Env.HYPHA: + self.viewer_event = threading.Event() + self.data_event = threading.Event() async def setup(self): pass @@ -86,11 +93,12 @@ async def run(self, ctx): # Create the initial screenshot await self.create_screenshot() itk_viewer.registerEventListener( - 'renderedImageAssigned', self.update_viewer_status + 'renderedImageAssigned', self.set_event ) if not defer_for_data_render(self.init_data): # Once the viewer has been created any queued requests can be run CellWatcher().update_viewer_status(self.parent, True) + asyncio.get_running_loop().call_soon_threadsafe(self.viewer_event.set) # Wait and then update the screenshot in case rendered level changed await asyncio.sleep(10) @@ -125,10 +133,16 @@ def update_screenshot(self, base64_image): ''') self.img.display(html) - def update_viewer_status(self, name): + def update_viewer_status(self): if not CellWatcher().viewer_ready(self.parent): CellWatcher().update_viewer_status(self.parent, True) + def set_event(self, event_data): + if not self.data_event.is_set(): + # Once the data has been set the deferred queue requests can be run + asyncio.get_running_loop().call_soon_threadsafe(self.data_event.set) + self.update_viewer_status() + class Viewer: """Pythonic Viewer class.""" @@ -148,12 +162,24 @@ def __init__( ui_collapsed=ui_collapsed, rotate=rotate, ui=ui, init_data=data, parent=self.name, **add_data_kwargs ) self.cw = CellWatcher() + if ENVIRONMENT is not Env.JUPYTERLITE: + self._setup_queueing() api.export(self.viewer_rpc) else: self._itk_viewer = add_data_kwargs.get('itk_viewer', None) self.server = add_data_kwargs.get('server', None) self.workspace = self.server.config.workspace + def _setup_queueing(self): + self.bg_jobs = bg.BackgroundJobManager() + self.queue = queue.Queue() + self.deferred_queue = queue.Queue() + self.bg_thread = self.bg_jobs.new(self.queue_worker) + + @property + def loop(self): + return asyncio.get_running_loop() + @property def has_viewer(self): if hasattr(self, "viewer_rpc"): @@ -166,11 +192,43 @@ def itk_viewer(self): return self.viewer_rpc.itk_viewer return self._itk_viewer + async def run_queued_requests(self): + def _run_queued_requests(queue): + method_name, args, kwargs = queue.get() + fn = getattr(self.itk_viewer, method_name) + self.loop.call_soon_threadsafe(asyncio.ensure_future, fn(*args, **kwargs)) + + # Wait for the viewer to be created + self.viewer_rpc.viewer_event.wait() + while self.queue.qsize(): + _run_queued_requests(self.queue) + # Wait for the data to be set + self.viewer_rpc.data_event.wait() + while self.deferred_queue.qsize(): + _run_queued_requests(self.deferred_queue) + + def queue_worker(self): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + task = loop.create_task(self.run_queued_requests()) + loop.run_until_complete(task) + def call_getter(self, future): name = uuid.uuid4() CellWatcher().results[name] = future future.add_done_callback(functools.partial(CellWatcher()._callback, name)) + def queue_request(self, method, *args, **kwargs): + if ( + ENVIRONMENT is Env.JUPYTERLITE or ENVIRONMENT is Env.HYPHA + ) or self.has_viewer: + fn = getattr(self.itk_viewer, method) + fn(*args, **kwargs) + elif method in deferred_methods(): + self.deferred_queue.put((method, args, kwargs)) + else: + self.queue.put((method, args, kwargs)) + def fetch_value(func): @functools.wraps(func) def _fetch_value(self, *args, **kwargs): @@ -184,28 +242,28 @@ def _fetch_value(self, *args, **kwargs): @fetch_value def set_annotations_enabled(self, enabled: bool): - self.viewer_rpc.itk_viewer.setAnnotationsEnabled(enabled) + self.queue_request('setAnnotationsEnabled', enabled) @fetch_value async def get_annotations_enabled(self): return await self.viewer_rpc.itk_viewer.getAnnotationsEnabled() @fetch_value def set_axes_enabled(self, enabled: bool): - self.viewer_rpc.itk_viewer.setAxesEnabled(enabled) + self.queue_request('setAxesEnabled', enabled) @fetch_value async def get_axes_enabled(self): return await self.viewer_rpc.itk_viewer.getAxesEnabled() @fetch_value def set_background_color(self, bgColor: List[float]): - self.viewer_rpc.itk_viewer.setBackgroundColor(bgColor) + self.queue_request('setBackgroundColor', bgColor) @fetch_value async def get_background_color(self): return await self.viewer_rpc.itk_viewer.getBackgroundColor() @fetch_value def set_cropping_planes(self, cropping_planes): - self.viewer_rpc.itk_viewer.setCroppingPlanes(cropping_planes) + self.queue_request('setCroppingPlanes', cropping_planes) @fetch_value async def get_cropping_planes(self): return await self.viewer_rpc.itk_viewer.getCroppingPlanes() @@ -222,11 +280,11 @@ def set_image(self, image: Image, name: str = 'Image'): svc = self.server.get_service(svc_name) svc.set_label_or_image('image') else: - self.viewer_rpc.itk_viewer.setImage(image, name) + self.queue_request('setImage', image, name) CellWatcher().update_viewer_status(self.name, False) elif render_type is RenderType.POINT_SET: image = _get_viewer_point_set(image) - self.viewer_rpc.itk_viewer.setPointSets(image) + self.queue_request('setPointSets', image) @fetch_value async def get_image(self, name: str = 'Image') -> NgffImage: """Get the full, highest resolution image. @@ -253,84 +311,84 @@ async def get_image(self, name: str = 'Image') -> NgffImage: @fetch_value def set_image_blend_mode(self, mode: str): - self.viewer_rpc.itk_viewer.setImageBlendMode(mode) + self.queue_request('setImageBlendMode', mode) @fetch_value async def get_image_blend_mode(self): return await self.viewer_rpc.itk_viewer.getImageBlendMode() @fetch_value def set_image_color_map(self, colorMap: str): - self.viewer_rpc.itk_viewer.setImageColorMap(colorMap) + self.queue_request('setImageColorMap', colorMap) @fetch_value async def get_image_color_map(self): return await self.viewer_rpc.itk_viewer.getImageColorMap() @fetch_value def set_image_color_range(self, range: List[float]): - self.viewer_rpc.itk_viewer.setImageColorRange(range) + self.queue_request('setImageColorRange', range) @fetch_value async def get_image_color_range(self): return await self.viewer_rpc.itk_viewer.getImageColorRange() @fetch_value def set_image_color_range_bounds(self, range: List[float]): - self.viewer_rpc.itk_viewer.setImageColorRangeBounds(range) + self.queue_request('setImageColorRangeBounds', range) @fetch_value async def get_image_color_range_bounds(self): return await self.viewer_rpc.itk_viewer.getImageColorRangeBounds() @fetch_value def set_image_component_visibility(self, visibility: bool, component: int): - self.viewer_rpc.itk_viewer.setImageComponentVisibility(visibility, component) + self.queue_request('setImageComponentVisibility', visibility, component) @fetch_value async def get_image_component_visibility(self, component: int): return await self.viewer_rpc.itk_viewer.getImageComponentVisibility(component) @fetch_value def set_image_gradient_opacity(self, opacity: float): - self.viewer_rpc.itk_viewer.setImageGradientOpacity(opacity) + self.queue_request('setImageGradientOpacity', opacity) @fetch_value async def get_image_gradient_opacity(self): return await self.viewer_rpc.itk_viewer.getImageGradientOpacity() @fetch_value def set_image_gradient_opacity_scale(self, min: float): - self.viewer_rpc.itk_viewer.setImageGradientOpacityScale(min) + self.queue_request('setImageGradientOpacityScale', min) @fetch_value async def get_image_gradient_opacity_scale(self): return await self.viewer_rpc.itk_viewer.getImageGradientOpacityScale() @fetch_value def set_image_interpolation_enabled(self, enabled: bool): - self.viewer_rpc.itk_viewer.setImageInterpolationEnabled(enabled) + self.queue_request('setImageInterpolationEnabled', enabled) @fetch_value async def get_image_interpolation_enabled(self): return await self.viewer_rpc.itk_viewer.getImageInterpolationEnabled() @fetch_value def set_image_piecewise_function_gaussians(self, gaussians: Gaussians): - self.viewer_rpc.itk_viewer.setImagePiecewiseFunctionGaussians(gaussians) + self.queue_request('setImagePiecewiseFunctionGaussians', gaussians) @fetch_value async def get_image_piecewise_function_gaussians(self): return await self.viewer_rpc.itk_viewer.getImagePiecewiseFunctionGaussians() @fetch_value def set_image_shadow_enabled(self, enabled: bool): - self.viewer_rpc.itk_viewer.setImageShadowEnabled(enabled) + self.queue_request('setImageShadowEnabled', enabled) @fetch_value async def get_image_shadow_enabled(self): return await self.viewer_rpc.itk_viewer.getImageShadowEnabled() @fetch_value def set_image_volume_sample_distance(self, distance: float): - self.viewer_rpc.itk_viewer.setImageVolumeSampleDistance(distance) + self.queue_request('setImageVolumeSampleDistance', distance) @fetch_value async def get_image_volume_sample_distance(self): return await self.viewer_rpc.itk_viewer.getImageVolumeSampleDistance() @fetch_value def set_image_volume_scattering_blend(self, scattering_blend: float): - self.viewer_rpc.itk_viewer.setImageVolumeScatteringBlend(scattering_blend) + self.queue_request('setImageVolumeScatteringBlend', scattering_blend) @fetch_value async def get_image_volume_scattering_blend(self): return await self.viewer_rpc.itk_viewer.getImageVolumeScatteringBlend() @@ -416,6 +474,7 @@ async def get_roi_slice(self, scale: int = -1): z0, z1 = idxs['z'] return np.index_exp[int(z0):int(z1+1), int(y0):int(y1+1), int(x0):int(x1+1)] + @fetch_value def compare_images(self, fixed_image: Union[str, Image], moving_image: Union[str, Image], method: str = None, image_mix: float = None, checkerboard: bool = None, pattern: Union[Tuple[int, int], Tuple[int, int, int]] = None, swap_image_order: bool = None): # image args may be image name or image object fixed_name = 'Fixed' @@ -440,7 +499,7 @@ def compare_images(self, fixed_image: Union[str, Image], moving_image: Union[str options['pattern'] = pattern if swap_image_order is not None: options['swapImageOrder'] = swap_image_order - self.viewer_rpc.itk_viewer.compareImages(fixed_name, moving_name, options) + self.queue_request('compareImages', fixed_name, moving_name, options) CellWatcher().update_viewer_status(self.name, False) @fetch_value @@ -455,11 +514,11 @@ def set_label_image(self, label_image: Image): svc = self.server.get_service(svc_name) svc.set_label_or_image('label_image') else: - self.viewer_rpc.itk_viewer.setLabelImage(label_image) + self.queue_request('setLabelImage', label_image) CellWatcher().update_viewer_status(self.name, False) elif render_type is RenderType.POINT_SET: label_image = _get_viewer_point_set(label_image) - self.viewer_rpc.itk_viewer.setPointSets(label_image) + self.queue_request('setPointSets', label_image) @fetch_value async def get_label_image(self) -> NgffImage: """Get the full, highest resolution label image. @@ -482,42 +541,42 @@ async def get_label_image(self) -> NgffImage: @fetch_value def set_label_image_blend(self, blend: float): - self.viewer_rpc.itk_viewer.setLabelImageBlend(blend) + self.queue_request('setLabelImageBlend', blend) @fetch_value async def get_label_image_blend(self): return await self.viewer_rpc.itk_viewer.getLabelImageBlend() @fetch_value def set_label_image_label_names(self, names: List[str]): - self.viewer_rpc.itk_viewer.setLabelImageLabelNames(names) + self.queue_request('setLabelImageLabelNames', names) @fetch_value async def get_label_image_label_names(self): return await self.viewer_rpc.itk_viewer.getLabelImageLabelNames() @fetch_value def set_label_image_lookup_table(self, lookupTable: str): - self.viewer_rpc.itk_viewer.setLabelImageLookupTable(lookupTable) + self.queue_request('setLabelImageLookupTable', lookupTable) @fetch_value async def get_label_image_lookup_table(self): return await self.viewer_rpc.itk_viewer.getLabelImageLookupTable() @fetch_value def set_label_image_weights(self, weights: float): - self.viewer_rpc.itk_viewer.setLabelImageWeights(weights) + self.queue_request('setLabelImageWeights', weights) @fetch_value async def get_label_image_weights(self): return await self.viewer_rpc.itk_viewer.getLabelImageWeights() @fetch_value def select_layer(self, name: str): - self.viewer_rpc.itk_viewer.selectLayer(name) + self.queue_request('selectLayer', name) @fetch_value async def get_layer_names(self): return await self.viewer_rpc.itk_viewer.getLayerNames() @fetch_value def set_layer_visibility(self, visible: bool, name: str): - self.viewer_rpc.itk_viewer.setLayerVisibility(visible, name) + self.queue_request('setLayerVisibility', visible, name) @fetch_value async def get_layer_visibility(self, name: str): return await self.viewer_rpc.itk_viewer.getLayerVisibility(name) @@ -529,64 +588,64 @@ def get_loaded_image_names(self): @fetch_value def add_point_set(self, pointSet: PointSet): pointSet = _get_viewer_point_set(pointSet) - self.viewer_rpc.itk_viewer.addPointSet(pointSet) + self.queue_request('addPointSet', pointSet) @fetch_value def set_point_set(self, pointSet: PointSet): pointSet = _get_viewer_point_set(pointSet) - self.viewer_rpc.itk_viewer.setPointSets(pointSet) + self.queue_request('setPointSets', pointSet) @fetch_value def set_rendering_view_container_style(self, containerStyle: Style): - self.viewer_rpc.itk_viewer.setRenderingViewContainerStyle(containerStyle) + self.queue_request('setRenderingViewContainerStyle', containerStyle) @fetch_value async def get_rendering_view_container_style(self): return await self.viewer_rpc.itk_viewer.getRenderingViewStyle() @fetch_value def set_rotate(self, enabled: bool): - self.viewer_rpc.itk_viewer.setRotateEnabled(enabled) + self.queue_request('setRotateEnabled', enabled) @fetch_value async def get_rotate(self): return await self.viewer_rpc.itk_viewer.getRotateEnabled() @fetch_value def set_ui_collapsed(self, collapsed: bool): - self.viewer_rpc.itk_viewer.setUICollapsed(collapsed) + self.queue_request('setUICollapsed', collapsed) @fetch_value async def get_ui_collapsed(self): return await self.viewer_rpc.itk_viewer.getUICollapsed() @fetch_value def set_units(self, units: str): - self.viewer_rpc.itk_viewer.setUnits(units) + self.queue_request('setUnits', units) @fetch_value async def get_units(self): return await self.viewer_rpc.itk_viewer.getUnits() @fetch_value def set_view_mode(self, mode: str): - self.viewer_rpc.itk_viewer.setViewMode(mode) + self.queue_request('setViewMode', mode) @fetch_value async def get_view_mode(self): return await self.viewer_rpc.itk_viewer.getViewMode() @fetch_value def set_x_slice(self, position: float): - self.viewer_rpc.itk_viewer.setXSlice(position) + self.queue_request('setXSlice', position) @fetch_value async def get_x_slice(self): return await self.viewer_rpc.itk_viewer.getXSlice() @fetch_value def set_y_slice(self, position: float): - self.viewer_rpc.itk_viewer.setYSlice(position) + self.queue_request('setYSlice', position) @fetch_value async def get_y_slice(self): return await self.viewer_rpc.itk_viewer.getYSlice() @fetch_value def set_z_slice(self, position: float): - self.viewer_rpc.itk_viewer.setZSlice(position) + self.queue_request('setZSlice', position) @fetch_value async def get_z_slice(self): return await self.viewer_rpc.itk_viewer.getZSlice()