Skip to content

Commit

Permalink
feat(gwas_catalog): harmonisation dag
Browse files Browse the repository at this point in the history
  • Loading branch information
Szymon Szyszkowski committed Oct 22, 2024
1 parent 99518f0 commit 6750ea0
Show file tree
Hide file tree
Showing 7 changed files with 620 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/ot_orchestration/dags/config/gwas_catalog_harmonisation.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Node definitions for the GWAS Catalog harmonisation DAG.
# The DAG module looks each node up by its `id` with find_node_in_config(config["nodes"], <id>).
nodes:
# Index-generation node: `google_batch_index_specs` is passed to BatchIndexOperator as `batch_index_specs`.
- id: generate_sumstat_index
kind: Task
prerequisites: []
google_batch_index_specs:
# Key into BatchIndexOperator.manifest_generator_registry; selects the manifest generator class.
manifest_generator_label: gwas_catalog_harmonisation
# Chunk size used by BatchIndex.partition(); each chunk becomes one BatchIndexRow (one batch job).
max_task_count: 1000
manifest_generator_specs:
commands:
- -c
- ./harmonise-sumstats.sh
- $RAW
- $HARMONISED
- $QC
# NOTE(review): unquoted 1.0e-8 may be loaded as a float rather than a string, depending on the
# YAML loader — confirm create_task_commands accepts a non-string entry or quote the value.
- 1.0e-8
options:
manifest_kwargs:
qc_output_pattern: gs://gwas_catalog_inputs/summary_statistics_qc/**_SUCCESS
harm_output_pattern: gs://gwas_catalog_inputs/harmonised_summary_statistics/**_SUCCESS
raw_input_pattern: gs://gwas_catalog_inputs/raw_summary_statistics/**.h.tsv.gz
manifest_output_uri: gs://gwas_catalog_inputs/harmonisation_manifest.csv

# Harmonisation node: `google_batch` is passed to GeneticsBatchJobOperator, which reads
# `image`, `entrypoint`, `resource_specs`, `task_specs` and `policy_specs` from it.
- id: gwas_catalog_harmonisation
kind: Task
prerequisites:
- generate_sumstat_index
google_batch:
entrypoint: /bin/sh
image: europe-west1-docker.pkg.dev/open-targets-genetics-dev/gentropy-app/gentropy:orchestration
resource_specs:
cpu_milli: 4000
memory_mib: 8000
boot_disk_mib: 8000
task_specs:
max_retry_count: 2
max_run_duration: "1h"
policy_specs:
machine_type: n1-standard-4
64 changes: 64 additions & 0 deletions src/ot_orchestration/dags/gwas_catalog_harmonisation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Airflow DAG for GWAS Catalog sumstat harmonisation."""

from __future__ import annotations

import logging
from pathlib import Path

from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG

from ot_orchestration.operators.batch.harmonisation import (
BatchIndexOperator,
GeneticsBatchJobOperator,
)
from ot_orchestration.utils import (
find_node_in_config,
read_yaml_config,
)
from ot_orchestration.utils.common import shared_dag_args, shared_dag_kwargs

# The node configuration for this DAG lives next to the module, under ./config.
SOURCE_CONFIG_FILE_PATH = Path(__file__).parent.joinpath(
    "config", "gwas_catalog_harmonisation.yaml"
)
# Parsed once at import time so the DAG definition below can look nodes up by id.
config = read_yaml_config(SOURCE_CONFIG_FILE_PATH)


@task(task_id="begin")
def begin():
    """Log the start of the DAG run together with the loaded node configuration.

    Registered as the Airflow task ``begin``; it is the first element passed to
    ``chain`` in the DAG definition of this module.
    """
    logging.info("STARTING")
    # Dump the parsed YAML config so the settings used by the run appear in the task log.
    logging.info(config)


@task(task_id="end")
def end():
    """Log that the DAG run has reached its final task.

    Registered as the Airflow task ``end``; it is the last element passed to
    ``chain`` in the DAG definition of this module.
    """
    logging.info("FINISHED")


with DAG(
    dag_id=Path(__file__).stem,
    description="Open Targets Genetics — GWAS Catalog Sumstat Harmonisation",
    default_args=shared_dag_args,
    **shared_dag_kwargs,
):
    # Step 1: build the batch index. The operator returns one BatchIndexRow per
    # Google Batch job to submit.
    node_config = find_node_in_config(config["nodes"], "generate_sumstat_index")
    batch_index = BatchIndexOperator(
        task_id=node_config["id"],
        batch_index_specs=node_config["google_batch_index_specs"],
    )

    # Step 2: one mapped harmonisation job per index row. Expanding over
    # `batch_index.output` already makes this task downstream of `batch_index`.
    node_config = find_node_in_config(config["nodes"], "gwas_catalog_harmonisation")
    harmonisation_batch_job = GeneticsBatchJobOperator.partial(
        task_id=node_config["id"],
        job_name="harmonisation",
        google_batch=node_config["google_batch"],
    ).expand(batch_index_row=batch_index.output)

    # The mapped jobs must sit between the index and `end`; otherwise `end`
    # depends only on `batch_index` and can run while the Batch jobs are still
    # in flight.
    chain(
        begin(),
        batch_index,
        harmonisation_batch_job,
        end(),
    )
153 changes: 153 additions & 0 deletions src/ot_orchestration/operators/batch/batch_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""Batch Index."""

