diff --git a/docs/api/base_models/BaseModels.rst b/docs/api/base_models/BaseModels.rst index 613bca0..8411d34 100644 --- a/docs/api/base_models/BaseModels.rst +++ b/docs/api/base_models/BaseModels.rst @@ -28,3 +28,23 @@ mambular.base_models .. autoclass:: mambular.base_models.TabTransformer :members: :no-inherited-members: + +.. autoclass:: mambular.base_models.TabulaRNN + :members: + :no-inherited-members: + +.. autoclass:: mambular.base_models.MambAttention + :members: + :no-inherited-members: + +.. autoclass:: mambular.base_models.TabM + :members: + :no-inherited-members: + +.. autoclass:: mambular.base_models.NODE + :members: + :no-inherited-members: + +.. autoclass:: mambular.base_models.NDTF + :members: + :no-inherited-members: diff --git a/docs/api/base_models/index.rst b/docs/api/base_models/index.rst index 8b09e9b..1ed06eb 100644 --- a/docs/api/base_models/index.rst +++ b/docs/api/base_models/index.rst @@ -5,18 +5,23 @@ BaseModels ========== -This module provides base classes for the Mambular models. +This module provides foundational classes and architectures for Mambular models, including various neural network architectures tailored for tabular data. ========================================= ======================================================================================================= Modules Description ========================================= ======================================================================================================= -:class:`BaseModel` Initializes the BaseModel with given hyperparameters -:class:`TaskModel` PyTorch Lightning Module for training and evaluating a model -:class:`Mambular` PyTorch model for tasks utilizing the Mamba architecture and various normalization techniques -:class:`MLP` Initializes the MLP model with the given configuration -:class:`ResNet` ResNet model for structured data -:class:`FTTransformer` PyTorch model for tasks utilizing the Transformer architecture and various normalization techniques -:class:`TabTransformer` PyTorch model for tasks utilizing the Transformer architecture and various normalization techniques +:class:`BaseModel` Abstract base class defining the core structure and initialization logic for Mambular models. +:class:`TaskModel` PyTorch Lightning module for managing model training, validation, and testing workflows. +:class:`Mambular` Flexible neural network model leveraging the Mamba architecture with configurable normalization techniques for tabular data. +:class:`MLP` Multi-layer perceptron (MLP) model designed for tabular tasks, initialized with a custom configuration. +:class:`ResNet` Deep residual network (ResNet) model optimized for structured/tabular datasets. +:class:`FTTransformer` Feature Tokenizer (FTTransformer) model for tabular tasks, incorporating advanced embedding and normalization techniques. +:class:`TabTransformer` TabTransformer model leveraging attention mechanisms for tabular data processing. +:class:`NODE` Neural Oblivious Decision Ensembles (NODE) for tabular tasks, combining decision tree logic with deep learning. +:class:`TabM` TabM architecture designed for tabular data, implementing batch-ensembling MLP techniques. +:class:`NDTF` Neural Decision Tree Forest (NDTF) model for tabular tasks, blending decision tree concepts with neural networks. +:class:`TabulaRNN` Recurrent neural network (RNN) model, including LSTM and GRU architectures, tailored for sequential or time-series tabular data. +:class:`MambAttention` Attention-based architecture for tabular tasks, combining feature importance weighting with advanced normalization techniques. ========================================= ======================================================================================================= @@ -24,6 +29,3 @@ Modules Description :maxdepth: 1 BaseModels - - - diff --git a/docs/api/configs/Configurations.rst b/docs/api/configs/Configurations.rst new file mode 100644 index 0000000..f0caea6 --- /dev/null +++ b/docs/api/configs/Configurations.rst @@ -0,0 +1,46 @@ +Configurations +=============== + +.. autoclass:: mambular.configs.DefaultMambularConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultFTTransformerConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultResNetConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultMLPConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultTabTransformerConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultMambaTabConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultTabulaRNNConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultMambAttentionConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultNDTFConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultNODEConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultTabMConfig + :members: + :undoc-members: diff --git a/docs/api/configs/index.rst b/docs/api/configs/index.rst new file mode 100644 index 0000000..c872dbf --- /dev/null +++ b/docs/api/configs/index.rst @@ -0,0 +1,101 @@ +.. -*- mode: rst -*- + +.. currentmodule:: mambular.configs + +Configurations +============== + +This module provides default configurations for Mambular models. Each configuration is implemented as a dataclass, offering a structured way to define model-specific hyperparameters. + +Mambular +-------- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultMambularConfig` Default configuration for the Mambular model. +======================================= ======================================================================================================= + +FTTransformer +------------- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultFTTransformerConfig` Default configuration for the FTTransformer model. +======================================= ======================================================================================================= + +ResNet +------ +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultResNetConfig` Default configuration for the ResNet model. +======================================= ======================================================================================================= + +MLP +--- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultMLPConfig` Default configuration for the MLP model. +======================================= ======================================================================================================= + +TabTransformer +-------------- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultTabTransformerConfig` Default configuration for the TabTransformer model. +======================================= ======================================================================================================= + +MambaTab +-------- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultMambaTabConfig` Default configuration for the MambaTab model. +======================================= ======================================================================================================= + +RNN +--- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultTabulaRNNConfig` Default configuration for RNN models (LSTM, GRU). +======================================= ======================================================================================================= + +MambAttention +------------- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultMambAttentionConfig` Default configuration for the MambAttention model. +======================================= ======================================================================================================= + +NDTF +---- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultNDTFConfig` Default configuration for the Neural Decision Tree Forest (NDTF) model. +======================================= ======================================================================================================= + +NODE +---- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultNODEConfig` Default configuration for the Neural Oblivious Decision Ensembles (NODE) model. +======================================= ======================================================================================================= + +TabM +---- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultTabMConfig` Default configuration for the TabM model (Batch-Ensembling MLP). +======================================= ======================================================================================================= + +.. toctree:: + :maxdepth: 1 + + Configurations diff --git a/docs/api/models/Models.rst b/docs/api/models/Models.rst index bb5aaf8..1a52268 100644 --- a/docs/api/models/Models.rst +++ b/docs/api/models/Models.rst @@ -73,6 +73,18 @@ mambular.models :members: :undoc-members: +.. autoclass:: mambular.models.MambAttentionClassifier + :members: + :undoc-members: + +.. autoclass:: mambular.models.MambAttentionRegressor + :members: + :undoc-members: + +.. autoclass:: mambular.models.MambAttentionLSS + :members: + :undoc-members: + .. autoclass:: mambular.models.TabulaRNNClassifier :members: :undoc-members: @@ -85,6 +97,42 @@ mambular.models :members: :undoc-members: +.. autoclass:: mambular.models.TabMClassifier + :members: + :inherited-members: + +.. autoclass:: mambular.models.TabMRegressor + :members: + :inherited-members: + +.. autoclass:: mambular.models.TabMLSS + :members: + :undoc-members: + +.. autoclass:: mambular.models.NODEClassifier + :members: + :inherited-members: + +.. autoclass:: mambular.models.NODERegressor + :members: + :inherited-members: + +.. autoclass:: mambular.models.NODELSS + :members: + :undoc-members: + +.. autoclass:: mambular.models.NDTFClassifier + :members: + :inherited-members: + +.. autoclass:: mambular.models.NDTFRegressor + :members: + :inherited-members: + +.. autoclass:: mambular.models.NDTFLSS + :members: + :undoc-members: + .. autoclass:: mambular.models.SklearnBaseClassifier :members: :undoc-members: diff --git a/docs/api/models/index.rst b/docs/api/models/index.rst index 5bb5980..9b689e3 100644 --- a/docs/api/models/index.rst +++ b/docs/api/models/index.rst @@ -7,30 +7,121 @@ Models This module provides classes for the Mambular models that adhere to scikit-learn's `BaseEstimator` interface. +Mambular +-------- ======================================= ======================================================================================================= Modules Description ======================================= ======================================================================================================= :class:`MambularClassifier` Multi-class and binary classification tasks with a sequential Mambular Model. :class:`MambularRegressor` Regression tasks with a sequential Mambular Model. :class:`MambularLSS` Various statistical distribution families for different types of regression and classification tasks. +======================================= ======================================================================================================= + +FTTransformer +------------- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= :class:`FTTransformerClassifier` FT transformer for classification tasks. -:class:`FTTransformerRegressor` FT transformer for regression tasks. -:class:`FTTransformerLSS` Various statistical distribution families for different types of regression and classification tasks. +:class:`FTTransformerRegressor` FT transformer for regression tasks. +:class:`FTTransformerLSS` Various statistical distribution families for different types of regression and classification tasks. +======================================= ======================================================================================================= + +MLP Models +---------- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= :class:`MLPClassifier` Multi-class and binary classification tasks. :class:`MLPRegressor` MLP for regression tasks. -:class:`MLPLSS` Various statistical distribution families for different types of regression and classification tasks. +:class:`MLPLSS` Various statistical distribution families for different types of regression and classification tasks. +======================================= ======================================================================================================= + +TabTransformer +-------------- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= :class:`TabTransformerClassifier` TabTransformer for classification tasks. :class:`TabTransformerRegressor` TabTransformer for regression tasks. :class:`TabTransformerLSS` TabTransformer for distributional tasks. +======================================= ======================================================================================================= + +ResNet +------ +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= :class:`ResNetClassifier` Multi-class and binary classification tasks using ResNet. :class:`ResNetRegressor` Regression tasks using ResNet. :class:`ResNetLSS` Distributional tasks using ResNet. +======================================= ======================================================================================================= + +MambaTab +-------- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= :class:`MambaTabClassifier` Multi-class and binary classification tasks using MambaTab. :class:`MambaTabRegressor` Regression tasks using MambaTab. :class:`MambaTabLSS` Distributional tasks using MambaTab. +======================================= ======================================================================================================= + +MambaAttention +-------------- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= +:class:`MambAttentionClassifier` Multi-class and binary classification tasks using a Combination between Mamba and Attention layers. +:class:`MambAttentionRegressor` Regression tasks using sing a Combination between Mamba and Attention layers. +:class:`MambAttentionLSS` Distributional tasks using sing a Combination between Mamba and Attention layers. +======================================= ======================================================================================================= + +RNN Models Including LSTM and GRU +--------------------------------- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= :class:`TabulaRNNClassifier` Multi-class and binary classification tasks using a RNN. :class:`TabulaRNNRegressor` Regression tasks using a RNN. :class:`TabulaRNNLSS` Distributional tasks using a RNN. +======================================= ======================================================================================================= + +TabM +---- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= +:class:`TabMClassifier` Multi-class and binary classification tasks using TabM - Batch Ensembling MLP. +:class:`TabMRegressor` Regression tasks using TabM - Batch Ensembling MLP. +:class:`TabMLSS` Distributional tasks using TabM - Batch Ensembling MLP. +======================================= ======================================================================================================= + +NODE +---- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= +:class:`NODEClassifier` Multi-class and binary classification tasks using Neural Oblivious Decision Ensembles. +:class:`NODERegressor` Regression tasks using Neural Oblivious Decision Ensembles. +:class:`NODELSS` Distributional tasks using Neural Oblivious Decision Ensembles. +======================================= ======================================================================================================= + +NDTF +---- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= +:class:`NDTFClassifier` Multi-class and binary classification tasks using a Neural Decision Forest. +:class:`NDTFRegressor` Regression tasks using a Neural Decision Forest +:class:`NDTFLSS` Distributional tasks using a Neural Decision Forest. +======================================= ======================================================================================================= + +Base Classes +------------ +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= :class:`SklearnBaseClassifier` Base class for classification tasks. :class:`SklearnBaseLSS` Base class for distributional tasks. :class:`SklearnBaseRegressor` Base class for regression tasks. @@ -40,4 +131,3 @@ Modules Description :maxdepth: 1 Models - diff --git a/docs/homepage.md b/docs/homepage.md index c11fcf1..9c8f684 100644 --- a/docs/homepage.md +++ b/docs/homepage.md @@ -1,68 +1,93 @@ -# Mambular: Tabular Deep Learning with Mamba Architectures +# Mambular: Tabular Deep Learning Made Simple -Mambular is a Python package that brings the power of advanced deep learning architectures to tabular data, offering a suite of models for regression, classification, and distributional regression tasks. Designed with ease of use in mind, Mambular models adhere to scikit-learn's `BaseEstimator` interface, making them highly compatible with the familiar scikit-learn ecosystem. This means you can fit, predict, and evaluate using Mambular models just as you would with any traditional scikit-learn model, but with the added performance and flexibility of deep learning. +Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP inspired tabular models. + + +# 🏃 Quickstart +Similar to any sklearn model, Mambular models can be fit as easy as this: + +```python +from mambular.models import MambularClassifier +# Initialize and fit your model +model = MambularClassifier() -## Features +# X can be a dataframe or something that can be easily transformed into a pd.DataFrame as a np.array +model.fit(X, y, max_epochs=150, lr=1e-04) +``` -- **Comprehensive Model Suite**: Includes modules for regression, classification, and distributional regression, catering to a wide range of tabular data tasks. -- **State-of-the-Art Architectures**: Leverages various advanced architectures known for their effectiveness in handling tabular data. Mambular models include powerful Mamba blocks [Gu and Dao](https://arxiv.org/pdf/2312.00752) and can include bidirectional processing as well as feature interaction layers. -- **Seamless Integration**: Designed to work effortlessly with scikit-learn, allowing for easy inclusion in existing machine learning pipelines, cross-validation, and hyperparameter tuning workflows. -- **Extensive Preprocessing**: Comes with a powerful preprocessing module that supports a broad array of data transformation techniques, ensuring that your data is optimally prepared for model training. -- **Sklearn-like API**: The familiar scikit-learn `fit`, `predict`, and `predict_proba` methods mean minimal learning curve for those already accustomed to scikit-learn. -- **PyTorch Lightning Under the Hood**: Built on top of PyTorch Lightning, Mambular models benefit from streamlined training processes, easy customization, and advanced features like distributed training and 16-bit precision. +# 📖 Introduction +Mambular is a Python package that brings the power of advanced deep learning architectures to tabular data, offering a suite of models for regression, classification, and distributional regression tasks. Designed with ease of use in mind, Mambular models adhere to scikit-learn's `BaseEstimator` interface, making them highly compatible with the familiar scikit-learn ecosystem. This means you can fit, predict, and evaluate using Mambular models just as you would with any traditional scikit-learn model, but with the added performance and flexibility of deep learning. -## Models +# 🤖 Models | Model | Description | | ---------------- | --------------------------------------------------------------------------------------------------------------------------------------------------- | -| `Mambular` | A sequential model using Mamba blocks [Gu and Dao](https://arxiv.org/pdf/2312.00752) specifically designed for various tabular data tasks. | +| `Mambular` | A sequential model using Mamba blocks specifically designed for various tabular data tasks introduced [here](https://arxiv.org/abs/2408.06291). | +| `TabM` | Batch Ensembling for a MLP as introduced by [Gorishniy et al.](https://arxiv.org/abs/2410.24210) | +| `NODE` | Neural Oblivious Decision Ensembles as introduced by [Popov et al.](https://arxiv.org/abs/1909.06312) | | `FTTransformer` | A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data. | | `MLP` | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks. | | `ResNet` | An adaptation of the ResNet architecture for tabular data applications. | | `TabTransformer` | A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities. | | `MambaTab` | A tabular model using a Mamba-Block on a joint input representation described [here](https://arxiv.org/abs/2401.08867) . Not a sequential model. | -| `TabulaRNN` | A Recurrent Neural Network for Tabular data. Not yet included in the benchmarks | +| `TabulaRNN` | A Recurrent Neural Network for Tabular data, introduced [here](https://arxiv.org/pdf/2411.17207). | +| `MambAttention` | A combination between Mamba and Transformers, also introduced [here](https://arxiv.org/pdf/2411.17207). | +| `NDTF` | A neural decision forest using soft decision trees. See [Kontschieder et al.](https://openaccess.thecvf.com/content_iccv_2015/html/Kontschieder_Deep_Neural_Decision_ICCV_2015_paper.html) for inspiration. | + +All models are available for `regression`, `classification` and distributional regression, denoted by `LSS`. +Hence, they are available as e.g. `MambularRegressor`, `MambularClassifier` or `MambularLSS` -## Documentation -You can find the Mamba-Tabular API documentation [here](https://mamba-tabular.readthedocs.io/en/latest/index.html). +# 📚 Documentation -## Installation +You can find the Mamba-Tabular API documentation [here](https://mambular.readthedocs.io/en/latest/). + +# 🛠️ Installation Install Mambular using pip: ```sh pip install mambular ``` -## Preprocessing +If you want to use the original mamba and mamba2 implementations, additionally install mamba-ssm via: -Mambular simplifies the preprocessing stage of model development with a comprehensive set of techniques to prepare your data for Mamba architectures. Our preprocessing module is designed to be both powerful and easy to use, offering a variety of options to efficiently transform your tabular data. +```sh +pip install mamba-ssm +``` -### Data Type Detection and Transformation +Be careful to use the correct torch and cuda versions: -Mambular automatically identifies the type of each feature in your dataset and applies the most appropriate transformations for numerical and categorical variables. This includes: -- **Ordinal Encoding**: Categorical features are seamlessly transformed into numerical values, preserving their inherent order and making them model-ready. -- **One-Hot Encoding**: For nominal data, Mambular employs one-hot encoding to capture the presence or absence of categories without imposing ordinality. -- **Binning**: Numerical features can be discretized into bins, a useful technique for handling continuous variables in certain modeling contexts. -- **Decision Tree Binning**: Optionally, Mambular can use decision trees to find the optimal binning strategy for numerical features, enhancing model interpretability and performance. -- **Normalization**: Mambular can easily handle numerical features without specifically turning them into categorical features. Standard preprocessing steps such as normalization per feature are possible. -- **Standardization**: Similarly, standardization instead of normalization can be used to scale features based on the mean and standard deviation. -- **PLE (Periodic Linear Encoding)**: This technique can be applied to numerical features to enhance the performance of tabular deep learning methods by encoding periodicity. -- **Quantile Transformation**: Numerical features can be transformed to follow a uniform or normal distribution, improving model robustness to outliers. -- **Spline Transformation**: Applies piecewise polynomial functions to numerical features, capturing nonlinear relationships more effectively. -- **Polynomial Features**: Generates polynomial and interaction features, increasing the feature space to capture more complex relationships within the data. +```sh +pip install torch==2.0.0+cu118 torchvision==0.15.0+cu118 torchaudio==2.0.0+cu118 -f https://download.pytorch.org/whl/cu118/torch_stable.html +pip install mamba-ssm +``` + +# 🚀 Usage + +

