diff --git a/dags/multipod/maxtext_checkpointing.py b/dags/multipod/maxtext_checkpointing.py index 17bd052ea..90a3f9461 100644 --- a/dags/multipod/maxtext_checkpointing.py +++ b/dags/multipod/maxtext_checkpointing.py @@ -55,17 +55,19 @@ for accelerator, slices in test_configs.items(): cores = accelerator.rsplit("-", maxsplit=1)[-1] for slice_num in slices: - command = ( - "bash end_to_end/test_checkpointing.sh" - f" checkpointing-{mode.value}-{slice_num}x-{accelerator}" - f" {base_output_directory} {dataset_path} true", - ) - maxtext_v4_configs_test = gke_config.get_gke_config( - num_slices=slice_num, - cluster=clusters[accelerator], - time_out_in_min=60, - test_name=f"maxtext-checkpointing-{mode.value}", - run_model_cmds=command, - docker_image=image.value, - test_owner=test_owner.SURBHI_J, - ).run() + for chkpt_mode in ["sync", "async"]: + async_checkpointing = chkpt_mode == "async" + command = ( + "bash end_to_end/test_checkpointing.sh" + f" checkpointing-{mode.value}-{slice_num}x-{accelerator}-{chkpt_mode}" + f" {base_output_directory} {dataset_path} true tfds autoselected {async_checkpointing}" + ) + maxtext_v4_configs_test = gke_config.get_gke_config( + num_slices=slice_num, + cluster=clusters[accelerator], + time_out_in_min=60, + test_name=f"maxtext-checkpointing-{mode.value}-{chkpt_mode}", + run_model_cmds=command, + docker_image=image.value, + test_owner=test_owner.SURBHI_J, + ).run()