diff --git a/docs/python_airflow_operator.md b/docs/python_airflow_operator.md index 29709313ac9..32fa081fe83 100644 --- a/docs/python_airflow_operator.md +++ b/docs/python_airflow_operator.md @@ -58,8 +58,6 @@ and handles job cancellation if the Airflow task is killed. -#### _property_ client(_: ArmadaClien_ ) - #### execute(context) Submits the job to Armada and polls for completion. @@ -76,6 +74,10 @@ Submits the job to Armada and polls for completion. +#### _property_ hook(_: ArmadaHoo_ ) + +#### lookout_url(job_id) + #### on_kill() Override this method to clean up subprocesses when a task instance gets killed. @@ -117,6 +119,8 @@ Args: #### template_fields(_: Sequence[str_ _ = ('job_request', 'job_set_prefix'_ ) + +#### template_fields_renderers(_: Dict[str, str_ _ = {'job_request': 'py'_ ) Initializes a new ArmadaOperator. @@ -158,8 +162,6 @@ acknowledged by Armada. :type job_acknowledgement_timeout: int :param kwargs: Additional keyword arguments to pass to the BaseOperator. - -### armada.operators.armada.log_exceptions(method) ## armada.triggers.armada module ## armada.auth module @@ -176,18 +178,10 @@ Bases: `Protocol` str - -#### serialize() - -* **Return type** - - *Tuple*[str, *Dict*[str, *Any*]] - - ## armada.model module -### _class_ armada.model.GrpcChannelArgs(target, options=None, compression=None, auth=None, auth_details=None) +### _class_ armada.model.GrpcChannelArgs(target, options=None, compression=None, auth=None) Bases: `object` @@ -197,32 +191,31 @@ Bases: `object` * **target** (*str*) – - * **options** (*Sequence**[**Tuple**[**str**, **Any**]**] **| **None*) – + * **options** (*Optional**[**Sequence**[**Tuple**[**str**, **Any**]**]**]*) – - * **compression** (*Compression** | **None*) – + * **compression** (*Optional**[**grpc.Compression**]*) – - * **auth** (*AuthMetadataPlugin** | **None*) – + * **auth** (*Optional**[**grpc.AuthMetadataPlugin**]*) – - * **auth_details** (*Dict**[**str**, **Any**] **| **None*) – +#### _static_ deserialize(data, version) +* **Parameters** -#### aio_channel() - -* **Return type** + + * **data** (*dict**[**str**, **Any**]*) – - *Channel* + * **version** (*int*) – -#### channel() * **Return type** - *Channel* + *GrpcChannelArgs* @@ -231,3 +224,50 @@ Bases: `object` * **Return type** *Dict*[str, *Any*] + + + +### _class_ armada.model.RunningJobContext(armada_queue: 'str', job_id: 'str', job_set_id: 'str', submit_time: 'DateTime', cluster: 'Optional[str]' = None, last_log_time: 'Optional[DateTime]' = None, job_state: 'str' = 'UNKNOWN') +Bases: `object` + + +* **Parameters** + + + * **armada_queue** (*str*) – + + + * **job_id** (*str*) – + + + * **job_set_id** (*str*) – + + + * **submit_time** (*DateTime*) – + + + * **cluster** (*str** | **None*) – + + + * **last_log_time** (*DateTime** | **None*) – + + + * **job_state** (*str*) – + + + +#### armada_queue(_: st_ ) + +#### cluster(_: str | Non_ _ = Non_ ) + +#### job_id(_: st_ ) + +#### job_set_id(_: st_ ) + +#### job_state(_: st_ _ = 'UNKNOWN_ ) + +#### last_log_time(_: DateTime | Non_ _ = Non_ ) + +#### _property_ state(_: JobStat_ ) + +#### submit_time(_: DateTim_ ) diff --git a/third_party/airflow/armada/__init__.py b/third_party/airflow/armada/__init__.py new file mode 100644 index 00000000000..807a199de85 --- /dev/null +++ b/third_party/airflow/armada/__init__.py @@ -0,0 +1,4 @@ +from airflow.serialization.serde import _extra_allowed + +_extra_allowed.add("armada.model.RunningJobContext") +_extra_allowed.add("armada.model.GrpcChannelArgs") diff --git a/third_party/airflow/armada/auth.py b/third_party/airflow/armada/auth.py index 16275dbc343..6bf45df780f 100644 --- a/third_party/airflow/armada/auth.py +++ b/third_party/airflow/armada/auth.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Protocol, Tuple +from typing import Protocol """ We use this interface for objects fetching Kubernetes auth tokens. Since it's used within the Trigger, it must be serialisable.""" @@ -6,5 +6,3 @@ class TokenRetriever(Protocol): def get_token(self) -> str: ... - - def serialize(self) -> Tuple[str, Dict[str, Any]]: ... diff --git a/third_party/airflow/armada/hooks.py b/third_party/airflow/armada/hooks.py new file mode 100644 index 00000000000..a894d09249e --- /dev/null +++ b/third_party/airflow/armada/hooks.py @@ -0,0 +1,129 @@ +import dataclasses +import json +import threading +from functools import cached_property +from typing import Dict, Optional + +import grpc +from airflow.exceptions import AirflowException +from airflow.serialization.serde import serialize +from airflow.utils.log.logging_mixin import LoggingMixin +from armada.model import GrpcChannelArgs +from armada_client.armada.job_pb2 import JobRunDetails +from armada_client.armada.submit_pb2 import JobSubmitRequestItem +from armada_client.client import ArmadaClient +from armada_client.typings import JobState +from pendulum import DateTime + +from .model import RunningJobContext + + +class ArmadaClientFactory: + CLIENTS_LOCK = threading.Lock() + CLIENTS: Dict[str, ArmadaClient] = {} + + @staticmethod + def client_for(args: GrpcChannelArgs) -> ArmadaClient: + """ + Armada clients, maintain GRPC connection to Armada API. + We cache them per channel args config in class level cache. + + Access to this method can be from multiple-threads. + """ + channel_args_key = json.dumps(serialize(args)) + with ArmadaClientFactory.CLIENTS_LOCK: + if channel_args_key not in ArmadaClientFactory.CLIENTS: + ArmadaClientFactory.CLIENTS[channel_args_key] = ArmadaClient( + channel=ArmadaClientFactory._create_channel(args) + ) + return ArmadaClientFactory.CLIENTS[channel_args_key] + + @staticmethod + def _create_channel(args: GrpcChannelArgs) -> grpc.Channel: + if args.auth is None: + return grpc.insecure_channel( + target=args.target, options=args.options, compression=args.compression + ) + + return grpc.secure_channel( + target=args.target, + options=args.options, + compression=args.compression, + credentials=grpc.composite_channel_credentials( + grpc.ssl_channel_credentials(), + grpc.metadata_call_credentials(args.auth), + ), + ) + + +class ArmadaHook(LoggingMixin): + def __init__(self, args: GrpcChannelArgs): + self.args = args + + @cached_property + def client(self): + return ArmadaClientFactory.client_for(self.args) + + def cancel_job(self, job_context: RunningJobContext) -> RunningJobContext: + try: + result = self.client.cancel_jobs( + queue=job_context.armada_queue, + job_set_id=job_context.job_set_id, + job_id=job_context.job_id, + ) + if len(list(result.cancelled_ids)) > 0: + self.log.info(f"Cancelled job with id {result.cancelled_ids}") + else: + self.log.warning(f"Failed to cancel job with id {job_context.job_id}") + except Exception as e: + self.log.warning(f"Failed to cancel job with id {job_context.job_id}: {e}") + finally: + return dataclasses.replace(job_context, job_state=JobState.CANCELLED.name) + + def submit_job( + self, queue: str, job_set_id: str, job_request: JobSubmitRequestItem + ) -> RunningJobContext: + resp = self.client.submit_jobs(queue, job_set_id, [job_request]) + num_responses = len(resp.job_response_items) + + # We submitted exactly one job to armada, so we expect a single response + if num_responses != 1: + raise AirflowException( + f"No valid received from Armada (expected 1 job to be created " + f"but got {num_responses})" + ) + job = resp.job_response_items[0] + + # Throw if armada told us we had submitted something bad + if job.error: + raise AirflowException(f"Error submitting job to Armada: {job.error}") + + return RunningJobContext(queue, job.job_id, job_set_id, DateTime.utcnow()) + + def refresh_context( + self, job_context: RunningJobContext, tracking_url: str + ) -> RunningJobContext: + response = self.client.get_job_status([job_context.job_id]) + state = JobState(response.job_states[job_context.job_id]) + if state != job_context.state: + self.log.info( + f"job {job_context.job_id} is in state: {state.name}. " + f"{tracking_url}" + ) + + cluster = job_context.cluster + if not cluster: + # Job is running / or completed already + if state == JobState.RUNNING or state.is_terminal(): + run_details = self._get_latest_job_run_details(job_context.job_id) + if run_details: + cluster = run_details.cluster + return dataclasses.replace(job_context, job_state=state.name, cluster=cluster) + + def _get_latest_job_run_details(self, job_id) -> Optional[JobRunDetails]: + job_details = self.client.get_job_details([job_id]).job_details[job_id] + if job_details and job_details.latest_run_id: + for run in job_details.job_runs: + if run.run_id == job_details.latest_run_id: + return run + return None diff --git a/third_party/airflow/armada/model.py b/third_party/airflow/armada/model.py index 00b9ab59800..91db62420e0 100644 --- a/third_party/airflow/armada/model.py +++ b/third_party/airflow/armada/model.py @@ -1,7 +1,11 @@ -import importlib -from typing import Any, Dict, Optional, Sequence, Tuple +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, ClassVar, Dict, Optional, Sequence, Tuple import grpc +from armada_client.typings import JobState +from pendulum import DateTime """ This class exists so that we can retain our connection to the Armada Query API when using the deferrable Armada Airflow Operator. Airflow requires any state @@ -10,73 +14,55 @@ class GrpcChannelArgs: + __version__: ClassVar[int] = 1 + def __init__( self, target: str, options: Optional[Sequence[Tuple[str, Any]]] = None, compression: Optional[grpc.Compression] = None, auth: Optional[grpc.AuthMetadataPlugin] = None, - auth_details: Optional[Dict[str, Any]] = None, ): self.target = target self.options = options self.compression = compression - if auth: - self.auth = auth - elif auth_details: - classpath, kwargs = auth_details - module_path, class_name = classpath.rsplit( - ".", 1 - ) # Split the classpath to module and class name - module = importlib.import_module( - module_path - ) # Dynamically import the module - cls = getattr(module, class_name) # Get the class from the module - self.auth = cls( - **kwargs - ) # Instantiate the class with the deserialized kwargs - else: - self.auth = None + self.auth = auth def serialize(self) -> Dict[str, Any]: - auth_details = self.auth.serialize() if self.auth else None return { "target": self.target, "options": self.options, "compression": self.compression, - "auth_details": auth_details, + "auth": self.auth, } - def channel(self) -> grpc.Channel: - if self.auth is None: - return grpc.insecure_channel( - target=self.target, options=self.options, compression=self.compression - ) + @staticmethod + def deserialize(data: dict[str, Any], version: int) -> GrpcChannelArgs: + if version > GrpcChannelArgs.__version__: + raise TypeError("serialized version > class version") + return GrpcChannelArgs(**data) - return grpc.secure_channel( - target=self.target, - options=self.options, - compression=self.compression, - credentials=grpc.composite_channel_credentials( - grpc.ssl_channel_credentials(), - grpc.metadata_call_credentials(self.auth), - ), + def __eq__(self, value: object) -> bool: + if type(value) is not GrpcChannelArgs: + return False + return ( + self.target == value.target + and self.options == value.options + and self.compression == value.compression + and self.auth == value.auth ) - def aio_channel(self) -> grpc.aio.Channel: - if self.auth is None: - return grpc.aio.insecure_channel( - target=self.target, - options=self.options, - compression=self.compression, - ) - return grpc.aio.secure_channel( - target=self.target, - options=self.options, - compression=self.compression, - credentials=grpc.composite_channel_credentials( - grpc.ssl_channel_credentials(), - grpc.metadata_call_credentials(self.auth), - ), - ) +@dataclass(frozen=True) +class RunningJobContext: + armada_queue: str + job_id: str + job_set_id: str + submit_time: DateTime + cluster: Optional[str] = None + last_log_time: Optional[DateTime] = None + job_state: str = JobState.UNKNOWN.name + + @property + def state(self) -> JobState: + return JobState[self.job_state] diff --git a/third_party/airflow/armada/operators/armada.py b/third_party/airflow/armada/operators/armada.py index 7e365417ed3..aa07227b80e 100644 --- a/third_party/airflow/armada/operators/armada.py +++ b/third_party/airflow/armada/operators/armada.py @@ -17,125 +17,31 @@ # under the License. from __future__ import annotations -import asyncio +import dataclasses import datetime -import functools import os -import threading import time -from dataclasses import dataclass -from functools import cached_property -from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple import jinja2 from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models import BaseOperator -from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.serialization.serde import deserialize from airflow.utils.context import Context from airflow.utils.log.logging_mixin import LoggingMixin from armada.auth import TokenRetriever from armada.log_manager import KubernetesPodLogManager from armada.model import GrpcChannelArgs -from armada_client.armada.job_pb2 import JobRunDetails from armada_client.armada.submit_pb2 import JobSubmitRequestItem -from armada_client.client import ArmadaClient from armada_client.typings import JobState from google.protobuf.json_format import MessageToDict, ParseDict from pendulum import DateTime - -def log_exceptions(method): - @functools.wraps(method) - def wrapper(self, *args, **kwargs): - try: - return method(self, *args, **kwargs) - except Exception as e: - if hasattr(self, "log") and hasattr(self.log, "error"): - self.log.error(f"Exception in {method.__name__}: {e}") - raise - - return wrapper - - -@dataclass(frozen=False) -class _RunningJobContext: - armada_queue: str - job_set_id: str - job_id: str - state: JobState = JobState.UNKNOWN - start_time: DateTime = DateTime.utcnow() - cluster: Optional[str] = None - last_log_time: Optional[DateTime] = None - - def serialize(self) -> tuple[str, Dict[str, Any]]: - return ( - "armada.operators.armada._RunningJobContext", - { - "armada_queue": self.armada_queue, - "job_set_id": self.job_set_id, - "job_id": self.job_id, - "state": self.state.value, - "start_time": self.start_time, - "cluster": self.cluster, - "last_log_time": self.last_log_time, - }, - ) - - def from_payload(payload: Dict[str, Any]) -> _RunningJobContext: - return _RunningJobContext( - armada_queue=payload["armada_queue"], - job_set_id=payload["job_set_id"], - job_id=payload["job_id"], - state=JobState(payload["state"]), - start_time=payload["start_time"], - cluster=payload["cluster"], - last_log_time=payload["last_log_time"], - ) - - -class _ArmadaPollJobTrigger(BaseTrigger): - def __init__(self, moment: datetime.timedelta, context: _RunningJobContext) -> None: - super().__init__() - self.moment = moment - self.context = context - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - "armada.operators.armada._ArmadaPollJobTrigger", - {"moment": self.moment, "context": self.context.serialize()}, - ) - - def __eq__(self, value: object) -> bool: - if not isinstance(value, _ArmadaPollJobTrigger): - return False - return self.moment == value.moment and self.context == value.context - - async def run(self) -> AsyncIterator[TriggerEvent]: - while self.moment > DateTime.utcnow(): - await asyncio.sleep(1) - yield TriggerEvent(self.context) - - -class _ArmadaClientFactory: - CLIENTS_LOCK = threading.Lock() - CLIENTS: Dict[str, ArmadaClient] = {} - - @staticmethod - def client_for(args: GrpcChannelArgs) -> ArmadaClient: - """ - Armada clients, maintain GRPC connection to Armada API. - We cache them per channel args config in class level cache. - - Access to this method can be from multiple-threads. - """ - channel_args_key = str(args.serialize()) - with _ArmadaClientFactory.CLIENTS_LOCK: - if channel_args_key not in _ArmadaClientFactory.CLIENTS: - _ArmadaClientFactory.CLIENTS[channel_args_key] = ArmadaClient( - channel=args.channel() - ) - return _ArmadaClientFactory.CLIENTS[channel_args_key] +from ..hooks import ArmadaHook +from ..model import RunningJobContext +from ..triggers import ArmadaPollJobTrigger +from ..utils import log_exceptions class ArmadaOperator(BaseOperator, LoggingMixin): @@ -147,6 +53,7 @@ class ArmadaOperator(BaseOperator, LoggingMixin): """ template_fields: Sequence[str] = ("job_request", "job_set_prefix") + template_fields_renderers: Dict[str, str] = {"job_request": "py"} """ Initializes a new ArmadaOperator. @@ -226,32 +133,24 @@ def execute(self, context) -> None: :param context: The execution context provided by Airflow. :type context: Context """ - # We take the job_set_id from Airflow's run_id. This means that all jobs in the - # dag will be in the same jobset. + # We take the job_set_id from Airflow's run_id. + # So all jobs in the dag will be in the same jobset. self.job_set_id = f"{self.job_set_prefix}{context['run_id']}" self._annotate_job_request(context, self.job_request) - # Submit job or reattach to previously submitted job. We always do this - # synchronously. - job_id = self._reattach_or_submit_job( - context, self.armada_queue, self.job_set_id, self.job_request - ) - - # Wait until finished - self.job_context = _RunningJobContext( - self.armada_queue, self.job_set_id, job_id, start_time=DateTime.utcnow() + # Submit job or reattach to previously submitted job. + # Always do this synchronously. + self.job_context = self._reattach_or_submit_job( + context, self.job_set_id, self.job_request ) - if self.deferrable: - self._deffered_yield(self.job_context) - else: - self._poll_for_termination(self.job_context) + self._poll_for_termination() - @cached_property - def client(self) -> ArmadaClient: - return _ArmadaClientFactory.client_for(self.channel_args) + @property + def hook(self) -> ArmadaHook: + return ArmadaHook(self.channel_args) - @cached_property + @property def pod_manager(self) -> KubernetesPodLogManager: return KubernetesPodLogManager(token_retriever=self.k8s_token_retriever) @@ -276,114 +175,96 @@ def render_template_fields( super().render_template_fields(context, jinja_env) self.job_request = ParseDict(self.job_request, JobSubmitRequestItem()) - def _cancel_job(self, job_context) -> None: - try: - result = self.client.cancel_jobs( - queue=job_context.armada_queue, - job_set_id=job_context.job_set_id, - job_id=job_context.job_id, - ) - if len(list(result.cancelled_ids)) > 0: - self.log.info(f"Cancelled job with id {result.cancelled_ids}") - else: - self.log.warning(f"Failed to cancel job with id {job_context.job_id}") - except Exception as e: - self.log.warning(f"Failed to cancel job with id {job_context.job_id}: {e}") - def on_kill(self) -> None: if self.job_context is not None: self.log.info( f"on_kill called, " - "cancelling job with id {self.job_context.job_id} in queue " + f"cancelling job with id {self.job_context.job_id} in queue " f"{self.job_context.armada_queue}" ) - self._cancel_job(self.job_context) + self.hook.cancel_job(self.job_context) + self.job_context = None - def _trigger_tracking_message(self, job_id: str): + def lookout_url(self, job_id): if self.lookout_url_template: - return ( - f"Job details available at " - f'{self.lookout_url_template.replace("", job_id)}' - ) + return self.lookout_url_template.replace("", job_id) + return None + + def _trigger_tracking_message(self, job_id): + url = self.lookout_url(job_id) + if url: + return f"Job details available at {url}" return "" - def _deffered_yield(self, context: _RunningJobContext): - self.defer( - timeout=self.execution_timeout, - trigger=_ArmadaPollJobTrigger( - DateTime.utcnow() + datetime.timedelta(seconds=self.poll_interval), - context, - ), - method_name="_deffered_poll_for_termination", - ) + def _yield(self): + if self.deferrable: + self.defer( + timeout=self.execution_timeout, + trigger=ArmadaPollJobTrigger( + DateTime.utcnow() + datetime.timedelta(seconds=self.poll_interval), + self.job_context, + self.channel_args, + ), + method_name="_trigger_reentry", + ) + else: + time.sleep(self.poll_interval) - @log_exceptions - def _deffered_poll_for_termination( + def _trigger_reentry( self, context: Context, event: Tuple[str, Dict[str, Any]] ) -> None: - job_run_context = _RunningJobContext.from_payload(event[1]) - while job_run_context.state.is_active(): - job_run_context = self._check_job_status_and_fetch_logs(job_run_context) - if job_run_context.state.is_active(): - self._deffered_yield(job_run_context) - - self._running_job_terminated(job_run_context) + self.job_context = deserialize(event) + self._poll_for_termination() def _reattach_or_submit_job( self, context: Context, - queue: str, job_set_id: str, job_request: JobSubmitRequestItem, - ) -> str: + ) -> RunningJobContext: + # Try to re-initialize job_context from xcom if it exist. ti = context["ti"] - existing_id = ti.xcom_pull( + existing_run = ti.xcom_pull( dag_id=ti.dag_id, task_ids=ti.task_id, key=f"{ti.try_number}" ) - if existing_id is not None: + if existing_run is not None: self.log.info( - f"Attached to existing job with id {existing_id['armada_job_id']}." - f" {self._trigger_tracking_message(existing_id['armada_job_id'])}" + f"Attached to existing job with id {existing_run['armada_job_id']}." + f" {self._trigger_tracking_message(existing_run['armada_job_id'])}" ) - return existing_id["armada_job_id"] - - job_id = self._submit_job(queue, job_set_id, job_request) - self.log.info( - f"Submitted job with id {job_id}. {self._trigger_tracking_message(job_id)}" - ) - ti.xcom_push(key=f"{ti.try_number}", value={"armada_job_id": job_id}) - return job_id - - def _submit_job( - self, queue: str, job_set_id: str, job_request: JobSubmitRequestItem - ) -> str: - resp = self.client.submit_jobs(queue, job_set_id, [job_request]) - num_responses = len(resp.job_response_items) - - # We submitted exactly one job to armada, so we expect a single response - if num_responses != 1: - raise AirflowException( - f"No valid received from Armada (expected 1 job to be created " - f"but got {num_responses}" + return RunningJobContext( + armada_queue=existing_run["armada_queue"], + job_id=existing_run["armada_job_id"], + job_set_id=existing_run["armada_job_set_id"], + submit_time=DateTime.utcnow(), ) - job = resp.job_response_items[0] - - # Throw if armada told us we had submitted something bad - if job.error: - raise AirflowException(f"Error submitting job to Armada: {job.error}") - return job.job_id + # We haven't got a running job, submit a new one and persist state to xcom. + ctx = self.hook.submit_job(self.armada_queue, job_set_id, job_request) + tracking_msg = self._trigger_tracking_message(ctx.job_id) + self.log.info(f"Submitted job with id {ctx.job_id}. {tracking_msg}") + + ti.xcom_push( + key=f"{ti.try_number}", + value={ + "armada_queue": ctx.armada_queue, + "armada_job_id": ctx.job_id, + "armada_job_set_id": ctx.job_set_id, + "armada_lookout_url": self.lookout_url(ctx.job_id), + }, + ) + return ctx - def _poll_for_termination(self, context: _RunningJobContext) -> None: - while context.state.is_active(): - context = self._check_job_status_and_fetch_logs(context) - if context.state.is_active(): - time.sleep(self.poll_interval) + def _poll_for_termination(self) -> None: + while self.job_context.state.is_active(): + self._check_job_status_and_fetch_logs() + if self.job_context.state.is_active(): + self._yield() - self._running_job_terminated(context) + self._running_job_terminated(self.job_context) - def _running_job_terminated(self, context: _RunningJobContext): + def _running_job_terminated(self, context: RunningJobContext): self.log.info( f"job {context.job_id} terminated with state: {context.state.name}" ) @@ -393,57 +274,43 @@ def _running_job_terminated(self, context: _RunningJobContext): f"Final status was {context.state.name}" ) - @log_exceptions - def _check_job_status_and_fetch_logs( - self, context: _RunningJobContext - ) -> _RunningJobContext: - response = self.client.get_job_status([context.job_id]) - state = JobState(response.job_states[context.job_id]) - if state != context.state: - self.log.info( - f"job {context.job_id} is in state: {state.name}. " - f"{self._trigger_tracking_message(context.job_id)}" - ) - context.state = state - - if context.state == JobState.UNKNOWN: + def _not_acknowledged_within_timeout(self) -> bool: + if self.job_context.state == JobState.UNKNOWN: if ( - DateTime.utcnow().diff(context.start_time).in_seconds() + DateTime.utcnow().diff(self.job_context.submit_time).in_seconds() > self.job_acknowledgement_timeout ): - self.log.info( - f"Job {context.job_id} not acknowledged by the Armada within " - f"timeout ({self.job_acknowledgement_timeout}), terminating" - ) - self._cancel_job(context) - context.state = JobState.CANCELLED - return context + return True + return False - if self.container_logs and not context.cluster: - if context.state == JobState.RUNNING or context.state.is_terminal(): - run_details = self._get_latest_job_run_details(context.job_id) - context.cluster = run_details.cluster + @log_exceptions + def _check_job_status_and_fetch_logs(self) -> None: + self.job_context = self.hook.refresh_context( + self.job_context, self._trigger_tracking_message(self.job_context.job_id) + ) + + if self._not_acknowledged_within_timeout(): + self.log.info( + f"Job {self.job_context.job_id} not acknowledged by the Armada within " + f"timeout ({self.job_acknowledgement_timeout}), terminating" + ) + self.job_context = self.hook.cancel_job(self.job_context) + return - if context.cluster: + if self.job_context.cluster and self.container_logs: try: - context.last_log_time = self.pod_manager.fetch_container_logs( - k8s_context=context.cluster, + last_log_time = self.pod_manager.fetch_container_logs( + k8s_context=self.job_context.cluster, namespace=self.job_request.namespace, - pod=f"armada-{context.job_id}-0", + pod=f"armada-{self.job_context.job_id}-0", container=self.container_logs, - since_time=context.last_log_time, + since_time=self.job_context.last_log_time, + ) + self.job_context = dataclasses.replace( + self.job_context, last_log_time=last_log_time ) except Exception as e: self.log.warning(f"Error fetching logs {e}") - return context - - def _get_latest_job_run_details(self, job_id) -> Optional[JobRunDetails]: - job_details = self.client.get_job_details([job_id]).job_details[job_id] - if job_details and job_details.latest_run_id: - for run in job_details.job_runs: - if run.run_id == job_details.latest_run_id: - return run - return None @staticmethod def _annotate_job_request(context, request: JobSubmitRequestItem): diff --git a/third_party/airflow/armada/triggers.py b/third_party/airflow/armada/triggers.py new file mode 100644 index 00000000000..2ea44e16c0c --- /dev/null +++ b/third_party/airflow/armada/triggers.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import asyncio +from datetime import timedelta +from typing import Any, AsyncIterator, ClassVar, Dict + +from airflow.exceptions import AirflowException +from airflow.models.taskinstance import TaskInstance +from airflow.serialization.serde import deserialize, serialize +from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.utils.session import provide_session +from airflow.utils.state import TaskInstanceState +from pendulum import DateTime +from sqlalchemy.orm.session import Session + +from .hooks import ArmadaHook +from .model import GrpcChannelArgs, RunningJobContext +from .utils import log_exceptions + + +class ArmadaPollJobTrigger(BaseTrigger): + __version__: ClassVar[int] = 1 + + @log_exceptions + def __init__( + self, + moment: timedelta, + context: RunningJobContext | tuple[str, Dict[str, Any]], + channel_args: GrpcChannelArgs | tuple[str, Dict[str, Any]], + ) -> None: + super().__init__() + + self.moment = moment + if type(context) is RunningJobContext: + self.context = context + else: + self.context = deserialize(context) + + if type(channel_args) is GrpcChannelArgs: + self.channel_args = channel_args + else: + self.channel_args = deserialize(channel_args) + + @log_exceptions + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "armada.triggers.ArmadaPollJobTrigger", + { + "moment": self.moment, + "context": serialize(self.context), + "channel_args": serialize(self.channel_args), + }, + ) + + @log_exceptions + @provide_session + def get_task_instance(self, session: Session) -> TaskInstance: + """ + Get the task instance for the current task. + :param session: Sqlalchemy session + """ + query = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.task_instance.dag_id, + TaskInstance.task_id == self.task_instance.task_id, + TaskInstance.run_id == self.task_instance.run_id, + TaskInstance.map_index == self.task_instance.map_index, + ) + task_instance = query.one_or_none() + if task_instance is None: + raise AirflowException( + "TaskInstance with dag_id: %s,task_id: %s, " + "run_id: %s and map_index: %s is not found", + self.task_instance.dag_id, + self.task_instance.task_id, + self.task_instance.run_id, + self.task_instance.map_index, + ) + return task_instance + + def should_cancel_job(self) -> bool: + """ + We only want to cancel jobs when task is being marked Failed/Succeeded. + """ + # Database query is needed to get the latest state of the task instance. + task_instance = self.get_task_instance() # type: ignore[call-arg] + return task_instance.state != TaskInstanceState.DEFERRED + + def __eq__(self, value: object) -> bool: + if not isinstance(value, ArmadaPollJobTrigger): + return False + return ( + self.moment == value.moment + and self.context == value.context + and self.channel_args == value.channel_args + ) + + @property + def hook(self) -> ArmadaHook: + return ArmadaHook(self.channel_args) + + @log_exceptions + async def run(self) -> AsyncIterator[TriggerEvent]: + try: + while self.moment > DateTime.utcnow(): + await asyncio.sleep(1) + yield TriggerEvent(serialize(self.context)) + except asyncio.CancelledError: + if self.should_cancel_job(): + self.hook.cancel_job(self.context) + raise diff --git a/third_party/airflow/armada/utils.py b/third_party/airflow/armada/utils.py new file mode 100644 index 00000000000..e700a1bbc5e --- /dev/null +++ b/third_party/airflow/armada/utils.py @@ -0,0 +1,14 @@ +import functools + + +def log_exceptions(method): + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + try: + return method(self, *args, **kwargs) + except Exception as e: + if hasattr(self, "log") and hasattr(self.log, "error"): + self.log.error(f"Exception in {method.__name__}: {e}") + raise + + return wrapper diff --git a/third_party/airflow/docs/source/conf.py b/third_party/airflow/docs/source/conf.py index 10d3949aee8..a7e2f5a75bb 100644 --- a/third_party/airflow/docs/source/conf.py +++ b/third_party/airflow/docs/source/conf.py @@ -13,14 +13,14 @@ import os import sys -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath("../..")) # -- Project information ----------------------------------------------------- -project = 'python-armadaairflowoperator' -copyright = '2022 Armada Project' -author = 'armada@armadaproject.io' +project = "python-armadaairflowoperator" +copyright = "2022 Armada Project" +author = "armada@armadaproject.io" # -- General configuration --------------------------------------------------- @@ -28,12 +28,12 @@ # Jekyll is the style of markdown used by github pages; using # sphinx_jekyll_builder here allows us to generate docs as # markdown files. -extensions = ['sphinx.ext.autodoc', 'sphinx_jekyll_builder'] +extensions = ["sphinx.ext.autodoc", "sphinx_jekyll_builder"] # This setting puts information about typing in the description section instead # of in the function signature directly. This makes rendered content look much # better in our gh-pages template that renders the generated markdown. -autodoc_typehints = 'description' +autodoc_typehints = "description" # Add any paths that contain templates here, relative to this directory. templates_path = [] @@ -49,7 +49,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'alabaster' +html_theme = "alabaster" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, diff --git a/third_party/airflow/examples/bad_armada.py b/third_party/airflow/examples/bad_armada.py index 137f4730791..809648e625a 100644 --- a/third_party/airflow/examples/bad_armada.py +++ b/third_party/airflow/examples/bad_armada.py @@ -5,8 +5,9 @@ from armada.operators.armada import ArmadaOperator from armada_client.armada import submit_pb2 from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 -from armada_client.k8s.io.apimachinery.pkg.api.resource import \ - generated_pb2 as api_resource +from armada_client.k8s.io.apimachinery.pkg.api.resource import ( + generated_pb2 as api_resource, +) def submit_sleep_container(image: str): diff --git a/third_party/airflow/examples/big_armada.py b/third_party/airflow/examples/big_armada.py index ebd84d723ce..868a31516e9 100644 --- a/third_party/airflow/examples/big_armada.py +++ b/third_party/airflow/examples/big_armada.py @@ -5,8 +5,9 @@ from armada.operators.armada import ArmadaOperator from armada_client.armada import submit_pb2 from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 -from armada_client.k8s.io.apimachinery.pkg.api.resource import \ - generated_pb2 as api_resource +from armada_client.k8s.io.apimachinery.pkg.api.resource import ( + generated_pb2 as api_resource, +) def submit_sleep_job(): diff --git a/third_party/airflow/examples/hello_armada.py b/third_party/airflow/examples/hello_armada.py index d3120bdf5f6..bdd773ce8fe 100644 --- a/third_party/airflow/examples/hello_armada.py +++ b/third_party/airflow/examples/hello_armada.py @@ -5,8 +5,9 @@ from armada.operators.armada import ArmadaOperator from armada_client.armada import submit_pb2 from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 -from armada_client.k8s.io.apimachinery.pkg.api.resource import \ - generated_pb2 as api_resource +from armada_client.k8s.io.apimachinery.pkg.api.resource import ( + generated_pb2 as api_resource, +) def submit_sleep_job(): diff --git a/third_party/airflow/examples/hello_armada_deferrable.py b/third_party/airflow/examples/hello_armada_deferrable.py index f3e661875d0..beac8fa9adb 100644 --- a/third_party/airflow/examples/hello_armada_deferrable.py +++ b/third_party/airflow/examples/hello_armada_deferrable.py @@ -5,8 +5,9 @@ from armada.operators.armada import ArmadaOperator from armada_client.armada import submit_pb2 from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 -from armada_client.k8s.io.apimachinery.pkg.api.resource import \ - generated_pb2 as api_resource +from armada_client.k8s.io.apimachinery.pkg.api.resource import ( + generated_pb2 as api_resource, +) def submit_sleep_job(): diff --git a/third_party/airflow/pyproject.toml b/third_party/airflow/pyproject.toml index 8f8fb538a57..39040c3f3d8 100644 --- a/third_party/airflow/pyproject.toml +++ b/third_party/airflow/pyproject.toml @@ -39,7 +39,7 @@ include = ["armada_airflow*"] [tool.black] line-length = 88 -target-version = ['py310'] +target-version = ['py38', 'py39', 'py310'] include = ''' /( armada diff --git a/third_party/airflow/test/integration/test_airflow_operator_logic.py b/third_party/airflow/test/integration/test_airflow_operator_logic.py index 4bc3c43418e..594d3d5eaec 100644 --- a/third_party/airflow/test/integration/test_airflow_operator_logic.py +++ b/third_party/airflow/test/integration/test_airflow_operator_logic.py @@ -85,9 +85,7 @@ def sleep_pod(image: str): ] -def test_success_job( - client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker -): +def test_success_job(client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs): job_set_name = f"test-{uuid.uuid1()}" job = client.submit_jobs( queue=DEFAULT_QUEUE, @@ -96,10 +94,11 @@ def test_success_job( ) job_id = job.job_response_items[0].job_id - mocker.patch( - "armada.operators.armada.ArmadaOperator._reattach_or_submit_job", - return_value=job_id, - ) + context["ti"].xcom_pull.return_value = { + "armada_queue": DEFAULT_QUEUE, + "armada_job_id": job_id, + "armada_job_set_id": job_set_name, + } operator = ArmadaOperator( task_id=DEFAULT_TASK_ID, @@ -113,13 +112,11 @@ def test_success_job( operator.execute(context) - response = operator.client.get_job_status([job_id]) + response = client.get_job_status([job_id]) assert JobState(response.job_states[job_id]) == JobState.SUCCEEDED -def test_bad_job( - client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker -): +def test_bad_job(client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs): job_set_name = f"test-{uuid.uuid1()}" job = client.submit_jobs( queue=DEFAULT_QUEUE, @@ -128,10 +125,11 @@ def test_bad_job( ) job_id = job.job_response_items[0].job_id - mocker.patch( - "armada.operators.armada.ArmadaOperator._reattach_or_submit_job", - return_value=job_id, - ) + context["ti"].xcom_pull.return_value = { + "armada_queue": DEFAULT_QUEUE, + "armada_job_id": job_id, + "armada_job_set_id": job_set_name, + } operator = ArmadaOperator( task_id=DEFAULT_TASK_ID, @@ -149,7 +147,7 @@ def test_bad_job( "Operator did not raise AirflowException on job failure as expected" ) except AirflowException: # Expected - response = operator.client.get_job_status([job_id]) + response = client.get_job_status([job_id]) assert JobState(response.job_states[job_id]) == JobState.FAILED except Exception as e: pytest.fail( @@ -159,7 +157,7 @@ def test_bad_job( def success_job( - task_number: int, context: Any, channel_args: GrpcChannelArgs + task_number: int, context: Any, channel_args: GrpcChannelArgs, client: ArmadaClient ) -> JobState: operator = ArmadaOperator( task_id=f"{DEFAULT_TASK_ID}_{task_number}", @@ -173,7 +171,7 @@ def success_job( operator.execute(context) - response = operator.client.get_job_status([operator.job_id]) + response = client.get_job_status([operator.job_id]) return JobState(response.job_states[operator.job_id]) @@ -182,7 +180,9 @@ def test_parallel_execution( client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker ): threads = [] - success_job(task_number=0, context=context, channel_args=channel_args) + success_job( + task_number=0, context=context, channel_args=channel_args, client=client + ) for task_number in range(5): t = threading.Thread( target=success_job, args=[task_number, context, channel_args] @@ -199,7 +199,9 @@ def test_parallel_execution_large( client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker ): threads = [] - success_job(task_number=0, context=context, channel_args=channel_args) + success_job( + task_number=0, context=context, channel_args=channel_args, client=client + ) for task_number in range(80): t = threading.Thread( target=success_job, args=[task_number, context, channel_args] @@ -216,7 +218,9 @@ def test_parallel_execution_huge( client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker ): threads = [] - success_job(task_number=0, context=context, channel_args=channel_args) + success_job( + task_number=0, context=context, channel_args=channel_args, client=client + ) for task_number in range(500): t = threading.Thread( target=success_job, args=[task_number, context, channel_args] diff --git a/third_party/airflow/test/operators/test_armada.py b/third_party/airflow/test/operators/test_armada.py deleted file mode 100644 index 85129000ad1..00000000000 --- a/third_party/airflow/test/operators/test_armada.py +++ /dev/null @@ -1,324 +0,0 @@ -import unittest -from datetime import timedelta -from math import ceil -from unittest.mock import MagicMock, PropertyMock, patch - -from airflow.exceptions import AirflowException -from armada.model import GrpcChannelArgs -from armada.operators.armada import ( - ArmadaOperator, - _ArmadaPollJobTrigger, - _RunningJobContext, -) -from armada_client.armada import job_pb2, submit_pb2 -from armada_client.armada.submit_pb2 import JobSubmitRequestItem -from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 -from armada_client.k8s.io.apimachinery.pkg.api.resource import ( - generated_pb2 as api_resource, -) -from armada_client.typings import JobState -from pendulum import UTC, DateTime - -DEFAULT_CURRENT_TIME = DateTime(2024, 8, 7, tzinfo=UTC) -DEFAULT_JOB_ID = "test_job" -DEFAULT_TASK_ID = "test_task_1" -DEFAULT_DAG_ID = "test_dag_1" -DEFAULT_RUN_ID = "test_run_1" -DEFAULT_QUEUE = "test_queue_1" -DEFAULT_POLLING_INTERVAL = 30 -DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT = 5 * 60 - - -class TestArmadaOperator(unittest.TestCase): - def setUp(self): - # Set up a mock context - mock_ti = MagicMock() - mock_ti.task_id = DEFAULT_TASK_ID - mock_dag = MagicMock() - mock_dag.dag_id = DEFAULT_DAG_ID - self.context = { - "ti": mock_ti, - "run_id": DEFAULT_RUN_ID, - "dag": mock_dag, - } - - @patch("time.sleep", return_value=None) - @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) - def test_execute(self, mock_client_fn, _): - test_cases = [ - { - "name": "Job Succeeds", - "statuses": [submit_pb2.RUNNING, submit_pb2.SUCCEEDED], - "success": True, - }, - { - "name": "Job Failed", - "statuses": [submit_pb2.RUNNING, submit_pb2.FAILED], - "success": False, - }, - { - "name": "Job cancelled", - "statuses": [submit_pb2.RUNNING, submit_pb2.CANCELLED], - "success": False, - }, - { - "name": "Job preempted", - "statuses": [submit_pb2.RUNNING, submit_pb2.PREEMPTED], - "success": False, - }, - { - "name": "Job Succeeds but takes a lot of transitions", - "statuses": [ - submit_pb2.SUBMITTED, - submit_pb2.RUNNING, - submit_pb2.RUNNING, - submit_pb2.RUNNING, - submit_pb2.RUNNING, - submit_pb2.RUNNING, - submit_pb2.SUCCEEDED, - ], - "success": True, - }, - ] - - for test_case in test_cases: - with self.subTest(test_case=test_case["name"]): - operator = ArmadaOperator( - name="test", - channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), - armada_queue=DEFAULT_QUEUE, - job_request=JobSubmitRequestItem(), - task_id=DEFAULT_TASK_ID, - ) - - # Set up Mock Armada - mock_client = MagicMock() - mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse( - job_response_items=[ - submit_pb2.JobSubmitResponseItem(job_id=DEFAULT_JOB_ID) - ] - ) - - mock_client.get_job_status.side_effect = [ - job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) - for x in test_case["statuses"] - ] - - mock_client_fn.return_value = mock_client - self.context["ti"].xcom_pull.return_value = None - - try: - operator.execute(self.context) - self.assertTrue(test_case["success"]) - except AirflowException: - self.assertFalse(test_case["success"]) - return - - self.assertEqual(mock_client.submit_jobs.call_count, 1) - self.assertEqual( - mock_client.get_job_status.call_count, len(test_case["statuses"]) - ) - - @patch("time.sleep", return_value=None) - @patch( - "armada.operators.armada.ArmadaOperator._cancel_job", new_callable=PropertyMock - ) - @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) - def test_unacknowledged_results_in_on_kill(self, mock_client_fn, mock_on_kill, _): - operator = ArmadaOperator( - name="test", - channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), - armada_queue=DEFAULT_QUEUE, - job_request=JobSubmitRequestItem(), - task_id=DEFAULT_TASK_ID, - deferrable=False, - job_acknowledgement_timeout=-1, - ) - - # Set up Mock Armada - mock_client = MagicMock() - mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse( - job_response_items=[submit_pb2.JobSubmitResponseItem(job_id=DEFAULT_JOB_ID)] - ) - mock_client_fn.return_value = mock_client - mock_client.get_job_status.side_effect = [ - job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) - for x in [submit_pb2.UNKNOWN, submit_pb2.UNKNOWN] - ] - - self.context["ti"].xcom_pull.return_value = None - with self.assertRaises(AirflowException): - operator.execute(self.context) - self.assertEqual(mock_on_kill.call_count, 1) - - """We call on_kill by triggering the job unacknowledged timeout""" - - @patch("time.sleep", return_value=None) - @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) - def test_on_kill_cancels_job(self, mock_client_fn, _): - operator = ArmadaOperator( - name="test", - channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), - armada_queue=DEFAULT_QUEUE, - job_request=JobSubmitRequestItem(), - task_id=DEFAULT_TASK_ID, - deferrable=False, - job_acknowledgement_timeout=-1, - ) - - # Set up Mock Armada - mock_client = MagicMock() - mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse( - job_response_items=[submit_pb2.JobSubmitResponseItem(job_id=DEFAULT_JOB_ID)] - ) - mock_client_fn.return_value = mock_client - mock_client.get_job_status.side_effect = [ - job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) - for x in [ - submit_pb2.UNKNOWN - for _ in range( - 1 - + ceil( - DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT / DEFAULT_POLLING_INTERVAL - ) - ) - ] - ] - - self.context["ti"].xcom_pull.return_value = None - with self.assertRaises(AirflowException): - operator.execute(self.context) - self.assertEqual(mock_client.cancel_jobs.call_count, 1) - - @patch("time.sleep", return_value=None) - @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) - def test_job_reattaches(self, mock_client_fn, _): - operator = ArmadaOperator( - name="test", - channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), - armada_queue=DEFAULT_QUEUE, - job_request=JobSubmitRequestItem(), - task_id=DEFAULT_TASK_ID, - deferrable=False, - job_acknowledgement_timeout=10, - ) - - # Set up Mock Armada - mock_client = MagicMock() - mock_client.get_job_status.side_effect = [ - job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) - for x in [ - submit_pb2.SUCCEEDED - for _ in range( - 1 - + ceil( - DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT / DEFAULT_POLLING_INTERVAL - ) - ) - ] - ] - mock_client_fn.return_value = mock_client - self.context["ti"].xcom_pull.return_value = {"armada_job_id": DEFAULT_JOB_ID} - - operator.execute(self.context) - self.assertEqual(mock_client.submit_jobs.call_count, 0) - - -class TestArmadaOperatorDeferrable(unittest.IsolatedAsyncioTestCase): - def setUp(self): - # Set up a mock context - mock_ti = MagicMock() - mock_ti.task_id = DEFAULT_TASK_ID - mock_dag = MagicMock() - mock_dag.dag_id = DEFAULT_DAG_ID - self.context = { - "ti": mock_ti, - "run_id": DEFAULT_RUN_ID, - "dag": mock_dag, - } - - @patch("pendulum.DateTime.utcnow") - @patch("armada.operators.armada.ArmadaOperator.defer") - @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) - def test_execute_deferred(self, mock_client_fn, mock_defer_fn, mock_datetime_now): - operator = ArmadaOperator( - name="test", - channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), - armada_queue=DEFAULT_QUEUE, - job_request=JobSubmitRequestItem(), - task_id=DEFAULT_TASK_ID, - deferrable=True, - ) - - mock_datetime_now.return_value = DEFAULT_CURRENT_TIME - - # Set up Mock Armada - mock_client = MagicMock() - mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse( - job_response_items=[submit_pb2.JobSubmitResponseItem(job_id=DEFAULT_JOB_ID)] - ) - mock_client_fn.return_value = mock_client - self.context["ti"].xcom_pull.return_value = None - - operator.execute(self.context) - self.assertEqual(mock_client.submit_jobs.call_count, 1) - mock_defer_fn.assert_called_with( - timeout=operator.execution_timeout, - trigger=_ArmadaPollJobTrigger( - moment=DEFAULT_CURRENT_TIME + timedelta(seconds=operator.poll_interval), - context=_RunningJobContext( - armada_queue=DEFAULT_QUEUE, - job_set_id=operator.job_set_id, - job_id=DEFAULT_JOB_ID, - state=JobState.UNKNOWN, - start_time=DEFAULT_CURRENT_TIME, - cluster=None, - last_log_time=None, - ), - ), - method_name="_deffered_poll_for_termination", - ) - - def test_templating(self): - """Tests templating for both the job_prefix and the pod spec""" - prefix = "{{ run_id }}" - pod_arg = "{{ run_id }}" - - pod = core_v1.PodSpec( - containers=[ - core_v1.Container( - name="sleep", - image="alpine:3.20.2", - args=[pod_arg], - securityContext=core_v1.SecurityContext(runAsUser=1000), - resources=core_v1.ResourceRequirements( - requests={ - "cpu": api_resource.Quantity(string="120m"), - "memory": api_resource.Quantity(string="510Mi"), - }, - limits={ - "cpu": api_resource.Quantity(string="120m"), - "memory": api_resource.Quantity(string="510Mi"), - }, - ), - ) - ], - ) - job = JobSubmitRequestItem(priority=1, pod_spec=pod, namespace="armada") - - operator = ArmadaOperator( - name="test", - channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), - armada_queue=DEFAULT_QUEUE, - job_request=job, - job_set_prefix=prefix, - task_id=DEFAULT_TASK_ID, - deferrable=True, - ) - - operator.render_template_fields(self.context) - - self.assertEqual(operator.job_set_prefix, "test_run_1") - self.assertEqual( - operator.job_request.pod_spec.containers[0].args[0], "test_run_1" - ) diff --git a/third_party/airflow/test/__init__.py b/third_party/airflow/test/unit/__init__.py similarity index 100% rename from third_party/airflow/test/__init__.py rename to third_party/airflow/test/unit/__init__.py diff --git a/third_party/airflow/test/operators/__init__.py b/third_party/airflow/test/unit/operators/__init__.py similarity index 100% rename from third_party/airflow/test/operators/__init__.py rename to third_party/airflow/test/unit/operators/__init__.py diff --git a/third_party/airflow/test/unit/operators/test_armada.py b/third_party/airflow/test/unit/operators/test_armada.py new file mode 100644 index 00000000000..d2aab33cce4 --- /dev/null +++ b/third_party/airflow/test/unit/operators/test_armada.py @@ -0,0 +1,262 @@ +import dataclasses +from datetime import timedelta +from typing import Optional +from unittest.mock import MagicMock, patch + +import pytest +from airflow.exceptions import AirflowException, TaskDeferred +from armada.model import GrpcChannelArgs, RunningJobContext +from armada.operators.armada import ArmadaOperator +from armada.triggers import ArmadaPollJobTrigger +from armada_client.armada.submit_pb2 import JobSubmitRequestItem +from armada_client.typings import JobState +from pendulum import UTC, DateTime + +DEFAULT_CURRENT_TIME = DateTime(2024, 8, 7, tzinfo=UTC) +DEFAULT_JOB_ID = "test_job" +DEFAULT_TASK_ID = "test_task_1" +DEFAULT_JOB_SET = "prefix-test_run_1" +DEFAULT_QUEUE = "test_queue_1" +DEFAULT_CLUSTER = "cluster-1" + + +def default_hook() -> MagicMock: + mock = MagicMock() + job_context = running_job_context() + mock.submit_job.return_value = job_context + mock.refresh_context.return_value = dataclasses.replace( + job_context, job_state=JobState.SUCCEEDED.name, cluster=DEFAULT_CLUSTER + ) + mock.cancel_job.return_value = dataclasses.replace( + job_context, job_state=JobState.CANCELLED.name + ) + + return mock + + +@pytest.fixture(scope="function", autouse=True) +def mock_operator_dependencies(): + # We no-op time.sleep in tests. + with patch("time.sleep", return_value=None) as sleep, patch( + "armada.log_manager.KubernetesPodLogManager.fetch_container_logs" + ) as logs, patch( + "armada.operators.armada.ArmadaOperator.hook", new_callable=default_hook + ) as hook: + yield sleep, logs, hook + + +@pytest.fixture +def context(): + mock_ti = MagicMock() + mock_ti.task_id = DEFAULT_TASK_ID + mock_ti.try_number = 0 + mock_ti.xcom_pull.return_value = None + + mock_dag = MagicMock() + mock_dag.dag_id = "test_dag_1" + + context = {"ti": mock_ti, "run_id": "test_run_1", "dag": mock_dag} + + return context + + +def operator( + job_request: JobSubmitRequestItem, + deferrable: bool = False, + job_acknowledgement_timeout_s: int = 30, + container_logs: Optional[str] = None, +) -> ArmadaOperator: + operator = ArmadaOperator( + armada_queue=DEFAULT_QUEUE, + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + container_logs=container_logs, + deferrable=deferrable, + job_acknowledgement_timeout=job_acknowledgement_timeout_s, + job_request=job_request, + job_set_prefix="prefix-", + lookout_url_template="http://lookout.armadaproject.io/jobs?job_id=", + name="test", + task_id=DEFAULT_TASK_ID, + ) + + return operator + + +def running_job_context( + cluster: str = None, + submit_time: DateTime = DateTime.now(), + job_state: str = JobState.UNKNOWN.name, +) -> RunningJobContext: + return RunningJobContext( + DEFAULT_QUEUE, + DEFAULT_JOB_ID, + DEFAULT_JOB_SET, + submit_time, + cluster, + job_state=job_state, + ) + + +@pytest.mark.parametrize( + "job_states", + [ + [JobState.RUNNING, JobState.SUCCEEDED], + [ + JobState.QUEUED, + JobState.LEASED, + JobState.QUEUED, + JobState.RUNNING, + JobState.SUCCEEDED, + ], + ], + ids=["success", "success - multiple events"], +) +def test_execute(job_states, context): + op = operator(JobSubmitRequestItem()) + + op.hook.refresh_context.side_effect = [ + running_job_context(cluster="cluster-1", job_state=s.name) for s in job_states + ] + + op.execute(context) + + op.hook.submit_job.assert_called_once_with( + DEFAULT_QUEUE, DEFAULT_JOB_SET, op.job_request + ) + assert op.hook.refresh_context.call_count == len(job_states) + + # We're not polling for logs + op.pod_manager.fetch_container_logs.assert_not_called() + + +@patch("pendulum.DateTime.utcnow", return_value=DEFAULT_CURRENT_TIME) +def test_execute_in_deferrable(_, context): + op = operator(JobSubmitRequestItem(), deferrable=True) + op.hook.refresh_context.side_effect = [ + running_job_context(cluster="cluster-1", job_state=s.name) + for s in [JobState.QUEUED, JobState.QUEUED] + ] + + with pytest.raises(TaskDeferred) as deferred: + op.execute(context) + + op.hook.submit_job.assert_called_once_with( + DEFAULT_QUEUE, DEFAULT_JOB_SET, op.job_request + ) + assert deferred.value.timeout == op.execution_timeout + assert deferred.value.trigger == ArmadaPollJobTrigger( + moment=DEFAULT_CURRENT_TIME + timedelta(seconds=op.poll_interval), + context=op.job_context, + channel_args=op.channel_args, + ) + assert deferred.value.method_name == "_trigger_reentry" + + +@pytest.mark.parametrize( + "terminal_state", + [JobState.FAILED, JobState.PREEMPTED, JobState.CANCELLED], + ids=["failed", "preempted", "cancelled"], +) +def test_execute_fail(terminal_state, context): + op = operator(JobSubmitRequestItem()) + + op.hook.refresh_context.side_effect = [ + running_job_context(cluster="cluster-1", job_state=s.name) + for s in [JobState.RUNNING, terminal_state] + ] + + with pytest.raises(AirflowException) as exec_info: + op.execute(context) + + # Error message contain terminal state and job id + assert DEFAULT_JOB_ID in str(exec_info) + assert terminal_state.name in str(exec_info) + + op.hook.submit_job.assert_called_once_with( + DEFAULT_QUEUE, DEFAULT_JOB_SET, op.job_request + ) + assert op.hook.refresh_context.call_count == 2 + + # We're not polling for logs + op.pod_manager.fetch_container_logs.assert_not_called() + + +def test_on_kill_terminates_running_job(): + op = operator(JobSubmitRequestItem()) + job_context = running_job_context() + op.job_context = job_context + + op.on_kill() + op.on_kill() + + # We ensure we only try to cancel job once. + op.hook.cancel_job.assert_called_once_with(job_context) + + +def test_not_acknowledged_within_timeout_terminates_running_job(context): + job_context = running_job_context() + op = operator(JobSubmitRequestItem(), job_acknowledgement_timeout_s=-1) + op.hook.refresh_context.return_value = job_context + + with pytest.raises(AirflowException) as exec_info: + op.execute(context) + + # Error message contain terminal state and job id + assert DEFAULT_JOB_ID in str(exec_info) + assert JobState.CANCELLED.name in str(exec_info) + + # We also cancel already submitted job + op.hook.cancel_job.assert_called_once_with(job_context) + + +def test_polls_for_logs(context): + op = operator( + JobSubmitRequestItem(namespace="namespace-1"), container_logs="alpine" + ) + op.execute(context) + + # We polled logs as expected. + op.pod_manager.fetch_container_logs.assert_called_once_with( + k8s_context="cluster-1", + namespace="namespace-1", + pod="armada-test_job-0", + container="alpine", + since_time=None, + ) + + +def test_publishes_xcom_state(context): + op = operator(JobSubmitRequestItem()) + op.execute(context) + + lookout_url = f"http://lookout.armadaproject.io/jobs?job_id={DEFAULT_JOB_ID}" + context["ti"].xcom_push.assert_called_once_with( + key="0", + value={ + "armada_job_id": DEFAULT_JOB_ID, + "armada_job_set_id": DEFAULT_JOB_SET, + "armada_lookout_url": lookout_url, + "armada_queue": DEFAULT_QUEUE, + }, + ) + + +def test_reattaches_to_running_job(context): + op = operator(JobSubmitRequestItem()) + context["ti"].xcom_pull.return_value = { + "armada_job_id": DEFAULT_JOB_ID, + "armada_job_set_id": DEFAULT_JOB_SET, + "armada_queue": DEFAULT_QUEUE, + } + + op.execute(context) + + assert op.job_context == running_job_context( + job_state=JobState.SUCCEEDED.name, cluster=DEFAULT_CLUSTER + ) + op.hook.submit_job.assert_not_called() + + +@pytest.mark.skip("TODO") +def test_templates_job_request_item(): + pass diff --git a/third_party/airflow/test/unit/test_hooks.py b/third_party/airflow/test/unit/test_hooks.py new file mode 100644 index 00000000000..0a2e1ba2e11 --- /dev/null +++ b/third_party/airflow/test/unit/test_hooks.py @@ -0,0 +1,16 @@ +import pytest + + +@pytest.mark.skip("TODO") +def test_submits_job_using_armada_client(): + pass + + +@pytest.mark.skip("TODO") +def test_cancels_job_using_armada_client(): + pass + + +@pytest.mark.skip("TODO") +def test_updates_job_context(): + pass diff --git a/third_party/airflow/test/unit/test_model.py b/third_party/airflow/test/unit/test_model.py new file mode 100644 index 00000000000..906b7315ad9 --- /dev/null +++ b/third_party/airflow/test/unit/test_model.py @@ -0,0 +1,33 @@ +import grpc +from airflow.serialization.serde import deserialize, serialize +from armada.model import GrpcChannelArgs, RunningJobContext +from armada_client.typings import JobState +from pendulum import DateTime + + +def test_roundtrip_running_job_context(): + context = RunningJobContext( + "queue_123", + "job_id_123", + "job_set_id_123", + DateTime.utcnow(), + "cluster-1.armada.localhost", + DateTime.utcnow().add(minutes=-2), + JobState.RUNNING.name, + ) + + result = deserialize(serialize(context)) + assert context == result + assert JobState.RUNNING == result.state + + +def test_roundtrip_grpc_channel_args(): + channel_args = GrpcChannelArgs( + "armada-api.localhost", + [("key-1", 10), ("key-2", "value-2")], + grpc.Compression.NoCompression, + None, + ) + + result = deserialize(serialize(channel_args)) + assert channel_args == result diff --git a/third_party/airflow/test/unit/test_triggers.py b/third_party/airflow/test/unit/test_triggers.py new file mode 100644 index 00000000000..bdd15333caa --- /dev/null +++ b/third_party/airflow/test/unit/test_triggers.py @@ -0,0 +1,16 @@ +import pytest + + +@pytest.mark.skip("TODO") +def test_yields_with_context(): + pass + + +@pytest.mark.skip("TODO") +def test_cancels_running_job_when_task_is_cancelled(): + pass + + +@pytest.mark.skip("TODO") +def test_do_not_cancels_running_job_when_trigger_is_suspended(): + pass diff --git a/third_party/airflow/tox.ini b/third_party/airflow/tox.ini index 09dd8ce15ea..ed457e94d70 100644 --- a/third_party/airflow/tox.ini +++ b/third_party/airflow/tox.ini @@ -13,7 +13,7 @@ allowlist_externals = find xargs commands = - coverage run -m unittest discover + coverage run -m pytest test/unit/ coverage xml # This executes the dag files in examples but really only checks for imports and python errors bash -c "find examples/ -maxdepth 1 -type f -name *.py | xargs python3"