Added possibility to load ports and steps in a different workflow (#274)
This commit adds the possibility to load `Step` and `Port` instances
from a `Database` to build a different workflow.
Documentation and unit tests have been added for this new feature.
LanderOtto authored Jan 13, 2024
1 parent 3aa3297 commit 4c9fce1
Showing 12 changed files with 919 additions and 30 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci-tests.yaml
@@ -147,10 +147,11 @@ jobs:
          python -m pip install -r docs/requirements.txt
      - name: "Build documentation and check for consistency"
        env:
          CHECKSUM: "caaa02b3477bf264a667c1684176ecf3f5e03e6c916d3107299c33670506239c"
          CHECKSUM: "fc9bdd01ef90f0b24d019da7683aa528af10119ef54d0a13cb16ec7adaa04242"
        run: |
          cd docs/
          HASH="$(make checksum | tail -n1)"
          echo "Docs checksum is ${HASH}"
          test "${HASH}" == "${CHECKSUM}"
  test-flux:
    runs-on: ubuntu-22.04
83 changes: 81 additions & 2 deletions docs/source/ext/database.rst
@@ -25,7 +25,10 @@ StreamFlow relies on a persistent ``Database`` to store all the metadata regardi
Each ``PersistableEntity`` is identified by a unique numerical ``persistent_id`` related to the corresponding ``Database`` record. Two methods, ``save`` and ``load``, allow persisting the entity in the ``Database`` and retrieving it from the persistent record. Note that ``load`` is a class method, as it must construct a new instance.

The ``load`` method receives three input parameters: the current execution ``context``, the ``persistent_id`` of the instance that should be loaded, and a ``loading_context``. The latter keeps track of all the objects already loaded in the current transaction, serving as a cache to efficiently load nested entities and prevent deadlocks when dealing with circular references.
The ``load`` method receives three input parameters: the current execution ``context``, the ``persistent_id`` of the instance that should be loaded, and a ``loading_context`` (see :ref:`DatabaseLoadingContext <DatabaseLoadingContext>`). Note that the ``load`` method should not directly assign the ``persistent_id`` to the new entity, as this operation is the responsibility of the :ref:`DatabaseLoadingContext <DatabaseLoadingContext>` class.
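
As an illustration, the following sketch shows the expected shape of a ``load`` implementation; ``MyEntity`` and the ``get_my_entity`` accessor are hypothetical, not part of the StreamFlow API:

.. code-block:: python

    class MyEntity(PersistableEntity):
        def __init__(self, name: str):
            super().__init__()
            self.name: str = name

        @classmethod
        async def load(
            cls,
            context: StreamFlowContext,
            persistent_id: int,
            loading_context: DatabaseLoadingContext,
        ) -> MyEntity:
            # Hypothetical accessor retrieving the database row of the entity
            row = await context.database.get_my_entity(persistent_id)
            obj = cls(name=row["name"])
            # No ``obj.persistent_id = persistent_id`` here: the
            # DatabaseLoadingContext assigns it when registering the entity
            return obj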

Persistence
===========

The ``Database`` interface, defined in the ``streamflow.core.persistence`` module, contains all the methods to create, modify, and retrieve this metadata. Data deletion is unnecessary, as StreamFlow never removes existing records. Internally, the ``save`` and ``load`` methods call one or more of these methods to perform the desired operations.
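
For instance, a typical round trip might look as follows (a sketch, assuming ``entity`` is any ``PersistableEntity`` instance and using the ``DefaultDatabaseLoadingContext`` described below):

.. code-block:: python

    # Persist the entity: ``save`` writes its records through the Database API
    await entity.save(context)

    # Retrieve a new instance of the entity from its persistent record
    loaded = await type(entity).load(
        context, entity.persistent_id, DefaultDatabaseLoadingContext()
    )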

@@ -230,8 +233,9 @@ Each ``get_data`` method receives in input the identifier (commonly the ``persis

The ``close`` method receives no input parameter and does not return anything. It frees stateful resources potentially allocated during the object’s lifetime, e.g., network or database connections.


Implementations
===============
---------------

====== ============================================
Type Class
@@ -247,3 +251,78 @@ The database schema is structured as follows:

.. literalinclude:: ../../../streamflow/persistence/schemas/sqlite.sql
   :language: sql


DatabaseLoadingContext
======================
Workflow loading is a delicate operation. If not managed properly, it can be costly in terms of time and memory, and it can lead to deadlocks in the presence of circular references.
The ``DatabaseLoadingContext`` interface allows defining classes in charge of managing these aspects. Users should always rely on these classes to load entities, instead of directly calling the ``load`` methods of ``PersistableEntity`` instances.

.. code-block:: python

    def add_deployment(self, persistent_id: int, deployment: DeploymentConfig):
        ...

    def add_filter(self, persistent_id: int, filter_config: FilterConfig):
        ...

    def add_port(self, persistent_id: int, port: Port):
        ...

    def add_step(self, persistent_id: int, step: Step):
        ...

    def add_target(self, persistent_id: int, target: Target):
        ...

    def add_token(self, persistent_id: int, token: Token):
        ...

    def add_workflow(self, persistent_id: int, workflow: Workflow):
        ...

    async def load_deployment(self, context: StreamFlowContext, persistent_id: int):
        ...

    async def load_filter(self, context: StreamFlowContext, persistent_id: int):
        ...

    async def load_port(self, context: StreamFlowContext, persistent_id: int):
        ...

    async def load_step(self, context: StreamFlowContext, persistent_id: int):
        ...

    async def load_target(self, context: StreamFlowContext, persistent_id: int):
        ...

    async def load_token(self, context: StreamFlowContext, persistent_id: int):
        ...

    async def load_workflow(self, context: StreamFlowContext, persistent_id: int):
        ...
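
For example, a whole workflow can be retrieved through a loading context rather than by calling ``Workflow.load`` directly (a sketch, assuming ``context`` is the current ``StreamFlowContext`` and ``workflow_id`` is a valid ``persistent_id``):

.. code-block:: python

    loading_context = DefaultDatabaseLoadingContext()
    workflow = await loading_context.load_workflow(context, workflow_id)
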
Implementations
---------------

==================================================================== ====================================================================
Name                                                                 Class
==================================================================== ====================================================================
:ref:`DefaultDatabaseLoadingContext <DefaultDatabaseLoadingContext>` streamflow.persistence.loading_context.DefaultDatabaseLoadingContext
:ref:`WorkflowBuilder <WorkflowBuilder>`                             streamflow.persistence.loading_context.WorkflowBuilder
==================================================================== ====================================================================

DefaultDatabaseLoadingContext
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The ``DefaultDatabaseLoadingContext`` keeps track of all the objects already loaded in the current transaction, serving as a cache to efficiently load nested entities and prevent deadlocks when dealing with circular references.
Furthermore, it is in charge of assigning the ``persistent_id`` when an entity is added to the cache through an ``add_*`` method.
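
For instance, in this sketch (assuming ``port_id`` identifies a persisted ``Port``), loading the same entity twice within one transaction yields a single instance, whose ``persistent_id`` has been assigned by ``add_port``:

.. code-block:: python

    loading_context = DefaultDatabaseLoadingContext()
    port_a = await loading_context.load_port(context, port_id)
    port_b = await loading_context.load_port(context, port_id)
    assert port_a is port_b
    assert port_a.persistent_id == port_id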

WorkflowBuilder
^^^^^^^^^^^^^^^
The ``WorkflowBuilder`` class loads the steps and ports of an existing workflow from a ``Database`` and inserts them into a new workflow object, received as a constructor argument. It extends the ``DefaultDatabaseLoadingContext`` class, overriding only the methods that involve ``step``, ``port``, and ``workflow`` entities. In particular, the ``add_*`` methods for these entities must not set the ``persistent_id``, as they deal with a newly created workflow, and the ``load_*`` methods should reset the internal state of each entity to its initial value (e.g., reset the status to ``Status.WAITING`` and clear the ``terminated`` flag).

The ``load_workflow`` method must behave in two different ways, depending on whether it is called directly by the user or from the internal logic of another entity's ``load`` method. In the first case, it should load all the entities related to the original workflow, identified by the ``persistent_id`` argument, into the new one. In the latter case, it should simply return the new workflow entity being built.

Other entities, such as ``deployment`` and ``target`` objects, can be safely shared between the old and the new workflows, as their internal state does not need to be modified. Therefore, they can be loaded following the common path implemented in the ``DefaultDatabaseLoadingContext`` class.
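
Putting it all together, the following sketch duplicates a persisted workflow into a fresh ``Workflow`` object (the constructor arguments mirror the ``Workflow.load`` call in this commit; the ``type``, ``config``, and ``name`` values, as well as ``wf_id``, are placeholders):

.. code-block:: python

    new_workflow = Workflow(context=context, type="cwl", config={}, name="copy")
    builder = WorkflowBuilder(new_workflow)
    loaded = await builder.load_workflow(context, wf_id)
    assert loaded is new_workflow  # steps and ports now populated from the database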

3 changes: 0 additions & 3 deletions streamflow/core/deployment.py
@@ -206,7 +206,6 @@ async def load(
            lazy=row["lazy"],
            workdir=row["workdir"],
        )
        obj.persistent_id = persistent_id
        loading_context.add_deployment(persistent_id, obj)
        return obj

@@ -276,7 +275,6 @@ async def load(
        row = await context.database.get_target(persistent_id)
        type = cast(Type[Target], utils.get_class_from_name(row["type"]))
        obj = await type._load(context, row, loading_context)
        obj.persistent_id = persistent_id
        loading_context.add_target(persistent_id, obj)
        return obj

@@ -338,7 +336,6 @@ async def load(
            type=row["type"],
            config=json.loads(row["config"]),
        )
        obj.persistent_id = persistent_id
        loading_context.add_filter(persistent_id, obj)
        return obj

1 change: 0 additions & 1 deletion streamflow/core/utils.py
@@ -15,7 +15,6 @@
    MutableSequence,
    TYPE_CHECKING,
)

from streamflow.core.exception import WorkflowExecutionException

if TYPE_CHECKING:
6 changes: 1 addition & 5 deletions streamflow/core/workflow.py
@@ -310,7 +310,6 @@ async def load(
        row = await context.database.get_port(persistent_id)
        type = cast(Type[Port], utils.get_class_from_name(row["type"]))
        port = await type._load(context, row, loading_context)
        port.persistent_id = persistent_id
        loading_context.add_port(persistent_id, port)
        return port

@@ -432,14 +431,14 @@ async def load(
        row = await context.database.get_step(persistent_id)
        type = cast(Type[Step], utils.get_class_from_name(row["type"]))
        step = await type._load(context, row, loading_context)
        step.persistent_id = persistent_id
        step.status = Status(row["status"])
        step.terminated = step.status in [
            Status.COMPLETED,
            Status.FAILED,
            Status.SKIPPED,
        ]
        input_deps = await context.database.get_input_ports(persistent_id)
        loading_context.add_step(persistent_id, step)
        input_ports = await asyncio.gather(
            *(
                asyncio.create_task(loading_context.load_port(context, d["port"]))
@@ -457,7 +456,6 @@
        step.output_ports = {
            d["name"]: p.name for d, p in zip(output_deps, output_ports)
        }
        loading_context.add_step(persistent_id, step)
        return step

    @abstractmethod
@@ -549,7 +547,6 @@ async def load(
        row = await context.database.get_token(persistent_id)
        type = cast(Type[Token], utils.get_class_from_name(row["type"]))
        token = await type._load(context, row, loading_context)
        token.persistent_id = persistent_id
        loading_context.add_token(persistent_id, token)
        return token

@@ -676,7 +673,6 @@ async def load(
        workflow = cls(
            context=context, type=row["type"], config=params["config"], name=row["name"]
        )
        workflow.persistent_id = row["id"]
        loading_context.add_workflow(persistent_id, workflow)
        rows = await context.database.get_workflow_ports(persistent_id)
        workflow.ports = {
8 changes: 5 additions & 3 deletions streamflow/cwl/processor.py
@@ -162,8 +162,8 @@ async def _load(
            format_graph=(
                format_graph.parse(data=row["format_graph"])
                if row["format_graph"] is not None
                else format_graph
            ),
                else None
            ),  # todo: fix multiple instances
            full_js=row["full_js"],
            load_contents=row["load_contents"],
            load_listing=LoadListing(row["load_listing"])
@@ -300,7 +300,9 @@ async def _save_additional_params(self, context: StreamFlowContext):
"expression_lib": self.expression_lib,
"file_format": self.file_format,
"format_graph": (
self.format_graph.serialize() if self.format_graph else None
self.format_graph.serialize()
if self.format_graph is not None
else None
),
"full_js": self.full_js,
"load_contents": self.load_contents,
59 changes: 58 additions & 1 deletion streamflow/persistence/loading_context.py
@@ -1,9 +1,10 @@
from __future__ import annotations
from typing import MutableMapping

from streamflow.core.context import StreamFlowContext
from streamflow.core.deployment import DeploymentConfig, Target, FilterConfig
from streamflow.core.persistence import DatabaseLoadingContext
from streamflow.core.workflow import Port, Step, Token, Workflow
from streamflow.core.workflow import Port, Step, Token, Workflow, Status


class DefaultDatabaseLoadingContext(DatabaseLoadingContext):
@@ -18,24 +19,31 @@ def __init__(self):
        self._workflows: MutableMapping[int, Workflow] = {}

    def add_deployment(self, persistent_id: int, deployment: DeploymentConfig):
        deployment.persistent_id = persistent_id
        self._deployment_configs[persistent_id] = deployment

    def add_filter(self, persistent_id: int, filter_config: FilterConfig):
        filter_config.persistent_id = persistent_id
        self._filter_configs[persistent_id] = filter_config

    def add_port(self, persistent_id: int, port: Port):
        port.persistent_id = persistent_id
        self._ports[persistent_id] = port

    def add_step(self, persistent_id: int, step: Step):
        step.persistent_id = persistent_id
        self._steps[persistent_id] = step

    def add_target(self, persistent_id: int, target: Target):
        target.persistent_id = persistent_id
        self._targets[persistent_id] = target

    def add_token(self, persistent_id: int, token: Token):
        token.persistent_id = persistent_id
        self._tokens[persistent_id] = token

    def add_workflow(self, persistent_id: int, workflow: Workflow):
        workflow.persistent_id = persistent_id
        self._workflows[persistent_id] = workflow

    async def load_deployment(self, context: StreamFlowContext, persistent_id: int):
@@ -72,3 +80,52 @@ async def load_workflow(self, context: StreamFlowContext, persistent_id: int):
        return self._workflows.get(persistent_id) or await Workflow.load(
            context, persistent_id, self
        )


class WorkflowBuilder(DefaultDatabaseLoadingContext):
    def __init__(self, workflow: Workflow):
        super().__init__()
        self.workflow: Workflow = workflow

    def add_port(self, persistent_id: int, port: Port):
        self._ports[persistent_id] = port

    def add_step(self, persistent_id: int, step: Step):
        self._steps[persistent_id] = step

    def add_workflow(self, persistent_id: int, workflow: Workflow):
        self._workflows[persistent_id] = self.workflow

    async def load_step(self, context: StreamFlowContext, persistent_id: int):
        if persistent_id in self._steps.keys():
            return self._steps[persistent_id]
        else:
            step_row = await context.database.get_step(persistent_id)
            if (step := self.workflow.steps.get(step_row["name"])) is None:
                # If the step is not available in the new workflow, a new one must be created
                self.add_workflow(step_row["workflow"], self.workflow)
                step = await Step.load(context, persistent_id, self)

                # restore initial step state
                step.status = Status.WAITING
                step.terminated = False

                self.workflow.steps[step.name] = step
            return step

    async def load_port(self, context: StreamFlowContext, persistent_id: int):
        if persistent_id in self._ports.keys():
            return self._ports[persistent_id]
        else:
            port_row = await context.database.get_port(persistent_id)
            if (port := self.workflow.ports.get(port_row["name"])) is None:
                # If the port is not available in the new workflow, a new one must be created
                self.add_workflow(port_row["workflow"], self.workflow)
                port = await Port.load(context, persistent_id, self)
                self.workflow.ports[port.name] = port
            return port

    async def load_workflow(self, context: StreamFlowContext, persistent_id: int):
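        # On the first direct call, the original workflow has not been
        # registered yet, so all of its steps and ports are loaded into the
        # new workflow. On nested calls, the mapping already exists and the
        # workflow being built is returned as-is.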
        if persistent_id not in self._workflows.keys():
            await Workflow.load(context, persistent_id, self)
        return self.workflow
