Commit

Merge pull request #500 from sarthakpati/overall_classification_metrics
Included classification-specific metrics using overall predictions and ground truths
AlexanderGetka-cbica authored Sep 19, 2022
2 parents ee4f82f + 34fbc34 commit b650b4c
Showing 9 changed files with 165 additions and 8 deletions.
26 changes: 25 additions & 1 deletion GANDLF/compute/forward_pass.py
@@ -15,7 +15,9 @@
get_unique_timestamp,
resample_image,
reverse_one_hot,
get_ground_truths_and_predictions_tensor,
)
from GANDLF.metrics import overall_stats
from tqdm import tqdm


@@ -99,6 +101,13 @@ def validate_network(
"output_predictions_" + get_unique_timestamp() + ".csv",
)

# get ground truths for classification problem, validation set
if is_classification and mode == "validation":
(
ground_truth_array,
predictions_array,
) = get_ground_truths_and_predictions_tensor(params, "validation_data")

for batch_idx, (subject) in enumerate(
tqdm(valid_dataloader, desc="Looping over " + mode + " data")
):
@@ -192,6 +201,11 @@ def validate_network(
final_loss, final_metric = get_loss_and_metrics(
image, valuesToPredict, pred_output, params
)

if is_classification and mode == "validation":
predictions_array[batch_idx] = (
torch.argmax(pred_output[0], 0).cpu().item()
)
# # Non network validation related
total_epoch_valid_loss += final_loss.detach().cpu().item()
for metric in final_metric.keys():
@@ -283,7 +297,7 @@ def validate_network(
attention_map, patches_batch[torchio.LOCATION]
)
else:
_, _, output = result
_, _, output, _ = result

if params["problem_type"] == "segmentation":
aggregator.add_batch(
@@ -359,6 +373,10 @@ def validate_network(
else:
# final regression output
output_prediction = output_prediction / len(patch_loader)
if is_classification and mode == "validation":
predictions_array[batch_idx] = (
torch.argmax(output_prediction[0], 0).cpu().item()
)
if params["save_output"]:
outputToWrite += (
str(epoch)
@@ -453,6 +471,11 @@ def validate_network(
if label_ground_truth is not None:
average_epoch_valid_loss = total_epoch_valid_loss / len(valid_dataloader)
print(" Epoch Final " + mode + " loss : ", average_epoch_valid_loss)
# get overall stats for classification
if is_classification and mode == "validation":
average_epoch_valid_metric = overall_stats(
predictions_array, ground_truth_array, params
)
for metric in params["metrics"]:
if isinstance(total_epoch_valid_metric[metric], np.ndarray):
to_print = (
@@ -461,6 +484,7 @@
else:
to_print = total_epoch_valid_metric[metric] / len(valid_dataloader)
average_epoch_valid_metric[metric] = to_print
for metric in average_epoch_valid_metric.keys():
print(
" Epoch Final " + mode + " " + metric + " : ",
average_epoch_valid_metric[metric],
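A minimal, self-contained illustration of the per-subject bookkeeping added to forward_pass.py above (hypothetical values; during validation each batch holds a single subject, so predictions_array is indexed directly by batch_idx):

import torch

# hypothetical 3-class scores for one subject; pred_output[0] in the diff plays this role
pred_output = [torch.tensor([0.1, 0.7, 0.2])]
predictions_array = torch.zeros(4, dtype=torch.int)  # one slot per validation subject
batch_idx = 2
predictions_array[batch_idx] = torch.argmax(pred_output[0], 0).cpu().item()
# predictions_array is now tensor([0, 0, 1, 0], dtype=torch.int32)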
6 changes: 2 additions & 4 deletions GANDLF/compute/step.py
@@ -80,6 +80,7 @@ def step(model, image, label, params, train=True):
else:
output = model(image)

attention_map = None
if "medcam_enabled" in params and params["medcam_enabled"]:
output, attention_map = output

@@ -97,7 +98,4 @@
if "medcam_enabled" in params and params["medcam_enabled"]:
attention_map = torch.unsqueeze(attention_map, -1)

if not ("medcam_enabled" in params and params["medcam_enabled"]):
return loss, metric_output, output
else:
return loss, metric_output, output, attention_map
return loss, metric_output, output, attention_map
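With this change, step() always returns a four-tuple, and attention_map is simply None when medcam is disabled. A toy stand-in (not the real GANDLF function) showing the calling convention call sites can now rely on:

from typing import Optional, Tuple
import torch

def step_stub(medcam_enabled: bool) -> Tuple[torch.Tensor, dict, torch.Tensor, Optional[torch.Tensor]]:
    # toy stand-in for GANDLF.compute.step.step: the attention-map slot is always
    # present and is None when medcam is disabled
    loss = torch.tensor(0.5)
    metric_output = {"accuracy": 1.0}
    output = torch.rand(1, 3)
    attention_map = torch.rand(1, 1, 8, 8) if medcam_enabled else None
    return loss, metric_output, output, attention_map

# every call site now unpacks four values, regardless of configuration
loss, metric_output, output, attention_map = step_stub(medcam_enabled=False)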
25 changes: 24 additions & 1 deletion GANDLF/compute/training_loop.py
@@ -16,7 +16,9 @@
version_check,
write_training_patches,
print_model_summary,
get_ground_truths_and_predictions_tensor,
)
from GANDLF.metrics import overall_stats
from GANDLF.logger import Logger
from .step import step
from .forward_pass import validate_network
@@ -69,6 +71,12 @@ def train_network(model, train_dataloader, optimizer, params):
if params["verbose"]:
print("Using Automatic mixed precision", flush=True)

# get ground truths
if params["problem_type"] == "classification":
(
ground_truth_array,
predictions_array,
) = get_ground_truths_and_predictions_tensor(params, "training_data")
# Set the model to train
model.train()
for batch_idx, (subject) in enumerate(
@@ -104,7 +112,15 @@ def train_network(model, train_dataloader, optimizer, params):
params["subject_spacing"] = subject["spacing"]
else:
params["subject_spacing"] = None
loss, calculated_metrics, _ = step(model, image, label, params)
loss, calculated_metrics, output, _ = step(model, image, label, params)
# store predictions for classification
if params["problem_type"] == "classification":
predictions_array[
batch_idx
* params["batch_size"] : (batch_idx + 1)
* params["batch_size"]
] = (torch.argmax(output[0], 0).cpu().item())

nan_loss = torch.isnan(loss)
second_order = (
hasattr(optimizer, "is_second_order") and optimizer.is_second_order
@@ -175,6 +191,12 @@ def train_network(model, train_dataloader, optimizer, params):

average_epoch_train_loss = total_epoch_train_loss / len(train_dataloader)
print(" Epoch Final train loss : ", average_epoch_train_loss)

# get overall stats for classification
if params["problem_type"] == "classification":
average_epoch_train_metric = overall_stats(
predictions_array, ground_truth_array, params
)
for metric in params["metrics"]:
if isinstance(total_epoch_train_metric[metric], np.ndarray):
to_print = (
@@ -183,6 +205,7 @@ def train_network(model, train_dataloader, optimizer, params):
else:
to_print = total_epoch_train_metric[metric] / len(train_dataloader)
average_epoch_train_metric[metric] = to_print
for metric in average_epoch_train_metric.keys():
print(
" Epoch Final train " + metric + " : ",
average_epoch_train_metric[metric],
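In train_network above, predictions are written one batch at a time into a pre-allocated tensor using slice arithmetic rather than a single index. A small sketch of that addressing (hypothetical sizes; assigning a Python scalar to a tensor slice broadcasts it across the slice):

import torch

predictions_array = torch.zeros(6, dtype=torch.int)  # one slot per training sample
batch_size, batch_idx = 2, 1
predictions_array[batch_idx * batch_size : (batch_idx + 1) * batch_size] = 1
# predictions_array is now tensor([0, 0, 1, 1, 0, 0], dtype=torch.int32)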
3 changes: 2 additions & 1 deletion GANDLF/metrics/__init__.py
@@ -12,6 +12,7 @@
)
from .regression import classification_accuracy, balanced_acc_score, per_label_accuracy
from .generic import recall_score, precision_score, iou_score, f1_score, accuracy
from .classification import overall_stats


# global defines for the metrics
@@ -35,5 +36,5 @@
"recall": recall_score,
"iou": iou_score,
"balanced_accuracy": balanced_acc_score,
"per_label_accuracy": per_label_accuracy,
"per_label_one_hot_accuracy": per_label_accuracy,
}
86 changes: 86 additions & 0 deletions GANDLF/metrics/classification.py
@@ -0,0 +1,86 @@
import torchmetrics as tm


def get_output_from_calculator(predictions, ground_truth, calculator):
"""
Helper function to get the output from a calculator.
Args:
predictions (torch.Tensor): The output of the model.
ground_truth (torch.Tensor): The ground truth labels.
calculator (torchmetrics.Metric): The calculator to use.
Returns:
float: The output from the calculator.
"""
temp_output = calculator(predictions, ground_truth)
if temp_output.dim() > 0:
temp_output = temp_output.cpu().tolist()
else:
temp_output = temp_output.cpu().item()
return temp_output


def overall_stats(predictions, ground_truth, params):
"""
Generates a dictionary of metrics calculated on the overall predictions and ground truths.
Args:
predictions (torch.Tensor): The output of the model.
ground_truth (torch.Tensor): The ground truth labels.
params (dict): The parameter dictionary containing training and data information.
Returns:
dict: A dictionary of metrics.
"""
assert (
params["problem_type"] == "classification"
), "Only classification is supported for overall stats"
assert len(predictions) == len(
ground_truth
), "Predictions and ground truth must be of same length"

output_metrics = {}

average_types_keys = {
"global": "micro",
"per_class": "none",
"per_class_average": "macro",
"per_class_weighted": "weighted",
}
# metrics that need the "average" parameter
for average_type, average_type_key in average_types_keys.items():
calculators = {
"accuracy": tm.Accuracy(
num_classes=params["model"]["num_classes"], average=average_type_key
),
"precision": tm.Precision(
num_classes=params["model"]["num_classes"], average=average_type_key
),
"recall": tm.Recall(
num_classes=params["model"]["num_classes"], average=average_type_key
),
"f1": tm.F1(
num_classes=params["model"]["num_classes"], average=average_type_key
),
## weird error for multi-class problem, where pos_label is not getting set
# "aucroc": tm.AUROC(
# num_classes=params["model"]["num_classes"], average=average_type_key
# ),
}
for metric_name, calculator in calculators.items():
output_metrics[
f"{metric_name}_{average_type}"
] = get_output_from_calculator(predictions, ground_truth, calculator)
# metrics that do not have any "average" parameter
calculators = {
"auc": tm.AUC(reorder=True),
## weird error for multi-class problem, where pos_label is not getting set
# "roc": tm.ROC(num_classes=params["model"]["num_classes"]),
}
for metric_name, calculator in calculators.items():
output_metrics[metric_name] = get_output_from_calculator(
predictions, ground_truth, calculator
)

return output_metrics
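A minimal usage sketch for the new module, assuming a torchmetrics release contemporary with this change (one that still exposes tm.F1 and tm.AUC); the params dict below carries only the keys overall_stats actually reads, and the tensors are hypothetical:

import torch
from GANDLF.metrics import overall_stats

params = {
    "problem_type": "classification",
    "model": {"num_classes": 3},
}
predictions = torch.tensor([0, 2, 1, 2, 0], dtype=torch.int)
ground_truth = torch.tensor([0, 1, 1, 2, 0], dtype=torch.int)

stats = overall_stats(predictions, ground_truth, params)
# keys follow the f"{metric_name}_{average_type}" pattern, e.g.:
# stats["accuracy_global"], stats["f1_per_class"] (a per-class list), stats["auc"]
print(stats["accuracy_global"])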
1 change: 1 addition & 0 deletions GANDLF/utils/__init__.py
@@ -19,6 +19,7 @@
get_class_imbalance_weights_classification,
get_linear_interpolation_mode,
print_model_summary,
get_ground_truths_and_predictions_tensor,
)

from .write_parse import (
23 changes: 23 additions & 0 deletions GANDLF/utils/tensor.py
@@ -419,3 +419,26 @@ def print_model_summary(
)
temp_output = stats.to_readable(stats.total_mult_adds)
print("\tTotal # of operations:", temp_output[1], temp_output[0])


def get_ground_truths_and_predictions_tensor(params, loader_type):
"""
This function is used to get the ground truths and predictions for a given loader type.
Args:
params (dict): The parameters passed by the user yaml.
loader_type (str): The loader type for which the ground truths and predictions are to be returned.
Returns:
torch.Tensor, torch.Tensor: The ground truths and base predictions for the given loader type.
"""
ground_truth_array = torch.from_numpy(
params[loader_type][
params[loader_type].columns[params["headers"]["predictionHeaders"]]
]
.to_numpy()
.ravel()
).type(torch.int)
predictions_array = torch.zeros_like(ground_truth_array)

return ground_truth_array, predictions_array
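The helper expects params[loader_type] to be the parsed data frame for that loader and params["headers"]["predictionHeaders"] to hold the column position(s) of the label column. A hypothetical minimal setup (the column names below are illustrative, mirroring what GaNDLF's CSV parsing would normally populate):

import pandas as pd
from GANDLF.utils import get_ground_truths_and_predictions_tensor

params = {
    "training_data": pd.DataFrame(
        {
            "SubjectID": ["001", "002", "003"],
            "Channel_0": ["a.nii.gz", "b.nii.gz", "c.nii.gz"],
            "ValueToPredict": [1, 0, 2],
        }
    ),
    "headers": {"predictionHeaders": [2]},  # column index of ValueToPredict
}

ground_truth_array, predictions_array = get_ground_truths_and_predictions_tensor(
    params, "training_data"
)
# ground_truth_array -> tensor([1, 0, 2], dtype=torch.int32)
# predictions_array  -> tensor([0, 0, 0], dtype=torch.int32), filled in as the epoch runs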
1 change: 1 addition & 0 deletions HISTORY.md
@@ -9,6 +9,7 @@
- Per class accuracy has been added as a metric
- Dedicated rescaling preprocessing function added for increased flexibility
- Largest Connected Component Analysis is now added
- Included metrics using overall predictions and ground truths

## 0.0.14

2 changes: 1 addition & 1 deletion testing/config_classification.yaml
@@ -14,7 +14,7 @@ metrics:
}
- accuracy
- balanced_accuracy
- per_label_accuracy
- per_label_one_hot_accuracy
- precision: {
average: weighted,
}
