Skip to content

Commit

Permalink
Add some basic summaries on sample groups
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Dec 12, 2024
1 parent c4813a7 commit bfcc922
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
54 changes: 49 additions & 5 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,13 @@ def __init__(
# The number of samples per day in time-ago (i.e., the nodes_time units).
self.num_samples_per_day = np.bincount(ts.nodes_time[samples].astype(int))

self.sample_group_id_prefix_len = 8
self.sample_group_nodes = collections.defaultdict(list)
self.sample_group_mutations = collections.defaultdict(list)
self.retro_sample_groups = {}
# print(top_level_md)
# # for retro_group in ope

if not quick:
self._preprocess_nodes(show_progress)
self._preprocess_sites(show_progress)
Expand Down Expand Up @@ -492,7 +499,6 @@ def _preprocess_nodes(self, show_progress):
self.nodes_num_deletion_sites = np.zeros(ts.num_nodes, dtype=np.int32)
self.nodes_num_exact_matches = np.zeros(ts.num_nodes, dtype=np.int32)
self.nodes_metadata = {}
self.nodes_sample_group = collections.defaultdict(list)
samples = ts.samples()

self.time_zero_as_date = np.array([self.date], dtype="datetime64[D]")[0]
Expand All @@ -518,7 +524,9 @@ def _preprocess_nodes(self, show_progress):
sc2ts_md = md["sc2ts"]
group_id = sc2ts_md.get("group_id", None)
if group_id is not None:
self.nodes_sample_group[group_id].append(node.id)
# Shorten key for readability.
gid = group_id[: self.sample_group_id_prefix_len]
self.sample_group_nodes[gid].append(node.id)
if node.is_sample():
self.nodes_date[node.id] = md["date"]
pango = md.get(self.pango_source, "unknown")
Expand Down Expand Up @@ -617,6 +625,7 @@ def _preprocess_mutations(self, show_progress):
iterator = tqdm(
np.arange(N), desc="Classifying mutations", disable=not show_progress
)
mutation_table = ts.tables.mutations
for mut_id in iterator:
tree.seek(self.mutations_position[mut_id])
mutation_node = ts.mutations_node[mut_id]
Expand Down Expand Up @@ -645,6 +654,13 @@ def _preprocess_mutations(self, show_progress):
sites_num_transitions[site] += mutations_is_transition[mut_id]
sites_num_transversions[site] += mutations_is_transversion[mut_id]

# classify by origin
md = mutation_table[mut_id].metadata["sc2ts"]
inference_type = md.get("type", None)
if inference_type == "parsimony":
gid = md["group_id"][: self.sample_group_id_prefix_len]
self.sample_group_mutations[gid].append(mut_id)

# Note: no real good reason for not just using self.mutations_num_descendants
# etc above
self.mutations_num_descendants = mutations_num_descendants
Expand Down Expand Up @@ -757,7 +773,10 @@ def _node_summary(self, u, child_mutations=True):
strain = md["strain"]
else:
md = md["sc2ts"]
if flags & (core.NODE_IS_MUTATION_OVERLAP | core.NODE_IS_REVERSION_PUSH) > 0:
if (
flags & (core.NODE_IS_MUTATION_OVERLAP | core.NODE_IS_REVERSION_PUSH)
> 0
):
try:
strain = f"{md['date_added']}:{', '.join(md['mutations'])}"
except KeyError:
Expand Down Expand Up @@ -883,6 +902,29 @@ def samples_summary(self):
df["total_hmm_cost"] = df["mean_hmm_cost"] * df["total"]
return df.astype({"date": "datetime64[s]"})

def sample_groups_summary(self):
    """
    Return a DataFrame with one row per sample group, indexed by ``group_id``.

    The ``group_id`` values are the keys of ``self.sample_group_nodes``
    (shortened prefixes of the full group hash stored in node metadata).

    Columns:

    - ``nodes``: number of nodes recorded for the group.
    - ``samples``: how many of those nodes are samples.
    - ``mutations``: number of mutations recorded for the group in
      ``self.sample_group_mutations`` (0 if the group has none).
    - ``is_retro``: whether the group id is in ``self.retro_sample_groups``.

    An empty (zero-row) frame with the same columns is returned when no
    sample groups have been recorded.
    """
    # Passing the columns explicitly keeps ``set_index`` working when
    # ``data`` is empty; a column-less frame would raise KeyError.
    columns = ["group_id", "nodes", "samples", "mutations", "is_retro"]
    data = []
    for group_id, nodes in self.sample_group_nodes.items():
        num_samples = 0
        full_hashes = set()
        for u in nodes:
            node = self.ts.node(u)
            if node.is_sample():
                num_samples += 1
            full_hashes.add(node.metadata["sc2ts"]["group_id"])
        # Every node filed under a shortened key must carry the same full
        # hash; more than one means two groups collided on the prefix.
        assert len(full_hashes) == 1, (
            f"Sample group {group_id!r} maps to {len(full_hashes)} full "
            "group ids; expected exactly 1"
        )
        (full_hash,) = full_hashes
        assert full_hash.startswith(group_id), (
            f"Full group id {full_hash!r} does not start with {group_id!r}"
        )
        data.append(
            {
                "group_id": group_id,
                "nodes": len(nodes),
                "samples": num_samples,
                # Use .get() so a read-only summary does not insert empty
                # entries into the (defaultdict) mutation map.
                "mutations": len(self.sample_group_mutations.get(group_id, ())),
                "is_retro": group_id in self.retro_sample_groups,
            }
        )
    return pd.DataFrame(data, columns=columns).set_index("group_id")

def recombinants_summary(self):
data = []
for u in self.recombinants:
Expand Down Expand Up @@ -1130,7 +1172,7 @@ def css_cell(allele, bold=False, show_colour=True):
parent_allele = var.alleles[var.genotypes[j]]
css = css_cell(
parent_allele,
bold=parent_allele==child_allele,
bold=parent_allele == child_allele,
show_colour=j == parent_col,
)
parents[j - 1].append(f"<td{css}>{parent_allele}</td>")
Expand Down Expand Up @@ -1305,7 +1347,9 @@ def recombinant_samples_report(self, nodes):
closest_recombinant, path_length = self._get_closest_recombinant(tree, node)
sample_is_recombinant = False
if closest_recombinant != -1:
recomb_date = self.ts.node(closest_recombinant).metadata["sc2ts"]["date_added"]
recomb_date = self.ts.node(closest_recombinant).metadata["sc2ts"][
"date_added"
]
sample_is_recombinant = recomb_date == str(node_summary["date"])
summary = {
"recombinant": closest_recombinant,
Expand Down
6 changes: 6 additions & 0 deletions tests/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,12 @@ def test_samples_summary(self, fx_ti_2020_02_13):
assert np.all(df["total"] >= (df["inserted"] + df["exact_matches"]))
assert df.shape[0] > 0

def test_sample_group_summary(self, fx_ti_2020_02_13):
    """Check basic invariants of the sample-groups summary table."""
    summary = fx_ti_2020_02_13.sample_groups_summary()
    num_rows = summary.shape[0]
    assert num_rows == 26
    node_counts = summary["nodes"]
    sample_counts = summary["samples"]
    # A group can never have more samples than nodes, and is never empty.
    assert np.all(node_counts >= sample_counts)
    assert np.all(node_counts > 0)

def test_node_summary(self, fx_ti_2020_02_13):
ti = fx_ti_2020_02_13
for u in range(ti.ts.num_nodes):
Expand Down

0 comments on commit bfcc922

Please sign in to comment.