Skip to content

Commit

Permalink
fix train script for babyai
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Oct 16, 2023
1 parent 9288371 commit d1c4bc9
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion scripts/train_gia2.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def main():
lambda example_batch: processor(**example_batch, padding="max_length", truncation="preserve"),
batched=True,
batch_size=10,
remove_columns={"text", "images"}.intersection(column_names),
remove_columns={"text", "images", "text_observations"}.intersection(column_names),
)
dataset = dataset.map(
lambda x: {"loss_weight": [LOSS_WEIGHTS.get(task, 1.0)] * len(next(iter(x.values())))}
Expand Down

0 comments on commit d1c4bc9

Please sign in to comment.