diff --git a/btrack/config.py b/btrack/config.py index d4a13a51..c5aa15e3 100644 --- a/btrack/config.py +++ b/btrack/config.py @@ -66,6 +66,9 @@ class TrackerConfig(BaseModel): tracking_updates : list A list of features to be used for tracking, such as MOTION or VISUAL. Must have at least one entry. + enable_optimisation + A flag which, if `False`, will report a warning to the user if they then + subsequently run the `BayesianTracker.optimise()` step. Notes ----- @@ -92,6 +95,7 @@ class TrackerConfig(BaseModel): ) = [ constants.BayesianUpdateFeatures.MOTION, ] + enable_optimisation = True @validator("volume", pre=True, always=True) def _parse_volume(cls, v): diff --git a/btrack/core.py b/btrack/core.py index 2dc017ad..208d506e 100644 --- a/btrack/core.py +++ b/btrack/core.py @@ -529,6 +529,9 @@ def optimise(self, options: Optional[dict] = None) -> list[hypothesis.Hypothesis optimiser and then performs track merging, removal of track fragments, renumbering and assignment of branches. """ + if not self.configuration.enable_optimisation: + logger.warning("The `enable_optimisation` flag is set to False") + logger.info(f"Loading hypothesis model: {self.hypothesis_model.name}") logger.info(f"Calculating hypotheses (relax: {self.hypothesis_model.relax})...") diff --git a/btrack/napari/main.py b/btrack/napari/main.py index ea97e19e..cf4792a0 100644 --- a/btrack/napari/main.py +++ b/btrack/napari/main.py @@ -80,6 +80,14 @@ def create_btrack_widget() -> btrack.napari.widgets.BtrackWidget: lambda selected: select_config(btrack_widget, all_configs, selected), ) + # Disable the Optimiser tab if unchecked + for tab in range(btrack_widget._tabs.count()): + if btrack_widget._tabs.tabText(tab) == "Optimiser": + break + btrack_widget.enable_optimisation.toggled.connect( + lambda is_checked: btrack_widget._tabs.setTabEnabled(tab, is_checked) + ) + btrack_widget.call_button.clicked.connect( lambda: run(btrack_widget, all_configs), ) @@ -264,7 +272,11 @@ def _run_tracker( """ Runs BayesianTracker with given segmentation and configuration. """ - with btrack.BayesianTracker() as tracker, napari.utils.progress(total=5) as pbr: + num_steps = 5 if tracker_config.enable_optimisation else 4 + + with btrack.BayesianTracker() as tracker, napari.utils.progress( + total=num_steps + ) as pbr: pbr.set_description("Initialising the tracker") tracker.configure(tracker_config) pbr.update(1) @@ -287,10 +299,11 @@ def _run_tracker( tracker.track(step_size=100) pbr.update(1) - # generate hypotheses and run the global optimizer - pbr.set_description("Run optimisation") - tracker.optimize() - pbr.update(1) + if tracker.enable_optimisation: + # generate hypotheses and run the global optimizer + pbr.set_description("Run optimisation") + tracker.optimize() + pbr.update(1) # get the tracks in a format for napari visualization pbr.set_description("Convert to napari tracks layer") diff --git a/btrack/napari/sync.py b/btrack/napari/sync.py index 22655aba..73eb0011 100644 --- a/btrack/napari/sync.py +++ b/btrack/napari/sync.py @@ -20,27 +20,28 @@ def update_config_from_widgets( btrack_widget: btrack.napari.widgets.BtrackWidget, ) -> UnscaledTrackerConfig: """Update an UnscaledTrackerConfig with the current widget values.""" - - # Update MotionModel matrix scaling factors - sigmas: Sigmas = unscaled_config.sigmas - for matrix_name in sigmas: - sigmas[matrix_name] = btrack_widget[f"{matrix_name}_sigma"].value() - - # Update TrackerConfig values + ## Retrieve model configs config = unscaled_config.tracker_config - update_method_index = btrack_widget.update_method.currentIndex() + motion_model = config.motion_model + hypothesis_model = config.hypothesis_model - config.update_method = update_method_index + ## Update widgets from the Method tab + config.update_method = btrack_widget.update_method.currentIndex() config.max_search_radius = btrack_widget.max_search_radius.value() - - # Update MotionModel values - motion_model = config.motion_model - motion_model.accuracy = btrack_widget.accuracy.value() motion_model.max_lost = btrack_widget.max_lost.value() motion_model.prob_not_assign = btrack_widget.prob_not_assign.value() + config.enable_optimisation = ( + btrack_widget.enable_optimisation.checkState() == QtCore.Qt.CheckState.Checked + ) - # Update HypothesisModel.hypotheses values - hypothesis_model = config.hypothesis_model + ## Update widgets from the Motion tab + sigmas: Sigmas = unscaled_config.sigmas + for matrix_name in sigmas: + sigmas[matrix_name] = btrack_widget[f"{matrix_name}_sigma"].value() + motion_model.accuracy = btrack_widget.accuracy.value() + + ## Update widgets from the Optimiser tab + # HypothesisModel.hypotheses values hypothesis_model.hypotheses = [ hypothesis for i, hypothesis in enumerate(btrack.optimise.hypothesis.H_TYPES) @@ -48,7 +49,7 @@ def update_config_from_widgets( == QtCore.Qt.CheckState.Checked ] - # Update HypothesisModel scaling factors + # HypothesisModel scaling factors for scaling_factor in btrack.napari.constants.HYPOTHESIS_SCALING_FACTORS: setattr( hypothesis_model, @@ -56,13 +57,17 @@ def update_config_from_widgets( btrack_widget[scaling_factor].value(), ) - # Update HypothesisModel thresholds + # HypothesisModel thresholds for threshold in btrack.napari.constants.HYPOTHESIS_THRESHOLDS: setattr(hypothesis_model, threshold, btrack_widget[threshold].value()) + # other hypothesis_model.segmentation_miss_rate = ( btrack_widget.segmentation_miss_rate.value() ) + hypothesis_model.relax = ( + btrack_widget.relax.checkState() == QtCore.Qt.CheckState.Checked + ) return unscaled_config @@ -75,25 +80,26 @@ def update_widgets_from_config( Update the widgets in a btrack_widget with the values in an UnscaledTrackerConfig. """ - - # Update widgets from MotionModel matrix scaling factors - sigmas: Sigmas = unscaled_config.sigmas - for matrix_name in sigmas: - btrack_widget[f"{matrix_name}_sigma"].setValue(sigmas[matrix_name]) - - # Update widgets from TrackerConfig values + ## Retrieve model configs config = unscaled_config.tracker_config + motion_model = config.motion_model + hypothesis_model = config.hypothesis_model + + ## Update widgets from the Method tab btrack_widget.update_method.setCurrentText(config.update_method.name) btrack_widget.max_search_radius.setValue(config.max_search_radius) - - # Update widgets from MotionModel values - motion_model = config.motion_model - btrack_widget.accuracy.setValue(motion_model.accuracy) btrack_widget.max_lost.setValue(motion_model.max_lost) btrack_widget.prob_not_assign.setValue(motion_model.prob_not_assign) + btrack_widget.enable_optimisation.setChecked(config.enable_optimisation) - # Update widgets from HypothesisModel.hypotheses values - hypothesis_model = config.hypothesis_model + ## Update widgets from the Motion tab + sigmas: Sigmas = unscaled_config.sigmas + for matrix_name in sigmas: + btrack_widget[f"{matrix_name}_sigma"].setValue(sigmas[matrix_name]) + btrack_widget.accuracy.setValue(motion_model.accuracy) + + ## Update widgets from the Optimiser tab + # HypothesisModel.hypotheses values for i, hypothesis in enumerate(btrack.optimise.hypothesis.H_TYPES): is_checked = ( QtCore.Qt.CheckState.Checked @@ -102,19 +108,20 @@ def update_widgets_from_config( ) btrack_widget["hypotheses"].item(i).setCheckState(is_checked) - # Update widgets from HypothesisModel scaling factors + # HypothesisModel scaling factors for scaling_factor in btrack.napari.constants.HYPOTHESIS_SCALING_FACTORS: new_value = getattr(hypothesis_model, scaling_factor) btrack_widget[scaling_factor].setValue(new_value) - # Update widgets from HypothesisModel thresholds + # HypothesisModel thresholds for threshold in btrack.napari.constants.HYPOTHESIS_THRESHOLDS: new_value = getattr(hypothesis_model, threshold) btrack_widget[threshold].setValue(new_value) - btrack_widget.relax.setChecked(hypothesis_model.relax) + # other btrack_widget.segmentation_miss_rate.setValue( hypothesis_model.segmentation_miss_rate ) + btrack_widget.relax.setChecked(hypothesis_model.relax) return btrack_widget diff --git a/btrack/napari/widgets/_general.py b/btrack/napari/widgets/_general.py index ab69ea40..ab8be9c9 100644 --- a/btrack/napari/widgets/_general.py +++ b/btrack/napari/widgets/_general.py @@ -13,9 +13,16 @@ def create_logo_widgets() -> dict[str, QtWidgets.QWidget]: widgets = {"title": title} logo = QtWidgets.QLabel() + pixmap = QtGui.QPixmap( + str(Path(__file__).resolve().parents[1] / "assets" / "btrack_logo.png") + ) + logo.setAlignment(QtCore.Qt.AlignHCenter) + scale = 0.8 logo.setPixmap( - QtGui.QPixmap( - str(Path(__file__).resolve().parents[1] / "assets" / "btrack_logo.png") + pixmap.scaled( + int(pixmap.width() * scale), + int(pixmap.height() * scale), + QtCore.Qt.KeepAspectRatio, ) ) widgets["logo"] = logo @@ -52,7 +59,7 @@ def create_input_widgets() -> dict[str, tuple[str, QtWidgets.QWidget]]: return widgets -def create_update_method_widgets() -> dict[str, tuple[str, QtWidgets.QWidget]]: +def create_basic_widgets() -> dict[str, tuple[str, QtWidgets.QWidget]]: """Create widgets for selecting the update method""" update_method = QtWidgets.QComboBox() @@ -98,6 +105,16 @@ def create_update_method_widgets() -> dict[str, tuple[str, QtWidgets.QWidget]]: not_assign, ) + optimise = QtWidgets.QCheckBox() + optimise.setChecked(True) # noqa: FBT003 + optimise.setToolTip( + "Enable the track optimisation.\n" + "This means that tracks will be optimised using the hypotheses" + "specified in the optimiser tab." + ) + optimise.setTristate(False) # noqa: FBT003 + widgets["enable_optimisation"] = ("enable optimisation", optimise) + return widgets @@ -114,9 +131,9 @@ def create_config_widgets() -> dict[str, QtWidgets.QWidget]: "reset_button", ] labels = [ - "Load configuration", - "Save configuration", - "Reset defaults", + "Load Configuration", + "Save Configuration", + "Reset Defaults", ] tooltips = [ "Load a TrackerConfig json file.", diff --git a/btrack/napari/widgets/_hypothesis.py b/btrack/napari/widgets/_optimiser.py similarity index 95% rename from btrack/napari/widgets/_hypothesis.py rename to btrack/napari/widgets/_optimiser.py index a35c5b87..f3658b3f 100644 --- a/btrack/napari/widgets/_hypothesis.py +++ b/btrack/napari/widgets/_optimiser.py @@ -21,7 +21,7 @@ def _create_hypotheses_widgets() -> dict[str, tuple[str, QtWidgets.QWidget]]: widget = QtWidgets.QListWidget() widget.addItems([f"{h.replace('_', '(')})" for h in hypotheses]) - flags = QtCore.Qt.ItemFlags(QtCore.Qt.ItemIsUserCheckable + QtCore.Qt.ItemIsEnabled) + flags = QtCore.Qt.ItemFlags(QtCore.Qt.ItemIsUserCheckable | QtCore.Qt.ItemIsEnabled) for i, tooltip in enumerate(tooltips): widget.item(i).setFlags(flags) widget.item(i).setToolTip(tooltip) @@ -37,12 +37,7 @@ def _create_hypotheses_widgets() -> dict[str, tuple[str, QtWidgets.QWidget]]: def _create_scaling_factor_widgets() -> dict[str, tuple[str, QtWidgets.QWidget]]: """Create widgets for setting the scaling factors of the HypothesisModel""" - names = [ - "lambda_time", - "lambda_dist", - "lambda_link", - "lambda_branch", - ] + names = btrack.napari.constants.HYPOTHESIS_SCALING_FACTORS labels = [ "λ time", "λ distance", @@ -118,7 +113,7 @@ def _create_bin_size_widgets() -> dict[str, tuple[str, QtWidgets.QWidget]]: return widgets -def create_hypothesis_model_widgets() -> dict[str, tuple[str, QtWidgets.QWidget]]: +def create_optimiser_widgets() -> dict[str, tuple[str, QtWidgets.QWidget]]: """Create widgets for setting parameters of the HypothesisModel""" widgets = { diff --git a/btrack/napari/widgets/create_ui.py b/btrack/napari/widgets/create_ui.py index e1b30b01..ca37703c 100644 --- a/btrack/napari/widgets/create_ui.py +++ b/btrack/napari/widgets/create_ui.py @@ -5,14 +5,14 @@ from napari.viewer import Viewer from btrack.napari.widgets._general import ( + create_basic_widgets, create_config_widgets, create_input_widgets, create_logo_widgets, create_track_widgets, - create_update_method_widgets, ) -from btrack.napari.widgets._hypothesis import create_hypothesis_model_widgets from btrack.napari.widgets._motion import create_motion_model_widgets +from btrack.napari.widgets._optimiser import create_optimiser_widgets class BtrackWidget(QtWidgets.QScrollArea): @@ -46,13 +46,13 @@ def __init__(self, napari_viewer: Viewer) -> None: self._add_logo_widgets() self._add_input_widgets() - # This must be added after the input widgets + self._add_track_widgets() + # This must be added after the track widget self._main_layout.addWidget(self._tabs, stretch=0) - self._add_update_method_widgets() + self._add_basic_widgets() self._add_motion_model_widgets() - self._add_hypothesis_model_widgets() + self._add_optimiser_widgets() self._add_config_widgets() - self._add_track_widgets() # Expand the main widget self._main_layout.addStretch(stretch=1) @@ -87,9 +87,9 @@ def _add_input_widgets(self) -> None: widget_holder.setLayout(layout) self._main_layout.addWidget(widget_holder, stretch=0) - def _add_update_method_widgets(self) -> None: + def _add_basic_widgets(self) -> None: """Create update method widgets and add to main layout""" - labels_and_widgets = create_update_method_widgets() + labels_and_widgets = create_basic_widgets() self._widgets.update( {key: value[1] for key, value in labels_and_widgets.items()} ) @@ -102,7 +102,7 @@ def _add_update_method_widgets(self) -> None: tab = QtWidgets.QWidget() tab.setLayout(layout) - self._tabs.addTab(tab, "Method") + self._tabs.addTab(tab, "Basic") def _add_motion_model_widgets(self) -> None: """Create motion model widgets and add to main layout""" @@ -121,9 +121,9 @@ def _add_motion_model_widgets(self) -> None: tab.setLayout(layout) self._tabs.addTab(tab, "Motion") - def _add_hypothesis_model_widgets(self) -> None: + def _add_optimiser_widgets(self) -> None: """Create hypothesis model widgets and add to main layout""" - labels_and_widgets = create_hypothesis_model_widgets() + labels_and_widgets = create_optimiser_widgets() self._widgets.update( {key: value[1] for key, value in labels_and_widgets.items()} ) @@ -136,7 +136,7 @@ def _add_hypothesis_model_widgets(self) -> None: tab = QtWidgets.QWidget() tab.setLayout(layout) - self._tabs.addTab(tab, "Hypothesis") + self._tabs.addTab(tab, "Optimiser") def _add_config_widgets(self) -> None: """Creates the IO widgets related to the user config""" diff --git a/tests/_utils.py b/tests/_utils.py index 22ff985c..f14b532d 100644 --- a/tests/_utils.py +++ b/tests/_utils.py @@ -213,7 +213,9 @@ def full_tracker_example( """Set up a full tracker example. kwargs can supply configuration options.""" # run the tracking tracker = btrack.BayesianTracker() - tracker.configure(CONFIG_FILE) + cfg = btrack.config.load_config(CONFIG_FILE) + cfg.motion_model.prob_not_assign = 0.001 + tracker.configure(cfg) for cfg_key, cfg_value in kwargs.items(): setattr(tracker, cfg_key, cfg_value) tracker.append(objects) diff --git a/tests/napari/test_dock_widget.py b/tests/napari/test_dock_widget.py index 4b59243b..53649f8c 100644 --- a/tests/napari/test_dock_widget.py +++ b/tests/napari/test_dock_widget.py @@ -98,19 +98,25 @@ def test_reset_button(track_widget): original_max_search_radius = track_widget.max_search_radius.value() original_relax = track_widget.relax.isChecked() + original_optimise = track_widget.enable_optimisation.isChecked() # change some widget values track_widget.max_search_radius.setValue(track_widget.max_search_radius.value() + 10) track_widget.relax.setChecked(not track_widget.relax.isChecked()) + track_widget.enable_optimisation.setChecked( + not track_widget.enable_optimisation.isChecked() + ) # click reset button - restores defaults of the currently-selected base config track_widget.reset_button.click() new_max_search_radius = track_widget.max_search_radius.value() new_relax = track_widget.relax.isChecked() + new_optimise = track_widget.enable_optimisation.isChecked() assert new_max_search_radius == original_max_search_radius assert new_relax == original_relax + assert new_optimise == original_optimise def test_run_button(track_widget, simplistic_tracker_outputs):