Merge pull request #430 from jeromekelleher/misc-1.0-updates
Misc 1.0 updates
jeromekelleher authored Dec 8, 2024
2 parents 89421fb + b221fd4 commit 53be6d9
Showing 9 changed files with 363 additions and 218 deletions.
32 changes: 27 additions & 5 deletions sc2ts/cli.py
@@ -258,19 +258,36 @@ def import_metadata(dataset, metadata, viridian, verbose):
@click.command()
@click.argument("in_dataset", type=click.Path(dir_okay=True, file_okay=False))
@click.argument("out_dataset", type=click.Path(dir_okay=True, file_okay=False))
@click.option("--date-field", default="date", help="The metadata field to use for dates")
@click.option("-a", "--additional-field", default=[], help="Additional fields to sort by",
multiple=True)
@click.option(
"--date-field", default="date", help="The metadata field to use for dates"
)
@click.option(
"-a",
"--additional-field",
default=[],
help="Additional fields to sort by",
multiple=True,
)
@chunk_cache_size
@progress
@verbose
def reorder_dataset(in_dataset, out_dataset, chunk_cache_size, date_field, additional_field, progress, verbose):
def reorder_dataset(
in_dataset,
out_dataset,
chunk_cache_size,
date_field,
additional_field,
progress,
verbose,
):
"""
Create a copy of the specified dataset where the samples are reordered by
date (and optionally other fields).
"""
setup_logging(verbose)
ds = sc2ts.Dataset(in_dataset, chunk_cache_size=chunk_cache_size, date_field=date_field)
ds = sc2ts.Dataset(
in_dataset, chunk_cache_size=chunk_cache_size, date_field=date_field
)
ds.reorder(out_dataset, show_progress=progress, additional_fields=additional_field)
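A usage sketch for the reworked command (hedged: the `sc2ts` console entry point and the `in.zarr`/`out.zarr` paths and the `Country` field are illustrative assumptions; `--date-field` and `-a` are the options defined above, and click derives the command name `reorder-dataset` from the function name):

sc2ts reorder-dataset in.zarr out.zarr --date-field date -a Country -a strain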


@@ -638,6 +655,11 @@ def extend(
logger.info(resource_usage)
if progress:
print(resource_usage, file=sys.stderr)
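# Summarise the day's samples per scorpio group: index the per-scorpio
# records by "scorpio", reverse the column order, and sort rows by the
# "total" count before printing.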
df = pd.DataFrame(
ts_out.metadata["sc2ts"]["daily_stats"][date]["samples_processed"]
).set_index("scorpio")
df = df[list(df.columns)[::-1]].sort_values("total")
print(df)


@click.command()
90 changes: 84 additions & 6 deletions sc2ts/core.py
@@ -4,11 +4,20 @@
import collections.abc
import csv

import tskit
import numba
import pyfaidx
import numpy as np


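# Report the version recorded at build time (e.g. by setuptools_scm);
# fall back to "undefined" when the generated _version module is absent.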
__version__ = "undefined"
try:
from . import _version

__version__ = _version.version
except ImportError:
pass

TIME_UNITS = "days"

REFERENCE_STRAIN = "Wuhan/Hu-1/2019"
@@ -23,15 +32,84 @@
NODE_IS_IMMEDIATE_REVERSION_MARKER = 1 << 25
NODE_IN_SAMPLE_GROUP = 1 << 26
NODE_IN_RETROSPECTIVE_SAMPLE_GROUP = 1 << 27
NODE_IS_REFERENCE = 1 << 28
NODE_IS_UNCONDITIONALLY_INCLUDED = 1 << 29


__version__ = "undefined"
try:
from . import _version
@dataclasses.dataclass(frozen=True)
class FlagValue:
value: int
short: str
long: str
description: str

__version__ = _version.version
except ImportError:
pass

flag_values = [
FlagValue(tskit.NODE_IS_SAMPLE, "S", "Sample", "Tskit defined sample node"),
FlagValue(
NODE_IS_MUTATION_OVERLAP,
"O",
"MutationOverlap",
"Node created by coalescing mutations shared by siblings",
),
FlagValue(
NODE_IS_REVERSION_PUSH,
"P",
"ReversionPush",
"Node created by pushing immediate reversions upwards",
),
FlagValue(
NODE_IS_RECOMBINANT,
"R",
"Recombinant",
"Node has two or more parents",
),
FlagValue(
NODE_IS_EXACT_MATCH,
"E",
"ExactMatch",
"Node is an exact match of its parent",
),
FlagValue(
NODE_IS_IMMEDIATE_REVERSION_MARKER,
"I",
"ImmediateReversion",
"Node is marking the existance of an immediate reversion which "
"has not been removed for technical reasons",
),
FlagValue(
NODE_IN_SAMPLE_GROUP,
"G",
"SampleGroup",
"Node is a member of a sample group",
),
FlagValue(
NODE_IN_RETROSPECTIVE_SAMPLE_GROUP,
"Q",
"RetroSampleGroup",
"Node is a member of a retrospective sample group",
),
FlagValue(
NODE_IS_REFERENCE,
"F",
"Reference",
"Node is a reference sequence",
),
FlagValue(
NODE_IS_UNCONDITIONALLY_INCLUDED,
"U",
"UnconditionalInclude",
"A sample that was flagged for unconditional inclusion",
),
]


def decode_flags(f):
return [v for v in flag_values if (v.value & f) > 0]


def flags_summary(f):
return "".join([v.short if (v.value & f) > 0 else "_" for v in flag_values])


class FastaReader(collections.abc.Mapping):
85 changes: 44 additions & 41 deletions sc2ts/inference.py
@@ -267,14 +267,14 @@ def initial_ts(problematic_sites=list()):
tables.metadata = {
"sc2ts": {
"date": core.REFERENCE_DATE,
"samples_strain": [core.REFERENCE_STRAIN],
"exact_matches": {
"pango": {},
"date": {},
"node": {},
"samples_strain": [],
"daily_stats": {},
"cumulative_stats": {
"exact_matches": {
"pango": {},
"node": {},
}
},
"samples_processed": {},
"samples_rejected": {},
"retro_groups": [],
}
}
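
This hunk restructures the top-level metadata: per-day summaries now live under `daily_stats` and running exact-match totals under `cumulative_stats`. A hypothetical read of the new layout (assuming `ts` is an ARG built by this version; the date key is illustrative):

md = ts.metadata["sc2ts"]
md["cumulative_stats"]["exact_matches"]["pango"]  # {pango: count, ...}
md["daily_stats"]["2020-06-01"]["samples_processed"]  # per-scorpio records, built in update_top_level_metadata below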
@@ -302,14 +302,16 @@
},
)
tables.nodes.add_row(
flags=tskit.NODE_IS_SAMPLE,
flags=core.NODE_IS_REFERENCE,
time=0,
metadata={
"strain": core.REFERENCE_STRAIN,
"date": core.REFERENCE_DATE,
"sc2ts": {"notes": "Reference sequence"},
},
)
# NOTE: we don't actually need to store this edge, we could include it
# at the point where we call the low-level match_tsinfer operations.
tables.edges.add_row(0, L, 0, 1)
return tables.tree_sequence()

@@ -345,6 +347,7 @@ class Sample:
haplotype: List = None
hmm_match: HmmMatch = None
hmm_reruns: Dict = dataclasses.field(default_factory=dict)
flags: int = tskit.NODE_IS_SAMPLE

@property
def is_recombinant(self):
@@ -643,6 +646,7 @@ def extend(
f"deletions={num_deletion_sites}"
)
if s.strain in include_strains:
s.flags |= core.NODE_IS_UNCONDITIONALLY_INCLUDED
unconditional_include_samples.append(s)
elif num_missing_sites <= max_missing_sites:
samples.append(s)
@@ -742,36 +746,42 @@ def update_top_level_metadata(ts, date, retro_groups, samples):
s = node.metadata["strain"]
samples_strain.append(s)
inserted_samples.add(s)
md["sc2ts"]["samples_strain"] = samples_strain

overall_processed = collections.Counter()
overall_hmm_cost = collections.Counter()
overall_processed = collections.defaultdict(list)
rejected = collections.Counter()
rejected_hmm_cost = collections.Counter()
for sample in samples:
overall_processed[sample.scorpio] += 1
overall_hmm_cost[sample.scorpio] += float(sample.hmm_match.cost)
overall_processed[sample.scorpio].append(sample.hmm_match.cost)
if sample.strain not in inserted_samples and sample.hmm_match.cost > 0:
rejected[sample.scorpio] += 1
rejected_hmm_cost[sample.scorpio] += float(sample.hmm_match.cost)

for scorpio in overall_processed.keys():
overall_hmm_cost[scorpio] /= overall_processed[scorpio]
for scorpio in rejected.keys():
rejected_hmm_cost[scorpio] /= rejected[scorpio]
samples_processed = []
for scorpio, hmm_cost in overall_processed.items():
hmm_cost = np.array(hmm_cost)
samples_processed.append(
{
"scorpio": scorpio,
"total": hmm_cost.shape[0],
"rejected": rejected[scorpio],
"exact_matches": int(np.sum(hmm_cost == 0)),
# Store the total as well to make later aggregation easier
"total_hmm_cost": float(np.sum(hmm_cost)),
"mean_hmm_cost": round(float(np.mean(hmm_cost)), 2),
"median_hmm_cost": float(np.median(hmm_cost)),
}
)
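# Each appended record might look like this (illustrative values only):
# {"scorpio": "Omicron (BA.2-like)", "total": 120, "rejected": 3,
#  "exact_matches": 50, "total_hmm_cost": 95.0,
#  "mean_hmm_cost": 0.79, "median_hmm_cost": 1.0}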

md["sc2ts"]["samples_strain"] = samples_strain
md["sc2ts"]["samples_processed"][date] = {
"count": dict(overall_processed),
"mean_hmm_cost": dict(overall_hmm_cost),
}
md["sc2ts"]["samples_rejected"][date] = {
"count": dict(rejected),
"mean_hmm_cost": dict(rejected_hmm_cost),
daily_stats = {
"samples_processed": samples_processed,
"arg": {
"nodes": ts.num_nodes,
"edges": ts.num_edges,
"mutations": ts.num_mutations,
},
}
md["sc2ts"]["daily_stats"][date] = daily_stats

existing_retro_groups = md["sc2ts"].get("retro_groups", [])
if isinstance(existing_retro_groups, dict):
# Hack to implement metadata format change
existing_retro_groups = []
for group in retro_groups:
d = group.tree_quality_metrics.asdict()
d["group_id"] = group.sample_hash
@@ -781,7 +791,7 @@ def update_top_level_metadata(ts, date, retro_groups, samples):
return tables.tree_sequence()


def add_sample_to_tables(sample, tables, flags=tskit.NODE_IS_SAMPLE, group_id=None):
def add_sample_to_tables(sample, tables, group_id=None):
sc2ts_md = {
"hmm_match": sample.hmm_match.asdict(),
"hmm_reruns": {k: m.asdict() for k, m in sample.hmm_reruns.items()},
@@ -791,7 +801,7 @@ def add_sample_to_tables(sample, tables, flags=tskit.NODE_IS_SAMPLE, group_id=None):
if group_id is not None:
sc2ts_md["group_id"] = group_id
metadata = {**sample.metadata, "sc2ts": sc2ts_md}
return tables.nodes.add_row(flags=flags, metadata=metadata)
return tables.nodes.add_row(flags=sample.flags, metadata=metadata)


def match_path_ts(group):
@@ -849,17 +859,10 @@ def add_exact_matches(match_db, ts, date):
# JSON treats dictionary keys as strings
node_counts[str(parent)] += 1
pango_counts[sample.pango] += 1
# node_id = add_sample_to_tables(
# sample,
# tables,
# flags=tskit.NODE_IS_SAMPLE | core.NODE_IS_EXACT_MATCH,
# )
# logger.debug(f"ARG add exact match {sample.strain}:{node_id}->{parent}")
# tables.edges.add_row(0, ts.sequence_length, parent=parent, child=node_id)
tables = ts.dump_tables()
md = tables.metadata
exact_matches_md = md["sc2ts"]["exact_matches"]
exact_matches_md["date"][date] = sum(pango_counts.values())
cstats = md["sc2ts"]["cumulative_stats"]
exact_matches_md = cstats["exact_matches"]
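# Counter.update() adds the stored totals into today's counts, so the
# dicts written back below remain cumulative across the whole run.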
pango_counts.update(exact_matches_md["pango"])
exact_matches_md["pango"] = dict(pango_counts)
node_counts.update(exact_matches_md["node"])
@@ -1104,7 +1107,7 @@ def add_matching_results(
tables.compute_mutation_parents()
ts = tables.tree_sequence()
ts = tree_ops.push_up_reversions(ts, attach_nodes, date)
ts = tree_ops.coalesce_mutations(ts, attach_nodes)
ts = tree_ops.coalesce_mutations(ts, attach_nodes, date)
ts = delete_immediate_reversion_nodes(ts, attach_nodes)
return ts, added_groups

(Diffs for the remaining files not shown.)
