Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates for sample group infrastructure #441

Merged
merged 7 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 2 additions & 16 deletions sc2ts/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@
NODE_IS_RECOMBINANT = 1 << 23
NODE_IS_EXACT_MATCH = 1 << 24
NODE_IS_IMMEDIATE_REVERSION_MARKER = 1 << 25
NODE_IN_SAMPLE_GROUP = 1 << 26
NODE_IN_RETROSPECTIVE_SAMPLE_GROUP = 1 << 27
NODE_IS_REFERENCE = 1 << 28
NODE_IS_UNCONDITIONALLY_INCLUDED = 1 << 29
NODE_IS_REFERENCE = 1 << 26
NODE_IS_UNCONDITIONALLY_INCLUDED = 1 << 27


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -77,18 +75,6 @@ class FlagValue:
"Node is marking the existance of an immediate reversion which "
"has not been removed for technical reasons",
),
FlagValue(
NODE_IN_SAMPLE_GROUP,
"G",
"SampleGroup",
"Node is a member of a sample group",
),
FlagValue(
NODE_IN_RETROSPECTIVE_SAMPLE_GROUP,
"Q",
"RetroSampleGroup",
"Node is a member of a retrospective sample group",
),
FlagValue(
NODE_IS_REFERENCE,
"F",
Expand Down
32 changes: 9 additions & 23 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,6 @@ def check_base_ts(ts):
return sc2ts_md["date"]



def mask_ambiguous(a):
a = a.copy()
a[a > DELETION] = -1
Expand Down Expand Up @@ -620,7 +619,7 @@ def extend(
show_progress=show_progress,
random_seed=random_seed,
num_threads=num_threads,
memory_limit=memory_limit * 2**30, # Convert to bytes
memory_limit=memory_limit * 2**30, # Convert to bytes
)

prov = get_provenance_dict("extend", params, start_time)
Expand Down Expand Up @@ -751,7 +750,6 @@ def _extend(
match_db=match_db,
date=date,
min_group_size=1,
additional_node_flags=core.NODE_IN_SAMPLE_GROUP,
show_progress=show_progress,
phase="close",
)
Expand All @@ -769,7 +767,6 @@ def _extend(
min_root_mutations=min_root_mutations,
max_mutations_per_sample=max_mutations_per_sample,
max_recurrent_mutations=max_recurrent_mutations,
additional_node_flags=core.NODE_IN_RETROSPECTIVE_SAMPLE_GROUP,
show_progress=show_progress,
phase="retro",
)
Expand Down Expand Up @@ -864,9 +861,10 @@ def match_path_ts(group):
site_id_map = {}
first_sample = len(tables.nodes)
root = len(group)
group_id = group.sample_hash
for sample in group:
assert sample.hmm_match.path == list(group.path)
node_id = add_sample_to_tables(sample, tables, group_id=group.sample_hash)
node_id = add_sample_to_tables(sample, tables, group_id=group_id)
tables.edges.add_row(0, tables.sequence_length, parent=root, child=node_id)
for mut in sample.hmm_match.mutations:
if (mut.site_id, mut.derived_state) in group.immediate_reversions:
Expand Down Expand Up @@ -968,7 +966,6 @@ class SampleGroup:
samples: List = None
path: List = None
immediate_reversions: List = None
additional_keys: Dict = None
sample_hash: str = None
tree_quality_metrics: GroupTreeQualityMetrics = None

Expand Down Expand Up @@ -1002,7 +999,6 @@ def summary(self):
f"{dict(self.date_count)} "
f"{dict(self.pango_count)} "
f"immediate_reversions={self.immediate_reversions} "
f"additional_keys={self.additional_keys} "
f"path={path_summary(self.path)} "
f"strains={self.strains}"
)
Expand Down Expand Up @@ -1034,9 +1030,7 @@ def add_matching_results(
min_root_mutations=0,
max_mutations_per_sample=np.inf,
max_recurrent_mutations=np.inf,
additional_node_flags=None,
show_progress=False,
additional_group_metadata_keys=list(),
phase=None,
):
logger.info(f"Querying match DB WHERE: {where_clause}")
Expand All @@ -1057,10 +1051,7 @@ def add_matching_results(
for mut in sample.hmm_match.mutations
if mut.is_immediate_reversion
)
additional_metadata = [
sample.metadata.get(k, None) for k in additional_group_metadata_keys
]
key = (path, immediate_reversions, *additional_metadata)
key = (path, immediate_reversions)
grouped_matches[key].append(sample)
num_samples += 1

