Skip to content

Commit

Permalink
Merge pull request #55 from georgian-io/akash/bugfix
Browse files Browse the repository at this point in the history
Akash/bugfix
  • Loading branch information
akashsara authored Nov 7, 2023
2 parents bfc6c24 + 5b0bffe commit 3543ca2
Show file tree
Hide file tree
Showing 12 changed files with 1,011 additions and 987 deletions.
16 changes: 8 additions & 8 deletions docs/source/notes/combine_methods.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ The following describes each supported method and whether or not it requires bot
| gating_on_cat_and_num_feats_then_sum | Gated summation of transformer outputs, numerical feats, and categorical feats before final classifier layer(s). Inspired by [Integrating Multimodal Information in Large Pretrained Transformers](https://www.aclweb.org/anthology/2020.acl-main.214.pdf) which performs the mechanism for each token. | False
| weighted_feature_sum_on_transformer_cat_and_numerical_feats | Learnable weighted feature-wise sum of transformer outputs, numerical feats and categorical feats for each feature dimension before final classifier layer(s) | False

This table shows the the equations involved with each method. First we define some notation
This table shows the the equations involved with each method. First we define some notations:

* ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bm%7D)  denotes the combined multimodal features
* ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bx%7D)  denotes the output text features from the transformer
* ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bc%7D)  denotes the categorical features
* ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bn%7D)  denotes the numerical features
* ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20h_%7B%5Cmathbf%7B%5CTheta%7D%7D) denotes a MLP parameterized by ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7B%5CTheta%7D)
* ![equation](https://latex.codecogs.com/svg.latex?%5Cmathbf%7BW%7D)  denotes a weight matrix
* ![equation](https://latex.codecogs.com/svg.latex?b)  denotes a scalar bias
* ![m](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bm%7D)   denotes the combined multimodal features
* ![x](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bx%7D)   denotes the output text features from the transformer
* ![c](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bc%7D)   denotes the categorical features
* ![n](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bn%7D)   denotes the numerical features
* ![h_theta](https://latex.codecogs.com/svg.latex?%5Cinline%20h_%7B%5Cmathbf%7B%5CTheta%7D%7D) denotes a MLP parameterized by ![theta](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7B%5CTheta%7D)
* ![W](https://latex.codecogs.com/svg.latex?%5Cmathbf%7BW%7D)   denotes a weight matrix
* ![b](https://latex.codecogs.com/svg.latex?b)   denotes a scalar bias

| Combine Feat Method | Equation |
|:--------------|:-------------------|
Expand Down
1 change: 0 additions & 1 deletion docs/source/notes/introduction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ Say for example we had categorical features of dim 9 and numerical features of d
cat_feat_dim=9, # need to specify this
numerical_feat_dim=5, # need to specify this
num_labels=2, # need to specify this, assuming our task is binary classification
use_num_bn=False,
)
bert_config.tabular_config = tabular_config
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def compute_metrics_fn(p: EvalPrediction):
)
if training_args.do_train:
trainer.train(
model_path=model_args.model_name_or_path
resume_from_checkpoint=model_args.model_name_or_path
if os.path.isdir(model_args.model_name_or_path)
else None
)
Expand Down
312 changes: 165 additions & 147 deletions multimodal_exp_args.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion multimodal_transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import multimodal_transformers.data
import multimodal_transformers.model

__version__ = "0.2-alpha"
__version__ = "0.3.0"

__all__ = ["multimodal_transformers", "__version__"]
29 changes: 11 additions & 18 deletions multimodal_transformers/model/tabular_combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,13 @@ def __init__(self, tabular_config):
self.numerical_feat_dim = tabular_config.numerical_feat_dim
self.num_labels = tabular_config.num_labels
self.numerical_bn = tabular_config.numerical_bn
self.categorical_bn = tabular_config.categorical_bn
self.mlp_act = tabular_config.mlp_act
self.mlp_dropout = tabular_config.mlp_dropout
self.mlp_division = tabular_config.mlp_division
self.text_out_dim = tabular_config.text_feat_dim
self.tabular_config = tabular_config

if self.numerical_bn and self.numerical_feat_dim > 0:
self.num_bn = nn.BatchNorm1d(self.numerical_feat_dim)
else:
self.num_bn = None

if self.combine_feat_method == "text_only":
self.final_out_dim = self.text_out_dim
elif self.combine_feat_method == "concat":
Expand All @@ -131,7 +127,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
hidden_channels=dims,
return_layer_outs=False,
bn=True,
bn=self.categorical_bn,
)
self.final_out_dim = (
self.text_out_dim + output_dim + self.numerical_feat_dim
Expand All @@ -157,7 +153,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
hidden_channels=dims,
return_layer_outs=False,
bn=True,
bn=self.categorical_bn and self.numerical_bn,
)
self.final_out_dim = self.text_out_dim + output_dim
elif (
Expand All @@ -181,7 +177,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
hidden_channels=dims,
return_layer_outs=False,
bn=True,
bn=self.categorical_bn,
)

output_dim_num = 0
Expand All @@ -194,7 +190,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
num_hidden_lyr=1,
return_layer_outs=False,
bn=True,
bn=self.numerical_bn,
)
self.final_out_dim = self.text_out_dim + output_dim_num + output_dim_cat
elif (
Expand All @@ -220,7 +216,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
hidden_channels=dims,
return_layer_outs=False,
bn=True,
bn=self.categorical_bn,
)
else:
self.cat_layer = nn.Linear(self.cat_feat_dim, output_dim_cat)
Expand All @@ -242,7 +238,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
hidden_channels=dims,
return_layer_outs=False,
bn=True,
bn=self.numerical_bn,
)
else:
self.num_layer = nn.Linear(self.numerical_feat_dim, output_dim_num)
Expand Down Expand Up @@ -275,7 +271,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
return_layer_outs=False,
hidden_channels=dims,
bn=True,
bn=self.categorical_bn,
)
else:
output_dim_cat = self.cat_feat_dim
Expand All @@ -297,7 +293,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
return_layer_outs=False,
hidden_channels=dims,
bn=True,
bn=self.numerical_bn,
)
else:
output_dim_num = self.numerical_feat_dim
Expand Down Expand Up @@ -330,7 +326,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
hidden_channels=dims,
return_layer_outs=False,
bn=True,
bn=self.categorical_bn,
)
self.g_cat_layer = nn.Linear(
self.text_out_dim + min(self.text_out_dim, self.cat_feat_dim),
Expand All @@ -357,7 +353,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
hidden_channels=dims,
return_layer_outs=False,
bn=True,
bn=self.numerical_bn,
)
self.g_num_layer = nn.Linear(
min(self.numerical_feat_dim, self.text_out_dim) + self.text_out_dim,
Expand Down Expand Up @@ -398,9 +394,6 @@ def forward(self, text_feats, cat_feats=None, numerical_feats=None):
text_feats.device
)

