
Commit 9b6f770

Deferred and non-deferred operator cancels armada jobs when task is marked as failed.

* Deferred and non-deferred operator cancels armada jobs when task is marked as failed.
* Minor test / quality of life fixes.
masipauskas authored Aug 22, 2024
1 parent 6923240 commit 9b6f770
Showing 23 changed files with 835 additions and 676 deletions.
86 changes: 63 additions & 23 deletions docs/python_airflow_operator.md
@@ -58,8 +58,6 @@ and handles job cancellation if the Airflow task is killed.



#### _property_ client(_: ArmadaClient_ )

#### execute(context)
Submits the job to Armada and polls for completion.

@@ -76,6 +74,10 @@ Submits the job to Armada and polls for completion.



#### _property_ hook(_: ArmadaHook_ )

#### lookout_url(job_id)

#### on_kill()
Override this method to clean up subprocesses when a task instance gets killed.

@@ -117,6 +119,8 @@ Args:


#### template_fields(_: Sequence[str]_ _ = ('job_request', 'job_set_prefix')_ )

#### template_fields_renderers(_: Dict[str, str]_ _ = {'job_request': 'py'}_ )
Initializes a new ArmadaOperator.


@@ -158,8 +162,6 @@ acknowledged by Armada.
:type job_acknowledgement_timeout: int
:param kwargs: Additional keyword arguments to pass to the BaseOperator.


### armada.operators.armada.log_exceptions(method)
## armada.triggers.armada module

## armada.auth module
@@ -176,18 +178,10 @@ Bases: `Protocol`
str



#### serialize()

* **Return type**

*Tuple*[str, *Dict*[str, *Any*]]


## armada.model module


### _class_ armada.model.GrpcChannelArgs(target, options=None, compression=None, auth=None, auth_details=None)
### _class_ armada.model.GrpcChannelArgs(target, options=None, compression=None, auth=None)
Bases: `object`


@@ -197,32 +191,31 @@ Bases: `object`
* **target** (*str*) –


* **options** (*Sequence**[**Tuple**[**str**, **Any**]**] **| **None*) –
* **options** (*Optional**[**Sequence**[**Tuple**[**str**, **Any**]**]**]*) –


* **compression** (*Compression** | **None*) –
* **compression** (*Optional**[**grpc.Compression**]*) –


* **auth** (*AuthMetadataPlugin** | **None*) –
* **auth** (*Optional**[**grpc.AuthMetadataPlugin**]*) –


* **auth_details** (*Dict**[**str**, **Any**] **| **None*) –

#### aio_channel()

* **Return type**

*Channel*


#### channel()

* **Return type**

*Channel*


#### _static_ deserialize(data, version)

* **Parameters**


* **data** (*dict**[**str**, **Any**]*) –


* **version** (*int*) –


* **Return type**

*GrpcChannelArgs*



@@ -231,3 +224,50 @@ Bases: `object`
* **Return type**

*Dict*[str, *Any*]



### _class_ armada.model.RunningJobContext(armada_queue: 'str', job_id: 'str', job_set_id: 'str', submit_time: 'DateTime', cluster: 'Optional[str]' = None, last_log_time: 'Optional[DateTime]' = None, job_state: 'str' = 'UNKNOWN')
Bases: `object`


* **Parameters**


* **armada_queue** (*str*) –


* **job_id** (*str*) –


* **job_set_id** (*str*) –


* **submit_time** (*DateTime*) –


* **cluster** (*str** | **None*) –


* **last_log_time** (*DateTime** | **None*) –


* **job_state** (*str*) –



#### armada_queue(_: str_ )

#### cluster(_: str | None_ _ = None_ )

#### job_id(_: str_ )

#### job_set_id(_: str_ )

#### job_state(_: str_ _ = 'UNKNOWN'_ )

#### last_log_time(_: DateTime | None_ _ = None_ )

#### _property_ state(_: JobState_ )

#### submit_time(_: DateTime_ )
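
To tie the documented pieces together, here is a hedged sketch of instantiating the operator from a DAG. Only `job_request`, `job_set_prefix`, and `job_acknowledgement_timeout` are confirmed by the docs above; `name`, `channel_args`, `armada_queue`, and `deferrable` are assumed parameter names, and the bare `JobSubmitRequestItem` stands in for a real pod spec.

```python
from armada.model import GrpcChannelArgs
from armada.operators.armada import ArmadaOperator
from armada_client.armada.submit_pb2 import JobSubmitRequestItem

hello_armada = ArmadaOperator(
    task_id="hello_armada",  # standard BaseOperator argument
    name="hello_armada",  # assumed: Armada job name
    channel_args=GrpcChannelArgs(target="localhost:50051"),  # assumed parameter name
    armada_queue="example-queue",  # assumed: target Armada queue
    job_request=JobSubmitRequestItem(priority=1, namespace="default"),
    job_acknowledgement_timeout=60,  # seconds to wait for Armada to ack the job
    deferrable=True,  # assumed: selects the deferred (trigger-based) path
)
```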
4 changes: 4 additions & 0 deletions third_party/airflow/armada/__init__.py
@@ -0,0 +1,4 @@
from airflow.serialization.serde import _extra_allowed

_extra_allowed.add("armada.model.RunningJobContext")
_extra_allowed.add("armada.model.GrpcChannelArgs")
4 changes: 1 addition & 3 deletions third_party/airflow/armada/auth.py
@@ -1,10 +1,8 @@
from typing import Any, Dict, Protocol, Tuple
from typing import Protocol

""" We use this interface for objects fetching Kubernetes auth tokens. Since
it's used within the Trigger, it must be serialisable."""


class TokenRetriever(Protocol):
def get_token(self) -> str: ...

def serialize(self) -> Tuple[str, Dict[str, Any]]: ...
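
With `serialize()` dropped from the protocol, an implementation now only needs `get_token`. A minimal illustrative retriever (the class name and fixed-token behaviour are invented for this sketch):

```python
from armada.auth import TokenRetriever


class StaticTokenRetriever:
    """Illustrative retriever that hands back one pre-issued Kubernetes token."""

    def __init__(self, token: str):
        self._token = token

    def get_token(self) -> str:
        return self._token


# Protocols are structurally typed, so no inheritance is needed.
retriever: TokenRetriever = StaticTokenRetriever(token="example-token")
```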
129 changes: 129 additions & 0 deletions third_party/airflow/armada/hooks.py
@@ -0,0 +1,129 @@
import dataclasses
import json
import threading
from functools import cached_property
from typing import Dict, Optional

import grpc
from airflow.exceptions import AirflowException
from airflow.serialization.serde import serialize
from airflow.utils.log.logging_mixin import LoggingMixin
from armada.model import GrpcChannelArgs
from armada_client.armada.job_pb2 import JobRunDetails
from armada_client.armada.submit_pb2 import JobSubmitRequestItem
from armada_client.client import ArmadaClient
from armada_client.typings import JobState
from pendulum import DateTime

from .model import RunningJobContext


class ArmadaClientFactory:
CLIENTS_LOCK = threading.Lock()
CLIENTS: Dict[str, ArmadaClient] = {}

@staticmethod
def client_for(args: GrpcChannelArgs) -> ArmadaClient:
"""
Armada clients, maintain GRPC connection to Armada API.
We cache them per channel args config in class level cache.
Access to this method can be from multiple-threads.
"""
channel_args_key = json.dumps(serialize(args))
with ArmadaClientFactory.CLIENTS_LOCK:
if channel_args_key not in ArmadaClientFactory.CLIENTS:
ArmadaClientFactory.CLIENTS[channel_args_key] = ArmadaClient(
channel=ArmadaClientFactory._create_channel(args)
)
return ArmadaClientFactory.CLIENTS[channel_args_key]

@staticmethod
def _create_channel(args: GrpcChannelArgs) -> grpc.Channel:
if args.auth is None:
return grpc.insecure_channel(
target=args.target, options=args.options, compression=args.compression
)

return grpc.secure_channel(
target=args.target,
options=args.options,
compression=args.compression,
credentials=grpc.composite_channel_credentials(
grpc.ssl_channel_credentials(),
grpc.metadata_call_credentials(args.auth),
),
)


class ArmadaHook(LoggingMixin):
def __init__(self, args: GrpcChannelArgs):
self.args = args

@cached_property
def client(self):
return ArmadaClientFactory.client_for(self.args)

def cancel_job(self, job_context: RunningJobContext) -> RunningJobContext:
try:
result = self.client.cancel_jobs(
queue=job_context.armada_queue,
job_set_id=job_context.job_set_id,
job_id=job_context.job_id,
)
if len(list(result.cancelled_ids)) > 0:
self.log.info(f"Cancelled job with id {result.cancelled_ids}")
else:
self.log.warning(f"Failed to cancel job with id {job_context.job_id}")
except Exception as e:
self.log.warning(f"Failed to cancel job with id {job_context.job_id}: {e}")
finally:
return dataclasses.replace(job_context, job_state=JobState.CANCELLED.name)

def submit_job(
self, queue: str, job_set_id: str, job_request: JobSubmitRequestItem
) -> RunningJobContext:
resp = self.client.submit_jobs(queue, job_set_id, [job_request])
num_responses = len(resp.job_response_items)

# We submitted exactly one job to armada, so we expect a single response
if num_responses != 1:
raise AirflowException(
f"No valid received from Armada (expected 1 job to be created "
f"but got {num_responses})"
)
job = resp.job_response_items[0]

# Throw if armada told us we had submitted something bad
if job.error:
raise AirflowException(f"Error submitting job to Armada: {job.error}")

return RunningJobContext(queue, job.job_id, job_set_id, DateTime.utcnow())

def refresh_context(
self, job_context: RunningJobContext, tracking_url: str
) -> RunningJobContext:
response = self.client.get_job_status([job_context.job_id])
state = JobState(response.job_states[job_context.job_id])
if state != job_context.state:
self.log.info(
f"job {job_context.job_id} is in state: {state.name}. "
f"{tracking_url}"
)

cluster = job_context.cluster
if not cluster:
# Job is running or has already completed
if state == JobState.RUNNING or state.is_terminal():
run_details = self._get_latest_job_run_details(job_context.job_id)
if run_details:
cluster = run_details.cluster
return dataclasses.replace(job_context, job_state=state.name, cluster=cluster)

def _get_latest_job_run_details(self, job_id) -> Optional[JobRunDetails]:
job_details = self.client.get_job_details([job_id]).job_details[job_id]
if job_details and job_details.latest_run_id:
for run in job_details.job_runs:
if run.run_id == job_details.latest_run_id:
return run
return None
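
For reference, here is how the new hook's job lifecycle fits together outside of Airflow, as a hedged sketch: it assumes an Armada server reachable at localhost:50051 and a Lookout UI at localhost:8089, and the bare `JobSubmitRequestItem` again stands in for a real pod spec.

```python
from armada.hooks import ArmadaHook
from armada.model import GrpcChannelArgs
from armada_client.armada.submit_pb2 import JobSubmitRequestItem

hook = ArmadaHook(GrpcChannelArgs(target="localhost:50051"))

# submit_job returns an immutable RunningJobContext snapshot.
ctx = hook.submit_job(
    queue="example-queue",
    job_set_id="example-job-set",
    job_request=JobSubmitRequestItem(priority=1, namespace="default"),
)

# refresh_context re-reads the job state and fills in the cluster once known.
ctx = hook.refresh_context(ctx, tracking_url="http://localhost:8089")

# cancel_job always returns a CANCELLED copy of the context, even if the RPC fails.
if not ctx.state.is_terminal():
    ctx = hook.cancel_job(ctx)
```

Because `ArmadaClientFactory` keys its cache on the serialized channel args, two hooks constructed with equal `GrpcChannelArgs` share a single underlying gRPC client.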