diff --git a/gnomad/assessment/summary_stats.py b/gnomad/assessment/summary_stats.py index 401e890e5..d2418f475 100644 --- a/gnomad/assessment/summary_stats.py +++ b/gnomad/assessment/summary_stats.py @@ -12,7 +12,7 @@ add_most_severe_consequence_to_consequence, filter_vep_to_canonical_transcripts, filter_vep_to_mane_select_transcripts, - get_most_severe_consequence_for_summary, + get_most_severe_csq_from_multiple_csq_lists, process_consequences, ) @@ -231,11 +231,11 @@ def get_summary_counts( logger.info("Filtering to mane select transcripts...") ht = filter_vep_to_mane_select_transcripts(ht) - logger.info("Getting VEP summary annotations...") - ht = get_most_severe_consequence_for_summary(ht) - - logger.info("Annotating with frequency bin information...") - ht = ht.annotate(freq_bin=freq_bin_expr(ht[freq_field], index)) + logger.info("Annotating with VEP summary and frequency bin information...") + ht = ht.annotate( + freq_bin=freq_bin_expr(ht[freq_field], index), + **get_most_severe_csq_from_multiple_csq_lists(ht.vep), + ) logger.info( "Annotating HT globals with total counts/total allele counts per variant" @@ -248,7 +248,7 @@ def get_summary_counts( ht.alleles, ht.lof, ht.no_lof_flags, - ht.most_severe_csq, + ht.most_severe_consequence, prefix_str="total_", ) ) @@ -259,7 +259,7 @@ def get_summary_counts( ht[freq_field][index].AC, ht.lof, ht.no_lof_flags, - ht.most_severe_csq, + ht.most_severe_consequence, ) ) ) @@ -272,7 +272,7 @@ def get_summary_counts( ht.alleles, ht.lof, ht.no_lof_flags, - ht.most_severe_csq, + ht.most_severe_consequence, ) ) @@ -518,7 +518,7 @@ def _create_filter_by_csq( if not isinstance(csq_set, hl.expr.CollectionExpression): csq_set = hl.set(csq_set) - return csq_set.contains(t.most_severe_csq) + return csq_set.contains(t.most_severe_consequence) # Set up filters for specific consequences or sets of consequences. csq_filters = { @@ -1062,11 +1062,11 @@ def default_generate_gene_lof_summary( ) if filter_loftee: - lof_ht = get_most_severe_consequence_for_summary(mt.rows()) + lof_expr = get_most_severe_csq_from_multiple_csq_lists(mt.vep) mt = mt.filter_rows( - hl.is_defined(lof_ht[mt.row_key].lof) - & (lof_ht[mt.row_key].lof == "HC") - & (lof_ht[mt.row_key].no_lof_flags) + hl.is_defined(lof_expr.lof) + & (lof_expr.lof == "HC") + & (lof_expr.no_lof_flags) ) ht = mt.annotate_rows( diff --git a/gnomad/utils/vep.py b/gnomad/utils/vep.py index 5d45924a4..f6c994e1d 100644 --- a/gnomad/utils/vep.py +++ b/gnomad/utils/vep.py @@ -4,9 +4,10 @@ import logging import os import subprocess -from typing import Callable, List, Optional, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import hail as hl +from deprecated import deprecated from gnomad.resources.resource_utils import VersionedTableResource from gnomad.utils.filtering import combine_functions @@ -292,11 +293,8 @@ def get_most_severe_consequence_expr( `CSQ_ORDER` global. :return: Most severe consequence in `csq_expr`. """ - if csq_order is None: - csq_order = CSQ_ORDER - csqs = hl.literal(csq_order) - - return csqs.find(lambda c: csq_expr.contains(c)) + csq_order = csq_order or CSQ_ORDER + return hl.literal(csq_order).find(lambda c: csq_expr.contains(c)) def add_most_severe_consequence_to_consequence( @@ -720,12 +718,169 @@ def get_csq_from_struct( return hl.or_missing(hl.len(csq) > 0, csq) +def filter_to_most_severe_consequences( + csq_expr: hl.expr.ArrayExpression, + csq_order: Optional[List[str]] = None, + loftee_labels: Optional[List[str]] = None, + prioritize_protein_coding: bool = False, + prioritize_loftee: bool = False, + prioritize_loftee_no_flags: bool = False, + additional_order_field: Optional[str] = None, + additional_order: Optional[List[str]] = None, +) -> hl.StructExpression: + """ + Filter an array of VEP consequences to all entries that have the most severe consequence. + + Returns a struct with the following annotations: + + - most_severe_consequence: Most severe consequence for variant. + - lof: Whether the variant is a loss-of-function variant. + - no_lof_flags: Whether the variant has any LOFTEE flags (True if no flags). + - consequences: Array of consequences that match the most severe consequence. + + .. note:: + + - If you have multiple lists of consequences (such as lists of both + 'transcript_consequences' and 'intergenic_consequences') and want to + determine the most severe consequence across all lists, consider using + `get_most_severe_csq_from_multiple_csq_lists`. + + - If you want to group consequences by gene and determine the most severe + consequence for each gene, consider using `process_consequences`. + + If `prioritize_protein_coding` is True, protein-coding transcripts are prioritized + by filtering to only protein-coding transcripts and determining the + most severe consequence. If no protein-coding transcripts are present, determine + the most severe consequence for all transcripts. + + If `prioritize_loftee` is True, prioritize consequences with LOFTEE annotations, in + the order of `loftee_labels`, over those without LOFTEE annotations. If + `prioritize_loftee_no_flags` is True, prioritize LOFTEE consequences with no flags + over those with flags. + + If `additional_order` is provided, additional ordering is applied to the + consequences in the list after any of the above prioritization. An example use of + this parameter is to prioritize by PolyPhen predictions. + + :param csq_expr: ArrayExpression of VEP consequences to filter. + :param csq_order: List indicating the order of VEP consequences, sorted from high to + low impact. Default is None, which uses the value of the `CSQ_ORDER` global. + :param loftee_labels: Annotations added by LOFTEE, sorted from high to low impact. + Default is None, which uses the value of the `LOFTEE_LABELS` global. + :param prioritize_protein_coding: Whether to prioritize protein-coding transcripts + when determining the worst consequence. Default is False. + :param prioritize_loftee: Whether to prioritize LOFTEE consequences. Default is + False. + :param prioritize_loftee_no_flags: Whether to prioritize LOFTEE consequences with no + flags over those with flags. Default is False. + :param additional_order_field: Field name of the consequence annotation to use for + additional ordering to apply to the consequences in the list. Default is None. + :param additional_order: The ordering to use for prioritizing consequences in the + `additional_order_field`. Default is None. + :return: ArrayExpression with of the consequences that match the most severe + consequence. + """ + # Get the dtype of the csq_expr ArrayExpression elements + csq_type = csq_expr.dtype.element_type + + if ((additional_order_field is None) + (additional_order is None)) == 1: + raise ValueError( + "If `additional_order_field` is provided, `additional_order` must also be" + " provided and vice versa." + ) + + if additional_order_field and additional_order_field not in csq_type.fields: + raise ValueError("Additional order field not found in consequence type.") + + # Define the order of fields to prioritize by, based on specified parameters. + priority_order = ( + (["protein_coding"] if prioritize_protein_coding else []) + + (["lof"] if prioritize_loftee else []) + + (["no_lof_flags"] if prioritize_loftee_no_flags else []) + + ["most_severe_consequence"] + ) + + # Define the impact ordering of VEP consequences and LOFTEE labels. If not provided, + # use the globals CSQ_ORDER (set in get_most_severe_consequence_expr) and + # LOFTEE_LABELS. + loftee_labels = loftee_labels or LOFTEE_LABELS + term_order = {"most_severe_consequence": csq_order, "lof": loftee_labels} + + # Add the additional order field to the priority order and term order if provided. + if additional_order_field: + priority_order.append(additional_order_field) + term_order[additional_order_field] = additional_order + + # Define initial result and current expression. + result = {} + order_result_fields = ["most_severe_consequence", "lof"] + curr_expr = hl.or_missing(hl.len(csq_expr) > 0, csq_expr) + + for curr_field in priority_order: + order = term_order.get(curr_field) + # Added the below line to get around the pylint error: Unexpected keyword + # argument 'most_severe_consequence' in function call. + f_param = curr_field if curr_field not in order_result_fields else None + if f_param is not None and order is None: + # If there is no order specified for the current field, then the field is + # used as a parameter to filter_vep_transcript_csqs_expr and if there are + # any consequences remaining, the result is set to True. + curr_expr = filter_vep_transcript_csqs_expr(curr_expr, **{f_param: True}) + result[curr_field] = hl.len(curr_expr) > 0 + else: + # Handle the case where the current field is a collection of consequences + # each with a 'consequence_terms' field (e.g. transcript_consequences) that + # need to be flattened to determine the most severe consequence. + f = curr_field if curr_field in csq_type.fields else "consequence_terms" + if isinstance(csq_type[f], hl.tarray) or isinstance(csq_type[f], hl.tset): + f_map = csq_expr.flatmap + f_func = lambda x, csq: x.contains(csq) + else: + f_map = csq_expr.map + f_func = lambda x, csq: x == csq + + # Get the most severe (highest impact) consequence for the current field. + ms_csq_expr = get_most_severe_consequence_expr(f_map(lambda x: x[f]), order) + + # Filter to only elements that contain the most severe (highest impact) + # consequence for the current field, and return missing if the most severe + # consequence is missing. + curr_expr = hl.or_missing( + hl.is_defined(ms_csq_expr), + csq_expr.filter(lambda x: f_func(x[f], ms_csq_expr)), + ) + + # Add the most severe consequence to the result if the field is in the order + # result fields. When there is no most severe consequence and the current + # field has a result expression, the result is kept as the existing result + # value. + if curr_field in order_result_fields: + if curr_field in result: + ms_csq_expr = hl.or_else(ms_csq_expr, result[curr_field]) + result[curr_field] = ms_csq_expr + + if curr_field == "lof": + result["no_lof_flags"] = hl.any( + curr_expr.map( + lambda x: hl.is_missing(x.lof_flags) | (x.lof_flags == "") + ) + ) + + curr_expr = hl.or_missing(hl.len(curr_expr) > 0, curr_expr) + csq_expr = hl.or_else(curr_expr, csq_expr) + + return hl.struct(**result, consequences=csq_expr) + + +@deprecated(reason="Replaced by get_most_severe_csq_from_multiple_csq_lists") def get_most_severe_consequence_for_summary( ht: hl.Table, csq_order: List[str] = CSQ_ORDER, loftee_labels: List[str] = LOFTEE_LABELS, ) -> hl.Table: """ + Use `get_most_severe_csq_from_multiple_csq_lists` instead, this function is deprecated. + Prepare a hail Table for summary statistics generation. Adds the following annotations: @@ -741,61 +896,153 @@ def get_most_severe_consequence_for_summary( :param loftee_labels: Annotations added by LOFTEE. Default is LOFTEE_LABELS. :return: Table annotated with VEP summary annotations. """ + csq_expr = get_most_severe_csq_from_multiple_csq_lists( + ht.vep, csq_order=csq_order, loftee_labels=loftee_labels + ) - def _get_most_severe_csq( - csq_list: hl.expr.ArrayExpression, protein_coding: bool - ) -> hl.expr.StructExpression: - """ - Process VEP consequences to generate summary annotations. + # Rename most_severe_consequence to most_severe_csq for consistency with older + # version of code. + csq_expr = csq_expr.rename({"most_severe_consequence": "most_severe_csq"}) - :param csq_list: VEP consequences list to be processed. - :param protein_coding: Whether variant is in a protein-coding transcript. - :return: Struct containing summary annotations. - """ - lof = hl.null(hl.tstr) - no_lof_flags = hl.null(hl.tbool) - if protein_coding: - all_lofs = csq_list.map(lambda x: x.lof) - lof = hl.literal(loftee_labels).find(lambda x: all_lofs.contains(x)) - csq_list = hl.if_else( - hl.is_defined(lof), csq_list.filter(lambda x: x.lof == lof), csq_list - ) - no_lof_flags = hl.or_missing( - hl.is_defined(lof), - csq_list.any(lambda x: (x.lof == lof) & hl.is_missing(x.lof_flags)), - ) - all_csq_terms = csq_list.flatmap(lambda x: x.consequence_terms) - most_severe_csq = hl.literal(csq_order).find( - lambda x: all_csq_terms.contains(x) - ) - return hl.struct( - most_severe_csq=most_severe_csq, - protein_coding=protein_coding, - lof=lof, - no_lof_flags=no_lof_flags, - ) + return ht.annotate(**csq_expr) - protein_coding = ht.vep.transcript_consequences.filter( - lambda x: x.biotype == "protein_coding" - ) - return ht.annotate( - **hl.case(missing_false=True) - .when(hl.len(protein_coding) > 0, _get_most_severe_csq(protein_coding, True)) - .when( - hl.len(ht.vep.transcript_consequences) > 0, - _get_most_severe_csq(ht.vep.transcript_consequences, False), - ) - .when( - hl.len(ht.vep.regulatory_feature_consequences) > 0, - _get_most_severe_csq(ht.vep.regulatory_feature_consequences, False), - ) - .when( - hl.len(ht.vep.motif_feature_consequences) > 0, - _get_most_severe_csq(ht.vep.motif_feature_consequences, False), - ) - .default(_get_most_severe_csq(ht.vep.intergenic_consequences, False)) + +def get_most_severe_csq_from_multiple_csq_lists( + vep_expr: hl.expr.StructExpression, + csq_order: Optional[List[str]] = None, + loftee_labels: Optional[List[str]] = None, + include_csqs: bool = False, + prioritize_protein_coding: bool = True, + prioritize_loftee: bool = True, + prioritize_loftee_no_flags: bool = False, + csq_list_order: Union[List[str], Tuple[str]] = ( + "transcript_consequences", + "regulatory_feature_consequences", + "motif_feature_consequences", + "intergenic_consequences", + ), + add_order_by_csq_list: Dict[str, Tuple[str, List[str]]] = None, +) -> hl.expr.StructExpression: + """ + Process multiple VEP consequences lists to determine the most severe consequence. + + Returns a struct expression with the following annotations: + - most_severe_consequence: Most severe consequence for variant. + - protein_coding: Whether the variant is present on a protein-coding transcript. + - lof: Whether the variant is a loss-of-function variant. + - no_lof_flags: Whether the variant has any LOFTEE flags (True if no flags). + + .. note:: + + Assumes input Table is annotated with VEP and that VEP annotations have been + filtered to canonical or MANE Select transcripts if wanted. + + If `include_csqs` is True, additional annotations are added for each VEP + consequences list in `csq_list_order`, with the consequences that match the most + severe consequence term. + + If `prioritize_protein_coding` is True and "transcript_consequences" is in + `csq_list_order`, protein-coding transcripts are prioritized by filtering to only + protein-coding transcripts and determining the most severe consequence. If no + protein-coding transcripts are present, determine the most severe consequence for + all transcripts. If additional VEP consequences lists are requested, process those + lists in the order they appear in `csq_list_order`. + + If `add_order_by_csq_list` is provided, additional ordering is applied to the + consequences in the list. The key is the name of the consequences list and the value + is the order of consequences, sorted from high to low impact. An example use of this + parameter is to prioritize PolyPhen consequences for protein-coding transcripts. + + If `prioritize_loftee` is True, prioritize consequences with LOFTEE annotations, in + the order of `loftee_labels`, over those without LOFTEE annotations. If + `prioritize_loftee_no_flags` is True, prioritize LOFTEE consequences with no flags + over those with flags. + + :param vep_expr: StructExpression of VEP consequences to get the most severe + consequence from. + :param csq_order: Order of VEP consequences, sorted from high to low impact. Default + is None, which uses the value of the `CSQ_ORDER` global. + :param loftee_labels: Annotations added by LOFTEE, sorted from high to low impact. + Default is None, which uses the value of the `LOFTEE_LABELS` global. + :param include_csqs: Whether to include all consequences for the most severe + consequence. Default is False. + :param prioritize_protein_coding: Whether to prioritize protein-coding transcripts + when determining the worst consequence. Default is True. + :param prioritize_loftee: Whether to prioritize consequences with LOFTEE annotations + over those without. Default is True. + :param prioritize_loftee_no_flags: Whether to prioritize LOFTEE annotated + consequences with no flags over those with flags. Default is False. + :param csq_list_order: Order of VEP consequences lists to be processed. Default is + ('transcript_consequences', 'regulatory_feature_consequences', + 'motif_feature_consequences', 'intergenic_consequences'). + :param add_order_by_csq_list: Dictionary of additional ordering for VEP consequences + lists. The key is the name of the consequences list and the value is the order + of consequences, sorted from high to low impact. Default is None. + :return: StructExpression with the most severe consequence and additional annotations. + """ + add_order_by_csq_list = add_order_by_csq_list or {} + loftee_labels = loftee_labels or LOFTEE_LABELS + + result = { + **({"protein_coding": hl.tbool} if prioritize_protein_coding else {}), + **({"lof": hl.tstr} if prioritize_loftee else {}), + **( + {"no_lof_flags": hl.tbool} + if prioritize_loftee or prioritize_protein_coding + else {} + ), + "most_severe_consequence": hl.tstr, + } + result = hl.struct(**{k: hl.missing(v) for k, v in result.items()}) + + # Create a struct with missing values for each VEP consequences list. + ms_csq_list_expr = hl.struct( + **{c: hl.missing(vep_expr[c].dtype) for c in csq_list_order if c in vep_expr} ) + # Build the case expression to determine the most severe consequence. + ms_csq_expr = hl.case(missing_false=True) + for csq_list in csq_list_order: + if csq_list not in vep_expr: + logger.warning("VEP consequences list %s not found in input!", csq_list) + continue + + is_tc = csq_list == "transcript_consequences" + csq_expr = vep_expr[csq_list] + + # Set the base arguments for filtering to the most severe consequence using + # filter_to_most_severe_consequences. + add_order = add_order_by_csq_list.get(csq_list) + base_args = { + "csq_order": csq_order, + "loftee_labels": loftee_labels, + "prioritize_protein_coding": ( + True if (prioritize_protein_coding and is_tc) else False + ), + "prioritize_loftee": True if (prioritize_loftee and is_tc) else False, + "prioritize_loftee_no_flags": ( + True if (prioritize_loftee_no_flags and is_tc) else False + ), + "additional_order_field": add_order[0] if add_order else None, + "additional_order": add_order[1] if add_order else None, + } + ms_expr = filter_to_most_severe_consequences(csq_expr, **base_args) + ms_expr = result.annotate(**ms_expr) + + # Annotate the current consequence list with the consequences that match the + # most severe consequence term. + if include_csqs: + ms_expr = ms_expr.annotate( + **ms_csq_list_expr.annotate(**{csq_list: ms_expr.consequences}) + ) + ms_expr = ms_expr.drop("consequences") + + # If the length of the consequence list is not 0, set the most severe + # consequence to the most severe consequence for the current list. + ms_csq_expr = ms_csq_expr.when(hl.len(csq_expr) > 0, ms_expr) + + return ms_csq_expr.or_missing() + def filter_vep_transcript_csqs( t: Union[hl.Table, hl.MatrixTable], @@ -866,13 +1113,15 @@ def filter_vep_transcript_csqs_expr( mane_select: bool = False, ensembl_only: bool = False, protein_coding: bool = False, + loftee_labels: Optional[List[str]] = None, + no_lof_flags: bool = False, csqs: List[str] = None, keep_csqs: bool = True, genes: Optional[List[str]] = None, keep_genes: bool = True, match_by_gene_symbol: bool = False, additional_filtering_criteria: Optional[List[Callable]] = None, -) -> Union[hl.Table, hl.MatrixTable]: +) -> hl.expr.ArrayExpression: """ Filter VEP transcript consequences based on specified criteria, and optionally filter to variants where transcript consequences is not empty after filtering. @@ -893,6 +1142,10 @@ def filter_vep_transcript_csqs_expr( Emsembl. Default is False. :param protein_coding: Whether to filter to only protein-coding transcripts. Default is False. + :param loftee_labels: List of LOFTEE labels to filter to. Default is None, which + filters to all LOFTEE labels. + :param no_lof_flags: Whether to filter to consequences with no LOFTEE flags. + Default is False. :param csqs: Optional list of consequence terms to filter to. Transcript consequences are filtered to those where 'most_severe_consequence' is in the list of consequence terms `csqs`. Default is None. @@ -910,12 +1163,13 @@ def filter_vep_transcript_csqs_expr( criteria to apply to the VEP transcript consequences. :return: ArrayExpression of filtered VEP transcript consequences. """ + csq_fields = csq_expr.dtype.element_type.fields criteria = [lambda csq: True] if synonymous: logger.info("Filtering to most severe consequence of synonymous_variant...") csqs = ["synonymous_variant"] if csqs is not None: - if "most_severe_consequence" not in csq_expr.dtype.element_type.fields: + if "most_severe_consequence" not in csq_fields: logger.info("Adding most_severe_consequence annotation...") csq_expr = add_most_severe_consequence_to_consequence(csq_expr) @@ -936,6 +1190,19 @@ def filter_vep_transcript_csqs_expr( if protein_coding: logger.info("Filtering to protein coding transcripts...") criteria.append(lambda csq: csq.biotype == "protein_coding") + if loftee_labels: + logger.info( + "Filtering to consequences with LOFTEE labels: %s...", loftee_labels + ) + criteria.append(lambda x: hl.set(loftee_labels).contains(x.lof)) + if no_lof_flags: + logger.info("Filtering to consequences with no LOFTEE flags...") + if "lof_flags" in csq_fields: + criteria.append(lambda x: hl.is_missing(x.lof_flags) | (x.lof_flags == "")) + else: + logger.warning( + "'lof_flags' not present in consequence struct, no consequences are filtered based on LOFTEE flags" + ) if genes is not None: logger.info("Filtering to genes of interest...") genes = hl.literal(genes) @@ -955,24 +1222,41 @@ def filter_vep_transcript_csqs_expr( def add_most_severe_csq_to_tc_within_vep_root( - t: Union[hl.Table, hl.MatrixTable], vep_root: str = "vep" + t: Union[hl.Table, hl.MatrixTable], + vep_root: str = "vep", + csq_field: str = "transcript_consequences", + most_severe_csq_field: str = "most_severe_consequence", + csq_order: Optional[List[str]] = None, ) -> Union[hl.Table, hl.MatrixTable]: """ - Add most_severe_consequence annotation to 'transcript_consequences' within the vep root annotation. + Add `most_severe_csq_field` annotation to `csq_field` within the `vep_root` annotation. :param t: Input Table or MatrixTable. :param vep_root: Root for vep annotation (probably vep). + :param csq_field: Field name of VEP consequences ArrayExpression within `vep_root` + to add most severe consequence to. Default is 'transcript_consequences'. + :param most_severe_csq_field: Field name to use for most severe consequence. Default + is 'most_severe_consequence'. + :param csq_order: Optional list indicating the order of VEP consequences, sorted + from high to low impact. Default is None, which uses the value of the + `CSQ_ORDER` global. :return: Table or MatrixTable with most_severe_consequence annotation added. """ - annotation = t[vep_root].annotate( - transcript_consequences=t[vep_root].transcript_consequences.map( - add_most_severe_consequence_to_consequence - ) + vep_expr = t[vep_root] + csq_expr = vep_expr[csq_field] + vep_expr = vep_expr.annotate( + **{ + csq_field: add_most_severe_consequence_to_consequence( + csq_expr, + csq_order=csq_order, + most_severe_csq_field=most_severe_csq_field, + ) + } ) return ( - t.annotate_rows(**{vep_root: annotation}) + t.annotate_rows(**{vep_root: vep_expr}) if isinstance(t, hl.MatrixTable) - else t.annotate(**{vep_root: annotation}) + else t.annotate(**{vep_root: vep_expr}) ) diff --git a/requirements.txt b/requirements.txt index f30acc1d6..563a61d1b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ annoy +deprecated ga4gh.vrs[extras]==0.8.4 hail hdbscan diff --git a/tests/utils/test_vep.py b/tests/utils/test_vep.py new file mode 100644 index 000000000..d9e76b6dc --- /dev/null +++ b/tests/utils/test_vep.py @@ -0,0 +1,603 @@ +"""Tests for the gnomAD VEP utility functions.""" + +from typing import Any, Dict, List, Optional + +import hail as hl +import pytest + +from gnomad.utils.vep import ( + add_most_severe_consequence_to_consequence, + add_most_severe_csq_to_tc_within_vep_root, + filter_to_most_severe_consequences, + filter_vep_transcript_csqs_expr, + get_most_severe_consequence_expr, + get_most_severe_csq_from_multiple_csq_lists, +) + + +class TestGetMostSevereConsequenceExpr: + """Test the get_most_severe_consequence_expr function.""" + + @pytest.mark.parametrize( + "csq_expr, custom_csq_order, expected", + [ + # Test with default csq_order. + ( + hl.literal(["splice_region_variant", "intron_variant"]), + None, + "splice_region_variant", + ), + # Test with custom csq_order. + ( + hl.literal(["splice_region_variant", "intron_variant"]), + ["intron_variant", "splice_region_variant"], + "intron_variant", + ), + # Test with csq_expr that contains a consequence not in csq_order. + (hl.literal(["non_existent_consequence"]), None, None), + # Test with empty csq_expr. + (hl.empty_array(hl.tstr), None, None), + ], + ) + def test_get_most_severe_consequence_expr( + self, + csq_expr: hl.expr.ArrayExpression, + custom_csq_order: List[str], + expected: str, + ) -> None: + """ + Test get_most_severe_consequence_expr with various parameters. + + :param csq_expr: The consequence terms to evaluate. + :param custom_csq_order: The custom consequence order to use. + :param expected: The expected most severe consequence. + :return: None. + """ + result = get_most_severe_consequence_expr(csq_expr, custom_csq_order) + assert hl.eval(result) == expected, f"Expected {expected}" + + +class TestAddMostSevereConsequenceToConsequence: + """Tests for the add_most_severe_consequence_to_consequence function.""" + + @pytest.mark.parametrize( + "tc_expr, custom_csq_order, expected", + [ + # Test with default csq_order. + ( + hl.struct(consequence_terms=["missense_variant", "synonymous_variant"]), + None, + "missense_variant", + ), + # Test with custom csq_order. + ( + hl.struct(consequence_terms=["missense_variant", "synonymous_variant"]), + ["synonymous_variant", "missense_variant"], + "synonymous_variant", + ), + # Test with csq_expr that contains a consequence not in csq_order. + (hl.struct(consequence_terms=["non_existent_consequence"]), None, None), + ], + ) + def test_add_most_severe_consequence_to_consequence( + self, + tc_expr: hl.expr.StructExpression, + custom_csq_order: List[str], + expected: str, + ) -> None: + """ + Test add_most_severe_consequence_to_consequence with various parameters. + + :param tc_expr: The transcript consequence to evaluate. + :param custom_csq_order: The custom consequence order to use. + :param expected: The expected most severe consequence. + :return: None. + """ + result = add_most_severe_consequence_to_consequence(tc_expr, custom_csq_order) + assert hl.eval(result.most_severe_consequence) == hl.eval( + get_most_severe_consequence_expr( + tc_expr.consequence_terms, custom_csq_order + ) + ), f"Expected {expected}" + + +class TestAddMostSevereCsqToTcWithinVepRoot: + """Tests for the add_most_severe_csq_to_tc_within_vep_root function.""" + + @pytest.mark.parametrize( + "transcript_consequences, expected_most_severe", + [ + # Test with multiple consequence terms of different severity. + ( + [ + hl.struct( + consequence_terms=["missense_variant", "synonymous_variant"] + ) + ], + "missense_variant", + ), + # Test with a single consequence term. + ( + [hl.struct(consequence_terms=["synonymous_variant"])], + "synonymous_variant", + ), + # Test with multiple consequence terms of the same severity. + ( + [ + hl.struct( + consequence_terms=["synonymous_variant", "synonymous_variant"] + ) + ], + "synonymous_variant", + ), + # Test with a consequence term not in the default order. + ([hl.struct(consequence_terms=["non_existent_consequence"])], None), + ], + ) + def test_add_most_severe_csq_to_tc_within_vep_root( + self, + transcript_consequences: List[hl.expr.StructExpression], + expected_most_severe: str, + ) -> None: + """ + Test the add_most_severe_csq_to_tc_within_vep_root function. + + :param transcript_consequences: List of transcript consequences to annotate. + :param expected_most_severe: The expected most severe consequence. + :return: None. + """ + # Create a mock MatrixTable with vep.transcript_consequences. + mt = hl.utils.range_matrix_table(1, 1) + mt = mt.annotate_rows( + vep=hl.struct(transcript_consequences=transcript_consequences) + ) + + # Apply the function. + result_mt = add_most_severe_csq_to_tc_within_vep_root(mt) + + # Check that the most_severe_consequence is correct. + assert ( + result_mt.vep.transcript_consequences[0].most_severe_consequence.take(1)[0] + == expected_most_severe + ), f"Expected '{expected_most_severe}'" + + +class TestFilterToMostSevereConsequences: + """Tests for the filter_to_most_severe_consequences function.""" + + @pytest.fixture + def mock_csq_expr(self) -> hl.expr.ArrayExpression: + """Fixture to create a mock array of VEP consequences.""" + + def _build_csq_struct( + csq: str, + protein_coding: bool, + lof: str, + no_lof_flags: bool, + polyphen_prediction: Optional[str] = None, + ) -> hl.Struct: + """ + Build a mock VEP consequence struct. + + :param csq: The consequence term. + :param protein_coding: Whether the consequence is protein coding. + :param lof: The LOF value. + :param no_lof_flags: Whether the consequence has no LOF flags. + :param polyphen_prediction: The PolyPhen prediction. + :return: The mock VEP consequence struct. + """ + return hl.Struct( + biotype="protein_coding" if protein_coding else "not_protein_coding", + lof=lof, + lof_flags=hl.missing(hl.tstr) if no_lof_flags else "flag1", + consequence_terms=[csq], + polyphen_prediction=( + hl.missing(hl.tstr) + if polyphen_prediction is None + else polyphen_prediction + ), + ) + + struct_order = [ + ["stop_gained", True, "HC", True, None], + ["stop_lost", True, "HC", True, None], + ["splice_acceptor_variant", True, "HC", False, "benign"], + ["splice_acceptor_variant", True, "HC", False, "possibly_damaging"], + ["splice_acceptor_variant", True, "LC", True, None], + ["splice_acceptor_variant", True, "LC", True, "possibly_damaging"], + ["splice_acceptor_variant", True, "LC", False, "possibly_damaging"], + ["stop_gained", False, "HC", True, None], + ["stop_gained", False, "HC", True, "probably_damaging"], + ["splice_acceptor_variant", False, "HC", False, "possibly_damaging"], + ["splice_acceptor_variant", False, "LC", True, "probably_damaging"], + ["splice_acceptor_variant", False, "LC", True, None], + ["splice_acceptor_variant", False, "LC", True, "possibly_damaging"], + ["splice_acceptor_variant", False, "LC", False, "probably_damaging"], + ] + + return hl.literal([_build_csq_struct(*p) for p in struct_order]) + + polyphen_order = ["probably_damaging", "possibly_damaging", "benign"] + polyphen_params = ["polyphen_prediction", polyphen_order] + + @pytest.mark.parametrize( + "prioritize_protein_coding, prioritize_loftee, prioritize_loftee_no_flags, additional_order_field, additional_order, expected_most_severe_csq, expected_polyphen_prediction", + [ + (False, False, False, None, None, None, None), + (True, False, False, None, None, None, None), + (False, True, False, None, None, None, None), + (False, False, True, None, None, None, None), + (False, False, False, *polyphen_params, None, None), + (True, True, False, None, None, None, None), + (True, False, True, None, None, None, None), + (True, False, False, *polyphen_params, None, "possibly_damaging"), + (False, True, True, None, None, "stop_gained", None), + (False, True, False, *polyphen_params, None, "possibly_damaging"), + (False, False, True, *polyphen_params, None, None), + (True, True, True, None, None, "stop_gained", None), + (True, True, False, *polyphen_params, None, "possibly_damaging"), + (True, False, True, *polyphen_params, None, "possibly_damaging"), + (False, True, True, *polyphen_params, "stop_gained", None), + # Need to figure out class too large error + (True, True, True, *polyphen_params, "stop_gained", "possibly_damaging"), + ], + ) + def test_filter_to_most_severe_consequences( + self, + mock_csq_expr: hl.expr.ArrayExpression, + prioritize_protein_coding: bool, + prioritize_loftee: bool, + prioritize_loftee_no_flags: bool, + additional_order_field: str, + additional_order: List[str], + expected_most_severe_csq: str, + expected_polyphen_prediction: str, + ) -> None: + """ + Test the filter_to_most_severe_consequences function. + + :param mock_csq_expr: The mock VEP consequence expression. + :param prioritize_protein_coding: Whether to prioritize protein coding. + :param prioritize_loftee: Whether to prioritize LOFTEE. + :param prioritize_loftee_no_flags: Whether to prioritize LOFTEE no flags. + :param additional_order_field: The additional order field to use. + :param additional_order: The additional order to use. + :param expected_most_severe_csq: The expected most severe consequence. + :param expected_polyphen_prediction: The expected PolyPhen prediction. + :return: None. + """ + result = filter_to_most_severe_consequences( + mock_csq_expr, + prioritize_protein_coding=prioritize_protein_coding, + prioritize_loftee=prioritize_loftee, + prioritize_loftee_no_flags=prioritize_loftee_no_flags, + additional_order_field=additional_order_field, + additional_order=additional_order, + ) + + expected_dict = hl.Struct( + protein_coding=True, + lof="HC", + most_severe_consequence="splice_acceptor_variant", + no_lof_flags=True, + ) + + def _get_csq_structs( + csq: str, + protein_coding: Optional[bool] = None, + lof: Optional[str] = None, + no_lof_flags: Optional[bool] = None, + polyphen_prediction: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """ + Get the expected consequence structs. + + :param csq: The consequence term to filter by. + :param protein_coding: Whether to filter by protein coding. + :param lof: The LOF value to filter by. + :param no_lof_flags: Whether to filter by no LOF flags. + :param polyphen_prediction: The PolyPhen prediction to filter by. + :return: The expected consequence structs. + """ + + def _get_csq_criteria(s): + keep = s.consequence_terms.contains(csq) + if protein_coding: + keep &= s.biotype == "protein_coding" + if lof: + keep &= s.lof == lof + if no_lof_flags: + keep &= hl.is_missing(s.lof_flags) + if polyphen_prediction: + keep &= s.polyphen_prediction == polyphen_prediction + + return keep + + return hl.eval(mock_csq_expr.filter(_get_csq_criteria)) + + expected_polyphen_prediction = ( + expected_polyphen_prediction or "probably_damaging" + ) + expected_most_severe_csq = expected_most_severe_csq or "splice_acceptor_variant" + add_ms = expected_most_severe_csq != "splice_acceptor_variant" + + expected_select = ( + (["protein_coding"] if prioritize_protein_coding else []) + + (["lof"] if prioritize_loftee else []) + + ( + ["no_lof_flags"] + if prioritize_loftee_no_flags or prioritize_loftee + else [] + ) + + (["most_severe_consequence"] if not add_ms else []) + ) + expected = expected_dict.select( + *expected_select, + **({"most_severe_consequence": expected_most_severe_csq} if add_ms else {}), + consequences=_get_csq_structs( + expected_most_severe_csq, + protein_coding=prioritize_protein_coding, + lof="HC" if prioritize_loftee else None, + no_lof_flags=prioritize_loftee_no_flags, + polyphen_prediction=( + expected_polyphen_prediction + if additional_order_field == "polyphen_prediction" + else None + ), + ), + ) + assert hl.eval(result) == expected, f"Expected '{expected}'" + + +class TestFilterVepTranscriptCsqsExprLoftee: + """Tests for the filter_vep_transcript_csqs_expr function.""" + + @pytest.fixture + def csq_expr(self) -> hl.expr.ArrayExpression: + """Fixture to create a mock array of VEP consequences with LOFTEE annotations.""" + return hl.literal( + [ + hl.struct(lof="HC", lof_flags=hl.missing(hl.tstr)), + hl.struct(lof="HC", lof_flags=""), + hl.struct(lof="LC", lof_flags="flag1"), + hl.struct(lof="OS", lof_flags=hl.missing(hl.tstr)), + hl.struct(lof=hl.missing(hl.tstr), lof_flags="flag2"), + ] + ) + + @staticmethod + def check_length(result: hl.expr.ArrayExpression, expected_len: int) -> None: + """ + Check the length of the result. + + :param result: The result to check. + :param expected_len: The expected length. + :return: None. + """ + assert ( + hl.eval(hl.len(result)) == expected_len + ), f"Expected {expected_len} consequences" + + @staticmethod + def check_lof(result: hl.expr.ArrayExpression, expected_lof: str) -> None: + """ + Check the LOF annotation. + + :param result: The result to check. + :param expected_lof: The expected LOF annotation. + :return: None. + """ + assert hl.eval(result[0].lof) == expected_lof, f"Expected '{expected_lof}'" + + @staticmethod + def check_no_lof_flags(result: hl.expr.ArrayExpression) -> None: + """ + Check the no LOF flags value. + + :param result: The result to check. + :return: None. + """ + assert hl.eval( + hl.all( + result.map( + lambda csq: hl.is_missing(csq.lof_flags) | (csq.lof_flags == "") + ) + ) + ), "Expected no LOFTEE flags" + + @pytest.mark.parametrize( + "loftee_labels, no_lof_flags, expected_len, expected_lof, expected_no_lof_flags", + [ + (None, None, 5, None, None), + (["HC"], None, 2, "HC", None), + (None, True, 3, None, True), + (["HC"], True, 2, "HC", True), + ], + ) + def test_filter_vep_transcript_csqs_expr_loftee( + self, + csq_expr: hl.expr.ArrayExpression, + loftee_labels: list, + no_lof_flags: bool, + expected_len: int, + expected_lof: str, + expected_no_lof_flags: bool, + ) -> None: + """ + Test the filter_vep_transcript_csqs_expr function. + + :param csq_expr: The VEP consequence expression to filter. + :param loftee_labels: The LOFTEE labels to filter by. + :param no_lof_flags: Whether to filter by no LOF flags. + :param expected_len: The expected length of the result. + :param expected_lof: The expected LOF value. + :param expected_no_lof_flags: The expected no LOF flags value. + :return: None. + """ + result = filter_vep_transcript_csqs_expr( + csq_expr, loftee_labels=loftee_labels, no_lof_flags=no_lof_flags + ) + self.check_length(result, expected_len) + if expected_lof is not None: + self.check_lof(result, expected_lof) + if expected_no_lof_flags is not None: + self.check_no_lof_flags(result) + + +class TestGetMostSevereCsqFromMultipleCsqLists: + """Tests for the get_most_severe_csq_from_multiple_csq_lists function.""" + + @pytest.fixture + def vep_expr(self) -> hl.expr.StructExpression: + """Fixture to create a mock VEP expression.""" + return hl.struct( + transcript_consequences=[ + hl.struct( + biotype="protein_coding", + lof="HC", + lof_flags="flag1", + consequence_terms=["splice_acceptor_variant"], + ), + hl.struct( + biotype="protein_coding", + lof="HC", + lof_flags="flag1", + consequence_terms=["splice_acceptor_variant"], + ), + hl.struct( + biotype="protein_coding", + lof="HC", + lof_flags=hl.missing(hl.tstr), + consequence_terms=["stop_lost"], + ), + hl.struct( + biotype="protein_coding", + lof="HC", + lof_flags=hl.missing(hl.tstr), + consequence_terms=["stop_gained"], + ), + hl.struct( + biotype="protein_coding", + lof=hl.missing(hl.tstr), + lof_flags=hl.missing(hl.tstr), + consequence_terms=["missense_variant"], + ), + ], + intergenic_consequences=[ + hl.struct( + biotype="intergenic", + consequence_terms=["intergenic_variant"], + ) + ], + ) + + @pytest.mark.parametrize( + "prioritize_loftee_no_flags, include_csqs, expected_most_severe, expected_protein_coding, expected_lof, expected_no_lof_flags, expected_transcript_consequences_len", + [ + (False, False, "splice_acceptor_variant", True, "HC", True, None), + (True, False, "stop_gained", True, "HC", True, None), + (False, True, "splice_acceptor_variant", True, "HC", True, 2), + (True, True, "stop_gained", True, "HC", True, 1), + ], + ) + def test_get_most_severe_csq_from_multiple_csq_lists( + self, + vep_expr: hl.expr.StructExpression, + prioritize_loftee_no_flags: bool, + include_csqs: bool, + expected_most_severe: str, + expected_protein_coding: bool, + expected_lof: str, + expected_no_lof_flags: bool, + expected_transcript_consequences_len: Optional[int], + ) -> None: + """ + Test the get_most_severe_csq_from_multiple_csq_lists function. + + :param vep_expr: The VEP expression to test. + :param prioritize_loftee_no_flags: Whether to prioritize LOFTEE no flags. + :param include_csqs: Whether to include consequences. + :param expected_most_severe: The expected most severe consequence. + :param expected_protein_coding: The expected protein coding value. + :param expected_lof: The expected LOF value. + :param expected_no_lof_flags: The expected no LOF flags value. + :param expected_transcript_consequences_len: The expected length of transcript + consequences. + :return: None. + """ + result = get_most_severe_csq_from_multiple_csq_lists( + vep_expr, + prioritize_loftee_no_flags=prioritize_loftee_no_flags, + include_csqs=include_csqs, + ) + self.check_most_severe_consequence(result, expected_most_severe) + self.check_protein_coding(result, expected_protein_coding) + self.check_lof(result, expected_lof) + self.check_no_lof_flags(result, expected_no_lof_flags) + if include_csqs: + self.check_transcript_consequences_len( + result, expected_transcript_consequences_len + ) + + @staticmethod + def check_most_severe_consequence( + result: hl.expr.StructExpression, expected: str + ) -> None: + """ + Check the most severe consequence. + + :param result: The result to check. + :param expected: The expected most severe consequence. + :return: None. + """ + assert ( + hl.eval(result.most_severe_consequence) == expected + ), f"Expected '{expected}'" + + @staticmethod + def check_protein_coding(result: hl.expr.StructExpression, expected: bool) -> None: + """ + Check the protein coding value. + + :param result: The result to check. + :param expected: The expected protein coding value. + :return: None. + """ + assert hl.eval(result.protein_coding) == expected, f"Expected '{expected}'" + + @staticmethod + def check_lof(result: hl.expr.StructExpression, expected: str) -> None: + """ + Check the LOF value. + + :param result: The result to check. + :param expected: The expected LOF value. + :return: None. + """ + assert hl.eval(result.lof) == expected, f"Expected '{expected}'" + + @staticmethod + def check_no_lof_flags(result: hl.expr.StructExpression, expected: bool) -> None: + """ + Check the no LOF flags value. + + :param result: The result to check. + :param expected: The expected no LOF flags value. + :return: None. + """ + assert hl.eval(result.no_lof_flags) == expected, f"Expected '{expected}'" + + @staticmethod + def check_transcript_consequences_len( + result: hl.expr.StructExpression, expected: int + ) -> None: + """ + Check the length of transcript consequences. + + :param result: The result to check. + :param expected: The expected length. + :return: None. + """ + assert ( + hl.eval(hl.len(result.transcript_consequences)) == expected + ), f"Expected {expected} transcript consequences"