From dfa17cbd94858b79673b1e5006b9c95017b3367e Mon Sep 17 00:00:00 2001
From: Polina Turishcheva
Date: Tue, 9 Jul 2024 17:21:39 +0200
Subject: [PATCH] pr review

---
 neuralpredictors/layers/readouts/gaussian.py | 31 ++++++++++++--------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/neuralpredictors/layers/readouts/gaussian.py b/neuralpredictors/layers/readouts/gaussian.py
index 6e6f53b5..80ba2002 100644
--- a/neuralpredictors/layers/readouts/gaussian.py
+++ b/neuralpredictors/layers/readouts/gaussian.py
@@ -286,13 +286,17 @@ def __init__(
         self.mean_activity = mean_activity
         # determines whether the Gaussian is isotropic or not
         self.gauss_type = gauss_type
-        self.regularizer_type = regularizer_type
+        self._regularizer_type = regularizer_type
 
-        if self.regularizer_type == "adaptive_log_norm":
+        if self._regularizer_type == "adaptive_log_norm":
             self.gamma_sigma = gamma_sigma
-            self.adaptive = torch.nn.Parameter(torch.normal(mean=torch.ones(1, outdims), std=torch.ones(1, outdims)))
-        elif self.regularizer_type != "l1":
-            raise ValueError(f"regularizer_type should be 'l1' or 'adaptive_log_norm' but got {self.regularizer_type}")
+            self.adaptive_neuron_reg_coefs = torch.nn.Parameter(
+                torch.normal(mean=torch.ones(1, outdims), std=torch.ones(1, outdims))
+            )
+        elif self._regularizer_type != "l1":
+            raise ValueError(
+                f"_regularizer_type should be 'l1' or 'adaptive_log_norm' but got {self._regularizer_type}"
+            )
 
         if init_mu_range > 1.0 or init_mu_range <= 0.0 or init_sigma <= 0.0:
             raise ValueError("either init_mu_range doesn't belong to [0.0, 1.0] or init_sigma_range is non-positive")
@@ -391,20 +395,21 @@ def feature_l1(self, reduction="sum", average=None):
 
     def adaptive_feature_l1_lognorm(self, reduction="sum", average=None):
         if self._original_features:
-            features = self.adaptive.abs() * self.features
-            return self.apply_reduction(features.abs(), reduction=reduction, average=average)
+            features = self.adaptive_neuron_reg_coefs.abs() * self.features
+            features_regularization = (
+                self.apply_reduction(features.abs(), reduction=reduction, average=average) * self.feature_reg_weight
+            )
+            # adaptive_neuron_reg_coefs (betas) are supposed to follow a lognormal distribution
+            coef_prior = 1 / (self.gamma_sigma**2) * ((torch.log(self.adaptive_neuron_reg_coefs.abs()) ** 2).sum())
+            return features_regularization + coef_prior
         else:
             return 0
 
     def regularizer(self, reduction="sum", average=None):
-        if self.regularizer_type == "l1":
+        if self._regularizer_type == "l1":
             return self.feature_l1(reduction=reduction, average=average) * self.feature_reg_weight
         else:
-            new_gamma = self.feature_reg_weight
-            readout_reg = self.adaptive_l1_lognorm(reduction=reduction, average=average) * new_gamma
-            # gammas are supposted to be from lognorm distribution
-            gamma_prior = 1 / (self.gamma_sigma**2) * ((torch.log(self.adaptive.abs()) ** 2).sum())
-            return readout_reg + gamma_prior
+            return self.adaptive_feature_l1_lognorm(reduction=reduction, average=average)
 
     def regularizer(self, reduction="sum", average=None):
         return self.feature_l1(reduction=reduction, average=average) * self.feature_reg_weight
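
For reference outside the patch: a minimal standalone sketch of the penalty that
adaptive_feature_l1_lognorm computes with the "adaptive_log_norm" regularizer type,
assuming self.features has the (1, channels, 1, outdims) shape used by the full
Gaussian readout and that gamma_sigma and feature_reg_weight are plain floats.
All names below are illustrative and not the readout's actual attributes.

import torch

outdims, channels = 8, 16
feature_reg_weight = 1.0
gamma_sigma = 1.0

# per-neuron feature weights and per-neuron adaptive coefficients (the "betas")
features = torch.randn(1, channels, 1, outdims)
betas = torch.normal(mean=torch.ones(1, outdims), std=torch.ones(1, outdims))

# L1 term on the features, rescaled per neuron by |beta_n| ("sum" reduction)
weighted_l1 = (betas.abs() * features).abs().sum() * feature_reg_weight
# log-normal prior on the coefficients: penalize (log |beta_n|)^2 / sigma^2
coef_prior = (torch.log(betas.abs()) ** 2).sum() / gamma_sigma**2

total_penalty = weighted_l1 + coef_prior
print(float(total_penalty))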