from __future__ import annotations

import logging
from typing import TypedDict

from airflow.exceptions import AirflowSkipException
from google.cloud.batch import Environment

from ot_orchestration.utils.batch import create_task_commands, create_task_env


class BatchCommandsSerialized(TypedDict):
    """Plain-dict form of ``BatchCommands``, as produced by ``BatchCommands.serialize``."""

    # Options forwarded, together with ``commands``, to ``create_task_commands``.
    options: dict[str, str]
    # Raw command parts before they are combined with ``options``.
    commands: list[str]


class BatchEnvironmentsSerialized(TypedDict):
    """Plain-dict form of ``BatchEnvironments``, as produced by ``BatchEnvironments.serialize``."""

    # One mapping of variables per entry; the whole list is passed to ``create_task_env``.
    vars_list: list[dict[str, str]]


class BatchCommands:
    """Commands and options of a batch task, convertible to and from a plain dict.

    The serialized form (``BatchCommandsSerialized``) carries exactly the two
    constructor arguments, so ``deserialize(serialize())`` rebuilds an
    equivalent object.
    """

    def __init__(self, options: dict[str, str], commands: list[str]):
        self.options, self.commands = options, commands

    def construct(self) -> list[str]:
        """Build the final task command list via ``create_task_commands``."""
        logging.info(
            "Constructing batch task commands from commands: %s and options: %s",
            self.commands,
            self.options,
        )
        return create_task_commands(self.commands, self.options)

    def serialize(self) -> BatchCommandsSerialized:
        """Return the options and commands as a ``BatchCommandsSerialized`` dict."""
        payload = BatchCommandsSerialized(
            options=self.options, commands=self.commands
        )
        return payload

    @staticmethod
    def deserialize(data: BatchCommandsSerialized) -> BatchCommands:
        """Rebuild a ``BatchCommands`` object from its serialized dict."""
        options, commands = data["options"], data["commands"]
        return BatchCommands(options=options, commands=commands)


class BatchEnvironments:
    """List of variable mappings that can be turned into Batch ``Environment`` objects.

    The serialized form (``BatchEnvironmentsSerialized``) carries only
    ``vars_list``, so ``deserialize(serialize())`` rebuilds an equivalent object.
    """

    def __init__(self, vars_list: list[dict[str, str]]):
        self.vars_list = vars_list

    def construct(self) -> list[Environment]:
        """Construct Batch environments from the list of variable mappings.

        Returns:
            The value produced by ``create_task_env`` for ``vars_list``.

        Raises:
            AirflowSkipException: When ``vars_list`` is empty, so that the
                calling Airflow task is skipped instead of submitting a job
                without environments.
        """
        logging.info(
            "Constructing batch environments from vars_list: %s", self.vars_list
        )
        if not self.vars_list:
            logging.warning(
                "Can not create Batch environments from empty variable list, skipping"
            )
            raise AirflowSkipException(
                "Can not create Batch environments from empty variable list"
            )
        environments = create_task_env(self.vars_list)
        # Routed through logging (not print) so the output follows the configured
        # handlers and level instead of going straight to stdout.
        logging.debug("Constructed batch environments: %s", environments)
        return environments

    def serialize(self) -> BatchEnvironmentsSerialized:
        """Serialize batch environments."""
        return BatchEnvironmentsSerialized(vars_list=self.vars_list)

    @staticmethod
    def deserialize(data: BatchEnvironmentsSerialized) -> BatchEnvironments:
        """Deserialize batch environments."""
        return BatchEnvironments(vars_list=data["vars_list"])


