Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compile topology with patches #439

Merged
merged 2 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/saturn_engine/stores/topologies_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,10 @@ def patch(*, session: AnySession, patch: BaseObject) -> TopologyPatch:
)
session.execute(stmt) # type: ignore
return topology_patch


def get_patches(*, session: AnySession) -> list[TopologyPatch]:
    """Return every ``TopologyPatch`` row visible through ``session``.

    ``session`` is keyword-only. The result is whatever
    ``session.query(TopologyPatch).all()`` yields, with no filtering
    or ordering applied here.
    """
    all_patches_query = session.query(TopologyPatch)
    return all_patches_query.all()
1 change: 0 additions & 1 deletion src/saturn_engine/utils/declarative_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,4 @@ def load_uncompiled_objects_from_directory(config_dir: str) -> list[UncompiledOb

with open(os.path.join(root, filename), "r", encoding="utf-8") as f:
uncompiled_objects.extend(load_uncompiled_objects_from_str(f.read()))

return uncompiled_objects
14 changes: 14 additions & 0 deletions src/saturn_engine/utils/dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import copy
import typing as t


def deep_merge(a: dict[str, t.Any], b: dict[str, t.Any]) -> dict[str, t.Any]:
    """Return a new dict that is ``b`` recursively merged over ``a``.

    Rules, applied per key:

    - key only in ``a``: the value from ``a`` is kept;
    - key only in ``b``: the value from ``b`` is added;
    - key in both and both values are dicts: the two dicts are merged
      recursively with the same rules;
    - key in both otherwise: the value from ``b`` wins.

    Neither input is modified, and the result shares no mutable state with
    them: every value that ends up in the result is deep-copied, so later
    in-place edits of the result cannot leak back into ``a`` or ``b``.

    Key order is ``a``'s keys in their original order followed by the keys
    that exist only in ``b``.
    """
    merged: dict[str, t.Any] = {}

    for key, a_value in a.items():
        if key not in b:
            merged[key] = copy.deepcopy(a_value)
            continue

        b_value = b[key]
        if isinstance(a_value, dict) and isinstance(b_value, dict):
            # The recursive call already returns an independent structure.
            merged[key] = deep_merge(a_value, b_value)
        else:
            merged[key] = copy.deepcopy(b_value)

    # Keys that only exist in b are appended after a's keys.
    for key, b_value in b.items():
        if key not in a:
            merged[key] = copy.deepcopy(b_value)

    return merged
37 changes: 35 additions & 2 deletions src/saturn_engine/worker_manager/config/declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import re
from collections import defaultdict

from saturn_engine.models.topology_patches import TopologyPatch
from saturn_engine.utils import dict as dict_utils
from saturn_engine.utils.declarative_config import UncompiledObject
from saturn_engine.utils.declarative_config import load_uncompiled_objects_from_path
from saturn_engine.utils.declarative_config import load_uncompiled_objects_from_str
Expand All @@ -31,8 +33,15 @@

def compile_static_definitions(
uncompiled_objects: list[UncompiledObject],
patches: list[TopologyPatch] | None = None,
) -> StaticDefinitions:
objects_by_kind: DefaultDict[str, dict[str, UncompiledObject]] = defaultdict(dict)

if patches:
uncompiled_objects = merge_with_patches(
uncompiled_objects=uncompiled_objects, patches=patches
)

for uncompiled_object in uncompiled_objects:
if uncompiled_object.name in objects_by_kind[uncompiled_object.kind]:
raise Exception(
Expand Down Expand Up @@ -122,12 +131,14 @@ def load_definitions_from_str(definitions: str) -> StaticDefinitions:
return compile_static_definitions(load_uncompiled_objects_from_str(definitions))


def load_definitions_from_paths(config_dirs: list[str]) -> StaticDefinitions:
def load_definitions_from_paths(
config_dirs: list[str], patches: list[TopologyPatch] | None = None
) -> StaticDefinitions:
uncompiled_objects = []
for config_dir in config_dirs:
uncompiled_objects.extend(load_uncompiled_objects_from_path(config_dir))

return compile_static_definitions(uncompiled_objects)
return compile_static_definitions(uncompiled_objects, patches=patches)


def filter_with_jobs_selector(
Expand All @@ -141,3 +152,25 @@ def filter_with_jobs_selector(
if pattern.search(name)
}
return dataclasses.replace(definitions, jobs=jobs, job_definitions=job_definitions)


def merge_with_patches(
    uncompiled_objects: list[UncompiledObject], patches: list[TopologyPatch]
) -> list[UncompiledObject]:
    """Apply each patch to the uncompiled object with the same kind and name.

    Objects are indexed by ``(kind, name)``; when several objects share a
    key, the last one in ``uncompiled_objects`` is the one that gets patched.
    For every patch, ``patch.data`` is deep-merged over the matching object's
    ``data`` and the result is assigned back to that object, so the objects
    in the input list are updated in place. A patch without a matching
    object is reported with a warning and skipped.

    Returns the same ``uncompiled_objects`` list that was passed in.
    """
    targets_by_key = {}
    for candidate in uncompiled_objects:
        targets_by_key[(candidate.kind, candidate.name)] = candidate

    for patch in patches:
        target = targets_by_key.get((patch.kind, patch.name))
        if target:
            target.data = dict_utils.deep_merge(a=target.data, b=patch.data)
        else:
            logging.warning(
                f"Can't find an uncompiled objects to use with patch {patch=}"
            )

    return uncompiled_objects
7 changes: 6 additions & 1 deletion src/saturn_engine/worker_manager/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from saturn_engine.config import WorkerManagerConfig
from saturn_engine.stores import topologies_store
from saturn_engine.utils.sqlalchemy import AnySession
from saturn_engine.worker_manager.config.declarative import filter_with_jobs_selector
from saturn_engine.worker_manager.config.declarative import load_definitions_from_paths
Expand Down Expand Up @@ -38,7 +39,11 @@ def _load_static_definition(
- Jobs
- JobDefinitions
"""
definitions = load_definitions_from_paths(config.static_definitions_directories)
patches = topologies_store.get_patches(session=session)
definitions = load_definitions_from_paths(
config.static_definitions_directories, patches=patches
)

if config.static_definitions_jobs_selector:
definitions = filter_with_jobs_selector(
definitions=definitions,
Expand Down
3 changes: 0 additions & 3 deletions src/saturn_engine/worker_manager/services/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def lock_jobs(
session: AnySyncSession,
) -> LockResponse:
logger = logging.getLogger(f"{__name__}.lock_jobs")

# Note:
# - Leftover items remain unassigned.
assignation_expiration_cutoff: datetime = datetime.now() - timedelta(minutes=15)
Expand Down Expand Up @@ -61,7 +60,6 @@ def lock_jobs(
selector=lock_input.selector,
)
)

# Join definitions and filtered out by executors
for item in assigned_items.copy():
try:
Expand Down Expand Up @@ -129,7 +127,6 @@ def lock_jobs(
continue

executors.setdefault(executor.name, executor)

# Refresh assignments
new_assigned_at = datetime.now()
for assigned_item in assigned_items:
Expand Down
152 changes: 152 additions & 0 deletions tests/worker_manager/api/test_topologies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from unittest import mock

from flask.testing import FlaskClient
from sqlalchemy.orm import Session

from saturn_engine.worker_manager.app import SaturnApp


def test_put_topology_patch(client: FlaskClient) -> None:
Expand Down Expand Up @@ -36,3 +41,150 @@ def test_put_topology_patch(client: FlaskClient) -> None:
"metadata": {"name": "test-topic", "labels": {}},
"spec": {"type": "RabbitMQTopic", "options": {"queue_name": "queue_2"}},
}
isra17 marked this conversation as resolved.
Show resolved Hide resolved


def _expected_lock_response(pipeline_name: str) -> dict[str, object]:
    """Build the `/api/lock` payload expected for the topology used below.

    Only the pipeline name differs between the unpatched and the patched
    topology, so it is the single parameter.
    """
    return {
        "executors": [
            {
                "name": "default",
                "options": {
                    "concurrency": 108,
                    "queue_name": "arq:saturn-default",
                    "redis_pool_args": {"max_connections": 10000},
                    "redis_url": "redis://redis",
                },
                "type": "ARQExecutor",
            }
        ],
        "items": [
            {
                "config": {},
                "executor": "default",
                "input": {"name": "test-inventory", "options": {}, "type": "testtype"},
                "labels": {"owner": "team-saturn"},
                "name": mock.ANY,
                "output": {},
                "pipeline": {
                    "args": {},
                    "info": {
                        "name": pipeline_name,
                        "resources": {},
                    },
                },
                "state": {
                    "cursor": None,
                    "started_at": mock.ANY,
                },
            }
        ],
        "resources": [],
        "resources_providers": [],
    }


def test_put_topology_patch_ensure_topology_changed(
    tmp_path: str, app: SaturnApp, client: FlaskClient, session: Session
) -> None:
    """A stored topology patch must show up after reloading static definitions.

    Flow: write a topology file, load it, check the lock response, PUT a
    patch that renames the job's pipeline, reload the static definitions and
    check that the lock response now carries the patched pipeline name.
    """
    # The YAML is kept at column 0 so the `---` document separators and the
    # nested mappings parse as intended.
    topology = """
apiVersion: saturn.flared.io/v1alpha1
kind: SaturnExecutor
metadata:
  name: default
spec:
  type: ARQExecutor
  options:
    redis_url: "redis://redis"
    queue_name: "arq:saturn-default"
    redis_pool_args:
      max_connections: 10000
    concurrency: 108
---
apiVersion: saturn.flared.io/v1alpha1
kind: SaturnInventory
metadata:
  name: test-inventory
spec:
  type: testtype
---
apiVersion: saturn.flared.io/v1alpha1
kind: SaturnJobDefinition
metadata:
  name: job_1
  labels:
    owner: team-saturn
spec:
  minimalInterval: "@weekly"
  template:
    input:
      inventory: test-inventory
    pipeline:
      name: something.saturn.pipelines.aa.bb
---
"""
    # The file is only written here; use an explicit encoding.
    with open(f"{tmp_path}/topology.yaml", "w", encoding="utf-8") as f:
        f.write(topology)

    app.saturn.config.static_definitions_directories = [tmp_path]
    app.saturn.load_static_definition(session=session)

    resp = client.post("/api/jobs/sync")
    assert resp.status_code == 200
    assert resp.json == {}

    resp = client.post("/api/lock", json={"worker_id": "worker-1"})
    assert resp.json == _expected_lock_response("something.saturn.pipelines.aa.bb")

    # Let's change the pipeline name
    # NOTE(review): the PUT response itself is not asserted; only its effect
    # on the next lock call is checked.
    resp = client.put(
        "/api/topologies/patch",
        json={
            "apiVersion": "saturn.flared.io/v1alpha1",
            "kind": "SaturnJobDefinition",
            "metadata": {"name": "job_1"},
            "spec": {
                "template": {
                    "pipeline": {"name": "something.else.saturn.pipelines.aa.bb"},
                },
            },
        },
    )

    # And reset the static definition
    session.commit()
    app.saturn.load_static_definition(session=session)

    # Make sure we have the new topology version
    resp = client.post("/api/lock", json={"worker_id": "worker-1"})
    assert resp.json == _expected_lock_response(
        "something.else.saturn.pipelines.aa.bb"
    )
74 changes: 74 additions & 0 deletions tests/worker_manager/config/test_declarative.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import os

import pytest
from flask.testing import FlaskClient
from sqlalchemy.orm import Session

from saturn_engine.core.api import ComponentDefinition
from saturn_engine.core.api import JobDefinition
from saturn_engine.core.api import ResourceItem
from saturn_engine.stores import topologies_store
from saturn_engine.utils.declarative_config import BaseObject
from saturn_engine.utils.declarative_config import ObjectMetadata
from saturn_engine.utils.declarative_config import load_uncompiled_objects_from_str
from saturn_engine.worker_manager.config.declarative import compile_static_definitions
from saturn_engine.worker_manager.config.declarative import filter_with_jobs_selector
from saturn_engine.worker_manager.config.declarative import load_definitions_from_paths
from saturn_engine.worker_manager.config.declarative import load_definitions_from_str
Expand Down Expand Up @@ -606,3 +613,70 @@ def test_dynamic_definition() -> None:
static_definitions = load_definitions_from_str(resources_provider_str)
assert "test-inventory" in static_definitions.inventories
assert static_definitions.inventories["test-inventory"].name == "test-inventory"


def test_compile_static_definitions_with_patches(
    client: FlaskClient, session: Session
) -> None:
    """Compiling with a patch must reflect the patched resource concurrency.

    The unpatched definition (concurrency 2) is expected to compile into two
    resources, ``test-resource-1`` and ``test-resource-2``. With a patch that
    sets the concurrency to 1, a single ``test-resource`` is expected.
    """
    # NOTE(review): `client` is not used directly; presumably it is requested
    # for the setup the fixture performs. Confirm against conftest.
    # The YAML is kept at column 0 so the nested mappings parse as intended.
    concurrency_definition_str = """
apiVersion: saturn.flared.io/v1alpha1
kind: SaturnResource
metadata:
  name: test-resource
  labels:
    owner: team-saturn
spec:
  type: TestApiKey
  data:
    key: "qwe"
  default_delay: 10
  concurrency: 2
"""

    uncompiled_objects = load_uncompiled_objects_from_str(concurrency_definition_str)

    compiled_without_patch = compile_static_definitions(
        uncompiled_objects=uncompiled_objects
    )

    assert compiled_without_patch.resources == {
        "test-resource-1": ResourceItem(
            name="test-resource-1",
            type="TestApiKey",
            data={"key": "qwe"},
            default_delay=10.0,
            rate_limit=None,
        ),
        "test-resource-2": ResourceItem(
            name="test-resource-2",
            type="TestApiKey",
            data={"key": "qwe"},
            default_delay=10.0,
            rate_limit=None,
        ),
    }

    # Now we create a patch to change the resource concurrency
    patch = topologies_store.patch(
        session=session,
        patch=BaseObject(
            kind="SaturnResource",
            apiVersion="saturn.flared.io/v1alpha1",
            metadata=ObjectMetadata(name="test-resource"),
            spec={"concurrency": 1},
        ),
    )

    compiled_with_patch = compile_static_definitions(
        uncompiled_objects=uncompiled_objects, patches=[patch]
    )

    assert compiled_with_patch.resources == {
        "test-resource": ResourceItem(
            name="test-resource",
            type="TestApiKey",
            data={"key": "qwe"},
            default_delay=10.0,
            rate_limit=None,
        )
    }
Loading
Loading