Merge pull request #115 from ImageMarkup/add-metric-to-cli
danlamanna authored Dec 23, 2020
2 parents 9d4d794 + c5b8249 commit 8001b1c
Showing 4 changed files with 60 additions and 56 deletions.
4 changes: 2 additions & 2 deletions isic_challenge_scoring/__init__.py
@@ -1,10 +1,10 @@
from importlib.metadata import PackageNotFoundError, version

-from isic_challenge_scoring.classification import ClassificationScore, ValidationMetric
+from isic_challenge_scoring.classification import ClassificationMetric, ClassificationScore
from isic_challenge_scoring.segmentation import SegmentationScore
from isic_challenge_scoring.types import ScoreException

-__all__ = ['ClassificationScore', 'SegmentationScore', 'ScoreException', 'ValidationMetric']
+__all__ = ['ClassificationScore', 'SegmentationScore', 'ScoreException', 'ClassificationMetric']

try:
__version__ = version('isic-challenge-scoring')
14 changes: 11 additions & 3 deletions isic_challenge_scoring/__main__.py
@@ -5,7 +5,7 @@
import click
import click_pathlib

-from isic_challenge_scoring.classification import ClassificationScore
+from isic_challenge_scoring.classification import ClassificationMetric, ClassificationScore
from isic_challenge_scoring.segmentation import SegmentationScore
from isic_challenge_scoring.types import ScoreException

@@ -40,11 +40,19 @@ def segmentation(ctx: click.Context, truth_dir: pathlib.Path, prediction_dir: pa
@click.pass_context
@click.argument('truth_file', type=FilePath)
@click.argument('prediction_file', type=FilePath)
+@click.option(
+'-m',
+'--metric',
+type=click.Choice([metric.value for metric in ClassificationMetric]),
+default=ClassificationMetric.BALANCED_ACCURACY.value,
+)
def classification(
-ctx: click.Context, truth_file: pathlib.Path, prediction_file: pathlib.Path
+ctx: click.Context, truth_file: pathlib.Path, prediction_file: pathlib.Path, metric: str
) -> None:
try:
-score = ClassificationScore.from_file(truth_file, prediction_file)
+score = ClassificationScore.from_file(
+truth_file, prediction_file, ClassificationMetric(metric)
+)
except ScoreException as e:
raise click.ClickException(str(e))

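For reference, the new --metric option relies on a simple string-to-enum round trip: click.Choice is built from the enum values, and ClassificationMetric(metric) converts the validated string back into the member handed to ClassificationScore.from_file. A minimal standalone sketch (the enum body is copied from classification.py below):

import enum


class ClassificationMetric(enum.Enum):
    BALANCED_ACCURACY = 'balanced_accuracy'
    AUC = 'auc'
    AVERAGE_PRECISION = 'ap'


# The CLI choices are the enum values, in definition order.
choices = [metric.value for metric in ClassificationMetric]
assert choices == ['balanced_accuracy', 'auc', 'ap']

# Click passes the command a plain string; the enum constructor maps it back to a member.
assert ClassificationMetric('auc') is ClassificationMetric.AUC
assert ClassificationMetric('balanced_accuracy') is ClassificationMetric.BALANCED_ACCURACY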
76 changes: 38 additions & 38 deletions isic_challenge_scoring/classification.py
@@ -3,7 +3,7 @@
from dataclasses import dataclass
import enum
import pathlib
-from typing import Dict, Optional, TextIO, cast
+from typing import Dict, TextIO, cast

import pandas as pd

@@ -13,7 +13,7 @@
from isic_challenge_scoring.types import DataFrameDict, RocDict, Score, ScoreDict, SeriesDict


-class ValidationMetric(enum.Enum):
+class ClassificationMetric(enum.Enum):
BALANCED_ACCURACY = 'balanced_accuracy'
AUC = 'auc'
AVERAGE_PRECISION = 'ap'
@@ -31,7 +31,7 @@ def __init__(
truth_probabilities: pd.DataFrame,
prediction_probabilities: pd.DataFrame,
truth_weights: pd.DataFrame,
-validation_metric: Optional[ValidationMetric] = None,
+target_metric: ClassificationMetric,
) -> None:
categories = truth_probabilities.columns

@@ -68,37 +68,37 @@ def __init__(
name='aggregate',
)

-self.overall = self.aggregate.at['balanced_accuracy']
-
-if validation_metric:
-if validation_metric == ValidationMetric.BALANCED_ACCURACY:
-self.validation = metrics.balanced_multiclass_accuracy(
-truth_probabilities, prediction_probabilities, truth_weights.validation_weight
-)
-elif validation_metric == ValidationMetric.AVERAGE_PRECISION:
-per_category_ap = pd.Series(
-[
-metrics.average_precision(
-truth_probabilities[category],
-prediction_probabilities[category],
-truth_weights.validation_weight,
-)
-for category in categories
-]
-)
-self.validation = per_category_ap.mean()
-elif validation_metric == ValidationMetric.AUC:
-per_category_auc = pd.Series(
-[
-metrics.auc(
-truth_probabilities[category],
-prediction_probabilities[category],
-truth_weights.validation_weight,
-)
-for category in categories
-]
-)
-self.validation = per_category_auc.mean()
+if target_metric == ClassificationMetric.BALANCED_ACCURACY:
+self.overall = self.aggregate.at['balanced_accuracy']
+self.validation = metrics.balanced_multiclass_accuracy(
+truth_probabilities, prediction_probabilities, truth_weights.validation_weight
+)
+elif target_metric == ClassificationMetric.AVERAGE_PRECISION:
+self.overall = self.macro_average['ap']
+per_category_ap = pd.Series(
+[
+metrics.average_precision(
+truth_probabilities[category],
+prediction_probabilities[category],
+truth_weights.validation_weight,
+)
+for category in categories
+]
+)
+self.validation = per_category_ap.mean()
+elif target_metric == ClassificationMetric.AUC:
+self.overall = self.macro_average['auc']
+per_category_auc = pd.Series(
+[
+metrics.auc(
+truth_probabilities[category],
+prediction_probabilities[category],
+truth_weights.validation_weight,
+)
+for category in categories
+]
+)
+self.validation = per_category_auc.mean()

@staticmethod
def _category_score(
@@ -191,7 +191,7 @@ def from_stream(
cls,
truth_file_stream: TextIO,
prediction_file_stream: TextIO,
-validation_metric: Optional[ValidationMetric] = None,
+target_metric: ClassificationMetric,
) -> ClassificationScore:
truth_probabilities, truth_weights = parse_truth_csv(truth_file_stream)
categories = truth_probabilities.columns
@@ -202,21 +202,21 @@
sort_rows(truth_probabilities)
sort_rows(prediction_probabilities)

-score = cls(truth_probabilities, prediction_probabilities, truth_weights, validation_metric)
+score = cls(truth_probabilities, prediction_probabilities, truth_weights, target_metric)
return score

@classmethod
def from_file(
cls,
truth_file: pathlib.Path,
prediction_file: pathlib.Path,
-validation_metric: Optional[ValidationMetric] = None,
+target_metric: ClassificationMetric,
) -> ClassificationScore:
with truth_file.open('r') as truth_file_stream, prediction_file.open(
'r'
) as prediction_file_stream:
return cls.from_stream(
truth_file_stream,
prediction_file_stream,
-validation_metric,
+target_metric,
)
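Taken together with the __main__.py change above, library consumers now pass the metric explicitly. A rough usage sketch, assuming the package is installed; the CSV paths are hypothetical placeholders:

import pathlib

from isic_challenge_scoring import ClassificationMetric, ClassificationScore, ScoreException

truth_file = pathlib.Path('truth.csv')            # hypothetical ground-truth CSV
prediction_file = pathlib.Path('prediction.csv')  # hypothetical prediction CSV

try:
    # target_metric is now a required positional argument rather than an optional keyword.
    score = ClassificationScore.from_file(
        truth_file, prediction_file, ClassificationMetric.AVERAGE_PRECISION
    )
except ScoreException as e:
    print(f'scoring failed: {e}')
else:
    # The chosen metric determines both score.overall and score.validation.
    print(score.overall, score.validation)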
22 changes: 9 additions & 13 deletions tests/test_classification.py
@@ -1,24 +1,20 @@
import pytest

-from isic_challenge_scoring.classification import ClassificationScore, ValidationMetric
-
-
-def test_score(classification_truth_file_path, classification_prediction_file_path):
-assert ClassificationScore.from_file(
-classification_truth_file_path, classification_prediction_file_path
-)
+from isic_challenge_scoring.classification import ClassificationMetric, ClassificationScore


@pytest.mark.parametrize(
-'validation_metric',
-[ValidationMetric.AUC, ValidationMetric.BALANCED_ACCURACY, ValidationMetric.AVERAGE_PRECISION],
+'target_metric',
+[
+ClassificationMetric.AUC,
+ClassificationMetric.BALANCED_ACCURACY,
+ClassificationMetric.AVERAGE_PRECISION,
+],
)
-def test_score_validation_metric(
-classification_truth_file_path, classification_prediction_file_path, validation_metric
-):
+def test_score(classification_truth_file_path, classification_prediction_file_path, target_metric):
score = ClassificationScore.from_file(
classification_truth_file_path,
classification_prediction_file_path,
-validation_metric=validation_metric,
+target_metric,
)
assert isinstance(score.validation, float)
