Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: flagging duplicated entries while keeping one of the duplicates #876

Merged
merged 8 commits into from
Oct 26, 2024
11 changes: 4 additions & 7 deletions src/gentropy/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,9 @@ def update_quality_flag(

@staticmethod
def flag_duplicates(test_column: Column) -> Column:
    """Return True for rows whose value in ``test_column`` was already seen.

    Exactly one row per distinct value is left unflagged (False), so this
    flag can be used to de-duplicate while keeping a single occurrence.

    NOTE(review): rows within a partition are ordered by ``f.rand()``, so
    WHICH duplicate survives is arbitrary and non-deterministic across
    runs — confirm this is acceptable for downstream consumers. Ordering
    by a stable key would make the choice reproducible.

    Args:
        test_column (Column): Column to check for duplicates.

    Returns:
        Column: Boolean column — True for every duplicate occurrence
        beyond the first row in each value partition.
    """
    return (
        # row_number == 1 marks the single kept row per value; all
        # subsequent rows in the partition are flagged as duplicates.
        f.row_number().over(Window.partitionBy(test_column).orderBy(f.rand())) > 1
    )

@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions src/gentropy/study_locus_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def __init__(
.filter_credible_set(credible_interval=CredibleInterval.IS99)
# Annotate credible set confidence:
.assign_confidence()
# Flagging credible sets that are duplicated:
.validate_unique_study_locus_id()
).persist() # we will need this for 2 types of outputs

study_locus_with_qc.valid_rows(invalid_qc_reasons, invalid=True).df.write.mode(
Expand Down
55 changes: 55 additions & 0 deletions tests/gentropy/dataset/test_study_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,3 +1121,58 @@ def test_qc_valid_chromosomes(
StudyLocusQualityCheck.INVALID_CHROMOSOME.value
in row["qualityControls"]
)


class TestStudyLocusDuplicationFlagging:
    """Collection of tests related to flagging redundant credible sets."""

    # Four rows total: one unique study locus plus one locus repeated
    # three times. Columns: studyLocusId, variantId, studyId, finemappingMethod.
    STUDY_LOCUS_DATA = [
        # Non-duplicated:
        ("1", "v1", "s1", "pics"),
        # Triplicate:
        ("3", "v3", "s1", "pics"),
        ("3", "v3", "s1", "pics"),
        ("3", "v3", "s1", "pics"),
    ]

    STUDY_LOCUS_SCHEMA = t.StructType(
        [
            t.StructField("studyLocusId", t.StringType(), False),
            t.StructField("variantId", t.StringType(), False),
            t.StructField("studyId", t.StringType(), False),
            t.StructField("finemappingMethod", t.StringType(), False),
        ]
    )

    @pytest.fixture(autouse=True)
    def _setup(self: TestStudyLocusDuplicationFlagging, spark: SparkSession) -> None:
        """Set up a StudyLocus dataset and run duplicate validation on it."""
        self.study_locus = StudyLocus(
            _df=spark.createDataFrame(
                self.STUDY_LOCUS_DATA, schema=self.STUDY_LOCUS_SCHEMA
            ).withColumn(
                # Start with an empty (typed) qualityControls array so the
                # validation step has a column to append flags to:
                "qualityControls", f.array().cast(t.ArrayType(t.StringType()))
            ),
            _schema=StudyLocus.get_schema(),
        )

        # Run validation:
        self.validated = self.study_locus.validate_unique_study_locus_id()

    def test_duplication_flag_type(self: TestStudyLocusDuplicationFlagging) -> None:
        """Test that duplication flagging returns a StudyLocus instance."""
        assert isinstance(self.validated, StudyLocus)

    def test_duplication_flag_no_data_loss(
        self: TestStudyLocusDuplicationFlagging,
    ) -> None:
        """Test that flagging annotates rows without dropping any of them."""
        assert self.validated.df.count() == self.study_locus.df.count()

    def test_duplication_flag_correctness(
        self: TestStudyLocusDuplicationFlagging,
    ) -> None:
        """Make sure that, at the end, exactly two study loci pass validation.

        The unique locus and one kept copy of the triplicate pass; the two
        remaining copies of the triplicate are flagged.
        """
        assert self.validated.df.filter(f.size("qualityControls") == 0).count() == 2

        assert self.validated.df.filter(f.size("qualityControls") > 0).count() == 2
Loading