feat(feature_matrix): impute values for gene attribute cols (#895)
* feat(feature_matrix): impute values for gene attribute cols + semantic test

* fix: change window

* chore: fill na in the feature matrix generation step
ireneisdoomed authored Nov 5, 2024
1 parent 3639b23 commit 04b1e22
Showing 5 changed files with 102 additions and 27 deletions.
22 changes: 19 additions & 3 deletions src/gentropy/dataset/l2g_feature_matrix.py
@@ -5,6 +5,8 @@
 from functools import reduce
 from typing import TYPE_CHECKING, Type
 
+import pyspark.sql.functions as f
+from pyspark.sql import Window
 from typing_extensions import Self
 
 from gentropy.common.spark_helpers import convert_from_long_to_wide
@@ -128,18 +130,32 @@ def calculate_feature_missingness_rate(
         }
 
     def fill_na(
-        self: L2GFeatureMatrix, value: float = 0.0, subset: list[str] | None = None
+        self: L2GFeatureMatrix, na_value: float = 0.0, subset: list[str] | None = None
     ) -> L2GFeatureMatrix:
         """Fill missing values in a column with a given value.
 
+        For features that correspond to gene attributes, missing values are imputed using the mean of the column.
+
         Args:
-            value (float): Value to replace missing values with. Defaults to 0.0.
+            na_value (float): Value to replace missing values with. Defaults to 0.0.
             subset (list[str] | None): Subset of columns to consider. Defaults to None.
 
         Returns:
            L2GFeatureMatrix: L2G feature matrix dataset
        """
-        self._df = self._df.fillna(value, subset=subset)
+        cols_to_impute = ["proteinGeneCount500kb", "geneCount500kb", "isProteinCoding"]
+        for col in cols_to_impute:
+            if col not in self._df.columns:
+                continue
+            else:
+                self._df = self._df.withColumn(
+                    col,
+                    f.when(
+                        f.col(col).isNull(),
+                        f.mean(f.col(col)).over(Window.partitionBy("studyLocusId")),
+                    ).otherwise(f.col(col)),
+                )
+        self._df = self._df.fillna(na_value, subset=subset)
         return self
 
     def select_features(
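The pattern introduced above (a per-locus mean computed with a window function, with a constant fill for every other column) can be tried in isolation. A minimal, self-contained sketch; the DataFrame, column values, and variable names are illustrative, not part of this commit:

import pyspark.sql.functions as f
from pyspark.sql import SparkSession, Window

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [("1", "gene1", 8.0), ("1", "gene2", 10.0), ("1", "gene3", None)],
    "studyLocusId STRING, geneId STRING, geneCount500kb DOUBLE",
)

# f.mean ignores nulls, so the window mean is computed from the non-null
# rows that share the same studyLocusId.
locus = Window.partitionBy("studyLocusId")
imputed = df.withColumn(
    "geneCount500kb",
    f.when(
        f.col("geneCount500kb").isNull(),
        f.mean(f.col("geneCount500kb")).over(locus),
    ).otherwise(f.col("geneCount500kb")),
)
imputed.show()  # gene3's geneCount500kb becomes 9.0, the mean of 8.0 and 10.0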
2 changes: 1 addition & 1 deletion src/gentropy/dataset/l2g_gold_standard.py
@@ -135,7 +135,7 @@ def build_feature_matrix(
             .drop("studyId", "variantId")
             .distinct(),
             with_gold_standard=True,
-        )
+        ).fill_na()
 
     def filter_unique_associations(
         self: L2GGoldStandard,
2 changes: 1 addition & 1 deletion src/gentropy/dataset/study_locus.py
@@ -806,7 +806,7 @@ def build_feature_matrix(
             self,
             features_list,
             features_input_loader,
-        )
+        ).fill_na()
 
     def annotate_credible_sets(self: StudyLocus) -> StudyLocus:
         """Annotate study-locus dataset with credible set flags.
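Both call sites can append .fill_na() directly because, per the signature above, the method mutates the wrapped DataFrame and returns Self. A toy illustration of that fluent pattern; the Matrix class below is a stand-in for L2GFeatureMatrix, not gentropy code:

from __future__ import annotations


class Matrix:
    """Toy stand-in showing why fill_na() can be chained onto a builder."""

    def __init__(self, values: list[float | None]) -> None:
        self.values = values

    def fill_na(self, na_value: float = 0.0) -> Matrix:
        # Mutate in place and return self, so call sites can end a
        # builder expression with .fill_na(), as this commit does.
        self.values = [na_value if v is None else v for v in self.values]
        return self


dense = Matrix([1.0, None, 3.0]).fill_na()
print(dense.values)  # [1.0, 0.0, 3.0]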
46 changes: 24 additions & 22 deletions src/gentropy/l2g.py
@@ -274,7 +274,6 @@ def _annotate_gold_standards_w_feature_matrix(self) -> L2GFeatureMatrix:
             gold_standards.build_feature_matrix(
                 self.feature_matrix, self.credible_set
             )
-            .fill_na()
             .select_features(self.features_list)
             .persist()
         )
@@ -322,6 +321,7 @@ def __init__(
             .json(evidence_output_path)
         )
 
+
 class LocusToGeneAssociationsStep:
     """Locus to gene associations step."""
 
@@ -343,39 +343,41 @@ def __init__(
             indirect_associations_output_path (str): Path to the indirect associations output dataset
         """
         # Read in the disease index
-        disease_index = (
-            session.spark.read.parquet(disease_index_path)
-            .select(
-                f.col("id").alias("diseaseId"),
-                f.explode("ancestors").alias("ancestorDiseaseId")
-            )
+        disease_index = session.spark.read.parquet(disease_index_path).select(
+            f.col("id").alias("diseaseId"),
+            f.explode("ancestors").alias("ancestorDiseaseId"),
         )
 
         # Read in the L2G evidence
-        disease_target_evidence = (
-            session.spark.read.json(evidence_input_path)
-            .select(
-                f.col("targetFromSourceId").alias("targetId"),
-                f.col("diseaseFromSourceMappedId").alias("diseaseId"),
-                f.col("resourceScore")
-            )
+        disease_target_evidence = session.spark.read.json(evidence_input_path).select(
+            f.col("targetFromSourceId").alias("targetId"),
+            f.col("diseaseFromSourceMappedId").alias("diseaseId"),
+            f.col("resourceScore"),
         )
 
         # Generate direct associations and save file
         (
-            disease_target_evidence
-            .groupBy("targetId", "diseaseId")
+            disease_target_evidence.groupBy("targetId", "diseaseId")
             .agg(f.collect_set("resourceScore").alias("scores"))
-            .select("targetId", "diseaseId", calculate_harmonic_sum(f.col("scores")).alias("harmonicSum"))
-            .write.mode(session.write_mode).parquet(direct_associations_output_path)
+            .select(
+                "targetId",
+                "diseaseId",
+                calculate_harmonic_sum(f.col("scores")).alias("harmonicSum"),
+            )
+            .write.mode(session.write_mode)
+            .parquet(direct_associations_output_path)
         )
 
         # Generate indirect associations and save file
         (
-            disease_target_evidence
-            .join(disease_index, on="diseaseId", how="inner")
+            disease_target_evidence.join(disease_index, on="diseaseId", how="inner")
             .groupBy("targetId", "ancestorDiseaseId")
             .agg(f.collect_set("resourceScore").alias("scores"))
-            .select("targetId", "ancestorDiseaseId", calculate_harmonic_sum(f.col("scores")).alias("harmonicSum"))
-            .write.mode(session.write_mode).parquet(indirect_associations_output_path)
+            .select(
+                "targetId",
+                "ancestorDiseaseId",
+                calculate_harmonic_sum(f.col("scores")).alias("harmonicSum"),
+            )
+            .write.mode(session.write_mode)
+            .parquet(indirect_associations_output_path)
         )
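For context, calculate_harmonic_sum collapses the per-evidence scores gathered by collect_set into a single association score. A plain-Python sketch of the idea, assuming the conventional Open Targets harmonic-sum scheme (each score divided by the square of its descending rank); the real helper is a Spark column expression in gentropy's common utilities and may normalize differently:

def harmonic_sum(scores: list[float]) -> float:
    # Rank scores from strongest to weakest; each contributes score / rank**2,
    # so the best evidence dominates and weak evidence adds diminishing weight.
    ranked = sorted(scores, reverse=True)
    return sum(s / (i + 1) ** 2 for i, s in enumerate(ranked))


print(harmonic_sum([0.8, 0.5, 0.5]))  # 0.8/1 + 0.5/4 + 0.5/9 ≈ 0.981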
57 changes: 57 additions & 0 deletions tests/gentropy/dataset/test_l2g_feature_matrix.py
@@ -4,6 +4,7 @@
 
 from typing import TYPE_CHECKING
 
+import pyspark.sql.functions as f
 import pytest
 from pyspark.sql.types import (
     ArrayType,
@@ -184,3 +185,59 @@ def _setup(self: TestFromFeaturesList, spark: SparkSession) -> None:
             ),
             _schema=GeneIndex.get_schema(),
         )
+
+
+def test_fill_na(spark: SparkSession) -> None:
+    """Tests L2GFeatureMatrix.fill_na, particularly the imputation logic."""
+    sample_fm = L2GFeatureMatrix(
+        _df=spark.createDataFrame(
+            [
+                {
+                    "studyLocusId": "1",
+                    "geneId": "gene1",
+                    "proteinGeneCount500kb": 3.0,
+                    "geneCount500kb": 8.0,
+                    "isProteinCoding": 1.0,
+                    "anotherFeature": None,
+                },
+                {
+                    "studyLocusId": "1",
+                    "geneId": "gene2",
+                    "proteinGeneCount500kb": 4.0,
+                    "geneCount500kb": 10.0,
+                    "isProteinCoding": 1.0,
+                    "anotherFeature": None,
+                },
+                {
+                    "studyLocusId": "1",
+                    "geneId": "gene3",
+                    "proteinGeneCount500kb": None,
+                    "geneCount500kb": None,
+                    "isProteinCoding": None,
+                    "anotherFeature": None,
+                },
+            ],
+            schema="studyLocusId STRING, geneId STRING, proteinGeneCount500kb DOUBLE, geneCount500kb DOUBLE, isProteinCoding DOUBLE, anotherFeature DOUBLE",
+        ),
+    )
+    observed_df = sample_fm.fill_na()._df.filter(f.col("geneId") == "gene3")
+    expected_df_missing_row = spark.createDataFrame(
+        [
+            {
+                "studyLocusId": "1",
+                "geneId": "gene3",
+                "proteinGeneCount500kb": 3.5,
+                "geneCount500kb": 9.0,
+                "isProteinCoding": 1.0,
+                "anotherFeature": 0.0,
+            },
+        ],
+    ).select(
+        "studyLocusId",
+        "geneId",
+        "proteinGeneCount500kb",
+        "geneCount500kb",
+        "isProteinCoding",
+        "anotherFeature",
+    )
+    assert observed_df.collect() == expected_df_missing_row.collect()
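The expected row follows directly from the fixture: the three gene-attribute columns take the mean of the two non-null rows at studyLocusId "1", while anotherFeature is not in the imputation list and falls through to the constant fill. A short derivation of the expected values for gene3:

# Illustrative arithmetic behind the expected_df_missing_row fixture:
protein_gene_count = (3.0 + 4.0) / 2  # 3.5, window mean over the locus
gene_count = (8.0 + 10.0) / 2         # 9.0
is_protein_coding = (1.0 + 1.0) / 2   # 1.0
another_feature = 0.0                 # default na_value from fill_na()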
