diff --git a/docs/api/base_models/BaseModels.rst b/docs/api/base_models/BaseModels.rst index 613bca0..8411d34 100644 --- a/docs/api/base_models/BaseModels.rst +++ b/docs/api/base_models/BaseModels.rst @@ -28,3 +28,23 @@ mambular.base_models .. autoclass:: mambular.base_models.TabTransformer :members: :no-inherited-members: + +.. autoclass:: mambular.base_models.TabulaRNN + :members: + :no-inherited-members: + +.. autoclass:: mambular.base_models.MambAttention + :members: + :no-inherited-members: + +.. autoclass:: mambular.base_models.TabM + :members: + :no-inherited-members: + +.. autoclass:: mambular.base_models.NODE + :members: + :no-inherited-members: + +.. autoclass:: mambular.base_models.NDTF + :members: + :no-inherited-members: diff --git a/docs/api/base_models/index.rst b/docs/api/base_models/index.rst index 8b09e9b..1ed06eb 100644 --- a/docs/api/base_models/index.rst +++ b/docs/api/base_models/index.rst @@ -5,18 +5,23 @@ BaseModels ========== -This module provides base classes for the Mambular models. +This module provides foundational classes and architectures for Mambular models, including various neural network architectures tailored for tabular data. ========================================= ======================================================================================================= Modules Description ========================================= ======================================================================================================= -:class:`BaseModel` Initializes the BaseModel with given hyperparameters -:class:`TaskModel` PyTorch Lightning Module for training and evaluating a model -:class:`Mambular` PyTorch model for tasks utilizing the Mamba architecture and various normalization techniques -:class:`MLP` Initializes the MLP model with the given configuration -:class:`ResNet` ResNet model for structured data -:class:`FTTransformer` PyTorch model for tasks utilizing the Transformer architecture and various normalization techniques -:class:`TabTransformer` PyTorch model for tasks utilizing the Transformer architecture and various normalization techniques +:class:`BaseModel` Abstract base class defining the core structure and initialization logic for Mambular models. +:class:`TaskModel` PyTorch Lightning module for managing model training, validation, and testing workflows. +:class:`Mambular` Flexible neural network model leveraging the Mamba architecture with configurable normalization techniques for tabular data. +:class:`MLP` Multi-layer perceptron (MLP) model designed for tabular tasks, initialized with a custom configuration. +:class:`ResNet` Deep residual network (ResNet) model optimized for structured/tabular datasets. +:class:`FTTransformer` Feature Tokenizer (FTTransformer) model for tabular tasks, incorporating advanced embedding and normalization techniques. +:class:`TabTransformer` TabTransformer model leveraging attention mechanisms for tabular data processing. +:class:`NODE` Neural Oblivious Decision Ensembles (NODE) for tabular tasks, combining decision tree logic with deep learning. +:class:`TabM` TabM architecture designed for tabular data, implementing batch-ensembling MLP techniques. +:class:`NDTF` Neural Decision Tree Forest (NDTF) model for tabular tasks, blending decision tree concepts with neural networks. +:class:`TabulaRNN` Recurrent neural network (RNN) model, including LSTM and GRU architectures, tailored for sequential or time-series tabular data. +:class:`MambAttention` Attention-based architecture for tabular tasks, combining feature importance weighting with advanced normalization techniques. ========================================= ======================================================================================================= @@ -24,6 +29,3 @@ Modules Description :maxdepth: 1 BaseModels - - - diff --git a/docs/api/configs/Configurations.rst b/docs/api/configs/Configurations.rst new file mode 100644 index 0000000..f0caea6 --- /dev/null +++ b/docs/api/configs/Configurations.rst @@ -0,0 +1,46 @@ +Configurations +=============== + +.. autoclass:: mambular.configs.DefaultMambularConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultFTTransformerConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultResNetConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultMLPConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultTabTransformerConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultMambaTabConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultTabulaRNNConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultMambAttentionConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultNDTFConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultNODEConfig + :members: + :undoc-members: + +.. autoclass:: mambular.configs.DefaultTabMConfig + :members: + :undoc-members: diff --git a/docs/api/configs/index.rst b/docs/api/configs/index.rst new file mode 100644 index 0000000..c872dbf --- /dev/null +++ b/docs/api/configs/index.rst @@ -0,0 +1,101 @@ +.. -*- mode: rst -*- + +.. currentmodule:: mambular.configs + +Configurations +============== + +This module provides default configurations for Mambular models. Each configuration is implemented as a dataclass, offering a structured way to define model-specific hyperparameters. + +Mambular +-------- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultMambularConfig` Default configuration for the Mambular model. +======================================= ======================================================================================================= + +FTTransformer +------------- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultFTTransformerConfig` Default configuration for the FTTransformer model. +======================================= ======================================================================================================= + +ResNet +------ +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultResNetConfig` Default configuration for the ResNet model. +======================================= ======================================================================================================= + +MLP +--- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultMLPConfig` Default configuration for the MLP model. +======================================= ======================================================================================================= + +TabTransformer +-------------- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultTabTransformerConfig` Default configuration for the TabTransformer model. +======================================= ======================================================================================================= + +MambaTab +-------- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultMambaTabConfig` Default configuration for the MambaTab model. +======================================= ======================================================================================================= + +RNN +--- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultTabulaRNNConfig` Default configuration for RNN models (LSTM, GRU). +======================================= ======================================================================================================= + +MambAttention +------------- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultMambAttentionConfig` Default configuration for the MambAttention model. +======================================= ======================================================================================================= + +NDTF +---- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultNDTFConfig` Default configuration for the Neural Decision Tree Forest (NDTF) model. +======================================= ======================================================================================================= + +NODE +---- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultNODEConfig` Default configuration for the Neural Oblivious Decision Ensembles (NODE) model. +======================================= ======================================================================================================= + +TabM +---- +======================================= ======================================================================================================= +Dataclass Description +======================================= ======================================================================================================= +:class:`DefaultTabMConfig` Default configuration for the TabM model (Batch-Ensembling MLP). +======================================= ======================================================================================================= + +.. toctree:: + :maxdepth: 1 + + Configurations diff --git a/docs/api/models/Models.rst b/docs/api/models/Models.rst index bb5aaf8..1a52268 100644 --- a/docs/api/models/Models.rst +++ b/docs/api/models/Models.rst @@ -73,6 +73,18 @@ mambular.models :members: :undoc-members: +.. autoclass:: mambular.models.MambAttentionClassifier + :members: + :undoc-members: + +.. autoclass:: mambular.models.MambAttentionRegressor + :members: + :undoc-members: + +.. autoclass:: mambular.models.MambAttentionLSS + :members: + :undoc-members: + .. autoclass:: mambular.models.TabulaRNNClassifier :members: :undoc-members: @@ -85,6 +97,42 @@ mambular.models :members: :undoc-members: +.. autoclass:: mambular.models.TabMClassifier + :members: + :inherited-members: + +.. autoclass:: mambular.models.TabMRegressor + :members: + :inherited-members: + +.. autoclass:: mambular.models.TabMLSS + :members: + :undoc-members: + +.. autoclass:: mambular.models.NODEClassifier + :members: + :inherited-members: + +.. autoclass:: mambular.models.NODERegressor + :members: + :inherited-members: + +.. autoclass:: mambular.models.NODELSS + :members: + :undoc-members: + +.. autoclass:: mambular.models.NDTFClassifier + :members: + :inherited-members: + +.. autoclass:: mambular.models.NDTFRegressor + :members: + :inherited-members: + +.. autoclass:: mambular.models.NDTFLSS + :members: + :undoc-members: + .. autoclass:: mambular.models.SklearnBaseClassifier :members: :undoc-members: diff --git a/docs/api/models/index.rst b/docs/api/models/index.rst index 5bb5980..9b689e3 100644 --- a/docs/api/models/index.rst +++ b/docs/api/models/index.rst @@ -7,30 +7,121 @@ Models This module provides classes for the Mambular models that adhere to scikit-learn's `BaseEstimator` interface. +Mambular +-------- ======================================= ======================================================================================================= Modules Description ======================================= ======================================================================================================= :class:`MambularClassifier` Multi-class and binary classification tasks with a sequential Mambular Model. :class:`MambularRegressor` Regression tasks with a sequential Mambular Model. :class:`MambularLSS` Various statistical distribution families for different types of regression and classification tasks. +======================================= ======================================================================================================= + +FTTransformer +------------- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= :class:`FTTransformerClassifier` FT transformer for classification tasks. -:class:`FTTransformerRegressor` FT transformer for regression tasks. -:class:`FTTransformerLSS` Various statistical distribution families for different types of regression and classification tasks. +:class:`FTTransformerRegressor` FT transformer for regression tasks. +:class:`FTTransformerLSS` Various statistical distribution families for different types of regression and classification tasks. +======================================= ======================================================================================================= + +MLP Models +---------- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= :class:`MLPClassifier` Multi-class and binary classification tasks. :class:`MLPRegressor` MLP for regression tasks. -:class:`MLPLSS` Various statistical distribution families for different types of regression and classification tasks. +:class:`MLPLSS` Various statistical distribution families for different types of regression and classification tasks. +======================================= ======================================================================================================= + +TabTransformer +-------------- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= :class:`TabTransformerClassifier` TabTransformer for classification tasks. :class:`TabTransformerRegressor` TabTransformer for regression tasks. :class:`TabTransformerLSS` TabTransformer for distributional tasks. +======================================= ======================================================================================================= + +ResNet +------ +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= :class:`ResNetClassifier` Multi-class and binary classification tasks using ResNet. :class:`ResNetRegressor` Regression tasks using ResNet. :class:`ResNetLSS` Distributional tasks using ResNet. +======================================= ======================================================================================================= + +MambaTab +-------- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= :class:`MambaTabClassifier` Multi-class and binary classification tasks using MambaTab. :class:`MambaTabRegressor` Regression tasks using MambaTab. :class:`MambaTabLSS` Distributional tasks using MambaTab. +======================================= ======================================================================================================= + +MambaAttention +-------------- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= +:class:`MambAttentionClassifier` Multi-class and binary classification tasks using a Combination between Mamba and Attention layers. +:class:`MambAttentionRegressor` Regression tasks using sing a Combination between Mamba and Attention layers. +:class:`MambAttentionLSS` Distributional tasks using sing a Combination between Mamba and Attention layers. +======================================= ======================================================================================================= + +RNN Models Including LSTM and GRU +--------------------------------- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= :class:`TabulaRNNClassifier` Multi-class and binary classification tasks using a RNN. :class:`TabulaRNNRegressor` Regression tasks using a RNN. :class:`TabulaRNNLSS` Distributional tasks using a RNN. +======================================= ======================================================================================================= + +TabM +---- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= +:class:`TabMClassifier` Multi-class and binary classification tasks using TabM - Batch Ensembling MLP. +:class:`TabMRegressor` Regression tasks using TabM - Batch Ensembling MLP. +:class:`TabMLSS` Distributional tasks using TabM - Batch Ensembling MLP. +======================================= ======================================================================================================= + +NODE +---- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= +:class:`NODEClassifier` Multi-class and binary classification tasks using Neural Oblivious Decision Ensembles. +:class:`NODERegressor` Regression tasks using Neural Oblivious Decision Ensembles. +:class:`NODELSS` Distributional tasks using Neural Oblivious Decision Ensembles. +======================================= ======================================================================================================= + +NDTF +---- +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= +:class:`NDTFClassifier` Multi-class and binary classification tasks using a Neural Decision Forest. +:class:`NDTFRegressor` Regression tasks using a Neural Decision Forest +:class:`NDTFLSS` Distributional tasks using a Neural Decision Forest. +======================================= ======================================================================================================= + +Base Classes +------------ +======================================= ======================================================================================================= +Modules Description +======================================= ======================================================================================================= :class:`SklearnBaseClassifier` Base class for classification tasks. :class:`SklearnBaseLSS` Base class for distributional tasks. :class:`SklearnBaseRegressor` Base class for regression tasks. @@ -40,4 +131,3 @@ Modules Description :maxdepth: 1 Models - diff --git a/docs/homepage.md b/docs/homepage.md index c11fcf1..9c8f684 100644 --- a/docs/homepage.md +++ b/docs/homepage.md @@ -1,68 +1,93 @@ -# Mambular: Tabular Deep Learning with Mamba Architectures +# Mambular: Tabular Deep Learning Made Simple -Mambular is a Python package that brings the power of advanced deep learning architectures to tabular data, offering a suite of models for regression, classification, and distributional regression tasks. Designed with ease of use in mind, Mambular models adhere to scikit-learn's `BaseEstimator` interface, making them highly compatible with the familiar scikit-learn ecosystem. This means you can fit, predict, and evaluate using Mambular models just as you would with any traditional scikit-learn model, but with the added performance and flexibility of deep learning. +Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP inspired tabular models. + + +# 🏃 Quickstart +Similar to any sklearn model, Mambular models can be fit as easy as this: + +```python +from mambular.models import MambularClassifier +# Initialize and fit your model +model = MambularClassifier() -## Features +# X can be a dataframe or something that can be easily transformed into a pd.DataFrame as a np.array +model.fit(X, y, max_epochs=150, lr=1e-04) +``` -- **Comprehensive Model Suite**: Includes modules for regression, classification, and distributional regression, catering to a wide range of tabular data tasks. -- **State-of-the-Art Architectures**: Leverages various advanced architectures known for their effectiveness in handling tabular data. Mambular models include powerful Mamba blocks [Gu and Dao](https://arxiv.org/pdf/2312.00752) and can include bidirectional processing as well as feature interaction layers. -- **Seamless Integration**: Designed to work effortlessly with scikit-learn, allowing for easy inclusion in existing machine learning pipelines, cross-validation, and hyperparameter tuning workflows. -- **Extensive Preprocessing**: Comes with a powerful preprocessing module that supports a broad array of data transformation techniques, ensuring that your data is optimally prepared for model training. -- **Sklearn-like API**: The familiar scikit-learn `fit`, `predict`, and `predict_proba` methods mean minimal learning curve for those already accustomed to scikit-learn. -- **PyTorch Lightning Under the Hood**: Built on top of PyTorch Lightning, Mambular models benefit from streamlined training processes, easy customization, and advanced features like distributed training and 16-bit precision. +# 📖 Introduction +Mambular is a Python package that brings the power of advanced deep learning architectures to tabular data, offering a suite of models for regression, classification, and distributional regression tasks. Designed with ease of use in mind, Mambular models adhere to scikit-learn's `BaseEstimator` interface, making them highly compatible with the familiar scikit-learn ecosystem. This means you can fit, predict, and evaluate using Mambular models just as you would with any traditional scikit-learn model, but with the added performance and flexibility of deep learning. -## Models +# 🤖 Models | Model | Description | | ---------------- | --------------------------------------------------------------------------------------------------------------------------------------------------- | -| `Mambular` | A sequential model using Mamba blocks [Gu and Dao](https://arxiv.org/pdf/2312.00752) specifically designed for various tabular data tasks. | +| `Mambular` | A sequential model using Mamba blocks specifically designed for various tabular data tasks introduced [here](https://arxiv.org/abs/2408.06291). | +| `TabM` | Batch Ensembling for a MLP as introduced by [Gorishniy et al.](https://arxiv.org/abs/2410.24210) | +| `NODE` | Neural Oblivious Decision Ensembles as introduced by [Popov et al.](https://arxiv.org/abs/1909.06312) | | `FTTransformer` | A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data. | | `MLP` | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks. | | `ResNet` | An adaptation of the ResNet architecture for tabular data applications. | | `TabTransformer` | A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities. | | `MambaTab` | A tabular model using a Mamba-Block on a joint input representation described [here](https://arxiv.org/abs/2401.08867) . Not a sequential model. | -| `TabulaRNN` | A Recurrent Neural Network for Tabular data. Not yet included in the benchmarks | +| `TabulaRNN` | A Recurrent Neural Network for Tabular data, introduced [here](https://arxiv.org/pdf/2411.17207). | +| `MambAttention` | A combination between Mamba and Transformers, also introduced [here](https://arxiv.org/pdf/2411.17207). | +| `NDTF` | A neural decision forest using soft decision trees. See [Kontschieder et al.](https://openaccess.thecvf.com/content_iccv_2015/html/Kontschieder_Deep_Neural_Decision_ICCV_2015_paper.html) for inspiration. | + +All models are available for `regression`, `classification` and distributional regression, denoted by `LSS`. +Hence, they are available as e.g. `MambularRegressor`, `MambularClassifier` or `MambularLSS` -## Documentation -You can find the Mamba-Tabular API documentation [here](https://mamba-tabular.readthedocs.io/en/latest/index.html). +# 📚 Documentation -## Installation +You can find the Mamba-Tabular API documentation [here](https://mambular.readthedocs.io/en/latest/). + +# 🛠️ Installation Install Mambular using pip: ```sh pip install mambular ``` -## Preprocessing +If you want to use the original mamba and mamba2 implementations, additionally install mamba-ssm via: -Mambular simplifies the preprocessing stage of model development with a comprehensive set of techniques to prepare your data for Mamba architectures. Our preprocessing module is designed to be both powerful and easy to use, offering a variety of options to efficiently transform your tabular data. +```sh +pip install mamba-ssm +``` -### Data Type Detection and Transformation +Be careful to use the correct torch and cuda versions: -Mambular automatically identifies the type of each feature in your dataset and applies the most appropriate transformations for numerical and categorical variables. This includes: -- **Ordinal Encoding**: Categorical features are seamlessly transformed into numerical values, preserving their inherent order and making them model-ready. -- **One-Hot Encoding**: For nominal data, Mambular employs one-hot encoding to capture the presence or absence of categories without imposing ordinality. -- **Binning**: Numerical features can be discretized into bins, a useful technique for handling continuous variables in certain modeling contexts. -- **Decision Tree Binning**: Optionally, Mambular can use decision trees to find the optimal binning strategy for numerical features, enhancing model interpretability and performance. -- **Normalization**: Mambular can easily handle numerical features without specifically turning them into categorical features. Standard preprocessing steps such as normalization per feature are possible. -- **Standardization**: Similarly, standardization instead of normalization can be used to scale features based on the mean and standard deviation. -- **PLE (Periodic Linear Encoding)**: This technique can be applied to numerical features to enhance the performance of tabular deep learning methods by encoding periodicity. -- **Quantile Transformation**: Numerical features can be transformed to follow a uniform or normal distribution, improving model robustness to outliers. -- **Spline Transformation**: Applies piecewise polynomial functions to numerical features, capturing nonlinear relationships more effectively. -- **Polynomial Features**: Generates polynomial and interaction features, increasing the feature space to capture more complex relationships within the data. +```sh +pip install torch==2.0.0+cu118 torchvision==0.15.0+cu118 torchaudio==2.0.0+cu118 -f https://download.pytorch.org/whl/cu118/torch_stable.html +pip install mamba-ssm +``` + +# 🚀 Usage + +