Skip to content

Commit

Permalink
Fix issues
Browse files Browse the repository at this point in the history
  • Loading branch information
kostrykin committed Mar 10, 2024
1 parent 15734c3 commit 5a19dd3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
4 changes: 3 additions & 1 deletion segmetrics/measure.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import (
List,
Callable,
List,
Protocol,
runtime_checkable,
)

from scipy import ndimage
Expand All @@ -10,6 +11,7 @@
from segmetrics.typing import LabelImage


@runtime_checkable
class MeasureProtocol(Protocol):
"""
Type protocol of performance measures.
Expand Down
14 changes: 7 additions & 7 deletions segmetrics/study.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import scipy.stats.mstats
import skimage.measure

from segmetrics.measure import Measure
from segmetrics.measure import MeasureProtocol
from segmetrics.typing import (
Image,
LabelImage,
Expand Down Expand Up @@ -71,7 +71,7 @@ def _label(im: Image, background: int = 0, neighbors: int = 4) -> LabelImage:


def _aggregate(
measure: Measure,
measure: MeasureProtocol,
values: List[float],
num_objects: int,
) -> float:
Expand Down Expand Up @@ -99,7 +99,7 @@ class Study:
"""

def __init__(self) -> None:
self.measures: Dict[str, Measure] = dict()
self.measures: Dict[str, MeasureProtocol] = dict()
self.csv_sample_id_column_name: str = 'Sample'

self._num_objects: Dict[Any, int] = dict()
Expand Down Expand Up @@ -146,7 +146,7 @@ def merge(
self._sample_ids.append(sample_id)
self._results_cache.clear()

def add_measure(self, measure: Measure, name: Optional[str] = None):
def add_measure(self, measure: MeasureProtocol, name: Optional[str] = None):
"""
Adds a performance measure to this study.
Expand All @@ -161,9 +161,9 @@ def add_measure(self, measure: Measure, name: Optional[str] = None):
:return:
The name used for the measure (see above).
"""
if not isinstance(measure, Measure):
if not isinstance(measure, MeasureProtocol):
raise ValueError(
'Argument "measure" must be a Measure object'
'Argument "measure" must implement MeasureProtocol'
f' ({type(measure)}, {measure})'
)
if name is None:
Expand Down Expand Up @@ -265,7 +265,7 @@ def process(

intermediate_results: Dict[str, List[float]] = dict()
for measure_name in self.measures:
measure: Measure = self.measures[measure_name]
measure: MeasureProtocol = self.measures[measure_name]
result: List[float] = measure.compute(actual)
self._results[measure_name][sample_id] = result
intermediate_results[measure_name] = result
Expand Down

0 comments on commit 5a19dd3

Please sign in to comment.