Skip to content

Commit

Permalink
fix normalization per model
Browse files Browse the repository at this point in the history
  • Loading branch information
tommiekerssies committed Jun 7, 2024
1 parent 6527c5a commit f0e75bb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
8 changes: 8 additions & 0 deletions models/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def __init__(
}
self.encoder = timm.create_model(**model_kwargs)

pixel_mean = torch.tensor(self.encoder.default_cfg["mean"]).reshape(1, -1, 1, 1)
pixel_std = torch.tensor(self.encoder.default_cfg["std"]).reshape(1, -1, 1, 1)

self.register_buffer("pixel_mean", pixel_mean)
self.register_buffer("pixel_std", pixel_std)

self.grid_size = tuple(round(size / patch_size) for size in img_size)

self.embed_dim = (
Expand Down Expand Up @@ -176,6 +182,8 @@ def interpolate_rel_pos(
return nn.Parameter(rel_pos)

def forward(self, x: torch.Tensor):
x = (x - self.pixel_mean) / self.pixel_std

x = self.encoder.forward_features(x)

if x.dim() == 4:
Expand Down
15 changes: 3 additions & 12 deletions training/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
import torch.nn as nn
from torch.optim import AdamW
from torchmetrics.classification import MulticlassJaccardIndex
from torchmetrics.detection import PanopticQuality
from PIL import Image
import matplotlib.colors as mcolors
from matplotlib.lines import Line2D
import io
import matplotlib.pyplot as plt
import numpy as np
from torch.nn.functional import interpolate
from torchvision.transforms.v2.functional import resize, pad
from torchvision.transforms.v2.functional import resize


class LightningModule(lightning.LightningModule):
Expand All @@ -24,8 +23,6 @@ def __init__(
weight_decay: float,
lr: float,
lr_multiplier_encoder: float,
pixel_mean=(123.675, 116.28, 103.53),
pixel_std=(58.395, 57.12, 57.375),
):
super().__init__()

Expand All @@ -35,12 +32,6 @@ def __init__(
self.weight_decay = weight_decay
self.lr_multiplier_encoder = lr_multiplier_encoder

pixel_mean = torch.tensor(pixel_mean).reshape(1, -1, 1, 1)
pixel_std = torch.tensor(pixel_std).reshape(1, -1, 1, 1)

self.register_buffer("pixel_mean", pixel_mean, persistent=False)
self.register_buffer("pixel_std", pixel_std, persistent=False)

for param in self.network.encoder.parameters():
param.requires_grad = not freeze_encoder

Expand Down Expand Up @@ -68,8 +59,8 @@ def update_metrics(
preds[i][None, ...], targets[i][None, ...]
)

def forward(self, x):
x = (x - self.pixel_mean) / self.pixel_std
def forward(self, imgs):
x = imgs / 255.0

output = self.network(x)

Expand Down

0 comments on commit f0e75bb

Please sign in to comment.