Preprocessing

+ +Mambular simplifies data preprocessing with a range of tools designed for easy transformation of tabular data. +

Data Type Detection and Transformation

-### Handling Missing Values +- **Ordinal & One-Hot Encoding**: Automatically transforms categorical data into numerical formats using continuous ordinal encoding or one-hot encoding. Includes options for transforming outputs to `float` for compatibility with downstream models. +- **Binning**: Discretizes numerical features into bins, with support for both fixed binning strategies and optimal binning derived from decision tree models. +- **MinMax**: Scales numerical data to a specific range, such as [-1, 1], using Min-Max scaling or similar techniques. +- **Standardization**: Centers and scales numerical features to have a mean of zero and unit variance for better compatibility with certain models. +- **Quantile Transformations**: Normalizes numerical data to follow a uniform or normal distribution, handling distributional shifts effectively. +- **Spline Transformations**: Captures nonlinearity in numerical features using spline-based transformations, ideal for complex relationships. +- **Piecewise Linear Encodings (PLE)**: Captures complex numerical patterns by applying piecewise linear encoding, suitable for data with periodic or nonlinear structures. +- **Polynomial Features**: Automatically generates polynomial and interaction terms for numerical features, enhancing the ability to capture higher-order relationships. +- **Box-Cox & Yeo-Johnson Transformations**: Performs power transformations to stabilize variance and normalize distributions. +- **Custom Binning**: Enables user-defined bin edges for precise discretization of numerical data. + -Our preprocessing pipeline effectively handles missing data by using mean imputation for numerical features and mode imputation for categorical features. This ensures that your models receive complete data inputs without needing manual intervention. -Additionally, Mambular can manage unknown categorical values during inference by incorporating classical tokens in categorical preprocessing. -## Fit a Model +

