diff --git a/src/gentropy/credible_set_qc.py b/src/gentropy/credible_set_qc.py index 31d298886..8ea9e06dd 100644 --- a/src/gentropy/credible_set_qc.py +++ b/src/gentropy/credible_set_qc.py @@ -16,11 +16,12 @@ def __init__( self, session: Session, credible_sets_path: str, - study_index_path: str, - ld_index_path: str, output_path: str, p_value_threshold: float = 1e-5, purity_min_r2: float = 0.01, + clump: bool = False, + ld_index_path: str | None = None, + study_index_path: str | None = None, ld_min_r2: float = 0.8, ) -> None: """Run credible set quality control step. @@ -28,23 +29,32 @@ def __init__( Args: session (Session): Session object. credible_sets_path (str): Path to credible sets file. - study_index_path (str): Path to study index file. - ld_index_path (str): Path to LD index file. output_path (str): Path to write the output file. p_value_threshold (float): P-value threshold for credible set quality control. Default is 1e-5. purity_min_r2 (float): Minimum R2 for purity estimation. Default is 0.01. + clump (bool): Whether to clump the credible sets by LD. Default is False. + ld_index_path (str | None): Path to LD index file. + study_index_path (str | None): Path to study index file. ld_min_r2 (float): Minimum R2 for LD estimation. Default is 0.8. 
""" - cred_sets = StudyLocus.from_parquet(session, credible_sets_path) - study_index = StudyIndex.from_parquet(session, study_index_path) - ld_index = LDIndex.from_parquet(session, ld_index_path) - + cred_sets = StudyLocus.from_parquet( + session, credible_sets_path, recursiveFileLookup=True + ).coalesce(200) + ld_index = ( + LDIndex.from_parquet(session, ld_index_path) if ld_index_path else None + ) + study_index = ( + StudyIndex.from_parquet(session, study_index_path) + if study_index_path + else None + ) cred_sets_clean = SUSIE_inf.credible_set_qc( cred_sets, - study_index, - ld_index, p_value_threshold, purity_min_r2, + clump, + ld_index, + study_index, ld_min_r2, ) diff --git a/src/gentropy/method/susie_inf.py b/src/gentropy/method/susie_inf.py index 4f75faad8..f90d15a14 100644 --- a/src/gentropy/method/susie_inf.py +++ b/src/gentropy/method/susie_inf.py @@ -9,6 +9,7 @@ import pyspark.sql.functions as f import scipy.linalg import scipy.special +from pyspark.sql.window import Window from scipy.optimize import minimize, minimize_scalar from scipy.special import logsumexp @@ -469,43 +470,59 @@ def cred_inf( @staticmethod def credible_set_qc( cred_sets: StudyLocus, - study_index: StudyIndex, - ld_index: LDIndex, p_value_threshold: float = 1e-5, purity_min_r2: float = 0.01, + clump: bool = False, + ld_index: LDIndex | None = None, + study_index: StudyIndex | None = None, ld_min_r2: float = 0.8, ) -> StudyLocus: """Filter credible sets by lead P-value and min-R2 purity, and performs LD clumping. 
+ In case of duplicated loci, the filtering retains the loci with the highest credibleSetLog10BF + Args: cred_sets (StudyLocus): StudyLocus object with credible sets to filter/clump - study_index (StudyIndex): StudyIndex object - ld_index (LDIndex): LDIndex object p_value_threshold (float): p-value threshold for filtering credible sets, default is 1e-5 purity_min_r2 (float): min-R2 purity threshold for filtering credible sets, default is 0.01 + clump (bool): Whether to clump the credible sets by LD, default is False + ld_index (LDIndex | None): LDIndex object + study_index (StudyIndex | None): StudyIndex object ld_min_r2 (float): LD R2 threshold for clumping, default is 0.8 Returns: StudyLocus: Credible sets which pass filters and LD clumping. """ - df = ( + cred_sets.df = ( cred_sets.df.withColumn( "pValue", f.col("pValueMantissa") * f.pow(10, f.col("pValueExponent")) ) .filter(f.col("pValue") <= p_value_threshold) .filter(f.col("purityMinR2") >= purity_min_r2) .drop("pValue") + .withColumn( + "rn", + f.row_number().over( + Window.partitionBy("studyLocusId").orderBy( + f.desc("credibleSetLog10BF") + ) + ), + ) + .filter(f.col("rn") == 1) + .drop("rn") ) - cred_sets.df = df - cred_sets = ( - cred_sets.annotate_ld(study_index, ld_index, ld_min_r2) - .clump() - .filter( - ~f.array_contains( - f.col("qualityControls"), - "Explained by a more significant variant in high LD (clumped)", + if clump: + assert study_index, "Running in clump mode, which requires study_index." + assert ld_index, "Running in clump mode, which requires ld_index." + cred_sets = ( + cred_sets.annotate_ld(study_index, ld_index, ld_min_r2) + .clump() + .filter( + ~f.array_contains( + f.col("qualityControls"), + "Explained by a more significant variant in high LD (clumped)", + ) ) ) - ) return cred_sets