Skip to content

Commit

Permalink
Default subplots to time difference from first sample
Browse files Browse the repository at this point in the history
But also allow calendar date, if required.

This makes it much easier to compare plots from ARGs with different endpoints
  • Loading branch information
hyanwong committed Nov 19, 2024
1 parent 3dd5b05 commit dff40d5
Showing 1 changed file with 47 additions and 2 deletions.
49 changes: 47 additions & 2 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -1617,6 +1617,9 @@ def draw_subtree(
extra_tracked_samples=None,
pack_untracked_polytomies=True,
time_scale="rank",
y_ticks=None,
y_label=None,
date_format=None,
title=None,
mutation_labels=None,
append_mutation_recurrence=None,
Expand Down Expand Up @@ -1647,6 +1650,12 @@ def draw_subtree(
containing tracked nodes. If ``None`` (default), do not collapse
any such clades , otherwise only collapse them when they contain more
than a fraction `collapse_tracked` of tracked nodes.
:param date_format str: How to format the displayed node date: one of "ts" (use
the tree sequence time, usually where the most recent sample is at zero),
"from_zero" (set the earliest sample to zero and count time as negative
days from the start of the genealogy), or "cal" (use the calendar date from
the sample metadata, hence do not display times for nonsample nodes).
Default: ``None`` treated as "from_zero".
:param remove_clones bool: Whether to remove samples that are clones of other
samples (i.e. that have no mutations above them). Currently unimplemented.
:param extra_tracked_samples list: Additional nodes to track in the tree, to
Expand All @@ -1658,11 +1667,14 @@ def draw_subtree(
:param mutation_labels dict: A dictionary mapping mutation IDs to labels. If not
provided, mutation labels are generated automatically, in the form
``{inherited_state}{position}{derived_state}``
:params append_mutation_dupes bool: If True (default), append a count to the
:param append_mutation_dupes bool: If True (default), append a count to the
mutation label indicating the number of other such mutations above the
shown nodes that are at the same position and to the same derived state.
:param time_scale str: As for the ``time_scale`` parameter of `draw_svg()`, but
defaults to "rank".
:param y_label str: As for the ``y_label`` parameter of `draw_svg()`.
:param y_ticks array: As for the ``y_ticks`` parameter of `draw_svg()`. Cannot be
combined with ``date_format="cal"``.
.. note::
By default, styles are set such that tracked pango / strain / sample nodes
Expand All @@ -1676,6 +1688,15 @@ def draw_subtree(
position = 21563 # pick the start of the spike
if size is None:
size = (1000, 1000)
if date_format is None:
date_format = "from_zero"
if y_label is None:
if date_format == "cal":
y_label = "Calendar date"
elif date_format == "ts":
y_label = f"Time ({self.ts.time_units} ago)"
else:
y_label = f"Time difference from earliest sample ({self.ts.time_units})"
if append_mutation_recurrence is None:
append_mutation_recurrence = True
if remove_clones:
Expand All @@ -1686,7 +1707,14 @@ def draw_subtree(
time = tables.mutations.time
time[:] = tskit.UNKNOWN_TIME
tables.mutations.time = time
#rescale to negative times
if date_format == "from_zero":
for node in reversed(self.ts.nodes(order="timeasc")):
if node.is_sample():
break
tables.nodes.time = tables.nodes.time - node.time
ts = tables.tree_sequence()


tracked_nodes = []
if tracked_pango is not None:
Expand Down Expand Up @@ -1725,7 +1753,7 @@ def draw_subtree(

if title is None:
title = f"Sc2ts genealogy of {len(tracked_nodes)} samples. "
simplified_ts = self.ts.simplify(
simplified_ts = ts.simplify(
order[np.where(ts.nodes_flags[order] & tskit.NODE_IS_SAMPLE)[0]]
)
num_trees = simplified_ts.num_trees
Expand Down Expand Up @@ -1822,15 +1850,32 @@ def draw_subtree(
",".join([f".node.n{u} > .sym" for u in re_nodes])
+ f"{{r:{symbol_size/2*1.5:.2f}px; stroke:black; fill:white}}"
)
if date_format == "cal":
if y_ticks is not None:
raise ValueError("Cannot set y_ticks when date_format='cal'")
shown_times = ts.nodes_time[shown_nodes]
if time_scale == "rank":
_, index = np.unique(shown_times, return_index=True)
y_ticks = {
i: ts.node(shown_nodes[t]).metadata.get("date", "")
for i, t in enumerate(index)
}
else:
# only place ticks at the sample nodes
y_ticks = {t: ts.node(u).metadata.get("date", "")
for u, t in zip(shown_nodes, shown_times)
}
return tree.draw_svg(
time_scale=time_scale,
y_axis=True,
x_axis=False,
y_label=y_label,
title=title,
size=size,
order=order,
mutation_labels=mut_labels,
all_edge_mutations=True,
y_ticks=y_ticks,
symbol_size=symbol_size,
pack_untracked_polytomies=pack_untracked_polytomies,
style="".join(styles) + style,
Expand Down

0 comments on commit dff40d5

Please sign in to comment.