Skip to content

Commit

Permalink
Merge pull request #241 from pollytur/core3d_bn_fix
Browse files Browse the repository at this point in the history
Batch norm fix for the 3D core
  • Loading branch information
KonstantinWilleke authored May 30, 2024
2 parents 504a5a6 + b779fde commit 5947994
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
4 changes: 3 additions & 1 deletion neuralpredictors/layers/cores/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict

Expand Down Expand Up @@ -77,7 +78,8 @@ def add_bn_layer(self, layer: OrderedDict, layer_idx: int):
raise NotImplementedError(f"Subclasses must have a `{attr}` attribute.")
for attr in ["batch_norm", "hidden_channels", "bias", "batch_norm_scale"]:
if not isinstance(getattr(self, attr), list):
raise ValueError(f"`{attr}` must be a list.")
setattr(self, attr, [getattr(self, attr)] * self.layers)
warnings.warn(f"The {attr} is applied to all layers", UserWarning)

if self.batch_norm[layer_idx]:
hidden_channels = self.hidden_channels[layer_idx]
Expand Down
11 changes: 7 additions & 4 deletions neuralpredictors/layers/cores/conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(
padding=(0, input_kernel[1] // 2, input_kernel[2] // 2) if self.padding else 0,
)

self.add_bn_layer(layer=layer, hidden_channels=self.hidden_channels[0])
self.add_bn_layer(layer=layer, layer_idx=0)

if layers > 1 or self.final_nonlinearity:
if hidden_nonlinearities == "adaptive_elu":
Expand All @@ -185,7 +185,7 @@ def __init__(
padding=(0, self.hidden_kernel[l][1] // 2, self.hidden_kernel[l][2] // 2) if self.padding else 0,
)

self.add_bn_layer(layer=layer, hidden_channels=self.hidden_channels[l + 1])
self.add_bn_layer(layer=layer, layer_idx=l + 1)

if self.final_nonlinearity or l < self.layers:
if hidden_nonlinearities == "adaptive_elu":
Expand Down Expand Up @@ -363,7 +363,10 @@ def __init__(
dilation=(self.temporal_dilation, 1, 1),
)

self.add_bn_layer(layer=layer, hidden_channels=self.hidden_channels[0])
self.add_bn_layer(
layer=layer,
layer_idx=0,
)

if layers > 1 or final_nonlin:
if hidden_nonlinearities == "adaptive_elu":
Expand Down Expand Up @@ -394,7 +397,7 @@ def __init__(
dilation=(self.hidden_temporal_dilation[l], 1, 1),
)

self.add_bn_layer(layer=layer, hidden_channels=self.hidden_channels[l + 1])
self.add_bn_layer(layer=layer, layer_idx=l + 1)

if final_nonlin or l < self.layers:
if hidden_nonlinearities == "adaptive_elu":
Expand Down

0 comments on commit 5947994

Please sign in to comment.