Fit a Model

Fitting a model in mambular is as simple as it gets. All models in mambular are sklearn BaseEstimators. Thus the `.fit` method is implemented for all of them. Additionally, this allows for using all other sklearn inherent methods such as their built in hyperparameter optimization tools. ```python @@ -70,9 +95,10 @@ from mambular.models import MambularClassifier # Initialize and fit your model model = MambularClassifier( d_model=64, - n_layers=8, + n_layers=4, numerical_preprocessing="ple", - n_bins=50 + n_bins=50, + d_conv=8 ) # X can be a dataframe or something that can be easily transformed into a pd.DataFrame as a np.array @@ -88,38 +114,88 @@ preds = model.predict(X) preds = model.predict_proba(X) ``` +

Hyperparameter Optimization

+Since all of the models are sklearn base estimators, you can use the built-in hyperparameter optimizatino from sklearn. -## Distributional Regression with MambularLSS +```python +from sklearn.model_selection import RandomizedSearchCV -Mambular introduces an approach to distributional regression through its `MambularLSS` module, allowing users to model the full distribution of a response variable, not just its mean. This method is particularly valuable in scenarios where understanding the variability, skewness, or kurtosis of the response distribution is as crucial as predicting its central tendency. All available moedls in mambular are also available as distributional models. +param_dist = { + 'd_model': randint(32, 128), + 'n_layers': randint(2, 10), + 'lr': uniform(1e-5, 1e-3) +} -### Key Features of MambularLSS: +random_search = RandomizedSearchCV( + estimator=model, + param_distributions=param_dist, + n_iter=50, # Number of parameter settings sampled + cv=5, # 5-fold cross-validation + scoring='accuracy', # Metric to optimize + random_state=42 +) -- **Full Distribution Modeling**: Unlike traditional regression models that predict a single value (e.g., the mean), `MambularLSS` models the entire distribution of the response variable. This allows for more informative predictions, including quantiles, variance, and higher moments. -- **Customizable Distribution Types**: `MambularLSS` supports a variety of distribution families (e.g., Gaussian, Poisson, Binomial), making it adaptable to different types of response variables, from continuous to count data. -- **Location, Scale, Shape Parameters**: The model predicts parameters corresponding to the location, scale, and shape of the distribution, offering a nuanced understanding of the data's underlying distributional characteristics. -- **Enhanced Predictive Uncertainty**: By modeling the full distribution, `MambularLSS` provides richer information on predictive uncertainty, enabling more robust decision-making processes in uncertain environments. +fit_params = {"max_epochs":5, "rebuild":False} +# Fit the model +random_search.fit(X, y, **fit_params) +# Best parameters and score +print("Best Parameters:", random_search.best_params_) +print("Best Score:", random_search.best_score_) +``` +Note, that using this, you can also optimize the preprocessing. Just use the prefix ``prepro__`` when specifying the preprocessor arguments you want to optimize: +```python +param_dist = { + 'd_model': randint(32, 128), + 'n_layers': randint(2, 10), + 'lr': uniform(1e-5, 1e-3), + "prepro__numerical_preprocessing": ["ple", "standardization", "box-cox"] +} -### Available Distribution Classes: +``` -`MambularLSS` offers a wide range of distribution classes to cater to various statistical modeling needs. The available distribution classes include: -- `normal`: Normal Distribution for modeling continuous data with a symmetric distribution around the mean. -- `poisson`: Poisson Distribution for modeling count data that for instance represent the number of events occurring within a fixed interval. -- `gamma`: Gamma Distribution for modeling continuous data that is skewed and bounded at zero, often used for waiting times. -- `beta`: Beta Distribution for modeling data that is bounded between 0 and 1, useful for proportions and percentages. -- `dirichlet`: Dirichlet Distribution for modeling multivariate data where individual components are correlated, and the sum is constrained to 1. -- `studentt`: Student's T-Distribution for modeling data with heavier tails than the normal distribution, useful when the sample size is small. -- `negativebinom`: Negative Binomial Distribution for modeling count data with over-dispersion relative to the Poisson distribution. -- `inversegamma`: Inverse Gamma Distribution, often used as a prior distribution in Bayesian inference for scale parameters. -- `categorical`: Categorical Distribution for modeling categorical data with more than two categories. +Since we have early stopping integrated and return the best model with respect to the validation loss, setting max_epochs to a large number is sensible. -These distribution classes allow `MambularLSS` to flexibly model a wide variety of data types and distributions, providing users with the tools needed to capture the full complexity of their data. +Or use the built-in bayesian hpo simply by running: -### Getting Started with MambularLSS: +```python +best_params = model.optimize_hparams(X, y) +``` + +This automatically sets the search space based on the default config from ``mambular.configs``. See the documentation for all params with regard to ``optimize_hparams()``. However, the preprocessor arguments are fixed and cannot be optimized here. + + +

