From e0c73792e007ae0c9a0b54c4d487addf2bb97d2e Mon Sep 17 00:00:00 2001 From: Aleksandr Movchan Date: Wed, 13 Nov 2024 14:04:14 +0000 Subject: [PATCH] Update datetime usage to ensure UTC consistency across task repository --- aana/storage/repository/task.py | 19 +++++++++++-------- aana/tests/db/datastore/test_task_repo.py | 6 +++--- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/aana/storage/repository/task.py b/aana/storage/repository/task.py index 74971013..a0752076 100644 --- a/aana/storage/repository/task.py +++ b/aana/storage/repository/task.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Any from uuid import UUID @@ -137,7 +137,6 @@ def fetch_unprocessed_tasks( .with_for_update(skip_locked=True) .all() ) - for task in tasks: self.update_status( task_id=task.id, @@ -167,9 +166,9 @@ def update_status( """ task = self.read(task_id) if status == TaskStatus.COMPLETED or status == TaskStatus.FAILED: - task.completed_at = datetime.now() # noqa: DTZ005 + task.completed_at = datetime.now(timezone.utc) if status == TaskStatus.ASSIGNED: - task.assigned_at = datetime.now() # noqa: DTZ005 + task.assigned_at = datetime.now(timezone.utc) task.num_retries += 1 if progress is not None: task.progress = progress @@ -249,8 +248,12 @@ def update_expired_tasks( Returns: list[TaskEntity]: the expired tasks. """ - timeout_cutoff = datetime.now() - timedelta(seconds=execution_timeout) # noqa: DTZ005 - heartbeat_cutoff = datetime.now() - timedelta(seconds=heartbeat_timeout) # noqa: DTZ005 + timeout_cutoff = datetime.now(timezone.utc) - timedelta( + seconds=execution_timeout + ) + heartbeat_cutoff = datetime.now(timezone.utc) - timedelta( + seconds=heartbeat_timeout + ) tasks = ( self.session.query(TaskEntity) .filter( @@ -268,7 +271,7 @@ def update_expired_tasks( ) for task in tasks: if task.num_retries >= max_retries: - if task.assigned_at <= timeout_cutoff: + if task.assigned_at.astimezone(timezone.utc) <= timeout_cutoff: result = { "error": "TimeoutError", "message": ( @@ -313,7 +316,7 @@ def heartbeat(self, task_ids: list[str] | set[str]): for task_id in task_ids ] self.session.query(TaskEntity).filter(TaskEntity.id.in_(task_ids)).update( - {TaskEntity.updated_at: datetime.now()}, # noqa: DTZ005 + {TaskEntity.updated_at: datetime.now(timezone.utc)}, synchronize_session=False, ) self.session.commit() diff --git a/aana/tests/db/datastore/test_task_repo.py b/aana/tests/db/datastore/test_task_repo.py index f8ab11ab..52b3a978 100644 --- a/aana/tests/db/datastore/test_task_repo.py +++ b/aana/tests/db/datastore/test_task_repo.py @@ -1,7 +1,7 @@ # ruff: noqa: S101 import asyncio -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import pytest @@ -51,7 +51,7 @@ def _create_sample_tasks(): db_session.commit() # Create sample tasks with different statuses - now = datetime.now() # noqa: DTZ005 + now = datetime.now(timezone.utc) task1 = TaskEntity( endpoint="/test1", @@ -291,7 +291,7 @@ def test_update_expired_tasks(db_session): db_session.commit() # Set up current time and a cutoff time - current_time = datetime.now() # noqa: DTZ005 + current_time = datetime.now(timezone.utc) execution_timeout = 3600 # 1 hour in seconds heartbeat_timeout = 60 # 1 minute in seconds