Skip to content

Commit

Permalink
Move fixtures into conftest.py (#331)
Browse files Browse the repository at this point in the history
* Move `fixture`s into `conftest.py`

* Just import `btrack`

* Move `Container`

* Fix tests

* Rename writer

* Use `qtpy`
  • Loading branch information
paddyroddy authored Aug 7, 2023
1 parent 44a1187 commit 99301cc
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 40 deletions.
52 changes: 38 additions & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from typing import Union

import numpy as np
import numpy.typing as npt
import pytest
from qtpy import QtWidgets

import btrack

Expand All @@ -14,6 +16,16 @@
)


def _write_h5_file(file_path: os.PathLike, test_objects) -> os.PathLike:
"""
Write a h5 file with test objects and return path.
"""
with btrack.io.HDF5FileHandler(file_path, "w") as h:
h.write_objects(test_objects)

return file_path


@pytest.fixture
def test_objects():
"""
Expand All @@ -31,24 +43,14 @@ def test_real_objects():
return btrack.io.import_CSV(TEST_DATA_PATH / "test_data.csv")


def write_h5_file(file_path: os.PathLike, test_objects) -> os.PathLike:
"""
Write a h5 file with test objects and return path.
"""
with btrack.io.HDF5FileHandler(file_path, "w") as h:
h.write_objects(test_objects)

return file_path


@pytest.fixture
def hdf5_file_path(tmp_path, test_objects) -> os.PathLike:
"""
Create and save a btrack HDF5 file, and return the path.
Note that this only saves segmentation results, not tracking results.
"""
return write_h5_file(tmp_path / "test.h5", test_objects)
return _write_h5_file(tmp_path / "test.h5", test_objects)


@pytest.fixture(params=["single", "list"])
Expand All @@ -61,11 +63,11 @@ def hdf5_file_path_or_paths(
Note that this only saves segmentation results, not tracking results.
"""
if request.param == "single":
return write_h5_file(tmp_path / "test.h5", test_objects)
return _write_h5_file(tmp_path / "test.h5", test_objects)
elif request.param == "list":
return [
write_h5_file(tmp_path / "test1.h5", test_objects),
write_h5_file(tmp_path / "test2.h5", test_objects),
_write_h5_file(tmp_path / "test1.h5", test_objects),
_write_h5_file(tmp_path / "test2.h5", test_objects),
]
else:
raise ValueError("Invalid requests.param, must be one of 'single' or 'list'")
Expand All @@ -85,3 +87,25 @@ def default_rng():
Create a default PRNG to use for tests.
"""
return np.random.default_rng(seed=RANDOM_SEED)


@pytest.fixture
def track_widget(make_napari_viewer) -> QtWidgets.QWidget:
"""Provides an instance of the track widget to test"""
make_napari_viewer() # make sure there is a viewer available
return btrack.napari.main.create_btrack_widget()


@pytest.fixture
def simplistic_tracker_outputs() -> (
tuple[npt.NDArray, dict[str, npt.NDArray], dict[int, list]]
):
"""Provides simplistic return values of a btrack run.
They have the correct types and dimensions, but contain zeros.
Useful for mocking the tracker.
"""
n, d = 10, 3
data = np.zeros((n, d + 1))
properties = {"some_property": np.zeros(n)}
graph = {0: [0]}
return data, properties, graph
26 changes: 0 additions & 26 deletions tests/napari/test_dock_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
import json
from unittest.mock import patch

import numpy as np
import numpy.typing as npt
import pytest
from qtpy import QtWidgets

import napari

Expand All @@ -32,13 +29,6 @@ def test_add_widget(make_napari_viewer):
assert len(list(viewer.window._dock_widgets)) == num_dw + 1


@pytest.fixture
def track_widget(make_napari_viewer) -> QtWidgets.QWidget:
"""Provides an instance of the track widget to test"""
make_napari_viewer() # make sure there is a viewer available
return btrack.napari.main.create_btrack_widget()


@pytest.mark.parametrize(
"config",
[btrack.datasets.cell_config(), btrack.datasets.particle_config()],
Expand Down Expand Up @@ -123,22 +113,6 @@ def test_reset_button(track_widget):
assert new_relax == original_relax


@pytest.fixture
def simplistic_tracker_outputs() -> (
tuple[npt.NDArray, dict[str, npt.NDArray], dict[int, list]]
):
"""Provides simplistic return values of a btrack run.
They have the correct types and dimensions, but contain zeros.
Useful for mocking the tracker.
"""
n, d = 10, 3
data = np.zeros((n, d + 1))
properties = {"some_property": np.zeros(n)}
graph = {0: [0]}
return data, properties, graph


def test_run_button(track_widget, simplistic_tracker_outputs):
"""Tests that clicking the run button calls run_tracker,
and that the napari viewer has an additional tracks layer after running.
Expand Down

0 comments on commit 99301cc

Please sign in to comment.