⚖️ Distributional Regression with MambularLSS

+ +MambularLSS allows you to model the full distribution of a response variable, not just its mean. This is crucial when understanding variability, skewness, or kurtosis is important. All Mambular models are available as distributional models. + +

Key Features of MambularLSS:

+ +- **Full Distribution Modeling**: Predicts the entire distribution, not just a single value, providing richer insights. +- **Customizable Distribution Types**: Supports various distributions (e.g., Gaussian, Poisson, Binomial) for different data types. +- **Location, Scale, Shape Parameters**: Predicts key distributional parameters for deeper insights. +- **Enhanced Predictive Uncertainty**: Offers more robust predictions by modeling the entire distribution. + +

Available Distribution Classes:

+ +- **normal**: For continuous data with a symmetric distribution. +- **poisson**: For count data within a fixed interval. +- **gamma**: For skewed continuous data, often used for waiting times. +- **beta**: For data bounded between 0 and 1, like proportions. +- **dirichlet**: For multivariate data with correlated components. +- **studentt**: For data with heavier tails, useful with small samples. +- **negativebinom**: For over-dispersed count data. +- **inversegamma**: Often used as a prior in Bayesian inference. +- **categorical**: For data with more than two categories. +- **Quantile**: For quantile regression using the pinball loss. + +These distribution classes make MambularLSS versatile in modeling various data types and distributions. + + +

