Update train script with new interleave function #171

Merged · 3 commits · Jun 6, 2024
README.md (2 additions & 2 deletions)

````diff
@@ -82,8 +82,8 @@ Here are some examples of how you might use JAT in both evaluation and fine-tuning
 - **Training JAT**: Train your own JAT model from scratch (run on 8xA100)
   ```shell
   accelerate launch scripts/train_jat_tokenized.py \
-    --output_dir checkpoints/jat_small_v100 \
-    --model_name_or_path jat-project/jat-small \
+    --output_dir checkpoints/jat \
+    --model_name_or_path jat-project/jat \
     --tasks all \
     --trust_remote_code \
     --per_device_train_batch_size 20 \
````
jat/processing_jat.py (1 addition & 0 deletions)

```diff
@@ -182,6 +182,7 @@ class JatProcessor(ProcessorMixin):
         tokenizer ([`AutoTokenizer`]):
             The tokenizer is a required input.
     """
+
     attributes = ["image_processor", "tokenizer"]
     image_processor_class = "AutoImageProcessor"
     tokenizer_class = "AutoTokenizer"
```
scripts/train_jat.py (9 additions & 3 deletions)

```diff
@@ -15,7 +15,7 @@

 from jat.eval.rl.core import TASK_NAME_TO_ENV_ID
 from jat.modeling_jat import JatModel
-from jat.utils import mix_iterable_datasets
+from jat.utils_interleave_datasets import interleave_datasets


 # Sometimes, the server is down; increasing the number of
@@ -185,9 +185,15 @@ def add_loss_weight(example, loss_weight):
             eval_dataset[key] = eval_dataset[key].take(data_args.eval_num_samples)

     weights = [SAMPLE_WEIGHTS.get(t, 1.0) for t in train_dataset.keys()]
-    train_dataset = mix_iterable_datasets(
-        list(train_dataset.values()), batch_size=training_args.per_device_train_batch_size, weights=weights
+
+    train_dataset = interleave_datasets(
+        list(train_dataset.values()),
+        probabilities=[w / sum(weights) for w in weights],
+        seed=training_args.seed,
+        stopping_strategy="all_exhausted",
+        n_contiguous=training_args.per_device_train_batch_size,
     )
+
     # Due to the train dataset's structure, where every 'n' consecutive samples share the same modalities, we can't
     # load all samples at once. Different sets of 'n' samples have different modalities. Therefore, we must load and
     # process each set of 'n' samples separately.
```
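For context: this change swaps the old batch-size-based mixer for a probability-driven interleaver that still emits `n_contiguous` samples in a row from the same source dataset, which is what the comment above depends on (each group of `n` consecutive samples shares one set of modalities, so each per-device batch can be collated uniformly). The per-dataset `SAMPLE_WEIGHTS` are normalized into sampling probabilities before the call. The actual implementation lives in `jat/utils_interleave_datasets.py` and is not shown in this diff; the sketch below is only a hypothetical reconstruction of the call's semantics, and its signature and stopping behavior (in particular the `all_exhausted` handling) are assumptions, not taken from the source.

```python
# Hypothetical sketch of the interleaving semantics assumed by the call above.
# The real function is defined in jat/utils_interleave_datasets.py (not shown
# in this diff); the signature and stopping behavior here are guesses.
import random
from typing import Iterable, Iterator, List, Optional, Sequence


def interleave_datasets(
    datasets: Sequence[Iterable],
    probabilities: Sequence[float],
    seed: Optional[int] = None,
    stopping_strategy: str = "all_exhausted",
    n_contiguous: int = 1,
) -> Iterator:
    """Yield blocks of `n_contiguous` consecutive samples, each block drawn
    from a single dataset chosen with the given probabilities."""
    rng = random.Random(seed)
    iterators: List[Optional[Iterator]] = [iter(d) for d in datasets]
    while any(it is not None for it in iterators):
        # Pick the source dataset for the next contiguous block.
        idx = rng.choices(range(len(datasets)), weights=probabilities, k=1)[0]
        if iterators[idx] is None:
            continue  # this source is exhausted; redraw
        for _ in range(n_contiguous):
            try:
                yield next(iterators[idx])
            except StopIteration:
                if stopping_strategy == "first_exhausted":
                    return  # stop as soon as any dataset runs out
                iterators[idx] = None  # "all_exhausted": drop the empty source
                break


if __name__ == "__main__":
    # Toy check: every block of two consecutive items comes from one source.
    print(list(interleave_datasets([range(4), "abcd"], [0.5, 0.5], seed=0, n_contiguous=2)))
```

Because the script passes `n_contiguous=training_args.per_device_train_batch_size`, each per-device batch is drawn from exactly one task's dataset under these assumed semantics, preserving the invariant the downstream processing comment relies on.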