Skip to content

Commit

Permalink
Update mxla test to use llama3 8B and remove v4 tests (#585)
Browse files Browse the repository at this point in the history
  • Loading branch information
raymondzouu authored Feb 3, 2025
1 parent 29800f6 commit 6332084
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 59 deletions.
13 changes: 0 additions & 13 deletions dags/common/quarantined_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,19 +165,6 @@ class QuarantineTests:
"mxla-gpt3-6b-nightly-gke-8xv5p-8": TestInfo(
team.PERFORMANCE, "2024-11-12"
),
# DAG: mxla_maxtext_nightly_gke
"mxla-maxtext-nightly-gke-v5p-8": TestInfo(
team.PERFORMANCE, "2024-11-12"
),
"mxla-maxtext-nightly-gke-2xv5p-8": TestInfo(
team.PERFORMANCE, "2024-11-12"
),
"mxla-maxtext-nightly-gke-4xv5p-8": TestInfo(
team.PERFORMANCE, "2024-11-12"
),
"mxla-maxtext-nightly-gke-8xv5p-8": TestInfo(
team.PERFORMANCE, "2024-11-12"
),
# DAG: maxtext_trillium_configs_perf
"maxtext-llama2_70b_4096-stable-3-2xv6e-256": TestInfo(
team.PERFORMANCE, "2024-11-12"
Expand Down
6 changes: 3 additions & 3 deletions dags/common/vm_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,11 @@ class XpkClusters:
zone=Zone.US_CENTRAL2_B.value,
)
TPU_V5P_8_CLUSTER = XpkClusterConfig(
name="v5p-8-bodaborg-us-east5-a",
name="v5p-8-bodaborg-europe-west4-b",
device_version=TpuVersion.V5P,
core_count=8,
project=Project.TPU_PROD_ENV_LARGE_CONT.value,
zone=Zone.US_EAST5_A.value,
project=Project.CLOUD_TPU_MULTIPOD_DEV.value,
zone=Zone.EUROPE_WEST4_B.value,
)
TPU_V5E_256_CLUSTER = XpkClusterConfig(
name="v5e-256-bodaborg-europe-west4",
Expand Down
2 changes: 1 addition & 1 deletion dags/multipod/configs/gke_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def get_gke_maxtext_nightly_config(
f" python3 MaxText/train.py MaxText/configs/base.yml run_name={run_name}"
f" base_output_directory={base_output_directory}"
" dataset_path=gs://max-datasets-rogue dataset_type=synthetic"
" per_device_batch_size=12 reuse_example_batch=1 global_parameter_scale=1 metrics_file='metrics.txt'"
" model_name=llama3-8b per_device_batch_size=12 reuse_example_batch=1 metrics_file='metrics.txt'"
" steps=50 enable_checkpointing=false profiler=xplane upload_all_profiler_results=true skip_first_n_steps_for_profiler=10 profiler_steps=10 gcs_metrics=true"
),
)
Expand Down
46 changes: 4 additions & 42 deletions dags/multipod/mxla_maxtext_nightly_gke.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,43 +40,12 @@
group_id="Quarantine", dag=dag, prefix_group_id=False
)

maxtext_nightly_1slice_v4_8 = gke_config.get_gke_maxtext_nightly_config(
time_out_in_min=60,
test_name=default_test_name,
docker_image=jax_nightly_image.value,
test_owner=test_owner.TONY_C,
).run_with_quarantine(quarantine_task_group)

maxtext_nightly_2slice_v4_8 = gke_config.get_gke_maxtext_nightly_config(
num_slices=2,
time_out_in_min=60,
test_name=default_test_name,
docker_image=jax_nightly_image.value,
test_owner=test_owner.TONY_C,
).run_with_quarantine(quarantine_task_group)

maxtext_nightly_4slice_v4_8 = gke_config.get_gke_maxtext_nightly_config(
num_slices=4,
time_out_in_min=60,
test_name=default_test_name,
docker_image=jax_nightly_image.value,
test_owner=test_owner.TONY_C,
).run_with_quarantine(quarantine_task_group)

maxtext_nightly_8slice_v4_8 = gke_config.get_gke_maxtext_nightly_config(
num_slices=8,
time_out_in_min=60,
test_name=default_test_name,
docker_image=jax_nightly_image.value,
test_owner=test_owner.TONY_C,
).run_with_quarantine(quarantine_task_group)

maxtext_nightly_1slice_v5p_8 = gke_config.get_gke_maxtext_nightly_config(
cluster=XpkClusters.TPU_V5P_8_CLUSTER,
time_out_in_min=60,
test_name=default_test_name,
docker_image=jax_nightly_image.value,
test_owner=test_owner.TONY_C,
test_owner=test_owner.RAYMOND_Z,
).run_with_quarantine(quarantine_task_group)

maxtext_nightly_2slice_v5p_8 = gke_config.get_gke_maxtext_nightly_config(
Expand All @@ -85,7 +54,7 @@
time_out_in_min=60,
test_name=default_test_name,
docker_image=jax_nightly_image.value,
test_owner=test_owner.TONY_C,
test_owner=test_owner.RAYMOND_Z,
).run_with_quarantine(quarantine_task_group)

maxtext_nightly_4slice_v5p_8 = gke_config.get_gke_maxtext_nightly_config(
Expand All @@ -94,7 +63,7 @@
time_out_in_min=60,
test_name=default_test_name,
docker_image=jax_nightly_image.value,
test_owner=test_owner.TONY_C,
test_owner=test_owner.RAYMOND_Z,
).run_with_quarantine(quarantine_task_group)

maxtext_nightly_8slice_v5p_8 = gke_config.get_gke_maxtext_nightly_config(
Expand All @@ -103,16 +72,9 @@
time_out_in_min=60,
test_name=default_test_name,
docker_image=jax_nightly_image.value,
test_owner=test_owner.TONY_C,
test_owner=test_owner.RAYMOND_Z,
).run_with_quarantine(quarantine_task_group)

(
maxtext_nightly_1slice_v4_8
>> maxtext_nightly_2slice_v4_8
>> maxtext_nightly_4slice_v4_8
>> maxtext_nightly_8slice_v4_8
)

(
maxtext_nightly_1slice_v5p_8
>> maxtext_nightly_2slice_v5p_8
Expand Down

0 comments on commit 6332084

Please sign in to comment.