Skip to content

Commit

Permalink
ci: Add new benchmarking cluster profile (#3665)
Browse files Browse the repository at this point in the history
# Overview

Minor edits! This PR creates a new `benchmarking-arm` cluster-profile,
as well as cleans up some GitHub Actions outputs.
  • Loading branch information
raunakab authored Jan 15, 2025
1 parent 0afc55f commit 6157ad8
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 15 deletions.
28 changes: 19 additions & 9 deletions .github/ci-scripts/job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@

from ray.job_submission import JobStatus, JobSubmissionClient

# We impose a 5min timeout here
# If any job does *not* finish in 5min, then we cancel it and mark the question as a "DNF" (did-not-finish).
TIMEOUT_S = 60 * 5


def parse_env_var_str(env_var_str: str) -> dict:
iter = map(
Expand All @@ -29,13 +33,17 @@ async def print_logs(logs):
print(lines, end="")


async def wait_on_job(logs, timeout_s):
await asyncio.wait_for(print_logs(logs), timeout=timeout_s)
async def wait_on_job(logs, timeout_s) -> bool:
try:
await asyncio.wait_for(print_logs(logs), timeout=timeout_s)
return False
except asyncio.exceptions.TimeoutError:
return True


@dataclass
class Result:
query: int
arguments: str
duration: timedelta
error_msg: Optional[str]

Expand Down Expand Up @@ -66,7 +74,7 @@ def submit_job(

results = []

for index, args in enumerate(list_of_entrypoint_args):
for args in list_of_entrypoint_args:
entrypoint = f"DAFT_RUNNER=ray python {entrypoint_script} {args}"
print(f"{entrypoint=}")
start = datetime.now()
Expand All @@ -78,18 +86,20 @@ def submit_job(
},
)

asyncio.run(wait_on_job(client.tail_job_logs(job_id), timeout_s=60 * 30))
timed_out = asyncio.run(wait_on_job(client.tail_job_logs(job_id), timeout_s=TIMEOUT_S))

status = client.get_job_status(job_id)
assert status.is_terminal(), "Job should have terminated"
end = datetime.now()
duration = end - start
error_msg = None
if status != JobStatus.SUCCEEDED:
job_info = client.get_job_info(job_id)
error_msg = job_info.message
if timed_out:
error_msg = f"Job exceeded {TIMEOUT_S} second(s)"
else:
job_info = client.get_job_info(job_id)
error_msg = job_info.message

result = Result(query=index, duration=duration, error_msg=error_msg)
result = Result(arguments=args, duration=duration, error_msg=error_msg)
results.append(result)

output_file = output_dir / "out.csv"
Expand Down
16 changes: 15 additions & 1 deletion .github/ci-scripts/templatize_ray_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ class Metadata(BaseModel, extra="allow"):
sudo chmod 777 /tmp
fi""",
),
"benchmarking-arm": Profile(
instance_type="i8g.4xlarge",
image_id="ami-0d4eea77bb23270f4",
node_count=8,
ssh_user="ubuntu",
volume_mount=""" |
findmnt /tmp 1> /dev/null
code=$?
if [ $code -ne 0 ]; then
sudo mkfs.ext4 /dev/nvme0n1
sudo mount -t ext4 /dev/nvme0n1 /tmp
sudo chmod 777 /tmp
fi""",
),
}


Expand All @@ -71,7 +85,7 @@ class Metadata(BaseModel, extra="allow"):
parser.add_argument("--daft-wheel-url")
parser.add_argument("--daft-version")
parser.add_argument("--python-version", required=True)
parser.add_argument("--cluster-profile", required=True, choices=["debug_xs-x86", "medium-x86"])
parser.add_argument("--cluster-profile", required=True, choices=["debug_xs-x86", "medium-x86", "benchmarking-arm"])
parser.add_argument("--working-dir", required=True)
parser.add_argument("--entrypoint-script", required=True)
args = parser.parse_args()
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/run-cluster.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ on:
description: Cluster profile
type: choice
options:
- benchmarking-arm
- medium-x86
- debug_xs-x86
required: false
Expand Down Expand Up @@ -49,7 +50,7 @@ jobs:
uses: ./.github/workflows/build-commit.yaml
if: ${{ inputs.daft_version == '' && inputs.daft_wheel_url == '' }}
with:
arch: x86
arch: ${{ (inputs.cluster_profile == 'debug_xs-x86' || inputs.cluster_profile == 'medium-x86') && 'x86' || 'arm' }}
python_version: ${{ inputs.python_version }}
secrets:
ACTIONS_AWS_ROLE_ARN: ${{ secrets.ACTIONS_AWS_ROLE_ARN }}
Expand Down
4 changes: 2 additions & 2 deletions benchmarking/tpcds/ray_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def run(
{
"question": question,
"scale-factor": scale_factor,
"planning-time": explain_delta,
"execution-time": execute_delta,
"planning-time": str(explain_delta),
"execution-time": str(execute_delta),
}
)
f.write(stats)
Expand Down
2 changes: 1 addition & 1 deletion tools/git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def get_name_and_commit_hash(branch_name: Optional[str]) -> tuple[str, str]:

def parse_questions(questions: Optional[str], total_number_of_questions: int) -> list[int]:
if questions is None:
return list(range(total_number_of_questions))
return list(range(1, total_number_of_questions + 1))
else:

def to_int(q: str) -> int:
Expand Down
2 changes: 1 addition & 1 deletion tools/tpcds.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def run(
)
parser.add_argument(
"--cluster-profile",
choices=["debug_xs-x86", "medium-x86"],
choices=["debug_xs-x86", "medium-x86", "benchmarking-arm"],
type=str,
required=False,
help="The ray cluster configuration to run on",
Expand Down

0 comments on commit 6157ad8

Please sign in to comment.