-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(gwas_catalog): harmonisation dag
- Loading branch information
Szymon Szyszkowski
committed
Oct 22, 2024
1 parent
99518f0
commit 6750ea0
Showing
7 changed files
with
620 additions
and
0 deletions.
There are no files selected for viewing
38 changes: 38 additions & 0 deletions
38
src/ot_orchestration/dags/config/gwas_catalog_harmonisation.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
nodes: | ||
- id: generate_sumstat_index | ||
kind: Task | ||
prerequisites: [] | ||
google_batch_index_specs: | ||
manifest_generator_label: gwas_catalog_harmonisation | ||
max_task_count: 1000 | ||
manifest_generator_specs: | ||
commands: | ||
- -c | ||
- ./harmonise-sumstats.sh | ||
- $RAW | ||
- $HARMONISED | ||
- $QC | ||
- 1.0e-8 | ||
options: | ||
manifest_kwargs: | ||
qc_output_pattern: gs://gwas_catalog_inputs/summary_statistics_qc/**_SUCCESS | ||
harm_output_pattern: gs://gwas_catalog_inputs/harmonised_summary_statistics/**_SUCCESS | ||
raw_input_pattern: gs://gwas_catalog_inputs/raw_summary_statistics/**.h.tsv.gz | ||
manifest_output_uri: gs://gwas_catalog_inputs/harmonisation_manifest.csv | ||
|
||
- id: gwas_catalog_harmonisation | ||
kind: Task | ||
prerequisites: | ||
- generate_sumstat_index | ||
google_batch: | ||
entrypoint: /bin/sh | ||
image: europe-west1-docker.pkg.dev/open-targets-genetics-dev/gentropy-app/gentropy:orchestration | ||
resource_specs: | ||
cpu_milli: 4000 | ||
memory_mib: 8000 | ||
boot_disk_mib: 8000 | ||
task_specs: | ||
max_retry_count: 2 | ||
max_run_duration: "1h" | ||
policy_specs: | ||
machine_type: n1-standard-4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
"""Airflow DAG for GWAS Catalog sumstat harmonisation.""" | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
from pathlib import Path | ||
|
||
from airflow.decorators import task | ||
from airflow.models.baseoperator import chain | ||
from airflow.models.dag import DAG | ||
|
||
from ot_orchestration.operators.batch.harmonisation import ( | ||
BatchIndexOperator, | ||
GeneticsBatchJobOperator, | ||
) | ||
from ot_orchestration.utils import ( | ||
find_node_in_config, | ||
read_yaml_config, | ||
) | ||
from ot_orchestration.utils.common import shared_dag_args, shared_dag_kwargs | ||
|
||
# The YAML configuration lives in the ``config`` directory next to this module.
SOURCE_CONFIG_FILE_PATH = Path(__file__).parent.joinpath(
    "config", "gwas_catalog_harmonisation.yaml"
)
# Loaded once at import time; the DAG definition below reads ``config["nodes"]``.
config = read_yaml_config(SOURCE_CONFIG_FILE_PATH)
|
||
|
||
@task(task_id="begin")
def begin():
    """Mark the start of the DAG run and log the loaded configuration."""
    logging.info("STARTING")
    # Dump the parsed YAML so the run log shows exactly what was used.
    logging.info(config)
|
||
|
||
@task(task_id="end")
def end():
    """Mark the end of the DAG run with a log message."""
    logging.info("FINISHED")
