From b779fdea97e582cbcddb4d097dad50c9390193dc Mon Sep 17 00:00:00 2001
From: Polina Turishcheva
Date: Tue, 28 May 2024 15:14:18 +0200
Subject: [PATCH] batch norms fix for 3d core

---
 neuralpredictors/layers/cores/base.py   |  4 +++-
 neuralpredictors/layers/cores/conv3d.py | 11 +++++++----
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/neuralpredictors/layers/cores/base.py b/neuralpredictors/layers/cores/base.py
index 001d2d3b..5e41c728 100644
--- a/neuralpredictors/layers/cores/base.py
+++ b/neuralpredictors/layers/cores/base.py
@@ -1,3 +1,4 @@
+import warnings
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 
@@ -77,7 +78,8 @@ def add_bn_layer(self, layer: OrderedDict, layer_idx: int):
                 raise NotImplementedError(f"Subclasses must have a `{attr}` attribute.")
         for attr in ["batch_norm", "hidden_channels", "bias", "batch_norm_scale"]:
             if not isinstance(getattr(self, attr), list):
-                raise ValueError(f"`{attr}` must be a list.")
+                setattr(self, attr, [getattr(self, attr)] * self.layers)
+                warnings.warn(f"The {attr} is applied to all layers", UserWarning)
 
         if self.batch_norm[layer_idx]:
             hidden_channels = self.hidden_channels[layer_idx]
diff --git a/neuralpredictors/layers/cores/conv3d.py b/neuralpredictors/layers/cores/conv3d.py
index cd39d263..8b8fe4c5 100644
--- a/neuralpredictors/layers/cores/conv3d.py
+++ b/neuralpredictors/layers/cores/conv3d.py
@@ -160,7 +160,7 @@ def __init__(
             padding=(0, input_kernel[1] // 2, input_kernel[2] // 2) if self.padding else 0,
         )
 
-        self.add_bn_layer(layer=layer, hidden_channels=self.hidden_channels[0])
+        self.add_bn_layer(layer=layer, layer_idx=0)
 
         if layers > 1 or self.final_nonlinearity:
             if hidden_nonlinearities == "adaptive_elu":
@@ -185,7 +185,7 @@
                 padding=(0, self.hidden_kernel[l][1] // 2, self.hidden_kernel[l][2] // 2) if self.padding else 0,
             )
 
-            self.add_bn_layer(layer=layer, hidden_channels=self.hidden_channels[l + 1])
+            self.add_bn_layer(layer=layer, layer_idx=l + 1)
 
             if self.final_nonlinearity or l < self.layers:
                 if hidden_nonlinearities == "adaptive_elu":
@@ -363,7 +363,10 @@ def __init__(
             dilation=(self.temporal_dilation, 1, 1),
         )
 
-        self.add_bn_layer(layer=layer, hidden_channels=self.hidden_channels[0])
+        self.add_bn_layer(
+            layer=layer,
+            layer_idx=0,
+        )
 
         if layers > 1 or final_nonlin:
             if hidden_nonlinearities == "adaptive_elu":
@@ -394,7 +397,7 @@
                 dilation=(self.hidden_temporal_dilation[l], 1, 1),
             )
 
-            self.add_bn_layer(layer=layer, hidden_channels=self.hidden_channels[l + 1])
+            self.add_bn_layer(layer=layer, layer_idx=l + 1)
 
             if final_nonlin or l < self.layers:
                 if hidden_nonlinearities == "adaptive_elu":