Skip to content

Commit

Permalink
Adding query API endpoints to python client (#3650)
Browse files Browse the repository at this point in the history
* Adding missing methods to python client (#154)

* Adding proto methods

* Adding tests and docs

* Making tox formatting pass

* Fixing docs

* Bump version

---------

Co-authored-by: Mustafa Ilyas <[email protected]>
  • Loading branch information
MustafaI and mustafai-gr authored Jun 5, 2024
1 parent ee1ee05 commit 553da79
Show file tree
Hide file tree
Showing 9 changed files with 340 additions and 12 deletions.
14 changes: 14 additions & 0 deletions client/python/armada_client/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
try:
from .typings import JobState
from ._proto_methods import is_active, is_terminal

JobState.is_active = is_active
JobState.is_terminal = is_terminal

del is_active, is_terminal, JobState
except ImportError:
"""
Import errors occur during proto generation, where certain
modules import types that don't exist yet. We can safely ignore these failures
"""
pass
38 changes: 38 additions & 0 deletions client/python/armada_client/_proto_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from armada_client.typings import JobState


def is_terminal(self) -> bool:
"""
Determines if a job state is terminal.
Terminal states indicate that a job has completed its lifecycle,
whether successfully or due to failure.
:param state: The current state of the job.
:type state: JobState
:returns: True if the job state is terminal, False if it is active.
:rtype: bool
"""
terminal_states = {
JobState.SUCCEEDED,
JobState.FAILED,
JobState.CANCELLED,
JobState.PREEMPTED,
}
return self.value in terminal_states


def is_active(self) -> bool:
"""
Determines if a job state is active.
Active states indicate that a job is still running or in a non-terminal state.
:param state: The current state of the job.
:type state: JobState
:returns: True if the job state is active, False if it is terminal.
:rtype: bool
"""
return not is_terminal(self.value)
47 changes: 46 additions & 1 deletion client/python/armada_client/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
submit_pb2,
submit_pb2_grpc,
health_pb2,
job_pb2,
job_pb2_grpc,
)
from armada_client.event import Event
from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1
Expand Down Expand Up @@ -104,6 +106,7 @@ def __init__(
) -> None:
self.submit_stub = submit_pb2_grpc.SubmitStub(channel)
self.event_stub = event_pb2_grpc.EventStub(channel)
self.job_stub = job_pb2_grpc.JobsStub(channel)
self.event_timeout = event_timeout

async def get_job_events_stream(
Expand Down Expand Up @@ -169,7 +172,7 @@ async def event_health(self) -> health_pb2.HealthCheckResponse:
async def submit_jobs(
self, queue: str, job_set_id: str, job_request_items
) -> AsyncIterator[submit_pb2.JobSubmitResponse]:
"""Submit a armada job.
"""Submit an armada job.
Uses SubmitJobs RPC to submit a job.
Expand All @@ -185,6 +188,48 @@ async def submit_jobs(
response = await self.submit_stub.SubmitJobs(request)
return response

async def get_job_status(self, job_ids: List[str]) -> job_pb2.JobStatusResponse:
"""
Asynchronously retrieves the status of a list of jobs from Armada.
:param job_ids: A list of unique job identifiers.
:type job_ids: List[str]
:returns: The response from the server containing the job status.
:rtype: JobStatusResponse
"""
req = job_pb2.JobStatusRequest(job_ids=job_ids)
resp = await self.job_stub.GetJobStatus(req)
return resp

async def get_job_details(self, job_ids: List[str]) -> job_pb2.JobDetailsResponse:
"""
Asynchronously retrieves the details of a job from Armada.
:param job_ids: A list of unique job identifiers.
:type job_ids: List[str]
:returns: The Armada job details response.
"""
req = job_pb2.JobDetailsRequest(job_ids=job_ids, expand_job_run=True)
resp = await self.job_stub.GetJobDetails(req)
return resp

async def get_job_run_details(
self, run_ids: List[str]
) -> job_pb2.JobRunDetailsResponse:
"""
Asynchronously retrieves the details of a job run from Armada.
:param run_ids: A list of unique job run identifiers.
:type run_ids: List[str]
:returns: The Armada run details response.
"""
req = job_pb2.JobRunDetailsRequest(run_ids=run_ids)
resp = await self.job_stub.GetJobRunDetails(req)
return resp

async def cancel_jobs(
self,
queue: str,
Expand Down
42 changes: 41 additions & 1 deletion client/python/armada_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
submit_pb2,
submit_pb2_grpc,
health_pb2,
job_pb2,
job_pb2_grpc,
)
from armada_client.event import Event
from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1
Expand Down Expand Up @@ -102,6 +104,7 @@ def __init__(self, channel, event_timeout: timedelta = timedelta(minutes=15)):
self.submit_stub = submit_pb2_grpc.SubmitStub(channel)
self.event_stub = event_pb2_grpc.EventStub(channel)
self.event_timeout = event_timeout
self.job_stub = job_pb2_grpc.JobsStub(channel)

def get_job_events_stream(
self,
Expand Down Expand Up @@ -161,10 +164,47 @@ def event_health(self) -> health_pb2.HealthCheckResponse:
"""
return self.event_stub.Health(request=empty_pb2.Empty())

def get_job_status(self, job_ids: List[str]) -> job_pb2.JobStatusResponse:
"""
Retrieves the status of a list of jobs from Armada.
:param job_ids: A list of unique job identifiers.
:type job_ids: List[str]
:returns: The response from the server containing the job status.
:rtype: JobStatusResponse
"""
req = job_pb2.JobStatusRequest(job_ids=job_ids)
return self.job_stub.GetJobStatus(req)

def get_job_details(self, job_ids: List[str]) -> job_pb2.JobDetailsResponse:
"""
Retrieves the details of a job from Armada.
:param job_ids: A list of unique job identifiers.
:type job_ids: List[str]
:returns: The Armada job details response.
"""
req = job_pb2.JobDetailsRequest(job_ids=job_ids, expand_job_run=True)
return self.job_stub.GetJobDetails(req)

def get_job_run_details(self, run_ids: List[str]) -> job_pb2.JobRunDetailsResponse:
"""
Retrieves the details of a job run from Armada.
:param run_ids: A list of unique job run identifiers.
:type run_ids: List[str]
:returns: The Armada run details response.
"""
req = job_pb2.JobRunDetailsRequest(run_ids=run_ids)
return self.job_stub.GetJobRunDetails(req)

def submit_jobs(
self, queue: str, job_set_id: str, job_request_items
) -> submit_pb2.JobSubmitResponse:
"""Submit a armada job.
"""Submit an armada job.
Uses SubmitJobs RPC to submit a job.
Expand Down
2 changes: 1 addition & 1 deletion client/python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "armada_client"
version = "0.3.2"
version = "0.3.3"
description = "Armada gRPC API python client"
readme = "README.md"
requires-python = ">=3.7"
Expand Down
48 changes: 48 additions & 0 deletions client/python/tests/unit/server_mock.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from google.protobuf import empty_pb2

from armada_client.armada import (
submit_pb2_grpc,
submit_pb2,
event_pb2,
event_pb2_grpc,
health_pb2,
job_pb2_grpc,
job_pb2,
)
from armada_client.armada.job_pb2 import JobRunState
from armada_client.armada.submit_pb2 import JobState


class SubmitService(submit_pb2_grpc.SubmitServicer):
Expand Down Expand Up @@ -101,3 +106,46 @@ def Health(self, request, context):
return health_pb2.HealthCheckResponse(
status=health_pb2.HealthCheckResponse.SERVING
)


class QueryAPIService(job_pb2_grpc.JobsServicer):
DEFAULT_JOB_DETAILS = {
"queue": "test_queue",
"jobset": "test_jobset",
"namespace": "test_namespace",
"state": JobState.RUNNING,
"cancel_reason": "",
"latest_run_id": "0",
}

DEFAULT_JOB_RUN_DETAILS = {
"job_id": "0",
"cluster": "test_cluster",
"node": "test_node",
"state": JobRunState.RUN_STATE_RUNNING,
}

def GetJobStatus(self, request, context):
return job_pb2.JobStatusResponse(
job_states={job: JobState.RUNNING for job in request.job_ids}
)

def GetJobDetails(self, request, context):
return job_pb2.JobDetailsResponse(
job_details={
job: job_pb2.JobDetails(
job_id=job, **QueryAPIService.DEFAULT_JOB_DETAILS
)
for job in request.job_ids
}
)

def GetJobRunDetails(self, request, context):
return job_pb2.JobRunDetailsResponse(
job_run_details={
run: job_pb2.JobRunDetails(
run_id=run, **QueryAPIService.DEFAULT_JOB_RUN_DETAILS
)
for run in request.run_ids
}
)
49 changes: 45 additions & 4 deletions client/python/tests/unit/test_asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,32 @@
import pytest
import pytest_asyncio

from server_mock import EventService, SubmitService

from armada_client.armada import event_pb2_grpc, submit_pb2_grpc, submit_pb2, health_pb2
from armada_client.typings import JobState
from armada_client.armada.job_pb2 import JobRunState
from server_mock import EventService, SubmitService, QueryAPIService

from armada_client.armada import (
event_pb2_grpc,
submit_pb2_grpc,
submit_pb2,
health_pb2,
job_pb2_grpc,
)
from armada_client.asyncio_client import ArmadaAsyncIOClient
from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1
from armada_client.k8s.io.apimachinery.pkg.api.resource import (
generated_pb2 as api_resource,
)

from armada_client.permissions import Permissions, Subject
from armada_client.typings import JobState


@pytest.fixture
def server_mock():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
submit_pb2_grpc.add_SubmitServicer_to_server(SubmitService(), server)
event_pb2_grpc.add_EventServicer_to_server(EventService(), server)
job_pb2_grpc.add_JobsServicer_to_server(QueryAPIService(), server)
server.add_insecure_port("[::]:50051")
server.start()
yield
Expand Down Expand Up @@ -302,3 +310,36 @@ async def test_health_submit(aio_client):
async def test_health_event(aio_client):
health = await aio_client.event_health()
assert health.SERVING == health_pb2.HealthCheckResponse.SERVING


@pytest.mark.asyncio
async def test_job_status(aio_client):
await test_create_queue(aio_client)
await test_submit_job(aio_client)

job_status_response = await aio_client.get_job_status(["job-1"])
assert job_status_response.job_states["job-1"] == submit_pb2.JobState.RUNNING


@pytest.mark.asyncio
async def test_job_details(aio_client):
await test_create_queue(aio_client)
await test_submit_job(aio_client)

job_details_response = await aio_client.get_job_details(["job-1"])
job_details = job_details_response.job_details
assert job_details["job-1"].state == submit_pb2.JobState.RUNNING
assert job_details["job-1"].job_id == "job-1"
assert job_details["job-1"].queue == "test_queue"


@pytest.mark.asyncio
async def test_job_run_details(aio_client):
await test_create_queue(aio_client)
await test_submit_job(aio_client)

run_details_response = await aio_client.get_job_run_details(["run-1"])
run_details = run_details_response.job_run_details
assert run_details["run-1"].state == JobRunState.RUN_STATE_RUNNING
assert run_details["run-1"].run_id == "run-1"
assert run_details["run-1"].cluster == "test_cluster"
Loading

0 comments on commit 553da79

Please sign in to comment.