Add gradient visualization during validation (#115)
sokovninn authored Nov 2, 2024
1 parent 1d83a2c commit 45fccdc
Showing 4 changed files with 217 additions and 0 deletions.
13 changes: 13 additions & 0 deletions luxonis_train/callbacks/README.md
@@ -66,6 +66,19 @@ Callback to perform a test run at the end of the training.

Callback that uploads the current best checkpoint (based on validation loss) to the tracker location where all other logs are stored.

## `GradCamCallback`

Callback to visualize gradients using Grad-CAM (experimental). Works only during validation.

**Parameters:**

| Key             | Type  | Default value      | Description                                                                                        |
| --------------- | ----- | ------------------ | -------------------------------------------------------------------------------------------------- |
| `target_layer`  | `int` | -                  | Index of the layer (in `named_modules()` order) whose gradients are visualized.                     |
| `class_idx`     | `int` | `0`                | Index of the target class for the visualization.                                                    |
| `log_n_batches` | `int` | `1`                | Number of validation batches to log.                                                                |
| `task`          | `str` | `"classification"` | Type of task; one of `"classification"`, `"segmentation"`, `"detection"`, `"keypoint_detection"`.   |
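
A minimal configuration sketch for enabling the callback, assuming the usual `trainer.callbacks` section of the training config; the `target_layer` value is only an illustrative index and depends on the model:

```yaml
trainer:
  callbacks:
    - name: GradCamCallback
      params:
        target_layer: 17  # illustrative: index into the flattened list of named modules
        class_idx: 0
        log_n_batches: 1
        task: "segmentation"
```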

## `EMACallback`

Callback that updates the stored parameters using a moving average.
3 changes: 3 additions & 0 deletions luxonis_train/callbacks/__init__.py
@@ -16,6 +16,7 @@
from .ema import EMACallback
from .export_on_train_end import ExportOnTrainEnd
from .gpu_stats_monitor import GPUStatsMonitor
from .gradcam_visializer import GradCamCallback
from .luxonis_progress_bar import (
    BaseLuxonisProgressBar,
    LuxonisRichProgressBar,
@@ -35,6 +36,7 @@
CALLBACKS.register_module(module=StochasticWeightAveraging)
CALLBACKS.register_module(module=Timer)
CALLBACKS.register_module(module=ModelPruning)
CALLBACKS.register_module(module=GradCamCallback)
CALLBACKS.register_module(module=EMACallback)


@@ -49,5 +51,6 @@
    "TestOnTrainEnd",
    "UploadCheckpoint",
    "GPUStatsMonitor",
    "GradCamCallback",
    "EMACallback",
]
200 changes: 200 additions & 0 deletions luxonis_train/callbacks/gradcam_visializer.py
@@ -0,0 +1,200 @@
import logging
from typing import Any, Union

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_grad_cam import HiResCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import (
    ClassifierOutputTarget,
    SemanticSegmentationTarget,
)
from pytorch_lightning.utilities.types import STEP_OUTPUT

from luxonis_train.attached_modules.visualizers import (
    get_denormalized_images,
)

logger = logging.getLogger(__name__)


class ModelWrapper(pl.LightningModule):
    def __init__(self, model: pl.LightningModule, task: str) -> None:
        """Constructs `ModelWrapper`.
        @type model: pl.LightningModule
        @param model: The model to be wrapped.
        @type task: str
        @param task: The type of task (e.g., segmentation, detection,
            classification, keypoint_detection).
        """
        super().__init__()
        self.model = model
        self.task = task

    def forward(
        self, inputs: torch.Tensor, *args: Any, **kwargs: Any
    ) -> Union[torch.Tensor, Any]:
        """Forward pass through the model, returning the output based on
        the task type.
        @type inputs: torch.Tensor
        @param inputs: Input tensor for the model.
        @type args: Any
        @param args: Additional positional arguments.
        @type kwargs: Any
        @param kwargs: Additional keyword arguments.
        @rtype: Union[torch.Tensor, Any]
        @return: The processed output based on the task type.
        """
        input_dict = dict(image=inputs)
        output = self.model(input_dict, *args, **kwargs)

        if self.task == "segmentation":
            return output.outputs["segmentation_head"]["segmentation"][0]
        elif self.task == "detection":
            scores = output.outputs["detection_head"]["class_scores"][0]
            return scores.sum(dim=1)
        elif self.task == "classification":
            return output.outputs["classification_head"]["classification"][0]
        elif self.task == "keypoint_detection":
            scores = output.outputs["kpt_detection_head"]["class_scores"][0]
            return scores.sum(dim=1)
        else:
            raise ValueError(f"Unknown task: {self.task}")


class GradCamCallback(pl.Callback):
    """Callback to visualize gradients using Grad-CAM (experimental).
    Works only during validation.
    """

    def __init__(
        self,
        target_layer: int,
        class_idx: int = 0,
        log_n_batches: int = 1,
        task: str = "classification",
    ) -> None:
        """Constructs `GradCamCallback`.
        @type target_layer: int
        @param target_layer: Index of the layer whose gradients are
            visualized.
        @type class_idx: int
        @param class_idx: Index of the class for visualization. Defaults
            to 0.
        @type log_n_batches: int
        @param log_n_batches: Number of batches to log. Defaults to 1.
        @type task: str
        @param task: The type of task. Defaults to "classification".
        """
        super().__init__()
        self.target_layer = target_layer
        self.class_idx = class_idx
        self.log_n_batches = log_n_batches
        self.task = task

    def setup(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
    ) -> None:
        """Initializes the model wrapper.
        @type trainer: pl.Trainer
        @param trainer: The PyTorch Lightning trainer.
        @type pl_module: pl.LightningModule
        @param pl_module: The PyTorch Lightning module.
        @type stage: str
        @param stage: The stage of the training loop.
        """

        self.model = ModelWrapper(pl_module, self.task)

    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """At the end of each of the first `log_n_batches` validation
        batches, visualize the gradients using Grad-CAM.
        @type trainer: pl.Trainer
        @param trainer: The PyTorch Lightning trainer.
        @type pl_module: pl.LightningModule
        @param pl_module: The PyTorch Lightning module.
        @type outputs: STEP_OUTPUT
        @param outputs: The output of the model.
        @type batch: Any
        @param batch: The input batch.
        @type batch_idx: int
        @param batch_idx: The index of the batch.
        """

        if batch_idx < self.log_n_batches:
            self.visualize_gradients(trainer, pl_module, batch, batch_idx)

    def visualize_gradients(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Visualizes the gradients using Grad-CAM.
        @type trainer: pl.Trainer
        @param trainer: The PyTorch Lightning trainer.
        @type pl_module: pl.LightningModule
        @param pl_module: The PyTorch Lightning module.
        @type batch: Any
        @param batch: The input batch.
        @type batch_idx: int
        @param batch_idx: The index of the batch.
        """

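        # Select the module at index `target_layer` from the flattened
        # list of the module's named submodules.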
        target_layers = [m[1] for m in pl_module.named_modules()][
            self.target_layer : self.target_layer + 1
        ]
        self.gradcam = HiResCAM(self.model, target_layers)

        x, y = batch
        model_input = x["image"]

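        # For segmentation, the CAM target is a per-image binary mask of
        # the chosen class; otherwise the target is the chosen class logit.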
        if self.task == "segmentation":
            output = self.model(model_input)
            normalized_masks = torch.nn.functional.softmax(output, dim=1).cpu()
            mask = normalized_masks.argmax(dim=1).detach().cpu().numpy()
            mask_float = (mask == self.class_idx).astype(np.float32)
            targets = [
                SemanticSegmentationTarget(self.class_idx, mask_float[i])
                for i in range(mask_float.shape[0])
            ]
        else:
            targets = [
                ClassifierOutputTarget(self.class_idx)
            ] * model_input.size(0)

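        # Validation runs with gradients disabled, so re-enable them for
        # the CAM backward pass.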
        with torch.enable_grad():
            grayscale_cams = self.gradcam(
                input_tensor=model_input,
                targets=targets,  # type: ignore
            )

        images = get_denormalized_images(pl_module.cfg, x).cpu().numpy()
        for zip_idx, (image, grayscale_cam) in enumerate(
            zip(images, grayscale_cams)
        ):
            image = image / 255.0
            image = image.transpose(1, 2, 0)
            visualization = show_cam_on_image(
                image, grayscale_cam, use_rgb=True
            )
            trainer.logger.log_image(  # type: ignore
                f"gradcam/gradcam_{batch_idx}_{zip_idx}",
                visualization,
                step=trainer.global_step,
            )
1 change: 1 addition & 0 deletions requirements.txt
@@ -17,3 +17,4 @@ typer>=0.9.0
mlflow>=2.10.0
psutil>=5.0.0
tabulate>=0.9.0
grad-cam>=1.5.4
