From 6a6a46a8b3757c2e19154e4645fc4e813e8b41c2 Mon Sep 17 00:00:00 2001
From: AFThielmann
Date: Fri, 2 Aug 2024 11:37:20 +0200
Subject: [PATCH] adding hotfix for mamba arch

---
 mambular/arch_utils/mamba_arch.py          | 51 ++++++++++++++-
 mambular/base_models/mambular.py           | 25 +++++--
 mambular/configs/mambular_config.py        | 18 +++--
 mambular/models/sklearn_base_classifier.py |  2 +-
 mambular/models/sklearn_base_lss.py        | 76 +++++++++++-----------
 mambular/models/sklearn_base_regressor.py  | 70 ++++++++++----------
 6 files changed, 158 insertions(+), 84 deletions(-)

diff --git a/mambular/arch_utils/mamba_arch.py b/mambular/arch_utils/mamba_arch.py
index 3db39ed..6417c7f 100644
--- a/mambular/arch_utils/mamba_arch.py
+++ b/mambular/arch_utils/mamba_arch.py
@@ -43,6 +43,9 @@ def __init__(
         activation=F.silu,
         bidirectional=False,
         use_learnable_interaction=False,
+        layer_norm_eps=1e-05,
+        AB_weight_decay=False,
+        AB_layer_norm=True,
     ):
         super().__init__()

@@ -66,6 +69,9 @@ def __init__(
                     activation,
                     bidirectional,
                     use_learnable_interaction,
+                    layer_norm_eps,
+                    AB_weight_decay,
+                    AB_layer_norm,
                 )
                 for _ in range(n_layers)
             ]
@@ -105,6 +111,9 @@ def __init__(
         activation=F.silu,
         bidirectional=False,
         use_learnable_interaction=False,
+        layer_norm_eps=1e-05,
+        AB_weight_decay=False,
+        AB_layer_norm=False,
     ):
         super().__init__()

@@ -149,8 +158,11 @@ def __init__(
             activation=activation,
             bidirectional=bidirectional,
             use_learnable_interaction=use_learnable_interaction,
+            layer_norm_eps=layer_norm_eps,
+            AB_weight_decay=AB_weight_decay,
+            AB_layer_norm=AB_layer_norm,
         )
-        self.norm = norm(d_model)
+        self.norm = norm(d_model, eps=layer_norm_eps)

     def forward(self, x):
         output = self.layers(self.norm(x)) + x
@@ -189,6 +201,9 @@ def __init__(
         activation=F.silu,
         bidirectional=False,
         use_learnable_interaction=False,
+        layer_norm_eps=1e-05,
+        AB_weight_decay=False,
+        AB_layer_norm=False,
     ):
         super().__init__()
         self.d_inner = d_model * expand_factor
@@ -239,6 +254,7 @@ def __init__(
         elif dt_init == "random":
             nn.init.uniform_(self.dt_proj_fwd.weight, -dt_init_std, dt_init_std)
             if self.bidirectional:
+                nn.init.uniform_(self.dt_proj_bwd.weight, -dt_init_std, dt_init_std)
         else:
             raise NotImplementedError

@@ -262,17 +278,35 @@ def __init__(
         A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
         self.A_log_fwd = nn.Parameter(torch.log(A))
+        self.D_fwd = nn.Parameter(torch.ones(self.d_inner))
+
         if self.bidirectional:
             self.A_log_bwd = nn.Parameter(torch.log(A))
+            self.D_bwd = nn.Parameter(torch.ones(self.d_inner))
+
+        if not AB_weight_decay:
+            self.A_log_fwd._no_weight_decay = True
+            self.D_fwd._no_weight_decay = True

-        self.D_fwd = nn.Parameter(torch.ones(self.d_inner))
         if self.bidirectional:
-            self.D_bwd = nn.Parameter(torch.ones(self.d_inner))
+
+            if not AB_weight_decay:
+                self.A_log_bwd._no_weight_decay = True
+                self.D_bwd._no_weight_decay = True

         self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)

         self.dt_rank = dt_rank
         self.d_state = d_state

+        if AB_layer_norm:
+            self.dt_layernorm = RMSNorm(self.dt_rank, eps=layer_norm_eps)
+            self.B_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
+            self.C_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
+        else:
+            self.dt_layernorm = None
+            self.B_layernorm = None
+            self.C_layernorm = None
+
     def forward(self, x):
         _, L, _ = x.shape

@@ -316,6 +350,15 @@ def forward(self, x):

         return output

+    def _apply_layernorms(self, dt, B, C):
+        if self.dt_layernorm is not None:
+            dt = self.dt_layernorm(dt)
+        if self.B_layernorm is not None:
+            B = self.B_layernorm(B)
+        if self.C_layernorm is not None:
+            C = self.C_layernorm(C)
+        return dt, B, C
+
     def ssm(self, x, forward=True):
         if forward:
             A = -torch.exp(self.A_log_fwd.float())
@@ -324,6 +367,7 @@ def ssm(self, x, forward=True):
             delta, B, C = torch.split(
                 deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
             )
+            delta, B, C = self._apply_layernorms(delta, B, C)
             delta = F.softplus(self.dt_proj_fwd(delta))
         else:
             A = -torch.exp(self.A_log_bwd.float())
@@ -332,6 +376,7 @@ def ssm(self, x, forward=True):
             delta, B, C = torch.split(
                 deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
             )
+            delta, B, C = self._apply_layernorms(delta, B, C)
             delta = F.softplus(self.dt_proj_bwd(delta))

         y = self.selective_scan_seq(x, delta, A, B, C, D)
diff --git a/mambular/base_models/mambular.py b/mambular/base_models/mambular.py
index de012d7..b690851 100644
--- a/mambular/base_models/mambular.py
+++ b/mambular/base_models/mambular.py
@@ -109,19 +109,34 @@ def __init__(
             use_learnable_interaction=self.hparams.get(
                 "use_learnable_interactions", config.use_learnable_interaction
             ),
+            AB_weight_decay=self.hparams.get("AB_weight_decay", config.AB_weight_decay),
+            AB_layer_norm=self.hparams.get("AB_layer_norm", config.AB_layer_norm),
+            layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps),
         )

         norm_layer = self.hparams.get("norm", config.norm)
         if norm_layer == "RMSNorm":
-            self.norm_f = RMSNorm(self.hparams.get("d_model", config.d_model))
+            self.norm_f = RMSNorm(
+                self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
+            )
         elif norm_layer == "LayerNorm":
-            self.norm_f = LayerNorm(self.hparams.get("d_model", config.d_model))
+            self.norm_f = LayerNorm(
+                self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
+            )
         elif norm_layer == "BatchNorm":
-            self.norm_f = BatchNorm(self.hparams.get("d_model", config.d_model))
+            self.norm_f = BatchNorm(
+                self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
+            )
         elif norm_layer == "InstanceNorm":
-            self.norm_f = InstanceNorm(self.hparams.get("d_model", config.d_model))
+            self.norm_f = InstanceNorm(
+                self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
+            )
         elif norm_layer == "GroupNorm":
-            self.norm_f = GroupNorm(1, self.hparams.get("d_model", config.d_model))
+            self.norm_f = GroupNorm(
+                1,
+                self.hparams.get("d_model", config.d_model),
+                eps=config.layer_norm_eps,
+            )
         elif norm_layer == "LearnableLayerScaling":
             self.norm_f = LearnableLayerScaling(
                 self.hparams.get("d_model", config.d_model)
diff --git a/mambular/configs/mambular_config.py b/mambular/configs/mambular_config.py
index 6075e31..c9b8afa 100644
--- a/mambular/configs/mambular_config.py
+++ b/mambular/configs/mambular_config.py
@@ -49,8 +49,8 @@ class DefaultMambularConfig:
         Normalization method to be used.
     activation : callable, default=nn.SELU()
         Activation function for the model.
-    num_embedding_activation : callable, default=nn.Identity()
-        Activation function for numerical embeddings.
+    embedding_activation : callable, default=nn.Identity()
+        Activation function for embeddings.
     head_layer_sizes : list, default=(128, 64, 32)
         Sizes of the layers in the head of the model.
     head_dropout : float, default=0.5
@@ -70,7 +70,13 @@ class DefaultMambularConfig:
    use_learnable_interaction : bool, default=False
        Whether to use learnable feature interactions before passing through mamba blocks.
    use_cls : bool, default=True
-        Whether to append a cls to the beginning of each 'sequence'.
+        Whether to append a cls to the end of each 'sequence'.
+    shuffle_embeddings : bool, default=False
+        Whether to shuffle the embeddings before being passed to the Mamba layers.
+    layer_norm_eps : float, default=1e-05
+        Epsilon value for layer normalization.
+    AB_weight_decay : bool, default=False
+        Whether weight decay is also applied to the A-B matrices.
    """

    lr: float = 1e-04
@@ -93,7 +99,7 @@ class DefaultMambularConfig:
     dt_init_floor: float = 1e-04
     norm: str = "LayerNorm"
     activation: callable = nn.SiLU()
-    num_embedding_activation: callable = nn.Identity()
+    embedding_activation: callable = nn.Identity()
     head_layer_sizes: list = ()
     head_dropout: float = 0.5
     head_skip_layers: bool = False
@@ -104,3 +110,7 @@ class DefaultMambularConfig:
     bidirectional: bool = False
     use_learnable_interaction: bool = False
     use_cls: bool = False
+    shuffle_embeddings: bool = False
+    layer_norm_eps: float = 1e-05
+    AB_weight_decay: bool = False
+    AB_layer_norm: bool = True
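Note on the `_no_weight_decay` flags set in `mamba_arch.py` above: the patch only tags the `A_log_*` and `D_*` parameters; actually excluding them from weight decay is left to the optimizer setup. Below is a minimal sketch of how a training loop might consume the tag. The `build_param_groups` helper is hypothetical and not part of this patch.

```python
import torch


def build_param_groups(model: torch.nn.Module, weight_decay: float = 1e-06):
    """Split parameters into decay / no-decay groups (hypothetical helper)."""
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        # MambaBlock sets `_no_weight_decay = True` on A_log_fwd/bwd and
        # D_fwd/bwd when AB_weight_decay=False; honor that tag here.
        if getattr(param, "_no_weight_decay", False):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Usage sketch: optimizer = torch.optim.AdamW(build_param_groups(model), lr=1e-04)
```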
diff --git a/mambular/models/sklearn_base_classifier.py b/mambular/models/sklearn_base_classifier.py
index 4a6935b..d0d5027 100644
--- a/mambular/models/sklearn_base_classifier.py
+++ b/mambular/models/sklearn_base_classifier.py
@@ -315,7 +315,7 @@ def fit(
         self : object
             The fitted classifier.
         """
-        if not self.built and not rebuild:
+        if (not self.built) or (self.built and rebuild):
             if not isinstance(X, pd.DataFrame):
                 X = pd.DataFrame(X)
             if isinstance(y, pd.Series):
diff --git a/mambular/models/sklearn_base_lss.py b/mambular/models/sklearn_base_lss.py
index 62f2d3a..506f04e 100644
--- a/mambular/models/sklearn_base_lss.py
+++ b/mambular/models/sklearn_base_lss.py
@@ -282,6 +282,7 @@ def fit(
         checkpoint_path="model_checkpoints",
         distributional_kwargs=None,
         dataloader_kwargs={},
+        rebuild=True,
         **trainer_kwargs
     ):
         """
@@ -357,45 +358,46 @@ def fit(
         else:
             raise ValueError("Unsupported family: {}".format(family))

-        if not isinstance(X, pd.DataFrame):
-            X = pd.DataFrame(X)
-        if isinstance(y, pd.Series):
-            y = y.values
-        if X_val:
-            if not isinstance(X_val, pd.DataFrame):
-                X_val = pd.DataFrame(X_val)
-            if isinstance(y_val, pd.Series):
-                y_val = y_val.values
-
-        self.data_module = MambularDataModule(
-            preprocessor=self.preprocessor,
-            batch_size=batch_size,
-            shuffle=shuffle,
-            X_val=X_val,
-            y_val=y_val,
-            val_size=val_size,
-            random_state=random_state,
-            regression=True,
-            **dataloader_kwargs
-        )
+        if (not self.built) or (self.built and rebuild):
+            if not isinstance(X, pd.DataFrame):
+                X = pd.DataFrame(X)
+            if isinstance(y, pd.Series):
+                y = y.values
+            if X_val:
+                if not isinstance(X_val, pd.DataFrame):
+                    X_val = pd.DataFrame(X_val)
+                if isinstance(y_val, pd.Series):
+                    y_val = y_val.values
+
+            self.data_module = MambularDataModule(
+                preprocessor=self.preprocessor,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                X_val=X_val,
+                y_val=y_val,
+                val_size=val_size,
+                random_state=random_state,
+                regression=True,
+                **dataloader_kwargs
+            )

-        self.data_module.preprocess_data(
-            X, y, X_val, y_val, val_size=val_size, random_state=random_state
-        )
+            self.data_module.preprocess_data(
+                X, y, X_val, y_val, val_size=val_size, random_state=random_state
+            )

-        self.model = TaskModel(
-            model_class=self.base_model,
-            num_classes=self.family.param_count,
-            family=self.family,
-            config=self.config,
-            cat_feature_info=self.data_module.cat_feature_info,
-            num_feature_info=self.data_module.num_feature_info,
-            lr=lr,
-            lr_patience=lr_patience,
-            lr_factor=factor,
-            weight_decay=weight_decay,
-            lss=True,
-        )
+            self.model = TaskModel(
+                model_class=self.base_model,
+                num_classes=self.family.param_count,
+                family=self.family,
+                config=self.config,
+                cat_feature_info=self.data_module.cat_feature_info,
+                num_feature_info=self.data_module.num_feature_info,
+                lr=lr,
+                lr_patience=lr_patience,
+                lr_factor=factor,
+                weight_decay=weight_decay,
+                lss=True,
+            )

         early_stop_callback = EarlyStopping(
             monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode
diff --git a/mambular/models/sklearn_base_regressor.py b/mambular/models/sklearn_base_regressor.py
index 1f57d58..bdcb5b8 100644
--- a/mambular/models/sklearn_base_regressor.py
+++ b/mambular/models/sklearn_base_regressor.py
@@ -257,6 +257,7 @@ def fit(
         weight_decay: float = 1e-06,
         checkpoint_path="model_checkpoints",
         dataloader_kwargs={},
+        rebuild=True,
         **trainer_kwargs
     ):
         """
@@ -308,42 +309,43 @@ def fit(
         self : object
             The fitted regressor.
         """
-        if not isinstance(X, pd.DataFrame):
-            X = pd.DataFrame(X)
-        if isinstance(y, pd.Series):
-            y = y.values
-        if X_val:
-            if not isinstance(X_val, pd.DataFrame):
-                X_val = pd.DataFrame(X_val)
-            if isinstance(y_val, pd.Series):
-                y_val = y_val.values
-
-        self.data_module = MambularDataModule(
-            preprocessor=self.preprocessor,
-            batch_size=batch_size,
-            shuffle=shuffle,
-            X_val=X_val,
-            y_val=y_val,
-            val_size=val_size,
-            random_state=random_state,
-            regression=True,
-            **dataloader_kwargs
-        )
+        if (not self.built) or (self.built and rebuild):
+            if not isinstance(X, pd.DataFrame):
+                X = pd.DataFrame(X)
+            if isinstance(y, pd.Series):
+                y = y.values
+            if X_val:
+                if not isinstance(X_val, pd.DataFrame):
+                    X_val = pd.DataFrame(X_val)
+                if isinstance(y_val, pd.Series):
+                    y_val = y_val.values
+
+            self.data_module = MambularDataModule(
+                preprocessor=self.preprocessor,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                X_val=X_val,
+                y_val=y_val,
+                val_size=val_size,
+                random_state=random_state,
+                regression=True,
+                **dataloader_kwargs
+            )

-        self.data_module.preprocess_data(
-            X, y, X_val, y_val, val_size=val_size, random_state=random_state
-        )
+            self.data_module.preprocess_data(
+                X, y, X_val, y_val, val_size=val_size, random_state=random_state
+            )

-        self.model = TaskModel(
-            model_class=self.base_model,
-            config=self.config,
-            cat_feature_info=self.data_module.cat_feature_info,
-            num_feature_info=self.data_module.num_feature_info,
-            lr=lr,
-            lr_patience=lr_patience,
-            lr_factor=factor,
-            weight_decay=weight_decay,
-        )
+            self.model = TaskModel(
+                model_class=self.base_model,
+                config=self.config,
+                cat_feature_info=self.data_module.cat_feature_info,
+                num_feature_info=self.data_module.num_feature_info,
+                lr=lr,
+                lr_patience=lr_patience,
+                lr_factor=factor,
+                weight_decay=weight_decay,
+            )

         early_stop_callback = EarlyStopping(
             monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode
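For reference, a usage sketch of the new `rebuild` flag added to the sklearn-style `fit` methods above: when the estimator is already built, `rebuild=False` skips reconstructing the data module and `TaskModel`, so a second `fit` call continues from the current weights. This assumes the public `MambularRegressor` wrapper and that `max_epochs` is forwarded to the trainer via `**trainer_kwargs`; it is illustrative, not part of the patch.

```python
import numpy as np
from mambular.models import MambularRegressor

# Toy data, purely for illustration.
X = np.random.randn(256, 8)
y = X[:, 0] * 2.0 + np.random.randn(256) * 0.1

model = MambularRegressor()
# First call: builds the data module and the underlying TaskModel, then trains.
model.fit(X, y, max_epochs=20)
# Second call: with rebuild=False the already-built model is reused, so
# training resumes from the current weights instead of starting over.
model.fit(X, y, max_epochs=5, rebuild=False)
```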