-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
32 b #121
base: main
Are you sure you want to change the base?
Conversation
@@ -130,7 +130,7 @@ def build(self, trainer: "Trainer") -> Optional[Callback]: | |||
eval_batch_size = ( | |||
self.eval_batch_size | |||
if self.eval_batch_size is not None | |||
else trainer.rank_microbatch_size * get_world_size(trainer.dp_process_group) | |||
else 2 * trainer.rank_microbatch_size * get_world_size(trainer.dp_process_group) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: you could instead passed an updated evaluator callback in OLMo2-32B.py
:
.with_callback(
"lm_evaluator",
LMEvaluatorCallbackConfig(
eval_batch_size=<whatever you want>,
eval_dataset=NumpyDatasetConfig.from_data_mix(
DataMix.v3_small_ppl_validation,
name=NumpyDatasetType.padded_fsl,
mix_base_dir=root_dir,
sequence_length=dataset_config.effective_sequence_length,
tokenizer=tokenizer_config,
work_dir=get_work_dir(root_dir),
),
eval_interval=1000,
),
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, but I think this is better. I think we can default to 2x the training batch size. It should always work.
# import flash_attn.ops.triton.cross_entropy as flash_attn_ce # type: ignore | ||
|
||
_fused_cross_entropy_loss = triton_ce_loss.cross_entropy_loss | ||
import flash_attn.ops.triton.cross_entropy as flash_attn_ce # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Our in-house triton CE loss was copied directly from the flash-attn repo, so I don't see the point of this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I took this back out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do I want compiling and fused loss at the same time?
""" | ||
d_model = 5120 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is a very narrow model then... are you sure about that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a clone of Qwen 32. The tradeoffs are, narrow d_model, wide FFN, GQA, lots of layers.
src/scripts/train/OLMo2-32B.py
Outdated
fused_loss=True, | ||
compile_loss=False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand the trepidation about the different loss implementations, but the way it was before was the most performant. This way will be slower and have a higher memory footprint.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we have some certainty that this will do the right thing? What happens if we take the 13B from a late checkpoint and run it?
src/scripts/train/OLMo2-32B.py
Outdated
enabled=False, | ||
cancel_check_interval=10, | ||
), | ||
).with_callback( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should just add this to the common callbacks.
"lm_evaluator": LMEvaluatorCallbackConfig( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know that we want these for everything. Default should probably be only the new, blessed ones.
This reverts commit e27b91d.
src/olmo_core/io.py
Outdated
@@ -590,7 +594,7 @@ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes | |||
) | |||
|
|||
|
|||
@retriable() | |||
@retriable(retry_condition=_gcs_is_retriable) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This general approach sort of blows up our retry time from 10 mins to 30 mins. Sort of not a fan.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But at least it looks like it works.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could always reduce the deadline/timeout
…ing fails" This reverts commit a0700e8.
This PR pulls the general important changes in from #121.
No description provided.