Expand All @@ -1073,7 +1064,6 @@ def add_matching_results(
samples,
key[0],
key[1],
{k: v for k, v in zip(additional_group_metadata_keys, key[2:])},
)
for key, samples in grouped_matches.items()
]
Expand Down Expand Up @@ -1120,7 +1110,7 @@ def add_matching_results(
f"exceeds threshold: {group.summary()}"
)
continue
nodes = attach_tree(ts, tables, group, poly_ts, date, additional_node_flags)
nodes = attach_tree(ts, tables, group, poly_ts, date)
logger.debug(
f"Attach {phase} metrics:{tqm.summary()}"
f"attach_nodes={len(nodes)} "
Expand Down Expand Up @@ -1701,7 +1691,6 @@ def attach_tree(
group,
child_ts,
date,
additional_node_flags,
epsilon=None,
):
attach_path = group.path
Expand Down Expand Up @@ -1742,6 +1731,7 @@ def attach_tree(
if child_ts.nodes_time[tree.root] != 1.0:
raise ValueError("Time must be scaled from 0 to 1.")

group_id = group.sample_hash
num_internal_nodes_visited = 0
for u in tree.postorder()[:-1]:
node = child_ts.node(u)
Expand All @@ -1756,14 +1746,12 @@ def attach_tree(
if tree.is_internal(u):
metadata = {
"sc2ts": {
"group_id": group.sample_hash,
"group_id": group_id,
"date_added": date,
}
}
new_id = parent_tables.nodes.append(
node.replace(
flags=node.flags | additional_node_flags, time=time, metadata=metadata
)
node.replace(flags=node.flags, time=time, metadata=metadata)
)
node_id_map[node.id] = new_id
for v in tree.children(u):
Expand All @@ -1790,9 +1778,7 @@ def attach_tree(
node=node_id_map[mutation.node],
derived_state=mutation.derived_state,
time=node_time[mutation.node],
metadata={
"sc2ts": {"type": "parsimony", "group_id": group.sample_hash}
},
metadata={"sc2ts": {"type": "parsimony", "group_id": group_id}},
)

if len(group.immediate_reversions) > 0:
Expand Down
87 changes: 71 additions & 16 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,14 @@ def __init__(
# The number of samples per day in time-ago (i.e., the nodes_time units).
self.num_samples_per_day = np.bincount(ts.nodes_time[samples].astype(int))

self.sample_group_id_prefix_len = 8
self.sample_group_nodes = collections.defaultdict(list)
self.sample_group_mutations = collections.defaultdict(list)
self.retro_sample_groups = {}
for retro_group in top_level_md["retro_groups"]:
gid = retro_group["group_id"][: self.sample_group_id_prefix_len]
self.retro_sample_groups[gid] = retro_group

if not quick:
self._preprocess_nodes(show_progress)
self._preprocess_sites(show_progress)
Expand All @@ -461,10 +469,7 @@ def node_counts(self):
pr_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_REVERSION_PUSH)
re_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_RECOMBINANT)
exact_matches = np.sum((self.ts.nodes_flags & core.NODE_IS_EXACT_MATCH) > 0)
sg_nodes = np.sum((self.ts.nodes_flags & core.NODE_IN_SAMPLE_GROUP) > 0)
rsg_nodes = np.sum(
(self.ts.nodes_flags & core.NODE_IN_RETROSPECTIVE_SAMPLE_GROUP) > 0
)
u_nodes = np.sum((self.ts.nodes_flags & core.NODE_IS_UNCONDITIONALLY_INCLUDED) > 0)
immediate_reversion_marker = np.sum(
(self.ts.nodes_flags & core.NODE_IS_IMMEDIATE_REVERSION_MARKER) > 0
)
Expand All @@ -476,8 +481,7 @@ def node_counts(self):
"mc": mc_nodes,
"pr": pr_nodes,
"re": re_nodes,
"sg": sg_nodes,
"rsg": rsg_nodes,
"u": u_nodes,
"imr": immediate_reversion_marker,
"zero_muts": nodes_with_zero_muts,
}
Expand All @@ -492,7 +496,6 @@ def _preprocess_nodes(self, show_progress):
self.nodes_num_deletion_sites = np.zeros(ts.num_nodes, dtype=np.int32)
self.nodes_num_exact_matches = np.zeros(ts.num_nodes, dtype=np.int32)
self.nodes_metadata = {}
self.nodes_sample_group = collections.defaultdict(list)
samples = ts.samples()

self.time_zero_as_date = np.array([self.date], dtype="datetime64[D]")[0]
Expand All @@ -518,7 +521,9 @@ def _preprocess_nodes(self, show_progress):
sc2ts_md = md["sc2ts"]
group_id = sc2ts_md.get("group_id", None)
if group_id is not None:
self.nodes_sample_group[group_id].append(node.id)
# Shorten key for readability.
gid = group_id[: self.sample_group_id_prefix_len]
self.sample_group_nodes[gid].append(node.id)
if node.is_sample():
self.nodes_date[node.id] = md["date"]
pango = md.get(self.pango_source, "unknown")
Expand Down Expand Up @@ -617,6 +622,7 @@ def _preprocess_mutations(self, show_progress):
iterator = tqdm(
np.arange(N), desc="Classifying mutations", disable=not show_progress
)
mutation_table = ts.tables.mutations
for mut_id in iterator:
tree.seek(self.mutations_position[mut_id])
mutation_node = ts.mutations_node[mut_id]
Expand Down Expand Up @@ -645,6 +651,13 @@ def _preprocess_mutations(self, show_progress):
sites_num_transitions[site] += mutations_is_transition[mut_id]
sites_num_transversions[site] += mutations_is_transversion[mut_id]

# classify by origin
md = mutation_table[mut_id].metadata["sc2ts"]
inference_type = md.get("type", None)
if inference_type == "parsimony":
gid = md["group_id"][: self.sample_group_id_prefix_len]
self.sample_group_mutations[gid].append(mut_id)

# Note: no real good reason for not just using self.mutations_num_descendants
# etc above
self.mutations_num_descendants = mutations_num_descendants
Expand Down Expand Up @@ -717,6 +730,8 @@ def summary(self):
),
("max_samples_per_day", np.max(self.num_samples_per_day)),
("mean_samples_per_day", np.mean(self.num_samples_per_day)),
("sample_groups", len(self.sample_group_nodes)),
("retro_sample_groups", len(self.retro_sample_groups)),
]
df = pd.DataFrame(
{"property": [d[0] for d in data], "value": [d[1] for d in data]}
Expand Down Expand Up @@ -757,7 +772,10 @@ def _node_summary(self, u, child_mutations=True):
strain = md["strain"]
else:
md = md["sc2ts"]
if flags & (core.NODE_IS_MUTATION_OVERLAP | core.NODE_IS_REVERSION_PUSH) > 0:
if (
flags & (core.NODE_IS_MUTATION_OVERLAP | core.NODE_IS_REVERSION_PUSH)
> 0
):
try:
strain = f"{md['date_added']}:{', '.join(md['mutations'])}"
except KeyError:
Expand Down Expand Up @@ -883,16 +901,50 @@ def samples_summary(self):
df["total_hmm_cost"] = df["mean_hmm_cost"] * df["total"]
return df.astype({"date": "datetime64[s]"})

def sample_groups_summary(self):
    """Return a pandas DataFrame summarising every sample group.

    Indexed by the shortened group_id prefix; columns give the total
    node count, the number of sample nodes, the number of parsimony
    mutations attributed to the group, and whether the group was added
    retrospectively.
    """
    records = []
    for gid, node_ids in self.sample_group_nodes.items():
        sample_ids = []
        hashes = []
        for node_id in node_ids:
            node = self.ts.node(node_id)
            if node.is_sample():
                sample_ids.append(node_id)
                hashes.append(node.metadata["sc2ts"]["group_id"])
        # Every sample in the group must carry the same full hash, and
        # the shortened key must be a prefix of it.
        assert len(set(hashes)) == 1
        assert hashes[0].startswith(gid)
        records.append(
            {
                "group_id": gid,
                "nodes": len(node_ids),
                "samples": len(sample_ids),
                "mutations": len(self.sample_group_mutations[gid]),
                "is_retro": gid in self.retro_sample_groups,
            }
        )
    return pd.DataFrame(records).set_index("group_id")

def retro_sample_groups_summary(self):
    """Return a pandas DataFrame summarising retrospective sample groups.

    Indexed by the shortened group_id; the per-group metadata is kept,
    except that dates, strains and pango_lineages are collapsed to
    counts (strains is reported as the "samples" column).
    """
    def _summarise(gid, group):
        row = {**group, "group_id": gid}
        row["dates"] = len(set(row["dates"]))
        row["samples"] = len(row.pop("strains"))
        row["pango_lineages"] = len(set(row["pango_lineages"]))
        return row

    rows = [_summarise(g, md) for g, md in self.retro_sample_groups.items()]
    return pd.DataFrame(rows).set_index("group_id")

def recombinants_summary(self):
data = []
for u in self.recombinants:
md = self.nodes_metadata[u]["sc2ts"]
group_id = md["group_id"]
group_id = md["group_id"][: self.sample_group_id_prefix_len]
# NOTE this is overlapping quite a bit with the SampleGroupInfo
# class functionality here, but we just want something quick for
# now here.
causal_lineages = collections.Counter()
for v in self.nodes_sample_group[group_id]:
for v in self.sample_group_nodes[group_id]:
if self.ts.nodes_flags[v] & tskit.NODE_IS_SAMPLE > 0:
pango = self.nodes_metadata[v].get(self.pango_source, "Unknown")
causal_lineages[pango] += 1
Expand Down Expand Up @@ -1130,7 +1182,7 @@ def css_cell(allele, bold=False, show_colour=True):
parent_allele = var.alleles[var.genotypes[j]]
css = css_cell(
parent_allele,
bold=parent_allele==child_allele,
bold=parent_allele == child_allele,
show_colour=j == parent_col,
)
parents[j - 1].append(f"<td{css}>{parent_allele}</td>")
Expand Down Expand Up @@ -1305,7 +1357,9 @@ def recombinant_samples_report(self, nodes):
closest_recombinant, path_length = self._get_closest_recombinant(tree, node)
sample_is_recombinant = False
if closest_recombinant != -1:
recomb_date = self.ts.node(closest_recombinant).metadata["sc2ts"]["date_added"]
recomb_date = self.ts.node(closest_recombinant).metadata["sc2ts"][
"date_added"
]
sample_is_recombinant = recomb_date == str(node_summary["date"])
summary = {
"recombinant": closest_recombinant,
Expand Down Expand Up @@ -1972,12 +2026,13 @@ def draw_subtree(
def get_sample_group_info(self, group_id):
samples = []

for u in self.nodes_sample_group[group_id]:
group_nodes = self.sample_group_nodes[group_id]
for u in group_nodes:
if self.ts.nodes_flags[u] & tskit.NODE_IS_SAMPLE > 0:
samples.append(u)

tree = self.ts.first()
while self.nodes_metadata[u]["sc2ts"].get("group_id", None) == group_id:
while u in group_nodes:
u = tree.parent(u)
attach_date = self.nodes_date[u]
ts = self.ts.simplify(samples + [u])
Expand All @@ -2004,7 +2059,7 @@ def get_sample_group_info(self, group_id):

return SampleGroupInfo(
group_id,
self.nodes_sample_group[group_id],
group_nodes,
ts=tables.tree_sequence(),
attach_date=attach_date,
)
Expand Down
Loading
Loading