Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
sarthakpati authored Aug 14, 2024
2 parents b7c9930 + 72bd1db commit bc66a28
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions GANDLF/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,17 @@ def __convert_tensor_to_int(input_tensor: torch.Tensor) -> torch.Tensor:
),
}
for metric_name, calculator in calculators.items():
metric_prediction = prediction
metric_target = target
if "auroc" in metric_name:
output_metrics[metric_name] = get_output_from_calculator(
predictions_prob, target_wrap, calculator
)
else:
output_metrics[metric_name] = get_output_from_calculator(
prediction, target, calculator
)
metric_prediction = predictions_prob
metric_target = target_wrap
if task == "binary":
metric_prediction = predictions_prob[:, 1]

output_metrics[metric_name] = get_output_from_calculator(
metric_prediction, metric_target, calculator
)

# metrics that do not need the "average" parameter
calculators = {
Expand Down

0 comments on commit bc66a28

Please sign in to comment.