class BatchIndexRow(TypedDict):
    """One row of the batch index, as built by ``BatchIndex.rows``."""

    # 1-based position of the row within the index.
    idx: int
    # Serialized commands and options, identical for every row of one index.
    command: BatchCommandsSerialized
    # Serialized variable mappings belonging to this row's batch.
    environment: BatchEnvironmentsSerialized


class BatchIndex:
    """Index of all batch jobs.

    Holds the full list of variable mappings together with the commands and
    options shared by every job. ``partition`` splits the list into batches of
    at most ``max_task_count`` mappings, and ``rows`` exposes one
    ``BatchIndexRow`` per batch; each row represents a single batch job.
    """

    def __init__(
        self,
        vars_list: list[dict[str, str]],
        options: dict[str, str],
        commands: list[str],
        max_task_count: int,
    ) -> None:
        self.vars_list = vars_list
        self.options = options
        self.commands = commands
        self.max_task_count = max_task_count
        self.vars_batches: list[BatchEnvironmentsSerialized] = []

    def partition(self) -> BatchIndex:
        """Partition the variable list into batches of at most ``max_task_count``.

        The batches are recomputed from scratch on every call, so calling this
        method repeatedly does not accumulate duplicates.

        Returns:
            This object, to allow ``index.partition().rows`` chaining.

        Raises:
            ValueError: When ``vars_list`` is non-empty and ``max_task_count``
                is smaller than 1.
        """
        # Reset first: appending to batches left over from a previous call would
        # duplicate every batch job.
        self.vars_batches = []

        if not self.vars_list:
            msg = "BatchIndex can not partition variable list, as list is empty."
            logging.warning(msg)
            return self

        # A zero step makes range() fail with an unrelated message and a negative
        # step silently produces no batches, so reject both explicitly.
        if self.max_task_count < 1:
            raise ValueError(
                f"max_task_count must be a positive integer, got {self.max_task_count!r}"
            )

        total = len(self.vars_list)
        if self.max_task_count > total:
            logging.warning(
                "BatchIndex will use only one partition due to size of the dataset being smaller then max_task_count %s < %s",
                total,
                self.max_task_count,
            )
            self.max_task_count = total

        for start in range(0, total, self.max_task_count):
            batch: BatchEnvironmentsSerialized = {
                "vars_list": self.vars_list[start : start + self.max_task_count]
            }
            self.vars_batches.append(batch)

        logging.info("Created %s task list batches.", len(self.vars_batches))

        return self

    @property
    def rows(self) -> list[BatchIndexRow]:
        """Build one ``BatchIndexRow`` per batch created by ``partition``.

        Row indices start at 1. Every row carries the same commands and options
        and its own batch of variable mappings.

        Raises:
            AirflowSkipException: When there are no batches, so that downstream
                Airflow tasks are skipped.
        """
        logging.info("Preparing BatchIndexRows. Each row represents a batch job.")
        rows: list[BatchIndexRow] = [
            {
                "idx": idx,
                # A fresh mapping per row, so rows do not share one command dict.
                "command": {"options": self.options, "commands": self.commands},
                "environment": batch,
            }
            for idx, batch in enumerate(self.vars_batches, start=1)
        ]

        logging.info("Prepared %s BatchIndexRows", len(rows))
        if not rows:
            raise AirflowSkipException(
                "Empty BatchIndexRows will not allow to create batch task. Skipping downstream"
            )
        return rows

    def __repr__(self) -> str:
        """Get batch index string representation."""
        return f"BatchIndex(vars_list={self.vars_list}, options={self.options}, commands={self.commands}, max_task_count={self.max_task_count})"
