Skip to content

Commit

Permalink
Merge pull request #498 from sarthakpati/497_multi-batch_per_label_ac…
Browse files Browse the repository at this point in the history
…curacy
  • Loading branch information
Geeks-Sid authored Sep 16, 2022
2 parents f577171 + 0e24f77 commit ee4f82f
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions GANDLF/metrics/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,14 @@ def per_label_accuracy(output, label, params):
torch.Tensor: The per class accuracy.
"""
if params["problem_type"] == "classification":
predicted_classes = np.array([0] * len(params["model"]["class_list"]))
label_cpu = np.array([0] * len(params["model"]["class_list"]))
predicted_classes[torch.argmax(output, 1).cpu().item()] = 1
label_cpu[label.cpu().item()] = 1
return torch.from_numpy((predicted_classes == label_cpu).astype(float))
# ensure this works for multiple batches
output_accuracy = torch.zeros(len(params["model"]["class_list"]))
for output_batch, label_batch in zip(output, label):
predicted_classes = torch.Tensor([0] * len(params["model"]["class_list"]))
label_cpu = torch.Tensor([0] * len(params["model"]["class_list"]))
predicted_classes[torch.argmax(output_batch, 0).cpu().item()] = 1
label_cpu[label_batch.cpu().item()] = 1
output_accuracy += (predicted_classes == label_cpu).type(torch.float)
return output_accuracy / len(output)
else:
return balanced_acc_score(output, label, params)

0 comments on commit ee4f82f

Please sign in to comment.