diff --git a/dags/gcs_bucket.py b/dags/gcs_bucket.py index c00f2902..315fa0b2 100644 --- a/dags/gcs_bucket.py +++ b/dags/gcs_bucket.py @@ -19,6 +19,7 @@ IMAGENET_DIR = "gs://ml-auto-solutions/data/imagenet" TFDS_DATA_DIR = "gs://ml-auto-solutions/data/tfds-data" PAX_DIR = "gs://cloud-tpu-checkpoints/pax" +MAXTEXT_DIR = "gs://max-datasets-rogue" # GCS bucket for output BENCHMARK_OUTPUT_DIR = "gs://ml-auto-solutions/output/benchmark" diff --git a/dags/multipod/configs/maxtext_gke_config.py b/dags/multipod/configs/maxtext_gke_config.py new file mode 100644 index 00000000..2dd81bab --- /dev/null +++ b/dags/multipod/configs/maxtext_gke_config.py @@ -0,0 +1,65 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Utilities to construct configs for maxtext DAG on GKE."""

from xlml.apis import gcp_config, metric_config, task, test_config
from dags import test_owner
from dags.vm_resource import TpuVersion, Project, ClusterName
from typing import Iterable


def get_maxtext_gke_config(
    tpu_version: TpuVersion,
    tpu_cores: int,
    tpu_zone: str,
    time_out_in_min: int,
    test_name: str,
    docker_image: str,
    test_owner: str,
    run_model_cmds: Iterable[str],
    cluster_name: str = ClusterName.V4_8_MULTISLICE_CLUSTER.value,
    project_name: str = Project.TPU_PROD_ENV_MULTIPOD.value,
    num_slices: int = 1,
    dataset_name: metric_config.DatasetOption = metric_config.DatasetOption.XLML_DATASET,
    dataset_project: str = Project.CLOUD_ML_AUTO_SOLUTIONS.value,
    composer_project: str = Project.CLOUD_ML_AUTO_SOLUTIONS.value,
) -> task.TpuXpkTask:
    """Build the `task.TpuXpkTask` for one maxtext test on GKE.

    The GCP-related arguments are packed into a `gcp_config.GCPConfig`, the
    accelerator and test arguments into a `test_config.TpuGkeTest`, and both
    are wrapped in a `task.TpuXpkTask`. The task is only constructed here;
    this function does not call any method on it.

    Args:
      tpu_version: Passed to `test_config.Tpu` as `version`.
      tpu_cores: Passed to `test_config.Tpu` as `cores`.
      tpu_zone: Passed to `GCPConfig` as `zone`.
      time_out_in_min: Passed to `TpuGkeTest` as `time_out_in_min`.
      test_name: Passed to `TpuGkeTest` as `test_name`.
      docker_image: Passed to `TpuGkeTest` as `docker_image`.
      test_owner: Passed to `TpuGkeTest` as `task_owner`. Inside this function
        the name shadows the module-level `test_owner` import.
      run_model_cmds: Passed to `TpuGkeTest` as `run_model_cmds`.
      cluster_name: Passed to `TpuGkeTest` as `cluster_name`.
      project_name: Passed to `GCPConfig` as `project_name`.
      num_slices: Passed to `TpuGkeTest` as `num_slices`.
      dataset_name: Passed to `GCPConfig` as `dataset_name`.
      dataset_project: Passed to `GCPConfig` as `dataset_project`.
      composer_project: Passed to `GCPConfig` as `composer_project`.

    Returns:
      A `task.TpuXpkTask` holding the test config and the GCP config.
    """
    gcp_settings = gcp_config.GCPConfig(
        project_name=project_name,
        zone=tpu_zone,
        dataset_name=dataset_name,
        dataset_project=dataset_project,
        composer_project=composer_project,
    )

    accelerator = test_config.Tpu(version=tpu_version, cores=tpu_cores)
    gke_test = test_config.TpuGkeTest(
        accelerator,
        test_name=test_name,
        run_model_cmds=run_model_cmds,
        # No set-up commands are configured for this test.
        set_up_cmds=None,
        time_out_in_min=time_out_in_min,
        task_owner=test_owner,
        num_slices=num_slices,
        cluster_name=cluster_name,
        docker_image=docker_image,
    )

    return task.TpuXpkTask(
        task_test_config=gke_test,
        task_gcp_config=gcp_settings,
    )
"""
A DAG to run MaxText convergence tests for both bf16 and int8.
"""
import datetime
from airflow import models
from dags import composer_env, test_owner, gcs_bucket
from dags.vm_resource import TpuVersion, Zone, DockerImage, ClusterName
from dags.multipod.configs import maxtext_gke_config
from dags.multipod.configs.common import SetupMode
from xlml.apis import gcp_config, metric_config, task, test_config

# Run once a day at 6 am UTC (10 pm PST)
# Only the prod environment gets a schedule; elsewhere the schedule is None.
SCHEDULED_TIME = "0 6 * * *" if composer_env.is_prod_env() else None

with models.DAG(
    dag_id="maxtext_convergence",
    schedule=SCHEDULED_TIME,
    tags=["multipod_team", "maxtext", "stable"],
    start_date=datetime.datetime(2024, 3, 1),
    catchup=False,
    concurrency=2,
) as dag:
    # NOTE(review): this timestamp is read while the module body executes,
    # so the date baked into the output path is the date this file was
    # loaded, not necessarily the date a given run starts. Confirm that is
    # the intended behavior.
    current_time = datetime.datetime.now()
    current_date = f"{current_time:%Y-%m-%d}"
    base_output_directory = (
        f"{gcs_bucket.XLML_OUTPUT_DIR}/maxtext/stable/automated/{current_date}"
    )
    dataset_path = gcs_bucket.MAXTEXT_DIR

    steps = 10200  # Half Chinchilla
    loss_threshold = 2.7

    # One shell command shared by every variant; pieces are concatenated
    # into a single space-separated line.
    base_convergence_command = (
        "bash end_to_end/test_convergence_1b_params.sh"
        f" OUTPUT_PATH={base_output_directory}"
        f" DATASET_PATH={dataset_path}"
        f" LOSS_THRESHOLD={loss_threshold}"
        f" STEPS={steps}"
    )

    # Test name -> one-element tuple of commands for that variant. The int8
    # variant only differs by exporting M_QUANTIZATION before the command.
    convergence_tests = {
        "maxtext-convergence-bf16": (base_convergence_command,),
        "maxtext-convergence-int8": (
            f"export M_QUANTIZATION=int8; {base_convergence_command}",
        ),
    }

    for test_name, run_command in convergence_tests.items():
        convergence_task = maxtext_gke_config.get_maxtext_gke_config(
            tpu_version=TpuVersion.V4,
            tpu_cores=128,
            tpu_zone=Zone.US_CENTRAL2_B.value,
            cluster_name=ClusterName.V4_128_MULTISLICE_CLUSTER.value,
            time_out_in_min=300,
            test_name=test_name,
            run_model_cmds=run_command,
            docker_image=DockerImage.MAXTEXT_JAX_STABLE.value,
            test_owner=test_owner.MATT_D,
        )
        # Calling run() inside the DAG context is what adds the test to the DAG.
        maxtext_v4_configs_test = convergence_task.run()