pr review
pollytur committed Jul 9, 2024
1 parent 2fe6121 commit dfa17cb
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions neuralpredictors/layers/readouts/gaussian.py
@@ -286,13 +286,17 @@ def __init__(
         self.mean_activity = mean_activity
         # determines whether the Gaussian is isotropic or not
         self.gauss_type = gauss_type
-        self.regularizer_type = regularizer_type
+        self._regularizer_type = regularizer_type

-        if self.regularizer_type == "adaptive_log_norm":
+        if self._regularizer_type == "adaptive_log_norm":
             self.gamma_sigma = gamma_sigma
-            self.adaptive = torch.nn.Parameter(torch.normal(mean=torch.ones(1, outdims), std=torch.ones(1, outdims)))
-        elif self.regularizer_type != "l1":
-            raise ValueError(f"regularizer_type should be 'l1' or 'adaptive_log_norm' but got {self.regularizer_type}")
+            self.adaptive_neuron_reg_coefs = torch.nn.Parameter(
+                torch.normal(mean=torch.ones(1, outdims), std=torch.ones(1, outdims))
+            )
+        elif self._regularizer_type != "l1":
+            raise ValueError(
+                f"regularizer_type should be 'l1' or 'adaptive_log_norm' but got {self._regularizer_type}"
+            )

         if init_mu_range > 1.0 or init_mu_range <= 0.0 or init_sigma <= 0.0:
             raise ValueError("init_mu_range must lie in (0.0, 1.0] and init_sigma must be positive")
@@ -391,20 +395,21 @@ def feature_l1(self, reduction="sum", average=None):

     def adaptive_feature_l1_lognorm(self, reduction="sum", average=None):
         if self._original_features:
-            features = self.adaptive.abs() * self.features
-            return self.apply_reduction(features.abs(), reduction=reduction, average=average)
+            features = self.adaptive_neuron_reg_coefs.abs() * self.features
+            features_regularization = (
+                self.apply_reduction(features.abs(), reduction=reduction, average=average) * self.feature_reg_weight
+            )
+            # adaptive_neuron_reg_coefs (betas) are assumed to follow a log-normal distribution
+            coef_prior = 1 / (self.gamma_sigma**2) * ((torch.log(self.adaptive_neuron_reg_coefs.abs()) ** 2).sum())
+            return features_regularization + coef_prior
         else:
             return 0

     def regularizer(self, reduction="sum", average=None):
-        if self.regularizer_type == "l1":
+        if self._regularizer_type == "l1":
             return self.feature_l1(reduction=reduction, average=average) * self.feature_reg_weight
         else:
-            new_gamma = self.feature_reg_weight
-            readout_reg = self.adaptive_l1_lognorm(reduction=reduction, average=average) * new_gamma
-            # gammas are assumed to follow a log-normal distribution
-            gamma_prior = 1 / (self.gamma_sigma**2) * ((torch.log(self.adaptive.abs()) ** 2).sum())
-            return readout_reg + gamma_prior
+            return self.adaptive_feature_l1_lognorm(reduction=reduction, average=average)

     def regularizer(self, reduction="sum", average=None):
         return self.feature_l1(reduction=reduction, average=average) * self.feature_reg_weight
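
A reading note, not part of the commit: the coef_prior term matches the negative log-density of a log-normal prior on each per-neuron coefficient, up to a factor of 2 in the quadratic term and the dropped lower-order terms. Assuming beta_n ~ LogNormal(0, sigma^2) with sigma = gamma_sigma (the zero log-mean is my assumption),

    -\log p(\beta_n) = \log \beta_n + \frac{(\log \beta_n)^2}{2\sigma^2} + \log\left(\sigma \sqrt{2\pi}\right),

so keeping only the quadratic term and summing over neurons gives the committed penalty \frac{1}{\sigma^2} \sum_n (\log |\beta_n|)^2.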
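For readers who want to run the new penalty in isolation, here is a minimal, self-contained sketch; the sizes, the betas name, and the stand-in features tensor are illustrative assumptions, and only the two formulas mirror the committed code.

    import torch

    # Illustrative sizes; these names are assumptions, not from the commit.
    outdims, n_features = 1000, 64
    gamma_sigma = 1.0  # width of the assumed log-normal prior
    feature_reg_weight = 1.0

    features = torch.randn(1, n_features, 1, outdims)  # stand-in for self.features
    betas = torch.nn.Parameter(
        torch.normal(mean=torch.ones(1, outdims), std=torch.ones(1, outdims))
    )  # per-neuron coefficients, initialized as in the commit

    # Per-neuron weighted L1 on the feature weights, scaled by |beta|.
    features_regularization = (betas.abs() * features).abs().sum() * feature_reg_weight

    # Negative log-prior (up to constants) of a log-normal on |beta|.
    coef_prior = (torch.log(betas.abs()) ** 2).sum() / gamma_sigma**2

    penalty = features_regularization + coef_prior
    print(penalty.item())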
