diff --git a/models/encoder.py b/models/encoder.py index 5b7ba6e..1726f47 100644 --- a/models/encoder.py +++ b/models/encoder.py @@ -30,6 +30,12 @@ def __init__( } self.encoder = timm.create_model(**model_kwargs) + pixel_mean = torch.tensor(self.encoder.default_cfg["mean"]).reshape(1, -1, 1, 1) + pixel_std = torch.tensor(self.encoder.default_cfg["std"]).reshape(1, -1, 1, 1) + + self.register_buffer("pixel_mean", pixel_mean) + self.register_buffer("pixel_std", pixel_std) + self.grid_size = tuple(round(size / patch_size) for size in img_size) self.embed_dim = ( @@ -176,6 +182,8 @@ def interpolate_rel_pos( return nn.Parameter(rel_pos) def forward(self, x: torch.Tensor): + x = (x - self.pixel_mean) / self.pixel_std + x = self.encoder.forward_features(x) if x.dim() == 4: diff --git a/training/lightning_module.py b/training/lightning_module.py index f483d7a..5f20a86 100644 --- a/training/lightning_module.py +++ b/training/lightning_module.py @@ -4,7 +4,6 @@ import torch.nn as nn from torch.optim import AdamW from torchmetrics.classification import MulticlassJaccardIndex -from torchmetrics.detection import PanopticQuality from PIL import Image import matplotlib.colors as mcolors from matplotlib.lines import Line2D @@ -12,7 +11,7 @@ import matplotlib.pyplot as plt import numpy as np from torch.nn.functional import interpolate -from torchvision.transforms.v2.functional import resize, pad +from torchvision.transforms.v2.functional import resize class LightningModule(lightning.LightningModule): @@ -24,8 +23,6 @@ def __init__( weight_decay: float, lr: float, lr_multiplier_encoder: float, - pixel_mean=(123.675, 116.28, 103.53), - pixel_std=(58.395, 57.12, 57.375), ): super().__init__() @@ -35,12 +32,6 @@ def __init__( self.weight_decay = weight_decay self.lr_multiplier_encoder = lr_multiplier_encoder - pixel_mean = torch.tensor(pixel_mean).reshape(1, -1, 1, 1) - pixel_std = torch.tensor(pixel_std).reshape(1, -1, 1, 1) - - self.register_buffer("pixel_mean", pixel_mean, persistent=False) - self.register_buffer("pixel_std", pixel_std, persistent=False) - for param in self.network.encoder.parameters(): param.requires_grad = not freeze_encoder @@ -68,8 +59,8 @@ def update_metrics( preds[i][None, ...], targets[i][None, ...] ) - def forward(self, x): - x = (x - self.pixel_mean) / self.pixel_std + def forward(self, imgs): + x = imgs / 255.0 output = self.network(x)