if self.numerical_bn and self.numerical_feat_dim != 0:
numerical_feats = self.num_bn(numerical_feats)

if self.combine_feat_method == "text_only":
combined_feats = text_feats
if self.combine_feat_method == "concat":
Expand Down
3 changes: 3 additions & 0 deletions multimodal_transformers/model/tabular_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class TabularConfig:
See :obj:`TabularFeatCombiner` for details on the supported methods.
mlp_dropout (float): dropout ratio used for MLP layers
numerical_bn (bool): whether to use batchnorm on numerical features
categorical_bn (bool): whether to use batchnorm on categorical features
use_simple_classifier (bool): whether to use single layer or MLP as final classifier
mlp_act (str): the activation function to use for finetuning layers
gating_beta (float): the beta hyperparameters used for gating tabular data
Expand All @@ -25,6 +26,7 @@ def __init__(
combine_feat_method="text_only",
mlp_dropout=0.1,
numerical_bn=True,
categorical_bn=True,
use_simple_classifier=True,
mlp_act="relu",
gating_beta=0.2,
Expand All @@ -36,6 +38,7 @@ def __init__(
self.combine_feat_method = combine_feat_method
self.mlp_dropout = mlp_dropout
self.numerical_bn = numerical_bn
self.categorical_bn = categorical_bn
self.use_simple_classifier = use_simple_classifier
self.mlp_act = mlp_act
self.gating_beta = gating_beta
Expand Down
2 changes: 2 additions & 0 deletions multimodal_transformers/model/tabular_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ def __init__(self, hf_model_config):
self.config.tabular_config = tabular_config.__dict__

tabular_config.text_feat_dim = hf_model_config.hidden_size
tabular_config.hidden_dropout_prob = hf_model_config.hidden_dropout_prob
self.tabular_combiner = TabularFeatCombiner(tabular_config)
self.num_labels = tabular_config.num_labels
combined_feat_dim = self.tabular_combiner.final_out_dim
Expand Down Expand Up @@ -603,6 +604,7 @@ def __init__(self, hf_model_config):
self.config.tabular_config = tabular_config.__dict__

tabular_config.text_feat_dim = hf_model_config.hidden_size
tabular_config.hidden_dropout_prob = hf_model_config.hidden_dropout_prob
self.tabular_combiner = TabularFeatCombiner(tabular_config)
self.num_labels = tabular_config.num_labels
combined_feat_dim = self.tabular_combiner.final_out_dim
Expand Down
Loading

0 comments on commit 3543ca2

Please sign in to comment.