def add_locus_to_gene_features(
    self: L2GPrediction, feature_matrix: L2GFeatureMatrix
) -> L2GPrediction:
    """Add the features behind each prediction as a `locusToGeneFeatures` map.

    Every column of the feature matrix other than the identifier columns
    (`studyLocusId`, `geneId`) is folded into a single map column keyed by the
    feature name. Entries whose value is null are removed from the map. The
    map is then left-joined onto the predictions, so predictions without a
    matching feature row are kept and get a null `locusToGeneFeatures`.

    The dataframe held by this object is not modified; a new `L2GPrediction`
    is returned.

    Args:
        feature_matrix (L2GFeatureMatrix): Feature matrix dataset

    Returns:
        L2GPrediction: L2G predictions with additional features
    """
    # Function-scope import keeps this method self-contained.
    from itertools import chain

    feature_map_column = "locusToGeneFeatures"

    # Columns identifying a studyLocus/gene pair
    prediction_id_columns = ["studyLocusId", "geneId"]

    # If the map column is already present it is rebuilt from scratch. The
    # drop is done on a local reference so the receiver is left untouched.
    predictions_df = self.df
    if feature_map_column in predictions_df.columns:
        predictions_df = predictions_df.drop(feature_map_column)

    # L2G matrix columns to build the map:
    columns_to_map = [
        column
        for column in feature_matrix._df.columns
        if column not in prediction_id_columns
    ]

    # Flat (key, value, key, value, ...) argument list for `create_map`.
    # Values are cast to float because the `locusToGeneFeatures` field of the
    # l2g_predictions schema declares float map values, and `create_map`
    # needs a single common value type.
    map_arguments = chain.from_iterable(
        (f.lit(colname), f.col(colname).cast("float"))
        for colname in columns_to_map
    )

    # Aggregating all features into a single map column:
    aggregated_features = (
        feature_matrix._df.withColumn(
            feature_map_column, f.create_map(*map_arguments)
        )
        # from the freshly created map, we filter out the null values
        .withColumn(
            feature_map_column,
            f.expr("map_filter(locusToGeneFeatures, (k, v) -> v is not null)"),
        )
        .drop(*columns_to_map)
    )

    return L2GPrediction(
        _df=predictions_df.join(
            aggregated_features, on=prediction_id_columns, how="left"
        ),
        _schema=self.get_schema(),
    )