From bfcc922e21173658b04e458e43f9d51a21baa0ff Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 12 Dec 2024 21:17:56 +0000 Subject: [PATCH] Add some basic summaries on sample groups --- sc2ts/info.py | 54 +++++++++++++++++++++++++++++++++++++++++----- tests/test_info.py | 6 ++++++ 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/sc2ts/info.py b/sc2ts/info.py index 08a56bf..06213a9 100644 --- a/sc2ts/info.py +++ b/sc2ts/info.py @@ -451,6 +451,13 @@ def __init__( # The number of samples per day in time-ago (i.e., the nodes_time units). self.num_samples_per_day = np.bincount(ts.nodes_time[samples].astype(int)) + self.sample_group_id_prefix_len = 8 + self.sample_group_nodes = collections.defaultdict(list) + self.sample_group_mutations = collections.defaultdict(list) + self.retro_sample_groups = {} + # print(top_level_md) + # # for retro_group in ope + if not quick: self._preprocess_nodes(show_progress) self._preprocess_sites(show_progress) @@ -492,7 +499,6 @@ def _preprocess_nodes(self, show_progress): self.nodes_num_deletion_sites = np.zeros(ts.num_nodes, dtype=np.int32) self.nodes_num_exact_matches = np.zeros(ts.num_nodes, dtype=np.int32) self.nodes_metadata = {} - self.nodes_sample_group = collections.defaultdict(list) samples = ts.samples() self.time_zero_as_date = np.array([self.date], dtype="datetime64[D]")[0] @@ -518,7 +524,9 @@ def _preprocess_nodes(self, show_progress): sc2ts_md = md["sc2ts"] group_id = sc2ts_md.get("group_id", None) if group_id is not None: - self.nodes_sample_group[group_id].append(node.id) + # Shorten key for readability. + gid = group_id[: self.sample_group_id_prefix_len] + self.sample_group_nodes[gid].append(node.id) if node.is_sample(): self.nodes_date[node.id] = md["date"] pango = md.get(self.pango_source, "unknown") @@ -617,6 +625,7 @@ def _preprocess_mutations(self, show_progress): iterator = tqdm( np.arange(N), desc="Classifying mutations", disable=not show_progress ) + mutation_table = ts.tables.mutations for mut_id in iterator: tree.seek(self.mutations_position[mut_id]) mutation_node = ts.mutations_node[mut_id] @@ -645,6 +654,13 @@ def _preprocess_mutations(self, show_progress): sites_num_transitions[site] += mutations_is_transition[mut_id] sites_num_transversions[site] += mutations_is_transversion[mut_id] + # classify by origin + md = mutation_table[mut_id].metadata["sc2ts"] + inference_type = md.get("type", None) + if inference_type == "parsimony": + gid = md["group_id"][: self.sample_group_id_prefix_len] + self.sample_group_mutations[gid].append(mut_id) + # Note: no real good reason for not just using self.mutations_num_descendants # etc above self.mutations_num_descendants = mutations_num_descendants @@ -757,7 +773,10 @@ def _node_summary(self, u, child_mutations=True): strain = md["strain"] else: md = md["sc2ts"] - if flags & (core.NODE_IS_MUTATION_OVERLAP | core.NODE_IS_REVERSION_PUSH) > 0: + if ( + flags & (core.NODE_IS_MUTATION_OVERLAP | core.NODE_IS_REVERSION_PUSH) + > 0 + ): try: strain = f"{md['date_added']}:{', '.join(md['mutations'])}" except KeyError: @@ -883,6 +902,29 @@ def samples_summary(self): df["total_hmm_cost"] = df["mean_hmm_cost"] * df["total"] return df.astype({"date": "datetime64[s]"}) + def sample_groups_summary(self): + data = [] + for group_id, nodes in self.sample_group_nodes.items(): + samples = [] + full_hashes = [] + for u in nodes: + node = self.ts.node(u) + if node.is_sample(): + samples.append(u) + full_hashes.append(node.metadata["sc2ts"]["group_id"]) + assert len(set(full_hashes)) == 1 + assert full_hashes[0].startswith(group_id) + data.append( + { + "group_id": group_id, + "nodes": len(nodes), + "samples": len(samples), + "mutations": len(self.sample_group_mutations[group_id]), + "is_retro": group_id in self.retro_sample_groups, + } + ) + return pd.DataFrame(data).set_index("group_id") + def recombinants_summary(self): data = [] for u in self.recombinants: @@ -1130,7 +1172,7 @@ def css_cell(allele, bold=False, show_colour=True): parent_allele = var.alleles[var.genotypes[j]] css = css_cell( parent_allele, - bold=parent_allele==child_allele, + bold=parent_allele == child_allele, show_colour=j == parent_col, ) parents[j - 1].append(f"{parent_allele}") @@ -1305,7 +1347,9 @@ def recombinant_samples_report(self, nodes): closest_recombinant, path_length = self._get_closest_recombinant(tree, node) sample_is_recombinant = False if closest_recombinant != -1: - recomb_date = self.ts.node(closest_recombinant).metadata["sc2ts"]["date_added"] + recomb_date = self.ts.node(closest_recombinant).metadata["sc2ts"][ + "date_added" + ] sample_is_recombinant = recomb_date == str(node_summary["date"]) summary = { "recombinant": closest_recombinant, diff --git a/tests/test_info.py b/tests/test_info.py index a37d6ae..ad38b15 100644 --- a/tests/test_info.py +++ b/tests/test_info.py @@ -231,6 +231,12 @@ def test_samples_summary(self, fx_ti_2020_02_13): assert np.all(df["total"] >= (df["inserted"] + df["exact_matches"])) assert df.shape[0] > 0 + def test_sample_group_summary(self, fx_ti_2020_02_13): + df = fx_ti_2020_02_13.sample_groups_summary() + assert df.shape[0] == 26 + assert np.all(df["nodes"] >= df["samples"]) + assert np.all(df["nodes"] > 0) + def test_node_summary(self, fx_ti_2020_02_13): ti = fx_ti_2020_02_13 for u in range(ti.ts.num_nodes):