diff --git a/pytorch_grad_cam/base_cam.py b/pytorch_grad_cam/base_cam.py index c54b2a29..484e8865 100644 --- a/pytorch_grad_cam/base_cam.py +++ b/pytorch_grad_cam/base_cam.py @@ -111,7 +111,7 @@ def forward( loss.backward(retain_graph=True) else: # keep the computational graph, create_graph = True is needed for hvp - torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True) + torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True) # When using the following loss.backward() method, a warning is raised: "UserWarning: Using backward() with create_graph=True will create a reference cycle" # loss.backward(retain_graph=True, create_graph=True) if 'hpu' in str(self.device): diff --git a/pytorch_grad_cam/shapley_cam.py b/pytorch_grad_cam/shapley_cam.py index 9398e8fc..e8331528 100644 --- a/pytorch_grad_cam/shapley_cam.py +++ b/pytorch_grad_cam/shapley_cam.py @@ -5,7 +5,7 @@ import numpy as np """ -Weighting the activation maps using Gradient and Hessian-Vector Product. +Weights the activation maps using the gradient and Hessian-Vector product. This method (https://arxiv.org/abs/2501.06261) reinterpret CAM methods (include GradCAM, HiResCAM and the original CAM) from a Shapley value perspective. """ class ShapleyCAM(BaseCAM): @@ -51,12 +51,10 @@ def get_cam_weights(self, if len(activations.shape) == 4: weight = np.mean(weight, axis=(2, 3)) return weight - # 3D image elif len(activations.shape) == 5: weight = np.mean(weight, axis=(2, 3, 4)) return weight - else: raise ValueError("Invalid grads shape." "Shape of grads should be 4 (2D image) or 5 (3D image).")