Skip to content

Commit

Permalink
Merge pull request #441 from jeromekelleher/no-singleton-sample-groups
Browse files Browse the repository at this point in the history
Updates for sample group infrastructure
  • Loading branch information
jeromekelleher authored Dec 12, 2024
2 parents 00303ed + 376e51d commit ecf29f4
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 72 deletions.
18 changes: 2 additions & 16 deletions sc2ts/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@
NODE_IS_RECOMBINANT = 1 << 23
NODE_IS_EXACT_MATCH = 1 << 24
NODE_IS_IMMEDIATE_REVERSION_MARKER = 1 << 25
NODE_IN_SAMPLE_GROUP = 1 << 26
NODE_IN_RETROSPECTIVE_SAMPLE_GROUP = 1 << 27
NODE_IS_REFERENCE = 1 << 28
NODE_IS_UNCONDITIONALLY_INCLUDED = 1 << 29
NODE_IS_REFERENCE = 1 << 26
NODE_IS_UNCONDITIONALLY_INCLUDED = 1 << 27


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -77,18 +75,6 @@ class FlagValue:
"Node is marking the existance of an immediate reversion which "
"has not been removed for technical reasons",
),
FlagValue(
NODE_IN_SAMPLE_GROUP,
"G",
"SampleGroup",
"Node is a member of a sample group",
),
FlagValue(
NODE_IN_RETROSPECTIVE_SAMPLE_GROUP,
"Q",
"RetroSampleGroup",
"Node is a member of a retrospective sample group",
),
FlagValue(
NODE_IS_REFERENCE,
"F",
Expand Down
32 changes: 9 additions & 23 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,6 @@ def check_base_ts(ts):
return sc2ts_md["date"]



def mask_ambiguous(a):
a = a.copy()
a[a > DELETION] = -1
Expand Down Expand Up @@ -620,7 +619,7 @@ def extend(
show_progress=show_progress,
random_seed=random_seed,
num_threads=num_threads,
memory_limit=memory_limit * 2**30, # Convert to bytes
memory_limit=memory_limit * 2**30, # Convert to bytes
)

prov = get_provenance_dict("extend", params, start_time)
Expand Down Expand Up @@ -751,7 +750,6 @@ def _extend(
match_db=match_db,
date=date,
min_group_size=1,
additional_node_flags=core.NODE_IN_SAMPLE_GROUP,
show_progress=show_progress,
phase="close",
)
Expand All @@ -769,7 +767,6 @@ def _extend(
min_root_mutations=min_root_mutations,
max_mutations_per_sample=max_mutations_per_sample,
max_recurrent_mutations=max_recurrent_mutations,
additional_node_flags=core.NODE_IN_RETROSPECTIVE_SAMPLE_GROUP,
show_progress=show_progress,
phase="retro",
)
Expand Down Expand Up @@ -864,9 +861,10 @@ def match_path_ts(group):
site_id_map = {}
first_sample = len(tables.nodes)
root = len(group)
group_id = group.sample_hash
for sample in group:
assert sample.hmm_match.path == list(group.path)
node_id = add_sample_to_tables(sample, tables, group_id=group.sample_hash)
node_id = add_sample_to_tables(sample, tables, group_id=group_id)
tables.edges.add_row(0, tables.sequence_length, parent=root, child=node_id)
for mut in sample.hmm_match.mutations:
if (mut.site_id, mut.derived_state) in group.immediate_reversions:
Expand Down Expand Up @@ -968,7 +966,6 @@ class SampleGroup:
samples: List = None
path: List = None
immediate_reversions: List = None
additional_keys: Dict = None
sample_hash: str = None
tree_quality_metrics: GroupTreeQualityMetrics = None

Expand Down Expand Up @@ -1002,7 +999,6 @@ def summary(self):
f"{dict(self.date_count)} "
f"{dict(self.pango_count)} "
f"immediate_reversions={self.immediate_reversions} "
f"additional_keys={self.additional_keys} "
f"path={path_summary(self.path)} "
f"strains={self.strains}"
)
Expand Down Expand Up @@ -1034,9 +1030,7 @@ def add_matching_results(
min_root_mutations=0,
max_mutations_per_sample=np.inf,
max_recurrent_mutations=np.inf,
additional_node_flags=None,
show_progress=False,
additional_group_metadata_keys=list(),
phase=None,
):
logger.info(f"Querying match DB WHERE: {where_clause}")
Expand All @@ -1057,10 +1051,7 @@ def add_matching_results(
for mut in sample.hmm_match.mutations
if mut.is_immediate_reversion
)
additional_metadata = [
sample.metadata.get(k, None) for k in additional_group_metadata_keys
]
key = (path, immediate_reversions, *additional_metadata)
key = (path, immediate_reversions)
grouped_matches[key].append(sample)
num_samples += 1