Getting Started with MambularLSS:

To integrate distributional regression into your workflow with `MambularLSS`, start by initializing the model with your desired configuration, similar to other Mambular models: @@ -147,7 +223,7 @@ model.fit( ``` -### Implement Your Own Model +# 💻 Implement Your Own Model Mambular allows users to easily integrate their custom models into the existing logic. This process is designed to be straightforward, making it simple to create a PyTorch model and define its forward pass. Instead of inheriting from `nn.Module`, you inherit from Mambular's `BaseModel`. Each Mambular model takes three main arguments: the number of classes (e.g., 1 for regression or 2 for binary classification), `cat_feature_info`, and `num_feature_info` for categorical and numerical feature information, respectively. Additionally, you can provide a config argument, which can either be a custom configuration or one of the provided default configs. @@ -155,79 +231,127 @@ One of the key advantages of using Mambular is that the inputs to the forward pa Here's how you can implement a custom model with Mambular: - -1. First, define your config: -The configuration class allows you to specify hyperparameters and other settings for your model. This can be done using a simple dataclass. - -```python -from dataclasses import dataclass - -@dataclass -class MyConfig: - lr: float = 1e-04 - lr_patience: int = 10 - weight_decay: float = 1e-06 - lr_factor: float = 0.1 -``` - -2. Second, define your model: -Define your custom model just as you would for an `nn.Module`. The main difference is that you will inherit from `BaseModel` and use the provided feature information to construct your layers. To integrate your model into the existing API, you only need to define the architecture and the forward pass. +1. **First, define your config:** + The configuration class allows you to specify hyperparameters and other settings for your model. This can be done using a simple dataclass. + + ```python + from dataclasses import dataclass + + @dataclass + class MyConfig: + lr: float = 1e-04 + lr_patience: int = 10 + weight_decay: float = 1e-06 + lr_factor: float = 0.1 + ``` + +2. **Second, define your model:** + Define your custom model just as you would for an `nn.Module`. The main difference is that you will inherit from `BaseModel` and use the provided feature information to construct your layers. To integrate your model into the existing API, you only need to define the architecture and the forward pass. + + ```python + from mambular.base_models import BaseModel + from mambular.utils.get_feature_dimensions import get_feature_dimensions + import torch + import torch.nn + + class MyCustomModel(BaseModel): + def __init__( + self, + cat_feature_info, + num_feature_info, + num_classes: int = 1, + config=None, + **kwargs, + ): + super().__init__(**kwargs) + self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) + + input_dim = get_feature_dimensions(num_feature_info, cat_feature_info) + + self.linear = nn.Linear(input_dim, num_classes) + + def forward(self, num_features, cat_features): + x = num_features + cat_features + x = torch.cat(x, dim=1) + + # Pass through linear layer + output = self.linear(x) + return output + ``` + +3. **Leverage the Mambular API:** + You can build a regression, classification, or distributional regression model that can leverage all of Mambular's built-in methods by using the following: + + ```python + from mambular.models import SklearnBaseRegressor + + class MyRegressor(SklearnBaseRegressor): + def __init__(self, **kwargs): + super().__init__(model=MyCustomModel, config=MyConfig, **kwargs) + ``` + +4. **Train and evaluate your model:** + You can now fit, evaluate, and predict with your custom model just like with any other Mambular model. For classification or distributional regression, inherit from `SklearnBaseClassifier` or `SklearnBaseLSS` respectively. + + ```python + regressor = MyRegressor(numerical_preprocessing="ple") + regressor.fit(X_train, y_train, max_epochs=50) + ``` + +# Custom Training +If you prefer to setup custom training, preprocessing and evaluation, you can simply use the `mambular.base_models`. +Just be careful that all basemodels expect lists of features as inputs. More precisely as list for numerical features and a list for categorical features. A custom training loop, with random data could look like this. ```python -from mambular.base_models import BaseModel import torch -import torch.nn - -class MyCustomModel(BaseModel): - def __init__( - self, - cat_feature_info, - num_feature_info, - num_classes: int = 1, - config=None, - **kwargs, - ): - super().__init__(**kwargs) - self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) - - input_dim = 0 - for feature_name, input_shape in num_feature_info.items(): - input_dim += input_shape - for feature_name, input_shape in cat_feature_info.items(): - input_dim += 1 - - self.linear = nn.Linear(input_dim, num_classes) - - def forward(self, num_features, cat_features): - x = num_features + cat_features - x = torch.cat(x, dim=1) - - # Pass through linear layer - output = self.linear(x) - return output -``` - -3. Leverage the Mambular API: -You can build a regression, classification or distributional regression model that can leverage all of mambulars built-in methods, by using the following: - -```python -from mambular.models import SklearnBaseRegressor +import torch.nn as nn +import torch.optim as optim +from mambular.base_models import Mambular +from mambular.configs import DefaultMambularConfig + +# Dummy data and configuration +cat_feature_info = { + "cat1": { + "preprocessing": "imputer -> continuous_ordinal", + "dimension": 1, + "categories": 4, + } +} # Example categorical feature information +num_feature_info = { + "num1": {"preprocessing": "imputer -> scaler", "dimension": 1, "categories": None} +} # Example numerical feature information +num_classes = 1 +config = DefaultMambularConfig() # Use the desired configuration + +# Initialize model, loss function, and optimizer +model = Mambular(cat_feature_info, num_feature_info, num_classes, config) +criterion = nn.MSELoss() # Use MSE for regression; change as appropriate for your task +optimizer = optim.Adam(model.parameters(), lr=0.001) + +# Example training loop +for epoch in range(10): # Number of epochs + model.train() + optimizer.zero_grad() + + # Dummy Data + num_features = [torch.randn(32, 1) for _ in num_feature_info] + cat_features = [torch.randint(0, 5, (32,)) for _ in cat_feature_info] + labels = torch.randn(32, num_classes) + + # Forward pass + outputs = model(num_features, cat_features) + loss = criterion(outputs, labels) + + # Backward pass and optimization + loss.backward() + optimizer.step() + + # Print loss for monitoring + print(f"Epoch [{epoch+1}/10], Loss: {loss.item():.4f}") -class MyRegressor(SklearnBaseRegressor): - def __init__(self, **kwargs): - super().__init__(model=MyCustomModel, config=MyConfig, **kwargs) ``` -4. Train and evaluate your model: -You can now fit, evaluate, and predict with your custom model just like with any other Mambular model. For classification or distributional regression, inherit from `SklearnBaseClassifier` or `SklearnBaseLSS` respectively. - -```python -regressor = MyRegressor(numerical_preprocessing="ple") -regressor.fit(X_train, y_train, max_epochs=50) -``` - - -## Citation +# 🏷️ Citation If you find this project useful in your research, please consider cite: ```BibTeX @@ -239,6 +363,16 @@ If you find this project useful in your research, please consider cite: } ``` -## License +If you use TabulaRNN please consider to cite: +```BibTeX +@article{thielmann2024efficiency, + title={On the Efficiency of NLP-Inspired Methods for Tabular Deep Learning}, + author={Thielmann, Anton Frederik and Samiee, Soheila}, + journal={arXiv preprint arXiv:2411.17207}, + year={2024} +} +``` + +# License -The entire codebase is under MIT license. \ No newline at end of file +The entire codebase is under MIT license. diff --git a/docs/index.rst b/docs/index.rst index 26d238a..1bea8b4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -33,6 +33,7 @@ api/base_models/index api/preprocessing/index api/data_utils/index + api/configs/index .. toctree:: diff --git a/mambular/base_models/__init__.py b/mambular/base_models/__init__.py index b3eda43..7452c90 100644 --- a/mambular/base_models/__init__.py +++ b/mambular/base_models/__init__.py @@ -6,9 +6,11 @@ from .resnet import ResNet from .tabtransformer import TabTransformer from .mambatab import MambaTab -from .mambattn import MambAttn +from .mambattn import MambAttention from .node import NODE from .tabm import TabM +from .tabularnn import TabulaRNN +from .ndtf import NDTF __all__ = [ "TaskModel", @@ -19,7 +21,9 @@ "MLP", "BaseModel", "MambaTab", - "MambAttn", + "MambAttention", "TabM", "NODE", + "NDTF", + "TabulaRNN", ]