99 changes: 99 additions & 0 deletions src/ot_orchestration/operators/batch/harmonisation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Operators for batch job."""

from __future__ import annotations

import logging
import time
from typing import Type

from airflow.models.baseoperator import BaseOperator
from airflow.providers.google.cloud.operators.cloud_batch import (
CloudBatchSubmitJobOperator,
)

from ot_orchestration.operators.batch.batch_index import (
BatchCommands,
BatchEnvironments,
BatchIndexRow,
)
from ot_orchestration.operators.batch.manifest_generators import ProtoManifestGenerator
from ot_orchestration.operators.batch.manifest_generators.harmonisation import (
HarmonisationManifestGenerator,
)
from ot_orchestration.types import GoogleBatchIndexSpecs, GoogleBatchSpecs
from ot_orchestration.utils.batch import create_batch_job, create_task_spec
from ot_orchestration.utils.common import GCP_PROJECT_GENETICS, GCP_REGION

# NOTE(review): this configures the root logger as a side effect of importing
# the module (basicConfig only takes effect when the root logger has no handlers
# yet) — confirm a DEBUG root level is intended rather than leftover debugging.
logging.basicConfig(level=logging.DEBUG)


class BatchIndexOperator(BaseOperator):
    """Airflow operator that prepares the Google Batch job index.

    The manifest generator class is looked up in ``manifest_generator_registry``
    by the label given in ``batch_index_specs``. On execution the generator
    builds a batch index, which is partitioned and returned as a list of
    ``BatchIndexRow``; each row represents a single batch job.
    """

    # NOTE: register every manifest generator here, keyed by its config label.
    manifest_generator_registry: dict[str, Type[ProtoManifestGenerator]] = {
        "gwas_catalog_harmonisation": HarmonisationManifestGenerator
    }

    def __init__(
        self,
        batch_index_specs: GoogleBatchIndexSpecs,
        **kwargs,
    ) -> None:
        label = batch_index_specs["manifest_generator_label"]
        self.generator_label = label
        self.manifest_generator = self.get_generator(label)
        self.manifest_generator_specs = batch_index_specs["manifest_generator_specs"]
        self.max_task_count = batch_index_specs["max_task_count"]
        super().__init__(**kwargs)

    @classmethod
    def get_generator(cls, label: str) -> Type[ProtoManifestGenerator]:
        """Return the generator class registered under ``label``."""
        registry = cls.manifest_generator_registry
        return registry[label]

    def execute(self, context) -> list[BatchIndexRow]:
        """Generate the batch index, partition it and return its rows."""
        generator_cls = self.manifest_generator
        generator = generator_cls.from_generator_config(
            self.manifest_generator_specs, max_task_count=self.max_task_count
        )
        batch_index = generator.generate_batch_index()
        # Log the unpartitioned index before it is split into batches.
        self.log.info(batch_index)
        return batch_index.partition().rows


class GeneticsBatchJobOperator(CloudBatchSubmitJobOperator):
    """Submit one Google Batch job described by a single ``BatchIndexRow``.

    The job definition is assembled in the constructor from the row's
    serialized commands and environments and from the ``google_batch`` specs,
    then handed to ``CloudBatchSubmitJobOperator`` with the genetics project
    and region.
    """

    def __init__(
        self,
        job_name: str,
        batch_index_row: BatchIndexRow,
        google_batch: GoogleBatchSpecs,
        **kwargs,
    ):
        # The row index plus a timestamp are appended to the job name.
        full_job_name = f"{job_name}-job-{batch_index_row['idx']}-{time.strftime('%Y%m%d-%H%M%S')}"

        task_spec = create_task_spec(
            image=google_batch["image"],
            commands=BatchCommands.deserialize(batch_index_row["command"]).construct(),
            task_specs=google_batch["task_specs"],
            resource_specs=google_batch["resource_specs"],
            entrypoint=google_batch["entrypoint"],
        )
        environments = BatchEnvironments.deserialize(
            batch_index_row["environment"]
        ).construct()
        batch_job = create_batch_job(
            task=task_spec,
            task_env=environments,
            policy_specs=google_batch["policy_specs"],
        )

        super().__init__(
            project_id=GCP_PROJECT_GENETICS,
            region=GCP_REGION,
            job_name=full_job_name,
            job=batch_job,
            deferrable=False,
            **kwargs,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Manifest generators."""

from __future__ import annotations

from typing import Protocol

from ot_orchestration.operators.batch.batch_index import BatchIndex
from ot_orchestration.types import ManifestGeneratorSpecs


class ProtoManifestGenerator(Protocol):
    """Structural interface that every manifest generator has to provide.

    A generator is created from its config specs with ``from_generator_config``
    and then asked for a ``BatchIndex`` with ``generate_batch_index``.
    """

    @classmethod
    def from_generator_config(
        cls, specs: ManifestGeneratorSpecs, max_task_count: int
    ) -> ProtoManifestGenerator:
        """Constructor for Manifest Generator.

        Args:
            specs: Generator specs taken from the node configuration.
            max_task_count: Maximum task count to use for the batch index.

        Raises:
            NotImplementedError: Always; implementations must override this.
        """
        raise NotImplementedError("Implement it in subclasses")

    def generate_batch_index(self) -> BatchIndex:
        """Generate batch index.

        Raises:
            NotImplementedError: Always; implementations must override this.
        """
        raise NotImplementedError("Implement it in subclasses")
Loading

0 comments on commit 6750ea0

Please sign in to comment.