Expand All @@ -1073,7 +1064,6 @@ def add_matching_results(
samples,
key[0],
key[1],
{k: v for k, v in zip(additional_group_metadata_keys, key[2:])},
)
for key, samples in grouped_matches.items()
]
Expand Down Expand Up @@ -1120,7 +1110,7 @@ def add_matching_results(
f"exceeds threshold: {group.summary()}"
)
continue
nodes = attach_tree(ts, tables, group, poly_ts, date, additional_node_flags)
nodes = attach_tree(ts, tables, group, poly_ts, date)
logger.debug(
f"Attach {phase} metrics:{tqm.summary()}"
f"attach_nodes={len(nodes)} "
Expand Down Expand Up @@ -1701,7 +1691,6 @@ def attach_tree(
group,
child_ts,
date,
additional_node_flags,
epsilon=None,
):
attach_path = group.path
Expand Down Expand Up @@ -1742,6 +1731,7 @@ def attach_tree(
if child_ts.nodes_time[tree.root] != 1.0:
raise ValueError("Time must be scaled from 0 to 1.")

group_id = group.sample_hash
num_internal_nodes_visited = 0
for u in tree.postorder()[:-1]:
node = child_ts.node(u)
Expand All @@ -1756,14 +1746,12 @@ def attach_tree(
if tree.is_internal(u):
metadata = {
"sc2ts": {
"group_id": group.sample_hash,
"group_id": group_id,
"date_added": date,
}
}
new_id = parent_tables.nodes.append(
node.replace(
flags=node.flags | additional_node_flags, time=time, metadata=metadata
)
node.replace(flags=node.flags, time=time, metadata=metadata)
)
node_id_map[node.id] = new_id
for v in tree.children(u):
Expand All @@ -1790,9 +1778,7 @@ def attach_tree(
node=node_id_map[mutation.node],
derived_state=mutation.derived_state,
time=node_time[mutation.node],
metadata={
"sc2ts": {"type": "parsimony", "group_id": group.sample_hash}
},
metadata={"sc2ts": {"type": "parsimony", "group_id": group_id}},
)

if len(group.immediate_reversions) > 0:
Expand Down
87 changes: 71 additions & 16 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,14 @@ def __init__(
# The number of samples per day in time-ago (i.e., the nodes_time units).
self.num_samples_per_day = np.bincount(ts.nodes_time[samples].astype(int))

self.sample_group_id_prefix_len = 8
self.sample_group_nodes = collections.defaultdict(list)
self.sample_group_mutations = collections.defaultdict(list)
self.retro_sample_groups = {}
for retro_group in top_level_md["retro_groups"]:
gid = retro_group["group_id"][: self.sample_group_id_prefix_len]
self.retro_sample_groups[gid] = retro_group

if not quick:
self._preprocess_nodes(show_progress)
self._preprocess_sites(show_progress)
Expand All @@ -461,10 +469,7 @@ def node_counts(self):
pr_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_REVERSION_PUSH)
re_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_RECOMBINANT)
exact_matches = np.sum((self.ts.nodes_flags & core.NODE_IS_EXACT_MATCH) > 0)
sg_nodes = np.sum((self.ts.nodes_flags & core.NODE_IN_SAMPLE_GROUP) > 0)
rsg_nodes = np.sum(
(self.ts.nodes_flags & core.NODE_IN_RETROSPECTIVE_SAMPLE_GROUP) > 0
)
u_nodes = np.sum((self.ts.nodes_flags & core.NODE_IS_UNCONDITIONALLY_INCLUDED) > 0)
immediate_reversion_marker = np.sum(
(self.ts.nodes_flags & core.NODE_IS_IMMEDIATE_REVERSION_MARKER) > 0
)
Expand All @@ -476,8 +481,7 @@ def node_counts(self):
"mc": mc_nodes,
"pr": pr_nodes,
"re": re_nodes,
"sg": sg_nodes,
"rsg": rsg_nodes,
"u": u_nodes,
"imr": immediate_reversion_marker,
"zero_muts": nodes_with_zero_muts,
}
Expand All @@ -492,7 +496,6 @@ def _preprocess_nodes(self, show_progress):
self.nodes_num_deletion_sites = np.zeros(ts.num_nodes, dtype=np.int32)
self.nodes_num_exact_matches = np.zeros(ts.num_nodes, dtype=np.int32)
self.nodes_metadata = {}
self.nodes_sample_group = collections.defaultdict(list)
samples = ts.samples()

self.time_zero_as_date = np.array([self.date], dtype="datetime64[D]")[0]
Expand All @@ -518,7 +521,9 @@ def _preprocess_nodes(self, show_progress):
sc2ts_md = md["sc2ts"]
group_id = sc2ts_md.get("group_id", None)
if group_id is not None:
self.nodes_sample_group[group_id].append(node.id)
# Shorten key for readability.
gid = group_id[: self.sample_group_id_prefix_len]
self.sample_group_nodes[gid].append(node.id)
if node.is_sample():
self.nodes_date[node.id] = md["date"]
pango = md.get(self.pango_source, "unknown")
Expand Down Expand Up @@ -617,6 +622,7 @@ def _preprocess_mutations(self, show_progress):
iterator = tqdm(
np.arange(N), desc="Classifying mutations", disable=not show_progress
)
mutation_table = ts.tables.mutations
for mut_id in iterator:
tree.seek(self.mutations_position[mut_id])
mutation_node = ts.mutations_node[mut_id]
Expand Down Expand Up @@ -645,6 +651,13 @@ def _preprocess_mutations(self, show_progress):
sites_num_transitions[site] += mutations_is_transition[mut_id]
sites_num_transversions[site] += mutations_is_transversion[mut_id]

# classify by origin
md = mutation_table[mut_id].metadata["sc2ts"]
inference_type = md.get("type", None)
if inference_type == "parsimony":
gid = md["group_id"][: self.sample_group_id_prefix_len]
self.sample_group_mutations[gid].append(mut_id)

# Note: no real good reason for not just using self.mutations_num_descendants
# etc above
self.mutations_num_descendants = mutations_num_descendants
Expand Down Expand Up @@ -717,6 +730,8 @@ def summary(self):
),
("max_samples_per_day", np.max(self.num_samples_per_day)),
("mean_samples_per_day", np.mean(self.num_samples_per_day)),
("sample_groups", len(self.sample_group_nodes)),
("retro_sample_groups", len(self.retro_sample_groups)),
]
df = pd.DataFrame(
{"property": [d[0] for d in data], "value": [d[1] for d in data]}
Expand Down Expand Up @@ -757,7 +772,10 @@ def _node_summary(self, u, child_mutations=True):
strain = md["strain"]
else:
md = md["sc2ts"]
if flags & (core.NODE_IS_MUTATION_OVERLAP | core.NODE_IS_REVERSION_PUSH) > 0:
if (
flags & (core.NODE_IS_MUTATION_OVERLAP | core.NODE_IS_REVERSION_PUSH)
> 0
):
try:
strain = f"{md['date_added']}:{', '.join(md['mutations'])}"
except KeyError:
Expand Down Expand Up @@ -883,16 +901,50 @@ def samples_summary(self):
df["total_hmm_cost"] = df["mean_hmm_cost"] * df["total"]
return df.astype({"date": "datetime64[s]"})

def sample_groups_summary(self):
    """Return a pandas DataFrame summarising each sample group, indexed by
    the shortened group_id.

    Columns: number of nodes and sample nodes in the group, the number of
    parsimony mutations attributed to it, and whether it is a retrospective
    group.
    """
    records = []
    for gid, group_nodes in self.sample_group_nodes.items():
        sample_ids = []
        full_hashes = []
        for node_id in group_nodes:
            node = self.ts.node(node_id)
            if node.is_sample():
                sample_ids.append(node_id)
                full_hashes.append(node.metadata["sc2ts"]["group_id"])
        # Internal consistency: every sample in the group carries the same
        # full hash, and the shortened key must be a prefix of it.
        assert len(set(full_hashes)) == 1
        assert full_hashes[0].startswith(gid)
        records.append(
            {
                "group_id": gid,
                "nodes": len(group_nodes),
                "samples": len(sample_ids),
                "mutations": len(self.sample_group_mutations[gid]),
                "is_retro": gid in self.retro_sample_groups,
            }
        )
    return pd.DataFrame(records).set_index("group_id")

def retro_sample_groups_summary(self):
    """Return a pandas DataFrame summarising the retrospective sample
    groups, indexed by the shortened group_id.

    The per-group "dates" and "pango_lineages" lists are collapsed to
    distinct counts, and the "strains" list is replaced by a "samples"
    count.
    """

    def _summarise(gid, group):
        # Copy everything except the strain list, which becomes a count.
        row = {k: v for k, v in group.items() if k != "strains"}
        row["group_id"] = gid
        row["samples"] = len(group["strains"])
        row["dates"] = len(set(group["dates"]))
        row["pango_lineages"] = len(set(group["pango_lineages"]))
        return row

    rows = [_summarise(gid, g) for gid, g in self.retro_sample_groups.items()]
    return pd.DataFrame(rows).set_index("group_id")

def recombinants_summary(self):
data = []
for u in self.recombinants:
md = self.nodes_metadata[u]["sc2ts"]
group_id = md["group_id"]
group_id = md["group_id"][: self.sample_group_id_prefix_len]
# NOTE this is overlapping quite a bit with the SampleGroupInfo
# class functionality here, but we just want something quick for
# now here.
causal_lineages = collections.Counter()
for v in self.nodes_sample_group[group_id]:
for v in self.sample_group_nodes[group_id]:
if self.ts.nodes_flags[v] & tskit.NODE_IS_SAMPLE > 0:
pango = self.nodes_metadata[v].get(self.pango_source, "Unknown")
causal_lineages[pango] += 1
Expand Down Expand Up @@ -1130,7 +1182,7 @@ def css_cell(allele, bold=False, show_colour=True):
parent_allele = var.alleles[var.genotypes[j]]
css = css_cell(
parent_allele,
bold=parent_allele==child_allele,
bold=parent_allele == child_allele,
show_colour=j == parent_col,
)
parents[j - 1].append(f"<td{css}>{parent_allele}</td>")
Expand Down Expand Up @@ -1305,7 +1357,9 @@ def recombinant_samples_report(self, nodes):
closest_recombinant, path_length = self._get_closest_recombinant(tree, node)
sample_is_recombinant = False
if closest_recombinant != -1:
recomb_date = self.ts.node(closest_recombinant).metadata["sc2ts"]["date_added"]
recomb_date = self.ts.node(closest_recombinant).metadata["sc2ts"][
"date_added"
]
sample_is_recombinant = recomb_date == str(node_summary["date"])
summary = {
"recombinant": closest_recombinant,
Expand Down Expand Up @@ -1972,12 +2026,13 @@ def draw_subtree(
def get_sample_group_info(self, group_id):
samples = []

for u in self.nodes_sample_group[group_id]:
group_nodes = self.sample_group_nodes[group_id]
for u in group_nodes:
if self.ts.nodes_flags[u] & tskit.NODE_IS_SAMPLE > 0:
samples.append(u)

tree = self.ts.first()
while self.nodes_metadata[u]["sc2ts"].get("group_id", None) == group_id:
while u in group_nodes:
u = tree.parent(u)
attach_date = self.nodes_date[u]
ts = self.ts.simplify(samples + [u])
Expand All @@ -2004,7 +2059,7 @@ def get_sample_group_info(self, group_id):

return SampleGroupInfo(
group_id,
self.nodes_sample_group[group_id],
group_nodes,
ts=tables.tree_sequence(),
attach_date=attach_date,
)
Expand Down
Loading

0 comments on commit ecf29f4

Please sign in to comment.