diff --git a/scripts/train_gia2.py b/scripts/train_gia2.py index d507fd8e..5a992739 100755 --- a/scripts/train_gia2.py +++ b/scripts/train_gia2.py @@ -145,7 +145,7 @@ def main(): lambda example_batch: processor(**example_batch, padding="max_length", truncation="preserve"), batched=True, batch_size=10, - remove_columns={"text", "images"}.intersection(column_names), + remove_columns={"text", "images", "text_observations"}.intersection(column_names), ) dataset = dataset.map( lambda x: {"loss_weight": [LOSS_WEIGHTS.get(task, 1.0)] * len(next(iter(x.values())))}