Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: flagging duplicated entries while keeping one of the duplicates #876

Merged
merged 8 commits into from
Oct 26, 2024
11 changes: 4 additions & 7 deletions src/gentropy/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,9 @@ def update_quality_flag(

@staticmethod
def flag_duplicates(test_column: Column) -> Column:
"""Return True for duplicated values in column.
"""Return True for rows, where the value was already seen in column.

This implementation allows keeping the first occurrence of the value.

Args:
test_column (Column): Column to check for duplicates
Expand All @@ -331,12 +333,7 @@ def flag_duplicates(test_column: Column) -> Column:
Column: Column with a boolean flag for duplicates
"""
return (
f.count(test_column).over(
Window.partitionBy(test_column).rowsBetween(
Window.unboundedPreceding, Window.unboundedFollowing
)
)
> 1
f.row_number().over(Window.partitionBy(test_column).orderBy(f.rand())) > 1
)

@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions src/gentropy/study_locus_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def __init__(
.filter_credible_set(credible_interval=CredibleInterval.IS99)
# Annotate credible set confidence:
.assign_confidence()
# Flagging credible sets that are duplicated:
.validate_unique_study_locus_id()
).persist() # we will need this for 2 types of outputs

study_locus_with_qc.valid_rows(invalid_qc_reasons, invalid=True).df.write.mode(
Expand Down
8 changes: 4 additions & 4 deletions tests/gentropy/dataset/test_study_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ class TestUniquenessValidation:
STUDY_DATA = [
# This is the only study to be flagged:
("s1", "eqtl", "p"),
("s1", "eqtl", "p"),
("s1", "eqtl", "p"), # Duplicate -> one should be flagged
("s3", "gwas", "p"),
("s4", "gwas", "p"),
]
Expand All @@ -337,8 +337,8 @@ def test_uniqueness_correct_data(self: TestUniquenessValidation) -> None:
"""Testing if the function returns the right type."""
validated = self.study_index.validate_unique_study_id().persist()

# We have more than one flagged studies:
assert validated.df.filter(f.size(f.col("qualityControls")) > 0).count() > 1
# We have only one flagged study:
assert validated.df.filter(f.size(f.col("qualityControls")) > 0).count() == 1

# The flagged study identifiers are found more than once:
flagged_ids = {
Expand All @@ -350,7 +350,7 @@ def test_uniqueness_correct_data(self: TestUniquenessValidation) -> None:
}

for _, count in flagged_ids.items():
assert count > 1
assert count == 1

# the right study is found:
assert "s1" in flagged_ids
Expand Down
55 changes: 55 additions & 0 deletions tests/gentropy/dataset/test_study_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,3 +1121,58 @@ def test_qc_valid_chromosomes(
StudyLocusQualityCheck.INVALID_CHROMOSOME.value
in row["qualityControls"]
)


class TestStudyLocusDuplicationFlagging:
    """Tests covering the QC flagging of duplicated credible sets."""

    # Rows: (studyLocusId, variantId, studyId, finemappingMethod).
    STUDY_LOCUS_DATA = [
        # Unique credible set — must never be flagged:
        ("1", "v1", "s1", "pics"),
        # The same credible set three times — one copy is kept, the rest flagged:
        ("3", "v3", "s1", "pics"),
        ("3", "v3", "s1", "pics"),
        ("3", "v3", "s1", "pics"),
    ]

    STUDY_LOCUS_SCHEMA = t.StructType(
        [
            t.StructField("studyLocusId", t.StringType(), False),
            t.StructField("variantId", t.StringType(), False),
            t.StructField("studyId", t.StringType(), False),
            t.StructField("finemappingMethod", t.StringType(), False),
        ]
    )

    @pytest.fixture(autouse=True)
    def _setup(self: TestStudyLocusDuplicationFlagging, spark: SparkSession) -> None:
        """Build the input StudyLocus and apply the uniqueness validation."""
        input_df = spark.createDataFrame(
            self.STUDY_LOCUS_DATA, schema=self.STUDY_LOCUS_SCHEMA
        ).withColumn("qualityControls", f.array().cast(t.ArrayType(t.StringType())))
        self.study_locus = StudyLocus(
            _df=input_df,
            _schema=StudyLocus.get_schema(),
        )

        # Validation under test — flags duplicated study locus identifiers:
        self.validated = self.study_locus.validate_unique_study_locus_id()

    def test_duplication_flag_type(self: TestStudyLocusDuplicationFlagging) -> None:
        """The validation must return a StudyLocus instance."""
        assert isinstance(self.validated, StudyLocus)

    def test_duplication_flag_no_data_loss(
        self: TestStudyLocusDuplicationFlagging,
    ) -> None:
        """Flagging must annotate rows, never drop them."""
        rows_before = self.study_locus.df.count()
        rows_after = self.validated.df.count()
        assert rows_after == rows_before

    def test_duplication_flag_correctness(
        self: TestStudyLocusDuplicationFlagging,
    ) -> None:
        """After validation, two loci pass (the unique one plus one kept duplicate) and two are flagged."""
        passing = self.validated.df.filter(f.size("qualityControls") == 0)
        flagged = self.validated.df.filter(f.size("qualityControls") > 0)

        assert passing.count() == 2

        assert flagged.count() == 2
Loading