Skip to content

Commit

Permalink
Dynamically move mask to device in SemanticSegmentationTarget (#546)
Browse files Browse the repository at this point in the history
Signed-off-by: Shreyas Ranganatha <[email protected]>
  • Loading branch information
shreyass-ranganatha authored Dec 12, 2024
1 parent 0d86e52 commit 5cef718
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions pytorch_grad_cam/utils/model_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,9 @@ class SemanticSegmentationTarget:
def __init__(self, category, mask):
self.category = category
self.mask = torch.from_numpy(mask)
if torch.cuda.is_available():
self.mask = self.mask.cuda()
if torch.backends.mps.is_available():
self.mask = self.mask.to("mps")

def __call__(self, model_output):
return (model_output[self.category, :, :] * self.mask).sum()
return (model_output[self.category, :, :] * self.mask.to(model_output.device)).sum()


class FasterRCNNBoxScoreTarget:
Expand Down

0 comments on commit 5cef718

Please sign in to comment.