|
||
|
||
with DAG(
    dag_id=Path(__file__).stem,
    description="Open Targets Genetics — GWAS Catalog Sumstat Harmonisation",
    default_args=shared_dag_args,
    **shared_dag_kwargs,
):
    # Step 1: build the batch index. The operator returns a list of rows,
    # each of which describes one batch job.
    index_node = find_node_in_config(config["nodes"], "generate_sumstat_index")
    batch_index = BatchIndexOperator(
        task_id=index_node["id"],
        batch_index_specs=index_node["google_batch_index_specs"],
    )

    # Step 2: submit one mapped batch job per row produced by ``batch_index``.
    harmonisation_node = find_node_in_config(
        config["nodes"], "gwas_catalog_harmonisation"
    )
    harmonisation_batch_job = GeneticsBatchJobOperator.partial(
        task_id=harmonisation_node["id"],
        job_name="harmonisation",
        google_batch=harmonisation_node["google_batch"],
    ).expand(batch_index_row=batch_index.output)

    # The harmonisation jobs must sit between the index and ``end``;
    # otherwise ``end`` only waits for the index and may report completion
    # while the batch jobs are still running.
    chain(
        begin(),
        batch_index,
        harmonisation_batch_job,
        end(),
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
"""Batch Index.""" | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
from typing import TypedDict | ||
|
||
from airflow.exceptions import AirflowSkipException | ||
from google.cloud.batch import Environment | ||
|
||
from ot_orchestration.utils.batch import create_task_commands, create_task_env | ||
|
||
|
||
class BatchCommandsSerialized(TypedDict):
    """Plain-dict form of ``BatchCommands``, as produced by its ``serialize``."""

    # Options mapping passed alongside ``commands`` when commands are built.
    options: dict[str, str]
    # Command tokens, in order.
    commands: list[str]
|
||
|
||
class BatchEnvironmentsSerialized(TypedDict):
    """Plain-dict form of ``BatchEnvironments``, as produced by its ``serialize``."""

    # List of variable-name -> value mappings.
    vars_list: list[dict[str, str]]
|
||
|
||
class BatchCommands:
    """Commands and options used to build the command line of a batch task.

    The object can be converted to a plain dict (``serialize``) and restored
    from one (``deserialize``).
    """

    def __init__(self, options: dict[str, str], commands: list[str]):
        self.options = options
        self.commands = commands

    def construct(self) -> list[str]:
        """Build the batch task commands from the stored commands and options.

        Returns:
            list[str]: Whatever ``create_task_commands`` produces for the
                stored ``commands`` and ``options``.
        """
        logging.info(
            "Constructing batch task commands from commands: %s and options: %s",
            self.commands,
            self.options,
        )
        return create_task_commands(self.commands, self.options)

    def serialize(self) -> BatchCommandsSerialized:
        """Return the commands and options as a plain dict."""
        serialized: BatchCommandsSerialized = {
            "options": self.options,
            "commands": self.commands,
        }
        return serialized

    @staticmethod
    def deserialize(data: BatchCommandsSerialized) -> BatchCommands:
        """Rebuild a ``BatchCommands`` object from its plain-dict form."""
        options = data["options"]
        commands = data["commands"]
        return BatchCommands(options=options, commands=commands)
|
||
|
||
class BatchEnvironments:
    """List of environment-variable mappings used to build batch environments.

    The object can be converted to a plain dict (``serialize``) and restored
    from one (``deserialize``).
    """

    def __init__(self, vars_list: list[dict[str, str]]):
        self.vars_list = vars_list

    def construct(self) -> list[Environment]:
        """Construct Batch Environment objects from the list of mappings.

        Returns:
            list[Environment]: Result of ``create_task_env`` for ``vars_list``.

        Raises:
            AirflowSkipException: When ``vars_list`` is empty, so the task is
                skipped rather than run with nothing to do.
        """
        logging.info(
            "Constructing batch environments from vars_list: %s", self.vars_list
        )
        if not self.vars_list:
            logging.warning(
                "Can not create Batch environments from empty variable list, skipping"
            )
            raise AirflowSkipException(
                "Can not create Batch environments from empty variable list"
            )
        environments = create_task_env(self.vars_list)
        # Goes through logging (not print) so it honours the configured level.
        logging.debug("Constructed batch environments: %s", environments)
        return environments

    def serialize(self) -> BatchEnvironmentsSerialized:
        """Return the variable list as a plain dict."""
        serialized: BatchEnvironmentsSerialized = {"vars_list": self.vars_list}
        return serialized

    @staticmethod
    def deserialize(data: BatchEnvironmentsSerialized) -> BatchEnvironments:
        """Rebuild a ``BatchEnvironments`` object from its plain-dict form."""
        return BatchEnvironments(vars_list=data["vars_list"])
|
||
|
||
class BatchIndexRow(TypedDict):
    """One row of the batch index; each row represents a single batch job."""

    # 1-based position of the row within the index.
    idx: int
    # Serialized commands and options shared by the job.
    command: BatchCommandsSerialized
    # Serialized chunk of the variable list assigned to the job.
    environment: BatchEnvironmentsSerialized
|
||
|
||
class BatchIndex:
    """Index of all batch jobs.

    The index holds the full list of variable mappings together with the
    commands and options shared by every job. ``partition`` splits the list
    into chunks of at most ``max_task_count`` entries, and ``rows`` exposes
    one row per chunk; each row represents a single batch job.
    """

    def __init__(
        self,
        vars_list: list[dict[str, str]],
        options: dict[str, str],
        commands: list[str],
        max_task_count: int,
    ) -> None:
        self.vars_list = vars_list
        self.options = options
        self.commands = commands
        self.max_task_count = max_task_count
        # Filled by ``partition``; empty until it has been called.
        self.vars_batches: list[BatchEnvironmentsSerialized] = []

    def partition(self) -> BatchIndex:
        """Partition the variable list into chunks of at most ``max_task_count``.

        The result is stored in ``vars_batches``. Calling the method again
        rebuilds the chunks from scratch instead of appending duplicates.

        Returns:
            BatchIndex: ``self``, to allow chaining.

        Raises:
            ValueError: When the variable list is not empty and
                ``max_task_count`` is smaller than 1.
        """
        # Start clean so that repeated calls do not accumulate batches.
        self.vars_batches = []

        if not self.vars_list:
            msg = "BatchIndex can not partition variable list, as list is empty."
            logging.warning(msg)
            return self

        if self.max_task_count < 1:
            # 0 would fail inside range() with an unrelated message, and a
            # negative value would silently produce no batches at all.
            raise ValueError(
                f"max_task_count must be a positive integer, got {self.max_task_count}"
            )

        total = len(self.vars_list)
        if self.max_task_count > total:
            logging.warning(
                "BatchIndex will use only one partition due to size of the dataset being smaller then max_task_count %s < %s",
                total,
                self.max_task_count,
            )
            self.max_task_count = total

        self.vars_batches = [
            {"vars_list": self.vars_list[start : start + self.max_task_count]}
            for start in range(0, total, self.max_task_count)
        ]

        logging.info("Created %s task list batches.", len(self.vars_batches))

        return self

    @property
    def rows(self) -> list[BatchIndexRow]:
        """Create the master manifest with one row per partitioned batch.

        Returns:
            list[BatchIndexRow]: Rows numbered from 1, each pairing the shared
                commands and options with one chunk of the variable list.

        Raises:
            AirflowSkipException: When there are no batches, so that
                downstream tasks are skipped.
        """
        logging.info("Preparing BatchIndexRows. Each row represents a batch job.")
        rows: list[BatchIndexRow] = [
            {
                "idx": idx,
                "command": {"options": self.options, "commands": self.commands},
                "environment": batch,
            }
            for idx, batch in enumerate(self.vars_batches, start=1)
        ]

        logging.info("Prepared %s BatchIndexRows", len(rows))
        if not rows:
            raise AirflowSkipException(
                "Empty BatchIndexRows will not allow to create batch task. Skipping downstream"
            )
        return rows

    def __repr__(self) -> str:
        """Get batch index string representation."""
        return f"BatchIndex(vars_list={self.vars_list}, options={self.options}, commands={self.commands}, max_task_count={self.max_task_count})"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
"""Operators for batch job.""" | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
import time | ||
from typing import Type | ||
|
||
from airflow.models.baseoperator import BaseOperator | ||
from airflow.providers.google.cloud.operators.cloud_batch import ( | ||
CloudBatchSubmitJobOperator, | ||
) | ||
|
||
from ot_orchestration.operators.batch.batch_index import ( | ||
BatchCommands, | ||
BatchEnvironments, | ||
BatchIndexRow, | ||
) | ||
from ot_orchestration.operators.batch.manifest_generators import ProtoManifestGenerator | ||
from ot_orchestration.operators.batch.manifest_generators.harmonisation import ( | ||
HarmonisationManifestGenerator, | ||
) | ||
from ot_orchestration.types import GoogleBatchIndexSpecs, GoogleBatchSpecs | ||
from ot_orchestration.utils.batch import create_batch_job, create_task_spec | ||
from ot_orchestration.utils.common import GCP_PROJECT_GENETICS, GCP_REGION | ||
|
||
# NOTE(review): this configures the root logger when the module is imported.
# ``basicConfig`` does nothing if the root logger already has handlers, so
# confirm whether it is still needed under the Airflow logging setup.
logging.basicConfig(level=logging.DEBUG)
|
||
|
||
class BatchIndexOperator(BaseOperator):
    """Operator to prepare google batch job index.

    The operator looks up a manifest generator by label, lets it generate a
    batch index, partitions that index and returns its rows. Each returned
    row describes a single batch job.
    """

    # NOTE: here register all manifest generators.
    manifest_generator_registry: dict[str, Type[ProtoManifestGenerator]] = {
        "gwas_catalog_harmonisation": HarmonisationManifestGenerator
    }

    def __init__(
        self,
        batch_index_specs: GoogleBatchIndexSpecs,
        **kwargs,
    ) -> None:
        """Read the generator label, generator specs and task limit from the specs.

        Args:
            batch_index_specs (GoogleBatchIndexSpecs): Mapping with the keys
                ``manifest_generator_label``, ``manifest_generator_specs``
                and ``max_task_count``.
            **kwargs: Passed through to ``BaseOperator``.

        Raises:
            KeyError: When the generator label is not registered.
        """
        self.generator_label = batch_index_specs["manifest_generator_label"]
        self.manifest_generator = self.get_generator(self.generator_label)
        self.manifest_generator_specs = batch_index_specs["manifest_generator_specs"]
        self.max_task_count = batch_index_specs["max_task_count"]
        super().__init__(**kwargs)

    @classmethod
    def get_generator(cls, label: str) -> Type[ProtoManifestGenerator]:
        """Get the generator by its label in the registry.

        Args:
            label (str): Key in ``manifest_generator_registry``.

        Returns:
            Type[ProtoManifestGenerator]: The registered generator class.

        Raises:
            KeyError: When ``label`` is not registered; the message lists the
                labels that are.
        """
        try:
            return cls.manifest_generator_registry[label]
        except KeyError:
            registered = ", ".join(sorted(cls.manifest_generator_registry))
            raise KeyError(
                f"Unknown manifest generator label {label!r}; "
                f"registered labels: {registered}"
            ) from None

    def execute(self, context) -> list[BatchIndexRow]:
        """Generate the batch index, partition it and return its rows.

        Returns:
            list[BatchIndexRow]: One row per batch job.
        """
        generator = self.manifest_generator.from_generator_config(
            self.manifest_generator_specs, max_task_count=self.max_task_count
        )
        index = generator.generate_batch_index()
        self.log.info(index)
        partitioned_index = index.partition()
        rows = partitioned_index.rows
        return rows
|
||
|
||
class GeneticsBatchJobOperator(CloudBatchSubmitJobOperator):
    """Submit one batch job described by a single batch index row."""

    def __init__(
        self,
        job_name: str,
        batch_index_row: BatchIndexRow,
        google_batch: GoogleBatchSpecs,
        **kwargs,
    ):
        """Build the job definition from the index row and the batch specs.

        Args:
            job_name (str): Prefix of the submitted job name; the row index
                and a timestamp are appended to it.
            batch_index_row (BatchIndexRow): Row holding the serialized
                commands and environments of the job.
            google_batch (GoogleBatchSpecs): Mapping with the keys ``image``,
                ``task_specs``, ``resource_specs``, ``entrypoint`` and
                ``policy_specs``.
            **kwargs: Passed through to ``CloudBatchSubmitJobOperator``.
        """
        # The row index plus a timestamp are appended to the job name.
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        full_job_name = f"{job_name}-job-{batch_index_row['idx']}-{timestamp}"

        task = create_task_spec(
            image=google_batch["image"],
            commands=BatchCommands.deserialize(batch_index_row["command"]).construct(),
            task_specs=google_batch["task_specs"],
            resource_specs=google_batch["resource_specs"],
            entrypoint=google_batch["entrypoint"],
        )
        task_env = BatchEnvironments.deserialize(
            batch_index_row["environment"]
        ).construct()
        job = create_batch_job(
            task=task,
            task_env=task_env,
            policy_specs=google_batch["policy_specs"],
        )

        super().__init__(
            project_id=GCP_PROJECT_GENETICS,
            region=GCP_REGION,
            job_name=full_job_name,
            job=job,
            deferrable=False,
            **kwargs,
        )
21 changes: 21 additions & 0 deletions
21
src/ot_orchestration/operators/batch/manifest_generators/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
"""Manifest generators.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Protocol | ||
|
||
from ot_orchestration.operators.batch.batch_index import BatchIndex | ||
from ot_orchestration.types import ManifestGeneratorSpecs | ||
|
||
|
||
class ProtoManifestGenerator(Protocol):
    """Structural interface that every manifest generator has to provide."""

    @classmethod
    def from_generator_config(
        cls, specs: ManifestGeneratorSpecs, max_task_count: int
    ) -> ProtoManifestGenerator:
        """Build a generator from its specs and the per-job task limit."""
        raise NotImplementedError("Implement it in subclasses")

    def generate_batch_index(self) -> BatchIndex:
        """Produce the batch index for this generator."""
        raise NotImplementedError("Implement it in subclasses")
Oops, something went wrong.