-
Notifications
You must be signed in to change notification settings - Fork 79
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #500 from sarthakpati/overall_classification_metrics
Included classification-specific metrics using overall predictions and ground truths
- Loading branch information
Showing
9 changed files
with
165 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import torchmetrics as tm | ||
|
||
|
||
def get_output_from_calculator(predictions, ground_truth, calculator):
    """
    Evaluate a metric and convert the result to plain Python values.

    Args:
        predictions (torch.Tensor): The output of the model.
        ground_truth (torch.Tensor): The ground truth labels.
        calculator (torchmetrics.Metric): The metric callable to evaluate.

    Returns:
        float or list: A float for scalar (0-dim) metric results, or a list
        of floats for per-class (1-dim or higher) results.
    """
    result = calculator(predictions, ground_truth)
    # A 0-dim tensor holds a single scalar; anything larger is per-class output.
    if result.dim() == 0:
        return result.cpu().item()
    return result.cpu().tolist()
|
||
|
||
def overall_stats(predictions, ground_truth, params):
    """
    Generates a dictionary of metrics calculated on the overall predictions and ground truths.

    Args:
        predictions (torch.Tensor): The output of the model.
        ground_truth (torch.Tensor): The ground truth labels.
        params (dict): The parameter dictionary containing training and data information.

    Returns:
        dict: A dictionary of metrics, keyed as "<metric>_<averaging-scheme>"
        for averaged metrics plus entries for metrics without averaging (e.g. "auc").

    Raises:
        ValueError: If the problem type is not classification, or if the
            predictions and ground truth differ in length.
    """
    # Explicit validation instead of `assert`, which is silently stripped
    # when Python runs with the -O flag.
    if params["problem_type"] != "classification":
        raise ValueError("Only classification is supported for overall stats")
    if len(predictions) != len(ground_truth):
        raise ValueError("Predictions and ground truth must be of same length")

    output_metrics = {}
    # Hoisted out of the loop: every calculator below uses the same class count.
    num_classes = params["model"]["num_classes"]

    # Maps the suffix used in output keys to the torchmetrics "average" mode.
    average_types_keys = {
        "global": "micro",
        "per_class": "none",
        "per_class_average": "macro",
        "per_class_weighted": "weighted",
    }
    # metrics that need the "average" parameter
    for average_type, average_type_key in average_types_keys.items():
        calculators = {
            "accuracy": tm.Accuracy(num_classes=num_classes, average=average_type_key),
            "precision": tm.Precision(num_classes=num_classes, average=average_type_key),
            "recall": tm.Recall(num_classes=num_classes, average=average_type_key),
            "f1": tm.F1(num_classes=num_classes, average=average_type_key),
            ## weird error for multi-class problem, where pos_label is not getting set
            # "aucroc": tm.AUROC(
            #     num_classes=params["model"]["num_classes"], average=average_type_key
            # ),
        }
        for metric_name, calculator in calculators.items():
            output_metrics[
                f"{metric_name}_{average_type}"
            ] = get_output_from_calculator(predictions, ground_truth, calculator)
    # metrics that do not have any "average" parameter
    calculators = {
        "auc": tm.AUC(reorder=True),
        ## weird error for multi-class problem, where pos_label is not getting set
        # "roc": tm.ROC(num_classes=params["model"]["num_classes"]),
    }
    for metric_name, calculator in calculators.items():
        output_metrics[metric_name] = get_output_from_calculator(
            predictions, ground_truth, calculator
        )

    return output_metrics
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters