From 06eb52d056753150e9404e1ee0c8024b73e50ca8 Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Thu, 15 Sep 2022 18:34:51 -0400
Subject: [PATCH 01/22] added baseline implementation to calculate global
 stats for classification

---
 GANDLF/metrics/regression.py | 50 ++++++++++++++++++++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/GANDLF/metrics/regression.py b/GANDLF/metrics/regression.py
index fa5274251..eb7473ee1 100644
--- a/GANDLF/metrics/regression.py
+++ b/GANDLF/metrics/regression.py
@@ -4,6 +4,7 @@
 import torch
 from sklearn.metrics import balanced_accuracy_score
 import numpy as np
+import torchmetrics as tm


 def classification_accuracy(output, label, params):
@@ -73,3 +74,52 @@ def per_label_accuracy(output, label, params):
         return output_accuracy / len(output)
     else:
         return balanced_acc_score(output, label, params)
+
+
+def overall_stats(predictions, ground_truth, params):
+    """
+    Generates a dictionary of metrics calculated on the overall predictions and ground truths.
+
+    Args:
+        predictions (torch.Tensor): The output of the model.
+        ground_truth (torch.Tensor): The ground truth labels.
+        params (dict): The parameter dictionary containing training and data information.
+
+    Returns:
+        dict: A dictionary of metrics.
+    """
+    assert (
+        params["problem_type"] == "classification"
+    ), "Only classification is supported for overall stats"
+    assert len(predictions) == len(
+        ground_truth
+    ), "Predictions and ground truth must be of same length"
+
+    output_metrics = {}
+    average_types_keys = {
+        "global": "micro",
+        "per_class": "none",
+        "per_class_average": "macro",
+        "per_class_weighted": "weighted",
+    }
+    for average_type, average_type_key in average_types_keys.items():
+        calculators = {
+            "accuracy": tm.Accuracy(
+                num_classes=params["model"]["num_classes"], average=average_type_key
+            ),
+            "precision": tm.Precision(
+                num_classes=params["model"]["num_classes"], average=average_type_key
+            ),
+            "recall": tm.Recall(
+                num_classes=params["model"]["num_classes"], average=average_type_key
+            ),
+            "f1": tm.F1(
+                num_classes=params["model"]["num_classes"], average=average_type_key
+            ),
+        }
+        for metric_name, calculator in calculators.items():
+            output_metrics[f"{metric_name}_{average_type}"] = calculator(
+                predictions, ground_truth
+            )
+
+    return output_metrics
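To make the intent of `overall_stats` concrete, here is a minimal, hypothetical usage sketch (not part of the patch series). It assumes a torchmetrics release contemporary with this change, i.e. pre-0.8, where `tm.Accuracy`, `tm.F1`, etc. are classes taking `num_classes` and `average` directly, and it assumes `predictions` already holds class indices:

```python
# Hypothetical standalone exercise of overall_stats (torchmetrics < 0.8 API).
import torch

from GANDLF.metrics.regression import overall_stats

params = {
    "problem_type": "classification",
    "model": {"num_classes": 3},
}
predictions = torch.tensor([0, 2, 1, 2, 0], dtype=torch.int)   # predicted class indices
ground_truth = torch.tensor([0, 1, 1, 2, 0], dtype=torch.int)  # true class indices

stats = overall_stats(predictions, ground_truth, params)
# stats["accuracy_global"]  -> micro-averaged accuracy (a 0-dim tensor at this stage;
#                              a later patch converts these to plain floats/lists)
# stats["f1_per_class"]     -> one value per class (average="none")
```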
From 7b7a86af1830bb229f328c595188e4d2e6a9c6d5 Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Thu, 15 Sep 2022 18:52:01 -0400
Subject: [PATCH 02/22] added a shell output to make api standardized

---
 GANDLF/compute/step.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/GANDLF/compute/step.py b/GANDLF/compute/step.py
index edfa6565e..fc506390d 100644
--- a/GANDLF/compute/step.py
+++ b/GANDLF/compute/step.py
@@ -80,6 +80,7 @@ def step(model, image, label, params, train=True):
         else:
             output = model(image)

+    attention_map = None
     if "medcam_enabled" in params and params["medcam_enabled"]:
         output, attention_map = output

@@ -97,7 +98,4 @@ def step(model, image, label, params, train=True):
     if "medcam_enabled" in params and params["medcam_enabled"]:
         attention_map = torch.unsqueeze(attention_map, -1)

-    if not ("medcam_enabled" in params and params["medcam_enabled"]):
-        return loss, metric_output, output
-    else:
-        return loss, metric_output, output, attention_map
+    return loss, metric_output, output, attention_map
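With this change `step` always returns a four-tuple, so call sites no longer need to special-case medcam; `attention_map` is simply `None` when medcam is disabled. A short fragment of the resulting calling convention (names as in the patch above, with `model`, `image`, `label`, and `params` coming from the surrounding loop):

```python
# The standardized step API: callers always unpack four values and may
# ignore attention_map, which is None unless params["medcam_enabled"] is set.
loss, metric_output, output, attention_map = step(model, image, label, params)

if attention_map is not None:
    ...  # medcam-specific handling goes here
```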
From 1f7ed65977f15bebf186d892663a02ef111a1a88 Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Thu, 15 Sep 2022 18:52:40 -0400
Subject: [PATCH 03/22] using new step api and `overall_stats` function for
 training

---
 GANDLF/compute/training_loop.py | 27 ++++++++++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)

diff --git a/GANDLF/compute/training_loop.py b/GANDLF/compute/training_loop.py
index 811055896..9ac2afcdc 100644
--- a/GANDLF/compute/training_loop.py
+++ b/GANDLF/compute/training_loop.py
@@ -17,6 +17,7 @@
     write_training_patches,
     print_model_summary,
 )
+from GANDLF.metrics.regression import overall_stats
 from GANDLF.logger import Logger
 from .step import step
 from .forward_pass import validate_network
@@ -69,6 +70,16 @@ def train_network(model, train_dataloader, optimizer, params):
         if params["verbose"]:
             print("Using Automatic mixed precision", flush=True)

+    # get ground truths
+    if params["problem_type"] == "classification":
+        ground_truth_array = (
+            params["training_data"][
+                params["training_data"].columns[params["headers"]["predictionHeaders"]]
+            ]
+            .to_numpy()
+            .ravel()
+        )
+        predictions_array = np.zeros_like(ground_truth_array)
     # Set the model to train
     model.train()
     for batch_idx, (subject) in enumerate(
@@ -104,7 +115,15 @@ def train_network(model, train_dataloader, optimizer, params):
             params["subject_spacing"] = subject["spacing"]
         else:
             params["subject_spacing"] = None
-        loss, calculated_metrics, _ = step(model, image, label, params)
+        loss, calculated_metrics, output, _ = step(model, image, label, params)
+        # store predictions for classification
+        if params["problem_type"] == "classification":
+            predictions_array[
+                batch_idx
+                * params["batch_size"] : (batch_idx + 1)
+                * params["batch_size"]
+            ] = output
+
         nan_loss = torch.isnan(loss)
         second_order = (
             hasattr(optimizer, "is_second_order") and optimizer.is_second_order
@@ -175,6 +194,12 @@ def train_network(model, train_dataloader, optimizer, params):

     average_epoch_train_loss = total_epoch_train_loss / len(train_dataloader)
     print(" Epoch Final train loss : ", average_epoch_train_loss)
+
+    # get overall stats for classification
+    if params["problem_type"] == "classification":
+        average_epoch_train_metric = overall_stats(
+            predictions_array, ground_truth_array, params
+        )
     for metric in params["metrics"]:
         if isinstance(total_epoch_train_metric[metric], np.ndarray):
             to_print = (

From 9d3603594499c896d585c7b4da6d51780f15ab44 Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Thu, 15 Sep 2022 18:52:45 -0400
Subject: [PATCH 04/22] using new step api and `overall_stats` function for
 validation

---
 GANDLF/compute/forward_pass.py | 26 +++++++++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/GANDLF/compute/forward_pass.py b/GANDLF/compute/forward_pass.py
index f109b8d44..f0ddfd966 100644
--- a/GANDLF/compute/forward_pass.py
+++ b/GANDLF/compute/forward_pass.py
@@ -16,6 +16,7 @@
     resample_image,
     reverse_one_hot,
 )
+from GANDLF.metrics.regression import overall_stats
 from tqdm import tqdm


@@ -99,6 +100,19 @@ def validate_network(
             "output_predictions_" + get_unique_timestamp() + ".csv",
         )

+    # get ground truths for classification problem, validation set
+    if is_classification and mode == "validation":
+        ground_truth_array = (
+            params["validation_data"][
+                params["validation_data"].columns[
+                    params["headers"]["predictionHeaders"]
+                ]
+            ]
+            .to_numpy()
+            .ravel()
+        )
+        predictions_array = np.zeros_like(ground_truth_array)
+
     for batch_idx, (subject) in enumerate(
         tqdm(valid_dataloader, desc="Looping over " + mode + " data")
     ):
@@ -192,6 +206,9 @@ def validate_network(
             final_loss, final_metric = get_loss_and_metrics(
                 image, valuesToPredict, pred_output, params
             )
+
+            if is_classification:
+                predictions_array[batch_idx] = pred_output.max().item()
             # # Non network validation related
             total_epoch_valid_loss += final_loss.detach().cpu().item()
             for metric in final_metric.keys():
@@ -283,7 +300,7 @@ def validate_network(
                             attention_map, patches_batch[torchio.LOCATION]
                         )
                 else:
-                    _, _, output = result
+                    _, _, output, _ = result

                 if params["problem_type"] == "segmentation":
                     aggregator.add_batch(
@@ -359,6 +376,8 @@ def validate_network(
             else:
                 # final regression output
                 output_prediction = output_prediction / len(patch_loader)
+                if is_classification:
+                    predictions_array[batch_idx] = output_prediction
                 if params["save_output"]:
                     outputToWrite += (
                         str(epoch)
@@ -453,6 +472,11 @@ def validate_network(
     if label_ground_truth is not None:
         average_epoch_valid_loss = total_epoch_valid_loss / len(valid_dataloader)
         print(" Epoch Final " + mode + " loss : ", average_epoch_valid_loss)
+        # get overall stats for classification
+        if is_classification:
+            average_epoch_valid_metric = overall_stats(
+                predictions_array, ground_truth_array, params
+            )
         for metric in params["metrics"]:
             if isinstance(total_epoch_valid_metric[metric], np.ndarray):
                 to_print = (
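The training loop fills `predictions_array` one batch at a time using slice indexing; a self-contained sketch of that bookkeeping with made-up sizes (the real loop gets `output` from `step`):

```python
# Sketch of the per-batch slice arithmetic used above (made-up data).
import numpy as np

batch_size = 4
num_samples = 12  # assume divisible by batch_size for this sketch
predictions_array = np.zeros(num_samples)

for batch_idx in range(num_samples // batch_size):
    output = np.full(batch_size, batch_idx)  # stand-in for one batch of predictions
    predictions_array[
        batch_idx * batch_size : (batch_idx + 1) * batch_size
    ] = output

assert (predictions_array[4:8] == 1).all()
```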
From 05840c07de1d7afb1266254d6f268524e9a30c27 Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Thu, 15 Sep 2022 21:05:49 -0400
Subject: [PATCH 05/22] syntax updated

---
 GANDLF/compute/forward_pass.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/GANDLF/compute/forward_pass.py b/GANDLF/compute/forward_pass.py
index f0ddfd966..2823a522a 100644
--- a/GANDLF/compute/forward_pass.py
+++ b/GANDLF/compute/forward_pass.py
@@ -102,7 +102,7 @@ def validate_network(

     # get ground truths for classification problem, validation set
     if is_classification and mode == "validation":
-        ground_truth_array = (
+        ground_truth_array = torch.from_numpy(
             params["validation_data"][
                 params["validation_data"].columns[
                     params["headers"]["predictionHeaders"]
                 ]
             ]
             .to_numpy()
             .ravel()
-        )
-        predictions_array = np.zeros_like(ground_truth_array)
+        ).type(torch.int)
+        predictions_array = torch.zeros_like(ground_truth_array)

From 89af597485d5b2c96d89bdc64eb1b9d92bca1ec6 Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Thu, 15 Sep 2022 21:06:04 -0400
Subject: [PATCH 06/22] syntax updated

---
 GANDLF/compute/training_loop.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/GANDLF/compute/training_loop.py b/GANDLF/compute/training_loop.py
index 9ac2afcdc..2b513db8e 100644
--- a/GANDLF/compute/training_loop.py
+++ b/GANDLF/compute/training_loop.py
@@ -72,14 +72,14 @@ def train_network(model, train_dataloader, optimizer, params):

     # get ground truths
     if params["problem_type"] == "classification":
-        ground_truth_array = (
+        ground_truth_array = torch.from_numpy(
             params["training_data"][
                 params["training_data"].columns[params["headers"]["predictionHeaders"]]
             ]
             .to_numpy()
             .ravel()
-        )
-        predictions_array = np.zeros_like(ground_truth_array)
+        ).type(torch.int)
+        predictions_array = torch.zeros_like(ground_truth_array)
     # Set the model to train
     model.train()
     for batch_idx, (subject) in enumerate(
@@ -118,11 +118,10 @@ def train_network(model, train_dataloader, optimizer, params):
         loss, calculated_metrics, output, _ = step(model, image, label, params)
         # store predictions for classification
         if params["problem_type"] == "classification":
-            predictions_array[
-                batch_idx
-                * params["batch_size"] : (batch_idx + 1)
-                * params["batch_size"]
-            ] = output
+            for i in range(params["batch_size"]):
+                predictions_array[batch_idx * params["batch_size"] + i] = (
+                    torch.argmax(output[i], 0).cpu().item()
+                )

         nan_loss = torch.isnan(loss)
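These two "syntax updated" commits switch the bookkeeping arrays from numpy to integer torch tensors, which is the input type the torchmetrics calculators expect. A runnable sketch of the conversion, with a made-up dataframe standing in for `params["training_data"]`:

```python
# Dataframe -> int tensor conversion, as used above (made-up data).
import pandas as pd
import torch

training_data = pd.DataFrame({"subject": ["a", "b", "c"], "label": [0, 2, 1]})
prediction_headers = [1]  # column indices that hold the labels

ground_truth_array = torch.from_numpy(
    training_data[training_data.columns[prediction_headers]].to_numpy().ravel()
).type(torch.int)
predictions_array = torch.zeros_like(ground_truth_array)  # same shape/dtype, zeroed

print(ground_truth_array)  # tensor([0, 2, 1], dtype=torch.int32)
```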
From e70b28ad4f84fc212f2b5b4266f93f663d5b5355 Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Thu, 15 Sep 2022 21:08:18 -0400
Subject: [PATCH 07/22] syntax updated

---
 GANDLF/compute/training_loop.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/GANDLF/compute/training_loop.py b/GANDLF/compute/training_loop.py
index 2b513db8e..bbcef615b 100644
--- a/GANDLF/compute/training_loop.py
+++ b/GANDLF/compute/training_loop.py
@@ -118,10 +118,15 @@ def train_network(model, train_dataloader, optimizer, params):
         loss, calculated_metrics, output, _ = step(model, image, label, params)
         # store predictions for classification
         if params["problem_type"] == "classification":
-            for i in range(params["batch_size"]):
-                predictions_array[batch_idx * params["batch_size"] + i] = (
-                    torch.argmax(output[i], 0).cpu().item()
-                )
+            # for i in range(params["batch_size"]):
+            #     predictions_array[batch_idx * params["batch_size"] + i] = (
+            #         torch.argmax(output[i], 0).cpu().item()
+            #     )
+            predictions_array[
+                batch_idx
+                * params["batch_size"] : (batch_idx + 1)
+                * params["batch_size"]
+            ] = (torch.argmax(output[0], 0).cpu().item())

         nan_loss = torch.isnan(loss)

From 2db7c280bbe78a83eb2171a37cb1fdbfb0ea3f3c Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Thu, 15 Sep 2022 21:28:55 -0400
Subject: [PATCH 08/22] updated logic to return simple data structures instead
 of tensors

---
 GANDLF/metrics/regression.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/GANDLF/metrics/regression.py b/GANDLF/metrics/regression.py
index eb7473ee1..a54c18268 100644
--- a/GANDLF/metrics/regression.py
+++ b/GANDLF/metrics/regression.py
@@ -118,8 +118,11 @@ def overall_stats(predictions, ground_truth, params):
             ),
         }
         for metric_name, calculator in calculators.items():
-            output_metrics[f"{metric_name}_{average_type}"] = calculator(
-                predictions, ground_truth
-            )
+            temp_output = calculator(predictions, ground_truth)
+            if temp_output.dim() > 0:
+                temp_output = temp_output.cpu().tolist()
+            else:
+                temp_output = temp_output.cpu().item()
+            output_metrics[f"{metric_name}_{average_type}"] = temp_output

     return output_metrics

From 997fbf375d327a23fae2951b0245c6e578c36f24 Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Thu, 15 Sep 2022 21:29:12 -0400
Subject: [PATCH 09/22] added another check to ensure this only happens for
 validation

---
 GANDLF/compute/forward_pass.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/GANDLF/compute/forward_pass.py b/GANDLF/compute/forward_pass.py
index 2823a522a..7bba243ec 100644
--- a/GANDLF/compute/forward_pass.py
+++ b/GANDLF/compute/forward_pass.py
@@ -207,7 +207,7 @@ def validate_network(
                 image, valuesToPredict, pred_output, params
             )

-            if is_classification:
+            if is_classification and mode == "validation":
                 predictions_array[batch_idx] = pred_output.max().item()
             # # Non network validation related
             total_epoch_valid_loss += final_loss.detach().cpu().item()
             for metric in final_metric.keys():
@@ -376,7 +376,7 @@ def validate_network(
             else:
                 # final regression output
                 output_prediction = output_prediction / len(patch_loader)
-                if is_classification:
+                if is_classification and mode == "validation":
                     predictions_array[batch_idx] = output_prediction
                 if params["save_output"]:
                     outputToWrite += (
                         str(epoch)
@@ -473,7 +473,7 @@ def validate_network(
         average_epoch_valid_loss = total_epoch_valid_loss / len(valid_dataloader)
         print(" Epoch Final " + mode + " loss : ", average_epoch_valid_loss)
         # get overall stats for classification
-        if is_classification:
+        if is_classification and mode == "validation":
             average_epoch_valid_metric = overall_stats(
                 predictions_array, ground_truth_array, params
             )
         for metric in params["metrics"]:
             if isinstance(total_epoch_valid_metric[metric], np.ndarray):
                 to_print = (
                     total_epoch_valid_metric[metric] / len(valid_dataloader)
                 ).tolist()
             else:
                 to_print = total_epoch_valid_metric[metric] / len(valid_dataloader)
             average_epoch_valid_metric[metric] = to_print
+
         for metric in average_epoch_valid_metric.keys():
             print(
                 " Epoch Final " + mode + " " + metric + " : ",
                 average_epoch_valid_metric[metric],

From 7040fac5d37204f2c7e58ccd4b0854e1ce01941f Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Thu, 15 Sep 2022 21:29:29 -0400
Subject: [PATCH 10/22] ensure correct print

---
 GANDLF/compute/training_loop.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/GANDLF/compute/training_loop.py b/GANDLF/compute/training_loop.py
index bbcef615b..233d2e143 100644
--- a/GANDLF/compute/training_loop.py
+++ b/GANDLF/compute/training_loop.py
@@ -212,6 +212,7 @@ def train_network(model, train_dataloader, optimizer, params):
         else:
             to_print = total_epoch_train_metric[metric] / len(train_dataloader)
         average_epoch_train_metric[metric] = to_print
+
     for metric in average_epoch_train_metric.keys():
         print(
             " Epoch Final train " + metric + " : ",
             average_epoch_train_metric[metric],

From a929295099b3b25b4474cd4694cfb2afb8b8fe45 Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Fri, 16 Sep 2022 10:11:22 -0400
Subject: [PATCH 11/22] added auc

---
 GANDLF/metrics/regression.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/GANDLF/metrics/regression.py b/GANDLF/metrics/regression.py
index a54c18268..ffc5615c5 100644
--- a/GANDLF/metrics/regression.py
+++ b/GANDLF/metrics/regression.py
@@ -102,6 +102,7 @@ def overall_stats(predictions, ground_truth, params):
         "per_class_average": "macro",
         "per_class_weighted": "weighted",
     }
+    # metrics that need the "average" parameter
     for average_type, average_type_key in average_types_keys.items():
         calculators = {
             "accuracy": tm.Accuracy(
@@ -116,6 +117,10 @@ def overall_stats(predictions, ground_truth, params):
             "f1": tm.F1(
                 num_classes=params["model"]["num_classes"], average=average_type_key
             ),
+            ## weird error for multi-class problem, where pos_label is not getting set
+            # "aucroc": tm.AUROC(
+            #     num_classes=params["model"]["num_classes"], average=average_type_key
+            # ),
         }
         for metric_name, calculator in calculators.items():
             temp_output = calculator(predictions, ground_truth)
@@ -124,5 +129,18 @@ def overall_stats(predictions, ground_truth, params):
             else:
                 temp_output = temp_output.cpu().item()
             output_metrics[f"{metric_name}_{average_type}"] = temp_output
+    # metrics that do not have any "average" parameter
+    calculators = {
+        "auc": tm.AUC(reorder=True),
+        ## weird error for multi-class problem, where pos_label is not getting set
+        # "roc": tm.ROC(num_classes=params["model"]["num_classes"]),
+    }
+    for metric_name, calculator in calculators.items():
+        temp_output = calculator(predictions, ground_truth)
+        if temp_output.dim() > 0:
+            temp_output = temp_output.cpu().tolist()
+        else:
+            temp_output = temp_output.cpu().item()
+        output_metrics[metric_name] = temp_output

     return output_metrics
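The `dim()` check introduced in patch 08 (and kept through patch 11) is what turns torchmetrics outputs into plain Python values for logging: per-class results become lists, scalar averages become floats. A tiny standalone sketch of the two cases:

```python
# The tensor -> plain-Python conversion logic, in isolation (made-up tensors).
import torch

scalar = torch.tensor(0.75)           # e.g. a micro-averaged metric
per_class = torch.tensor([0.5, 1.0])  # e.g. an average="none" result

for temp_output in (scalar, per_class):
    if temp_output.dim() > 0:
        temp_output = temp_output.cpu().tolist()  # [0.5, 1.0]
    else:
        temp_output = temp_output.cpu().item()    # 0.75
    print(type(temp_output))  # <class 'float'>, then <class 'list'>
```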
From f429cb168c36a059456a7ae2ef33d0deede5b25e Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Fri, 16 Sep 2022 16:19:34 -0400
Subject: [PATCH 12/22] renamed to make it clearer

---
 GANDLF/metrics/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/GANDLF/metrics/__init__.py b/GANDLF/metrics/__init__.py
index f989684c3..c6fb154b9 100644
--- a/GANDLF/metrics/__init__.py
+++ b/GANDLF/metrics/__init__.py
@@ -35,5 +35,5 @@
     "recall": recall_score,
     "iou": iou_score,
     "balanced_accuracy": balanced_acc_score,
-    "per_label_accuracy": per_label_accuracy,
+    "per_label_one_hot_accuracy": per_label_accuracy,
 }

From 1255aedbdf59cf037d66a93e68b899913a6d75a5 Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Fri, 16 Sep 2022 16:19:41 -0400
Subject: [PATCH 13/22] rename completed

---
 testing/config_classification.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/testing/config_classification.yaml b/testing/config_classification.yaml
index 5e3d78d09..4fc15f733 100644
--- a/testing/config_classification.yaml
+++ b/testing/config_classification.yaml
@@ -14,7 +14,7 @@ metrics:
   }
   - accuracy
   - balanced_accuracy
-  - per_label_accuracy
+  - per_label_one_hot_accuracy
   - precision: {
     average: weighted,
   }

From 60b3f7292139549550b5fdc85bf7b032b092c3e4 Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Fri, 16 Sep 2022 16:22:09 -0400
Subject: [PATCH 14/22] note updated

---
 HISTORY.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/HISTORY.md b/HISTORY.md
index 3c90e6387..b8758b86f 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -9,6 +9,7 @@
 - Per class accuracy has been added as a metric
 - Dedicated rescaling preprocessing function added for increased flexibility
 - Largest Connected Component Analysis is now added
+- Included metrics using overall predictions and ground truths

 ## 0.0.14
+ """ + assert ( + params["problem_type"] == "classification" + ), "Only classification is supported for overall stats" + assert len(predictions) == len( + ground_truth + ), "Predictions and ground truth must be of same length" + + output_metrics = {} + + average_types_keys = { + "global": "micro", + "per_class": "none", + "per_class_average": "macro", + "per_class_weighted": "weighted", + } + # metrics that need the "average" parameter + for average_type, average_type_key in average_types_keys.items(): + calculators = { + "accuracy": tm.Accuracy( + num_classes=params["model"]["num_classes"], average=average_type_key + ), + "precision": tm.Precision( + num_classes=params["model"]["num_classes"], average=average_type_key + ), + "recall": tm.Recall( + num_classes=params["model"]["num_classes"], average=average_type_key + ), + "f1": tm.F1( + num_classes=params["model"]["num_classes"], average=average_type_key + ), + ## weird error for multi-class problem, where pos_label is not getting set + # "aucroc": tm.AUROC( + # num_classes=params["model"]["num_classes"], average=average_type_key + # ), + } + for metric_name, calculator in calculators.items(): + temp_output = calculator(predictions, ground_truth) + if temp_output.dim() > 0: + temp_output = temp_output.cpu().tolist() + else: + temp_output = temp_output.cpu().item() + output_metrics[f"{metric_name}_{average_type}"] = temp_output + # metrics that do not have any "average" parameter + calculators = { + "auc": tm.AUC(reorder=True), + ## weird error for multi-class problem, where pos_label is not getting set + # "roc": tm.ROC(num_classes=params["model"]["num_classes"]), + } + for metric_name, calculator in calculators.items(): + temp_output = calculator(predictions, ground_truth) + if temp_output.dim() > 0: + temp_output = temp_output.cpu().tolist() + else: + temp_output = temp_output.cpu().item() + output_metrics[metric_name] = temp_output + + return output_metrics From 0e13d14b1238e0f452fa78284a4ac691389f450d Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Fri, 16 Sep 2022 16:28:43 -0400 Subject: [PATCH 16/22] calling new module --- GANDLF/compute/forward_pass.py | 2 +- GANDLF/compute/training_loop.py | 2 +- GANDLF/metrics/__init__.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/GANDLF/compute/forward_pass.py b/GANDLF/compute/forward_pass.py index 7bba243ec..0f1546bb9 100644 --- a/GANDLF/compute/forward_pass.py +++ b/GANDLF/compute/forward_pass.py @@ -16,7 +16,7 @@ resample_image, reverse_one_hot, ) -from GANDLF.metrics.regression import overall_stats +from GANDLF.metrics import overall_stats from tqdm import tqdm diff --git a/GANDLF/compute/training_loop.py b/GANDLF/compute/training_loop.py index 233d2e143..af933846a 100644 --- a/GANDLF/compute/training_loop.py +++ b/GANDLF/compute/training_loop.py @@ -17,7 +17,7 @@ write_training_patches, print_model_summary, ) -from GANDLF.metrics.regression import overall_stats +from GANDLF.metrics import overall_stats from GANDLF.logger import Logger from .step import step from .forward_pass import validate_network diff --git a/GANDLF/metrics/__init__.py b/GANDLF/metrics/__init__.py index c6fb154b9..caea9bcb6 100644 --- a/GANDLF/metrics/__init__.py +++ b/GANDLF/metrics/__init__.py @@ -12,6 +12,7 @@ ) from .regression import classification_accuracy, balanced_acc_score, per_label_accuracy from .generic import recall_score, precision_score, iou_score, f1_score, accuracy +from .classification import overall_stats # global defines for the metrics From 
From 7258e71e0995a7d66d4192252d7dda98a63e1aea Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Fri, 16 Sep 2022 16:28:51 -0400
Subject: [PATCH 17/22] removed older implementation

---
 GANDLF/metrics/regression.py | 71 ------------------------------------
 1 file changed, 71 deletions(-)

diff --git a/GANDLF/metrics/regression.py b/GANDLF/metrics/regression.py
index ffc5615c5..fa5274251 100644
--- a/GANDLF/metrics/regression.py
+++ b/GANDLF/metrics/regression.py
@@ -4,7 +4,6 @@
 import torch
 from sklearn.metrics import balanced_accuracy_score
 import numpy as np
-import torchmetrics as tm


 def classification_accuracy(output, label, params):
@@ -74,73 +73,3 @@ def per_label_accuracy(output, label, params):
         return output_accuracy / len(output)
     else:
         return balanced_acc_score(output, label, params)
-
-
-def overall_stats(predictions, ground_truth, params):
-    """
-    Generates a dictionary of metrics calculated on the overall predictions and ground truths.
-
-    Args:
-        predictions (torch.Tensor): The output of the model.
-        ground_truth (torch.Tensor): The ground truth labels.
-        params (dict): The parameter dictionary containing training and data information.
-
-    Returns:
-        dict: A dictionary of metrics.
-    """
-    assert (
-        params["problem_type"] == "classification"
-    ), "Only classification is supported for overall stats"
-    assert len(predictions) == len(
-        ground_truth
-    ), "Predictions and ground truth must be of same length"
-
-    output_metrics = {}
-    average_types_keys = {
-        "global": "micro",
-        "per_class": "none",
-        "per_class_average": "macro",
-        "per_class_weighted": "weighted",
-    }
-    # metrics that need the "average" parameter
-    for average_type, average_type_key in average_types_keys.items():
-        calculators = {
-            "accuracy": tm.Accuracy(
-                num_classes=params["model"]["num_classes"], average=average_type_key
-            ),
-            "precision": tm.Precision(
-                num_classes=params["model"]["num_classes"], average=average_type_key
-            ),
-            "recall": tm.Recall(
-                num_classes=params["model"]["num_classes"], average=average_type_key
-            ),
-            "f1": tm.F1(
-                num_classes=params["model"]["num_classes"], average=average_type_key
-            ),
-            ## weird error for multi-class problem, where pos_label is not getting set
-            # "aucroc": tm.AUROC(
-            #     num_classes=params["model"]["num_classes"], average=average_type_key
-            # ),
-        }
-        for metric_name, calculator in calculators.items():
-            temp_output = calculator(predictions, ground_truth)
-            if temp_output.dim() > 0:
-                temp_output = temp_output.cpu().tolist()
-            else:
-                temp_output = temp_output.cpu().item()
-            output_metrics[f"{metric_name}_{average_type}"] = temp_output
-    # metrics that do not have any "average" parameter
-    calculators = {
-        "auc": tm.AUC(reorder=True),
-        ## weird error for multi-class problem, where pos_label is not getting set
-        # "roc": tm.ROC(num_classes=params["model"]["num_classes"]),
-    }
-    for metric_name, calculator in calculators.items():
-        temp_output = calculator(predictions, ground_truth)
-        if temp_output.dim() > 0:
-            temp_output = temp_output.cpu().tolist()
-        else:
-            temp_output = temp_output.cpu().item()
-        output_metrics[metric_name] = temp_output
From 5a54e23ab1371cb3b0b0b25f57c192a3858b035f Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Fri, 16 Sep 2022 16:49:44 -0400
Subject: [PATCH 18/22] ensuring correct prediction is picked up

---
 GANDLF/compute/forward_pass.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/GANDLF/compute/forward_pass.py b/GANDLF/compute/forward_pass.py
index 0f1546bb9..ac88f7e72 100644
--- a/GANDLF/compute/forward_pass.py
+++ b/GANDLF/compute/forward_pass.py
@@ -208,7 +208,9 @@ def validate_network(
             )

             if is_classification and mode == "validation":
-                predictions_array[batch_idx] = pred_output.max().item()
+                predictions_array[batch_idx] = (
+                    torch.argmax(pred_output[0], 0).cpu().item()
+                )
             # # Non network validation related
             total_epoch_valid_loss += final_loss.detach().cpu().item()
             for metric in final_metric.keys():
@@ -377,7 +379,9 @@ def validate_network(
             else:
                 # final regression output
                 output_prediction = output_prediction / len(patch_loader)
                 if is_classification and mode == "validation":
-                    predictions_array[batch_idx] = output_prediction
+                    predictions_array[batch_idx] = (
+                        torch.argmax(output_prediction[0], 0).cpu().item()
+                    )
                 if params["save_output"]:
                     outputToWrite += (
                         str(epoch)

From b85957c03725876b8bd60dd6ea7eb2aedcca7861 Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Fri, 16 Sep 2022 16:49:52 -0400
Subject: [PATCH 19/22] no need for extra code

---
 GANDLF/compute/training_loop.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/GANDLF/compute/training_loop.py b/GANDLF/compute/training_loop.py
index af933846a..99f6c46ed 100644
--- a/GANDLF/compute/training_loop.py
+++ b/GANDLF/compute/training_loop.py
@@ -118,10 +118,6 @@ def train_network(model, train_dataloader, optimizer, params):
         loss, calculated_metrics, output, _ = step(model, image, label, params)
         # store predictions for classification
         if params["problem_type"] == "classification":
-            # for i in range(params["batch_size"]):
-            #     predictions_array[batch_idx * params["batch_size"] + i] = (
-            #         torch.argmax(output[i], 0).cpu().item()
-            #     )
             predictions_array[
                 batch_idx
                 * params["batch_size"] : (batch_idx + 1)
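Patch 18 fixes a real bug: `pred_output.max().item()` returns the largest logit value, whereas the metrics need the index of that logit, i.e. the predicted class. A minimal illustration:

```python
# Why argmax rather than max: metrics consume class indices, not logit values.
import torch

pred_output = torch.tensor([[0.1, 2.7, 0.4]])  # one subject, three class logits

wrong = pred_output.max().item()                      # 2.7 (a logit value)
right = torch.argmax(pred_output[0], 0).cpu().item()  # 1   (the class index)
```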
From 0827df1859221a15066ded625fb0dc8252ccc972 Mon Sep 17 00:00:00 2001
From: sarthakpati
Date: Fri, 16 Sep 2022 16:53:43 -0400
Subject: [PATCH 20/22] reduce code duplication

---
 GANDLF/metrics/classification.py | 38 ++++++++++++++++++++++----------
 1 file changed, 26 insertions(+), 12 deletions(-)

diff --git a/GANDLF/metrics/classification.py b/GANDLF/metrics/classification.py
index c9a730530..4d5e350ed 100644
--- a/GANDLF/metrics/classification.py
+++ b/GANDLF/metrics/classification.py
@@ -1,6 +1,26 @@
 import torchmetrics as tm


+def get_output_from_calculator(predictions, ground_truth, calculator):
+    """
+    Helper function to get the output from a calculator.
+
+    Args:
+        predictions (torch.Tensor): The output of the model.
+        ground_truth (torch.Tensor): The ground truth labels.
+        calculator (torchmetrics.Metric): The calculator to use.
+
+    Returns:
+        float: The output from the calculator.
+    """
+    temp_output = calculator(predictions, ground_truth)
+    if temp_output.dim() > 0:
+        temp_output = temp_output.cpu().tolist()
+    else:
+        temp_output = temp_output.cpu().item()
+    return temp_output
+
+
 def overall_stats(predictions, ground_truth, params):
     """
     Generates a dictionary of metrics calculated on the overall predictions and ground truths.
@@ -49,12 +69,9 @@ def overall_stats(predictions, ground_truth, params):
             # ),
         }
         for metric_name, calculator in calculators.items():
-            temp_output = calculator(predictions, ground_truth)
-            if temp_output.dim() > 0:
-                temp_output = temp_output.cpu().tolist()
-            else:
-                temp_output = temp_output.cpu().item()
-            output_metrics[f"{metric_name}_{average_type}"] = temp_output
+            output_metrics[
+                f"{metric_name}_{average_type}"
+            ] = get_output_from_calculator(predictions, ground_truth, calculator)
     # metrics that do not have any "average" parameter
     calculators = {
         "auc": tm.AUC(reorder=True),
@@ -62,11 +79,8 @@ def overall_stats(predictions, ground_truth, params):
         # "roc": tm.ROC(num_classes=params["model"]["num_classes"]),
     }
     for metric_name, calculator in calculators.items():
-        temp_output = calculator(predictions, ground_truth)
-        if temp_output.dim() > 0:
-            temp_output = temp_output.cpu().tolist()
-        else:
-            temp_output = temp_output.cpu().item()
-        output_metrics[metric_name] = temp_output
+        output_metrics[metric_name] = get_output_from_calculator(
+            predictions, ground_truth, calculator
+        )

     return output_metrics
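The new helper collapses the repeated conversion block into one call per metric; a hypothetical direct use (same pre-0.8 torchmetrics assumption as above):

```python
# Hypothetical direct use of get_output_from_calculator (torchmetrics < 0.8).
import torch
import torchmetrics as tm

from GANDLF.metrics.classification import get_output_from_calculator

predictions = torch.tensor([0, 1, 1], dtype=torch.int)
ground_truth = torch.tensor([0, 1, 0], dtype=torch.int)

acc = get_output_from_calculator(
    predictions, ground_truth, tm.Accuracy(num_classes=2, average="micro")
)
print(acc)  # a plain float, e.g. 0.666...
```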
+ """ + ground_truth_array = torch.from_numpy( + params[loader_type][ + params[loader_type].columns[params["headers"]["predictionHeaders"]] + ] + .to_numpy() + .ravel() + ).type(torch.int) + predictions_array = torch.zeros_like(ground_truth_array) + + return ground_truth_array, predictions_array From 34fbc34ddd380411ac64f83478e8dd3f00e777e6 Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Fri, 16 Sep 2022 18:16:53 -0400 Subject: [PATCH 22/22] using the new function --- GANDLF/compute/forward_pass.py | 15 +++++---------- GANDLF/compute/training_loop.py | 13 +++++-------- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/GANDLF/compute/forward_pass.py b/GANDLF/compute/forward_pass.py index ac88f7e72..0316d4e95 100644 --- a/GANDLF/compute/forward_pass.py +++ b/GANDLF/compute/forward_pass.py @@ -15,6 +15,7 @@ get_unique_timestamp, resample_image, reverse_one_hot, + get_ground_truths_and_predictions_tensor, ) from GANDLF.metrics import overall_stats from tqdm import tqdm @@ -102,16 +103,10 @@ def validate_network( # get ground truths for classification problem, validation set if is_classification and mode == "validation": - ground_truth_array = torch.from_numpy( - params["validation_data"][ - params["validation_data"].columns[ - params["headers"]["predictionHeaders"] - ] - ] - .to_numpy() - .ravel() - ).type(torch.int) - predictions_array = torch.zeros_like(ground_truth_array) + ( + ground_truth_array, + predictions_array, + ) = get_ground_truths_and_predictions_tensor(params, "validation_data") for batch_idx, (subject) in enumerate( tqdm(valid_dataloader, desc="Looping over " + mode + " data") diff --git a/GANDLF/compute/training_loop.py b/GANDLF/compute/training_loop.py index 99f6c46ed..141d1894c 100644 --- a/GANDLF/compute/training_loop.py +++ b/GANDLF/compute/training_loop.py @@ -16,6 +16,7 @@ version_check, write_training_patches, print_model_summary, + get_ground_truths_and_predictions_tensor, ) from GANDLF.metrics import overall_stats from GANDLF.logger import Logger @@ -72,14 +73,10 @@ def train_network(model, train_dataloader, optimizer, params): # get ground truths if params["problem_type"] == "classification": - ground_truth_array = torch.from_numpy( - params["training_data"][ - params["training_data"].columns[params["headers"]["predictionHeaders"]] - ] - .to_numpy() - .ravel() - ).type(torch.int) - predictions_array = torch.zeros_like(ground_truth_array) + ( + ground_truth_array, + predictions_array, + ) = get_ground_truths_and_predictions_tensor(params, "training_data") # Set the model to train model.train() for batch_idx, (subject) in enumerate(