diff --git a/src/saturn_engine/stores/topologies_store.py b/src/saturn_engine/stores/topologies_store.py index 7888fc48..d4ee9e3d 100644 --- a/src/saturn_engine/stores/topologies_store.py +++ b/src/saturn_engine/stores/topologies_store.py @@ -23,3 +23,10 @@ def patch(*, session: AnySession, patch: BaseObject) -> TopologyPatch: ) session.execute(stmt) # type: ignore return topology_patch + + +def get_patches(*, session: AnySession) -> list[TopologyPatch]: + """ + Return all the patches + """ + return session.query(TopologyPatch).all() diff --git a/src/saturn_engine/utils/declarative_config.py b/src/saturn_engine/utils/declarative_config.py index 82d423b8..2694a532 100644 --- a/src/saturn_engine/utils/declarative_config.py +++ b/src/saturn_engine/utils/declarative_config.py @@ -75,5 +75,4 @@ def load_uncompiled_objects_from_directory(config_dir: str) -> list[UncompiledOb with open(os.path.join(root, filename), "r", encoding="utf-8") as f: uncompiled_objects.extend(load_uncompiled_objects_from_str(f.read())) - return uncompiled_objects diff --git a/src/saturn_engine/utils/dict.py b/src/saturn_engine/utils/dict.py new file mode 100644 index 00000000..18cdbfc2 --- /dev/null +++ b/src/saturn_engine/utils/dict.py @@ -0,0 +1,14 @@ +import typing as t + + +def deep_merge(a: dict[str, t.Any], b: dict[str, t.Any]) -> dict[str, t.Any]: + """ + Merge b into a + """ + result = a.copy() + for key, value in b.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = deep_merge(result[key], value) + else: + result[key] = value + return result diff --git a/src/saturn_engine/worker_manager/config/declarative.py b/src/saturn_engine/worker_manager/config/declarative.py index fd0076c4..9fd387a7 100644 --- a/src/saturn_engine/worker_manager/config/declarative.py +++ b/src/saturn_engine/worker_manager/config/declarative.py @@ -5,6 +5,8 @@ import re from collections import defaultdict +from saturn_engine.models.topology_patches 
import TopologyPatch +from saturn_engine.utils import dict as dict_utils from saturn_engine.utils.declarative_config import UncompiledObject from saturn_engine.utils.declarative_config import load_uncompiled_objects_from_path from saturn_engine.utils.declarative_config import load_uncompiled_objects_from_str @@ -31,8 +33,15 @@ def compile_static_definitions( uncompiled_objects: list[UncompiledObject], + patches: list[TopologyPatch] | None = None, ) -> StaticDefinitions: objects_by_kind: DefaultDict[str, dict[str, UncompiledObject]] = defaultdict(dict) + + if patches: + uncompiled_objects = merge_with_patches( + uncompiled_objects=uncompiled_objects, patches=patches + ) + for uncompiled_object in uncompiled_objects: if uncompiled_object.name in objects_by_kind[uncompiled_object.kind]: raise Exception( @@ -122,12 +131,14 @@ def load_definitions_from_str(definitions: str) -> StaticDefinitions: return compile_static_definitions(load_uncompiled_objects_from_str(definitions)) -def load_definitions_from_paths(config_dirs: list[str]) -> StaticDefinitions: +def load_definitions_from_paths( + config_dirs: list[str], patches: list[TopologyPatch] | None = None +) -> StaticDefinitions: uncompiled_objects = [] for config_dir in config_dirs: uncompiled_objects.extend(load_uncompiled_objects_from_path(config_dir)) - return compile_static_definitions(uncompiled_objects) + return compile_static_definitions(uncompiled_objects, patches=patches) def filter_with_jobs_selector( @@ -141,3 +152,25 @@ def filter_with_jobs_selector( if pattern.search(name) } return dataclasses.replace(definitions, jobs=jobs, job_definitions=job_definitions) + + +def merge_with_patches( + uncompiled_objects: list[UncompiledObject], patches: list[TopologyPatch] +) -> list[UncompiledObject]: + uncompiled_object_by_kind_and_name = { + (u.kind, u.name): u for u in uncompiled_objects + } + for patch in patches: + uncompiled_object = uncompiled_object_by_kind_and_name.get( + (patch.kind, patch.name) + ) + if not 
uncompiled_object: + logging.warning( + f"Can't find an uncompiled object to use with patch {patch=}" + ) + continue + + uncompiled_object.data = dict_utils.deep_merge( + a=uncompiled_object.data, b=patch.data + ) + return uncompiled_objects diff --git a/src/saturn_engine/worker_manager/context.py b/src/saturn_engine/worker_manager/context.py index 07710554..0d22596e 100644 --- a/src/saturn_engine/worker_manager/context.py +++ b/src/saturn_engine/worker_manager/context.py @@ -1,4 +1,5 @@ from saturn_engine.config import WorkerManagerConfig +from saturn_engine.stores import topologies_store from saturn_engine.utils.sqlalchemy import AnySession from saturn_engine.worker_manager.config.declarative import filter_with_jobs_selector from saturn_engine.worker_manager.config.declarative import load_definitions_from_paths @@ -38,7 +39,11 @@ def _load_static_definition( - Jobs - JobDefinitions """ - definitions = load_definitions_from_paths(config.static_definitions_directories) + patches = topologies_store.get_patches(session=session) + definitions = load_definitions_from_paths( + config.static_definitions_directories, patches=patches + ) + if config.static_definitions_jobs_selector: definitions = filter_with_jobs_selector( definitions=definitions, diff --git a/src/saturn_engine/worker_manager/services/lock.py b/src/saturn_engine/worker_manager/services/lock.py index ab497c97..0d314ba5 100644 --- a/src/saturn_engine/worker_manager/services/lock.py +++ b/src/saturn_engine/worker_manager/services/lock.py @@ -22,7 +22,6 @@ def lock_jobs( session: AnySyncSession, ) -> LockResponse: logger = logging.getLogger(f"{__name__}.lock_jobs") - # Note: # - Leftover items remain unassigned. 
assignation_expiration_cutoff: datetime = datetime.now() - timedelta(minutes=15) @@ -61,7 +60,6 @@ def lock_jobs( selector=lock_input.selector, ) ) - # Join definitions and filtered out by executors for item in assigned_items.copy(): try: @@ -129,7 +127,6 @@ def lock_jobs( continue executors.setdefault(executor.name, executor) - # Refresh assignments new_assigned_at = datetime.now() for assigned_item in assigned_items: diff --git a/tests/worker_manager/api/test_topologies.py b/tests/worker_manager/api/test_topologies.py index 6175871c..e46564bd 100644 --- a/tests/worker_manager/api/test_topologies.py +++ b/tests/worker_manager/api/test_topologies.py @@ -1,4 +1,9 @@ +from unittest import mock + from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from saturn_engine.worker_manager.app import SaturnApp def test_put_topology_patch(client: FlaskClient) -> None: @@ -36,3 +41,150 @@ def test_put_topology_patch(client: FlaskClient) -> None: "metadata": {"name": "test-topic", "labels": {}}, "spec": {"type": "RabbitMQTopic", "options": {"queue_name": "queue_2"}}, } + + +def test_put_topology_patch_ensure_topology_changed( + tmp_path: str, app: SaturnApp, client: FlaskClient, session: Session +) -> None: + topology = """ +apiVersion: saturn.flared.io/v1alpha1 +kind: SaturnExecutor +metadata: + name: default +spec: + type: ARQExecutor + options: + redis_url: "redis://redis" + queue_name: "arq:saturn-default" + redis_pool_args: + max_connections: 10000 + concurrency: 108 +--- +apiVersion: saturn.flared.io/v1alpha1 +kind: SaturnInventory +metadata: + name: test-inventory +spec: + type: testtype +--- +apiVersion: saturn.flared.io/v1alpha1 +kind: SaturnJobDefinition +metadata: + name: job_1 + labels: + owner: team-saturn +spec: + minimalInterval: "@weekly" + template: + input: + inventory: test-inventory + pipeline: + name: something.saturn.pipelines.aa.bb +--- + """ + with open(f"{tmp_path}/topology.yaml", "+w") as f: + f.write(topology) + + 
app.saturn.config.static_definitions_directories = [tmp_path] + app.saturn.load_static_definition(session=session) + + resp = client.post("/api/jobs/sync") + assert resp.status_code == 200 + assert resp.json == {} + resp = client.post("/api/lock", json={"worker_id": "worker-1"}) + assert resp.json == { + "executors": [ + { + "name": "default", + "options": { + "concurrency": 108, + "queue_name": "arq:saturn-default", + "redis_pool_args": {"max_connections": 10000}, + "redis_url": "redis://redis", + }, + "type": "ARQExecutor", + } + ], + "items": [ + { + "config": {}, + "executor": "default", + "input": {"name": "test-inventory", "options": {}, "type": "testtype"}, + "labels": {"owner": "team-saturn"}, + "name": mock.ANY, + "output": {}, + "pipeline": { + "args": {}, + "info": { + "name": "something.saturn.pipelines.aa.bb", + "resources": {}, + }, + }, + "state": { + "cursor": None, + "started_at": mock.ANY, + }, + } + ], + "resources": [], + "resources_providers": [], + } + + # Let's change the pipeline name + resp = client.put( + "/api/topologies/patch", + json={ + "apiVersion": "saturn.flared.io/v1alpha1", + "kind": "SaturnJobDefinition", + "metadata": {"name": "job_1"}, + "spec": { + "template": { + "pipeline": {"name": "something.else.saturn.pipelines.aa.bb"}, + }, + }, + }, + ) + + # And reset the static definition + session.commit() + app.saturn.load_static_definition(session=session) + + # Make sure we have the new topology version + resp = client.post("/api/lock", json={"worker_id": "worker-1"}) + assert resp.json == { + "executors": [ + { + "name": "default", + "options": { + "concurrency": 108, + "queue_name": "arq:saturn-default", + "redis_pool_args": {"max_connections": 10000}, + "redis_url": "redis://redis", + }, + "type": "ARQExecutor", + } + ], + "items": [ + { + "config": {}, + "executor": "default", + "input": {"name": "test-inventory", "options": {}, "type": "testtype"}, + "labels": {"owner": "team-saturn"}, + "name": mock.ANY, + "output": {}, + 
"pipeline": { + "args": {}, + "info": { + "name": "something.else.saturn.pipelines.aa.bb", + "resources": {}, + }, + }, + "state": { + "cursor": None, + "started_at": mock.ANY, + }, + } + ], + "resources": [], + "resources_providers": [], + } diff --git a/tests/worker_manager/config/test_declarative.py b/tests/worker_manager/config/test_declarative.py index 67df81cc..6d203d74 100644 --- a/tests/worker_manager/config/test_declarative.py +++ b/tests/worker_manager/config/test_declarative.py @@ -1,10 +1,17 @@ import os import pytest +from flask.testing import FlaskClient +from sqlalchemy.orm import Session from saturn_engine.core.api import ComponentDefinition from saturn_engine.core.api import JobDefinition from saturn_engine.core.api import ResourceItem +from saturn_engine.stores import topologies_store +from saturn_engine.utils.declarative_config import BaseObject +from saturn_engine.utils.declarative_config import ObjectMetadata +from saturn_engine.utils.declarative_config import load_uncompiled_objects_from_str +from saturn_engine.worker_manager.config.declarative import compile_static_definitions from saturn_engine.worker_manager.config.declarative import filter_with_jobs_selector from saturn_engine.worker_manager.config.declarative import load_definitions_from_paths from saturn_engine.worker_manager.config.declarative import load_definitions_from_str @@ -606,3 +613,70 @@ def test_dynamic_definition() -> None: static_definitions = load_definitions_from_str(resources_provider_str) assert "test-inventory" in static_definitions.inventories assert static_definitions.inventories["test-inventory"].name == "test-inventory" + + +def test_compile_static_definitions_with_patches( + client: FlaskClient, session: Session +) -> None: + concurrency_definition_str = """ + apiVersion: saturn.flared.io/v1alpha1 + kind: SaturnResource + metadata: + name: test-resource + labels: + owner: team-saturn + spec: + type: TestApiKey + data: + key: "qwe" + default_delay: 10 + concurrency: 
2 """ + + uncompiled_objects = load_uncompiled_objects_from_str(concurrency_definition_str) + + compiled_static_definitions_without_patch = compile_static_definitions( + uncompiled_objects=uncompiled_objects + ) + + assert compiled_static_definitions_without_patch.resources == { + "test-resource-1": ResourceItem( + name="test-resource-1", + type="TestApiKey", + data={"key": "qwe"}, + default_delay=10.0, + rate_limit=None, + ), + "test-resource-2": ResourceItem( + name="test-resource-2", + type="TestApiKey", + data={"key": "qwe"}, + default_delay=10.0, + rate_limit=None, + ), + } + + # Now we create a patch to change the resource concurrency + patch = topologies_store.patch( + session=session, + patch=BaseObject( + kind="SaturnResource", + apiVersion="saturn.flared.io/v1alpha1", + metadata=ObjectMetadata(name="test-resource"), + spec={"concurrency": 1}, + ), + ) + + compiled_static_definitions_with_patch = compile_static_definitions( + uncompiled_objects=uncompiled_objects, patches=[patch] + ) + + assert compiled_static_definitions_with_patch.resources == { + "test-resource": ResourceItem( + name="test-resource", + type="TestApiKey", + data={"key": "qwe"}, + default_delay=10.0, + rate_limit=None, + ) + } diff --git a/tests/worker_manager/conftest.py b/tests/worker_manager/conftest.py index 73e7e143..8e4f4019 100644 --- a/tests/worker_manager/conftest.py +++ b/tests/worker_manager/conftest.py @@ -9,6 +9,7 @@ from saturn_engine.core import api from saturn_engine.models import Base from saturn_engine.worker_manager import server as worker_manager_server +from saturn_engine.worker_manager.app import SaturnApp from saturn_engine.worker_manager.config.declarative import StaticDefinitions @@ -95,17 +96,23 @@ def fake_job_definition( @pytest.fixture -def client( - static_definitions: StaticDefinitions, -) -> t.Iterator[FlaskClient]: +def app() -> t.Iterator[SaturnApp]: app = worker_manager_server.get_app( config={ "TESTING": True, }, ) - 
app.saturn._static_definitions = static_definitions with app.app_context(): Base.metadata.drop_all(bind=database.engine()) Base.metadata.create_all(bind=database.engine()) - with app.test_client() as client: - yield client + yield app + + +@pytest.fixture +def client( + app: SaturnApp, + static_definitions: StaticDefinitions, +) -> t.Iterator[FlaskClient]: + app.saturn._static_definitions = static_definitions + with app.test_client() as client: + yield client