Util fixes #202

Merged · 18 commits · Jan 19, 2025
50 changes: 30 additions & 20 deletions mambular/arch_utils/layer_utils/embedding_layer.py
@@ -22,8 +22,12 @@ def __init__(self, num_feature_info, cat_feature_info, config):
super().__init__()

self.d_model = getattr(config, "d_model", 128)
self.embedding_activation = getattr(config, "embedding_activation", nn.Identity())
self.layer_norm_after_embedding = getattr(config, "layer_norm_after_embedding", False)
self.embedding_activation = getattr(
config, "embedding_activation", nn.Identity()
)
self.layer_norm_after_embedding = getattr(
config, "layer_norm_after_embedding", False
)
self.use_cls = getattr(config, "use_cls", False)
self.cls_position = getattr(config, "cls_position", 0)
self.embedding_dropout = (
@@ -71,22 +75,26 @@ def __init__(self, num_feature_info, cat_feature_info, config):
# for splines and other embeddings
# splines followed by linear if n_knots actual knots is less than the defined knots
else:
raise ValueError("Invalid embedding_type. Choose from 'linear', 'ndt', or 'plr'.")
raise ValueError(
"Invalid embedding_type. Choose from 'linear', 'ndt', or 'plr'."
)

self.cat_embeddings = nn.ModuleList(
[
nn.Sequential(
nn.Embedding(feature_info["categories"] + 1, self.d_model),
self.embedding_activation,
)
if feature_info["dimension"] == 1
else nn.Sequential(
nn.Linear(
feature_info["dimension"],
self.d_model,
bias=self.embedding_bias,
),
self.embedding_activation,
(
nn.Sequential(
nn.Embedding(feature_info["categories"] + 1, self.d_model),
self.embedding_activation,
)
if feature_info["dimension"] == 1
else nn.Sequential(
nn.Linear(
feature_info["dimension"],
self.d_model,
bias=self.embedding_bias,
),
self.embedding_activation,
)
)
for feature_name, feature_info in cat_feature_info.items()
]
@@ -124,17 +132,17 @@ def forward(self, num_features=None, cat_features=None):
# Class token initialization
if self.use_cls:
batch_size = (
cat_features[0].size( # type: ignore
0
)
cat_features[0].size(0) # type: ignore
if cat_features != []
else num_features[0].size(0) # type: ignore
) # type: ignore
cls_tokens = self.cls_token.expand(batch_size, -1, -1)

# Process categorical embeddings
if self.cat_embeddings and cat_features is not None:
cat_embeddings = [emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)]
cat_embeddings = [
emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)
]
cat_embeddings = torch.stack(cat_embeddings, dim=1)
cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
if self.layer_norm_after_embedding:
@@ -182,7 +190,9 @@ def forward(self, num_features=None, cat_features=None):
elif self.cls_position == 1:
x = torch.cat([x, cls_tokens], dim=1) # type: ignore
else:
raise ValueError("Invalid cls_position value. It should be either 0 or 1.")
raise ValueError(
"Invalid cls_position value. It should be either 0 or 1."
)

# Apply dropout to embeddings if specified in config
if self.embedding_dropout is not None:
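For reference, the `cat_embeddings` construction above chooses `nn.Embedding` for integer-coded categorical features (`dimension == 1`) and `nn.Linear` for features that already arrive as multi-dimensional encodings. Below is a minimal, self-contained sketch of that dispatch; the feature names, sizes, and the `nn.Identity()` activation are illustrative assumptions, not project defaults:

```python
# Illustrative only: mirrors the Embedding-vs-Linear dispatch from the diff above.
import torch
import torch.nn as nn

d_model = 128
cat_feature_info = {
    "color": {"categories": 5, "dimension": 1},             # integer codes -> nn.Embedding
    "region_onehot": {"categories": 10, "dimension": 10},   # one-hot encoding -> nn.Linear
}

cat_embeddings = nn.ModuleList(
    [
        nn.Sequential(nn.Embedding(info["categories"] + 1, d_model), nn.Identity())
        if info["dimension"] == 1
        else nn.Sequential(nn.Linear(info["dimension"], d_model, bias=True), nn.Identity())
        for info in cat_feature_info.values()
    ]
)

color = torch.randint(0, 5, (32, 1))      # (batch, 1) integer codes
region = torch.rand(32, 10)               # (batch, dimension) float encoding
print(cat_embeddings[0](color).shape)     # torch.Size([32, 1, 128])
print(cat_embeddings[1](region).shape)    # torch.Size([32, 128])
```

In the layer's forward pass these per-feature embeddings are stacked along dim 1 and squeezed, so both branches end up contributing one `d_model`-sized token per feature.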
36 changes: 34 additions & 2 deletions mambular/base_models/basemodel.py
@@ -33,7 +33,11 @@ def save_hyperparameters(self, ignore=[]):
List of keys to ignore while saving hyperparameters, by default [].
"""
# Filter the config and extra hparams for ignored keys
config_hparams = {k: v for k, v in vars(self.config).items() if k not in ignore} if self.config else {}
config_hparams = (
{k: v for k, v in vars(self.config).items() if k not in ignore}
if self.config
else {}
)
extra_hparams = {k: v for k, v in self.extra_hparams.items() if k not in ignore}
config_hparams.update(extra_hparams)

@@ -148,7 +152,9 @@ def initialize_pooling_layers(self, config, n_inputs):
"""Initializes the layers needed for learnable pooling methods based on self.hparams.pooling_method."""
if self.hparams.pooling_method == "learned_flatten":
# Flattening + Linear layer
self.learned_flatten_pooling = nn.Linear(n_inputs * config.dim_feedforward, config.dim_feedforward)
self.learned_flatten_pooling = nn.Linear(
n_inputs * config.dim_feedforward, config.dim_feedforward
)

elif self.hparams.pooling_method == "attention":
# Attention-based pooling with learnable attention weights
@@ -216,3 +222,29 @@ def pool_sequence(self, out):
return out
else:
raise ValueError(f"Invalid pooling method: {self.hparams.pooling_method}")

def encode(self, num_features, cat_features):
if not hasattr(self, "embedding_layer"):
raise ValueError("The model does not have an embedding layer")

# Check if at least one of the contextualized embedding methods exists
valid_layers = ["mamba", "rnn", "lstm", "encoder"]
available_layer = next(
(attr for attr in valid_layers if hasattr(self, attr)), None
)

if not available_layer:
raise ValueError("The model does not generate contextualized embeddings")

# Get the actual layer and call it
x = self.embedding_layer(num_features=num_features, cat_features=cat_features)

if getattr(self.hparams, "shuffle_embeddings", False):
x = x[:, self.perm, :]

layer = getattr(self, available_layer)
if available_layer == "rnn":
embeddings, _ = layer(x)
else:
embeddings = layer(x)
return embeddings
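The new `encode` helper embeds the raw features and passes them through whichever contextualizing block the model defines (`mamba`, `rnn`, `lstm`, or `encoder`), returning token-level embeddings instead of pooled predictions. The following is a hedged, self-contained sketch of the same pattern using a stand-in model; `TinyModel` and its layer sizes are assumptions for illustration, not mambular's API:

```python
# Minimal sketch of the dispatch encode() performs: a stand-in model with an
# embedding layer and an encoder block. Not the real mambular BaseModel.
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self, d_model=16):
        super().__init__()
        self.embed = nn.Linear(1, d_model)  # toy embedding shared by every numerical feature
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True),
            num_layers=1,
        )

    def embedding_layer(self, num_features=None, cat_features=None):
        # one token per numerical feature -> (batch, n_features, d_model)
        return torch.stack([self.embed(f) for f in num_features], dim=1)

    def encode(self, num_features, cat_features):
        # same pattern as the PR: embed, then run the contextualizing block
        x = self.embedding_layer(num_features=num_features, cat_features=cat_features)
        return self.encoder(x)


model = TinyModel()
num_features = [torch.rand(8, 1) for _ in range(3)]       # batch of 8, three features
print(model.encode(num_features, cat_features=[]).shape)  # torch.Size([8, 3, 16])
```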
99 changes: 52 additions & 47 deletions mambular/base_models/lightning_wrapper.py
@@ -3,7 +3,6 @@
import lightning as pl
import torch
import torch.nn as nn
import torchmetrics


class TaskModel(pl.LightningModule):
@@ -41,6 +40,8 @@ def __init__(
pruning_epoch=5,
optimizer_type: str = "Adam",
optimizer_args: dict | None = None,
train_metrics: dict[str, Callable] | None = None,
val_metrics: dict[str, Callable] | None = None,
**kwargs,
):
super().__init__()
@@ -53,6 +54,10 @@ def __init__(
self.pruning_epoch = pruning_epoch
self.val_losses = []

# Store custom metrics
self.train_metrics = train_metrics or {}
self.val_metrics = val_metrics or {}

self.optimizer_params = {
k.replace("optimizer_", ""): v
for k, v in optimizer_args.items() # type: ignore
@@ -65,16 +70,10 @@ def __init__(
if num_classes == 2:
if not self.loss_fct:
self.loss_fct = nn.BCEWithLogitsLoss()
self.acc = torchmetrics.Accuracy(task="binary")
self.auroc = torchmetrics.AUROC(task="binary")
self.precision = torchmetrics.Precision(task="binary")
self.num_classes = 1
elif num_classes > 2:
if not self.loss_fct:
self.loss_fct = nn.CrossEntropyLoss()
self.acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
self.auroc = torchmetrics.AUROC(task="multiclass", num_classes=num_classes)
self.precision = torchmetrics.Precision(task="multiclass", num_classes=num_classes)
else:
self.loss_fct = nn.MSELoss()

@@ -187,7 +186,7 @@ def training_step(self, batch, batch_idx):  # type: ignore
Tensor
Training loss.
"""
cat_features, num_features, labels = batch
num_features, cat_features, labels = batch

# Check if the model has a `penalty_forward` method
if hasattr(self.base_model, "penalty_forward"):
@@ -200,18 +199,17 @@ def training_step(self, batch, batch_idx):  # type: ignore
# Log the training loss
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

# Log additional metrics
if not self.lss and not hasattr(self.base_model, "returns_ensemble"):
if self.num_classes > 1:
acc = self.acc(preds, labels)
self.log(
"train_acc",
acc,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
# Log custom training metrics
for metric_name, metric_fn in self.train_metrics.items():
metric_value = metric_fn(preds, labels)
self.log(
f"train_{metric_name}",
metric_value,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)

return loss

@@ -231,7 +229,7 @@ def validation_step(self, batch, batch_idx):  # type: ignore
Validation loss.
"""

cat_features, num_features, labels = batch
num_features, cat_features, labels = batch
preds = self(num_features=num_features, cat_features=cat_features)
val_loss = self.compute_loss(preds, labels)

Expand All @@ -244,18 +242,17 @@ def validation_step(self, batch, batch_idx): # type: ignore
logger=True,
)

# Log additional metrics
if not self.lss and not hasattr(self.base_model, "returns_ensemble"):
if self.num_classes > 1:
acc = self.acc(preds, labels)
self.log(
"val_acc",
acc,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
# Log custom validation metrics
for metric_name, metric_fn in self.val_metrics.items():
metric_value = metric_fn(preds, labels)
self.log(
f"val_{metric_name}",
metric_value,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)

return val_loss

@@ -274,7 +271,7 @@ def test_step(self, batch, batch_idx):  # type: ignore
Tensor
Test loss.
"""
cat_features, num_features, labels = batch
num_features, cat_features, labels = batch
preds = self(num_features=num_features, cat_features=cat_features)
test_loss = self.compute_loss(preds, labels)

@@ -287,21 +284,29 @@ def test_step(self, batch, batch_idx):  # type: ignore
logger=True,
)

# Log additional metrics
if not self.lss and not hasattr(self.base_model, "returns_ensemble"):
if self.num_classes > 1:
acc = self.acc(preds, labels)
self.log(
"test_acc",
acc,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)

return test_loss

def predict_step(self, batch, batch_idx):
"""Predict step for a single batch.

Parameters
----------
batch : tuple
Batch of data containing numerical features, categorical features, and labels.
batch_idx : int
Index of the batch.

Returns
-------
Tensor
Predictions.
"""

num_features, cat_features = batch
preds = self(num_features=num_features, cat_features=cat_features)

return preds
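With the hard-coded torchmetrics accuracy/AUROC/precision removed, metrics are now supplied as plain callables via `train_metrics` / `val_metrics` and logged with a `train_` or `val_` prefix. Here is a hedged sketch of building such a dict and of what the logging loop does with it; `torchmetrics` is only one convenient source of callables (the wrapper itself no longer imports it), and the tensors are made-up illustration data:

```python
# Sketch of the new metric plumbing: metrics are plain callables keyed by name.
# TaskModel takes these as train_metrics=... / val_metrics=...; everything else
# below is illustration data, not part of the PR.
import torch
import torchmetrics

train_metrics = {
    "acc": torchmetrics.Accuracy(task="binary"),
    "auroc": torchmetrics.AUROC(task="binary"),
}

# What training_step now does with each entry (self.log replaced by print here):
preds = torch.randn(16)              # raw logits for a binary task
labels = torch.randint(0, 2, (16,))  # ground-truth classes
for name, fn in train_metrics.items():
    # torchmetrics' binary metrics treat out-of-range floats as logits
    print(f"train_{name}", fn(preds, labels).item())
```

When no dicts are passed, the wrapper falls back to empty dicts, so only the losses are logged.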

def on_validation_epoch_end(self):
"""Callback executed at the end of each validation epoch.
