Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
cai2-huaiguang committed Jan 19, 2025
1 parent c25ca78 commit 9f2d539
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pytorch_grad_cam/base_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def forward(
loss.backward(retain_graph=True)
else:
# keep the computational graph, create_graph = True is needed for hvp
torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)
torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)
# When using the following loss.backward() method, a warning is raised: "UserWarning: Using backward() with create_graph=True will create a reference cycle"
# loss.backward(retain_graph=True, create_graph=True)
if 'hpu' in str(self.device):
Expand Down
4 changes: 1 addition & 3 deletions pytorch_grad_cam/shapley_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

"""
Weighting the activation maps using Gradient and Hessian-Vector Product.
Weights the activation maps using the gradient and Hessian-Vector product.
This method (https://arxiv.org/abs/2501.06261) reinterpret CAM methods (include GradCAM, HiResCAM and the original CAM) from a Shapley value perspective.
"""
class ShapleyCAM(BaseCAM):
Expand Down Expand Up @@ -51,12 +51,10 @@ def get_cam_weights(self,
if len(activations.shape) == 4:
weight = np.mean(weight, axis=(2, 3))
return weight

# 3D image
elif len(activations.shape) == 5:
weight = np.mean(weight, axis=(2, 3, 4))
return weight

else:
raise ValueError("Invalid grads shape."
"Shape of grads should be 4 (2D image) or 5 (3D image).")

0 comments on commit 9f2d539

Please sign in to comment.