Skip to content

Commit

Permalink
MaxText Convergence Tests (#180)
Browse files Browse the repository at this point in the history
Add MaxText convergence tests for bf16 and int8
  • Loading branch information
gobbleturk authored Mar 5, 2024
1 parent f65c4fb commit 2be6407
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 0 deletions.
1 change: 1 addition & 0 deletions dags/gcs_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
IMAGENET_DIR = "gs://ml-auto-solutions/data/imagenet"
TFDS_DATA_DIR = "gs://ml-auto-solutions/data/tfds-data"
PAX_DIR = "gs://cloud-tpu-checkpoints/pax"
# Passed as DATASET_PATH to the MaxText convergence script
# (see dags/multipod/maxtext_convergence.py).
MAXTEXT_DIR = "gs://max-datasets-rogue"

# GCS bucket for output
BENCHMARK_OUTPUT_DIR = "gs://ml-auto-solutions/output/benchmark"
Expand Down
65 changes: 65 additions & 0 deletions dags/multipod/configs/maxtext_gke_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to construct configs for maxtext DAG on GKE."""

from xlml.apis import gcp_config, metric_config, task, test_config
from dags import test_owner
from dags.vm_resource import TpuVersion, Project, ClusterName
from typing import Iterable


def get_maxtext_gke_config(
    tpu_version: TpuVersion,
    tpu_cores: int,
    tpu_zone: str,
    time_out_in_min: int,
    test_name: str,
    docker_image: str,
    test_owner: str,
    run_model_cmds: Iterable[str],
    cluster_name: str = ClusterName.V4_8_MULTISLICE_CLUSTER.value,
    project_name: str = Project.TPU_PROD_ENV_MULTIPOD.value,
    num_slices: int = 1,
    dataset_name: metric_config.DatasetOption = metric_config.DatasetOption.XLML_DATASET,
    dataset_project: str = Project.CLOUD_ML_AUTO_SOLUTIONS.value,
    composer_project: str = Project.CLOUD_ML_AUTO_SOLUTIONS.value,
) -> task.TpuXpkTask:
  """Build the XPK task that runs a MaxText workload on a GKE TPU cluster.

  Args:
    tpu_version: TPU generation forwarded to `test_config.Tpu`.
    tpu_cores: Core count forwarded to `test_config.Tpu`.
    tpu_zone: Zone recorded in the GCP config.
    time_out_in_min: Timeout, in minutes, given to the test config.
    test_name: Name given to the test config.
    docker_image: Image reference given to the test config.
    test_owner: Owner recorded as `task_owner`. Inside this function the
      parameter shadows the `dags.test_owner` module imported by this file.
    run_model_cmds: Commands given to the test config as `run_model_cmds`.
    cluster_name: Cluster name given to the test config.
    project_name: Project recorded in the GCP config.
    num_slices: Slice count given to the test config.
    dataset_name: Dataset option recorded in the GCP config.
    dataset_project: Dataset project recorded in the GCP config.
    composer_project: Composer project recorded in the GCP config.

  Returns:
    A `task.TpuXpkTask` wrapping the assembled test and GCP configs.
  """
  gcp_settings = gcp_config.GCPConfig(
      project_name=project_name,
      zone=tpu_zone,
      dataset_name=dataset_name,
      dataset_project=dataset_project,
      composer_project=composer_project,
  )

  accelerator = test_config.Tpu(version=tpu_version, cores=tpu_cores)

  # No set-up commands are supplied; only `run_model_cmds` is passed on.
  gke_test = test_config.TpuGkeTest(
      accelerator,
      test_name=test_name,
      run_model_cmds=run_model_cmds,
      set_up_cmds=None,
      time_out_in_min=time_out_in_min,
      task_owner=test_owner,
      num_slices=num_slices,
      cluster_name=cluster_name,
      docker_image=docker_image,
  )

  return task.TpuXpkTask(
      task_test_config=gke_test,
      task_gcp_config=gcp_settings,
  )
66 changes: 66 additions & 0 deletions dags/multipod/maxtext_convergence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
A DAG to run MaxText convergence tests for both bf16 and int8.
"""
import datetime
from airflow import models
from dags import composer_env, test_owner, gcs_bucket
from dags.vm_resource import TpuVersion, Zone, DockerImage, ClusterName
from dags.multipod.configs import maxtext_gke_config
from dags.multipod.configs.common import SetupMode
from xlml.apis import gcp_config, metric_config, task, test_config

# Run once a day at 6 am UTC (10 pm PST). When composer_env.is_prod_env() is
# false the DAG is created without a schedule.
SCHEDULED_TIME = "0 6 * * *" if composer_env.is_prod_env() else None

with models.DAG(
    dag_id="maxtext_convergence",
    schedule=SCHEDULED_TIME,
    tags=["multipod_team", "maxtext", "stable"],
    start_date=datetime.datetime(2024, 3, 1),
    catchup=False,
    # `max_active_tasks` replaces the deprecated `concurrency` argument.
    max_active_tasks=2,
) as dag:
  # NOTE(review): this timestamp is taken when the file is executed (i.e. when
  # the DAG is parsed), not when an individual DAG run starts, so the dated
  # output directory follows the parse date - confirm that is intended.
  current_time = datetime.datetime.now()
  current_date = current_time.strftime("%Y-%m-%d")
  base_output_directory = (
      f"{gcs_bucket.XLML_OUTPUT_DIR}/maxtext/stable/automated/{current_date}"
  )
  dataset_path = gcs_bucket.MAXTEXT_DIR

  steps = 10200  # Half Chinchilla
  loss_threshold = 2.7

  # Shared command line; each test below may prepend environment settings.
  base_convergence_command = (
      "bash end_to_end/test_convergence_1b_params.sh"
      f" OUTPUT_PATH={base_output_directory}"
      f" DATASET_PATH={dataset_path}"
      f" LOSS_THRESHOLD={loss_threshold}"
      f" STEPS={steps}"
  )

  # Maps test name -> tuple of commands passed as `run_model_cmds`.
  convergence_tests = {
      "maxtext-convergence-bf16": (base_convergence_command,),
      "maxtext-convergence-int8": (
          f"export M_QUANTIZATION=int8; {base_convergence_command}",
      ),
  }

  # One task per test, created inside the DAG context.
  for test_name, run_command in convergence_tests.items():
    maxtext_v4_configs_test = maxtext_gke_config.get_maxtext_gke_config(
        tpu_version=TpuVersion.V4,
        tpu_cores=128,
        tpu_zone=Zone.US_CENTRAL2_B.value,
        cluster_name=ClusterName.V4_128_MULTISLICE_CLUSTER.value,
        time_out_in_min=300,
        test_name=test_name,
        run_model_cmds=run_command,
        docker_image=DockerImage.MAXTEXT_JAX_STABLE.value,
        test_owner=test_owner.MATT_D,
    ).run()
1 change: 1 addition & 0 deletions dags/test_owner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
TONY_C = "Tony C."
JON_B = "Jon B."
RAYMOND_Z = "Raymond Z."
MATT_D = "Matt D."

# MLCompass
ORTI_B = "Orti B."
4 changes: 4 additions & 0 deletions dags/vm_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""The file for common projects, zone, and runtime versions."""

import enum
import datetime


V5_NETWORKS_PREFIX = "projects/tpu-prod-env-automated"
Expand Down Expand Up @@ -116,6 +117,7 @@ class ClusterName(enum.Enum):
V4_32_CLUSTER = "mas-v4-32"
V5E_4_CLUSTER = "mas-v5e-4"
V5E_16_CLUSTER = "mas-v5e-16"
V4_8_MULTISLICE_CLUSTER = "v4-8-maxtext"
V4_128_MULTISLICE_CLUSTER = "v4-bodaborg"
V5E_16_MULTISLICE_CLUSTER = "v5e-16-bodaborg"
V5E_256_MULTISLICE_CLUSTER = "v5e-256-bodaborg"
Expand All @@ -126,3 +128,5 @@ class DockerImage(enum.Enum):

XPK_JAX_TEST = "gcr.io/cloud-ml-auto-solutions/xpk_jax_test:latest"
XPK_MAXTEXT_TEST = "gcr.io/tpu-prod-env-multipod/xpk_maxtext_test:latest"
# NOTE(review): the date tag below is formatted from datetime.datetime.today()
# while this class body executes, so each value is fixed to the local date at
# the moment this module is imported - confirm an image with that tag is
# expected to exist whenever the DAG files are parsed.
MAXTEXT_JAX_STABLE = f"gcr.io/tpu-prod-env-multipod/maxtext_jax_stable:{datetime.datetime.today().strftime('%Y-%m-%d')}"
MAXTEXT_JAX_NIGHTLY = f"gcr.io/tpu-prod-env-multipod/maxtext_jax_nightly:{datetime.datetime.today().strftime('%Y-%m-%d')}"

0 comments on commit 2be6407

Please sign in to comment.