From dff40d5dd1771daeddf430908035173adb6fefc9 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Sun, 17 Nov 2024 16:35:35 +0000 Subject: [PATCH] Default subplots to time difference from first sample But also allow calendar date, if required. This makes it much easier to compare plots from ARGs with different endpoints --- sc2ts/info.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/sc2ts/info.py b/sc2ts/info.py index 2fe0292..748b7f2 100644 --- a/sc2ts/info.py +++ b/sc2ts/info.py @@ -1617,6 +1617,9 @@ def draw_subtree( extra_tracked_samples=None, pack_untracked_polytomies=True, time_scale="rank", + y_ticks=None, + y_label=None, + date_format=None, title=None, mutation_labels=None, append_mutation_recurrence=None, @@ -1647,6 +1650,12 @@ def draw_subtree( containing tracked nodes. If ``None`` (default), do not collapse any such clades , otherwise only collapse them when they contain more than a fraction `collapse_tracked` of tracked nodes. + :param date_format str: How to format the displayed node date: one of "ts" (use + the tree sequence time, usually where the most recent sample is at zero), + "from_zero" (set the earliest sample to zero and count time as negative + days from the start of the genealogy), or "cal" (use the calendar date from + the sample metadata, hence do not display times for nonsample nodes). + Default: ``None`` treated as "from_zero". :param remove_clones bool: Whether to remove samples that are clones of other samples (i.e. that have no mutations above them). Currently unimplemented. :param extra_tracked_samples list: Additional nodes to track in the tree, to @@ -1658,11 +1667,14 @@ def draw_subtree( :param mutation_labels dict: A dictionary mapping mutation IDs to labels. If not provided, mutation labels are generated automatically, in the form ``{inherited_state}{position}{derived_state}`` - :params append_mutation_dupes bool: If True (default), append a count to the + :param append_mutation_dupes bool: If True (default), append a count to the mutation label indicating the number of other such mutations above the shown nodes that are at the same position and to the same derived state. :param time_scale str: As for the ``time_scale`` parameter of `draw_svg()`, but defaults to "rank". + :param y_label str: As for the ``y_label`` parameter of `draw_svg()`. + :param y_ticks array: As for the ``y_ticks`` parameter of `draw_svg()`. Cannot be + combined with ``date_format="cal"``. .. note:: By default, styles are set such that tracked pango / strain / sample nodes @@ -1676,6 +1688,15 @@ def draw_subtree( position = 21563 # pick the start of the spike if size is None: size = (1000, 1000) + if date_format is None: + date_format = "from_zero" + if y_label is None: + if date_format == "cal": + y_label = "Calendar date" + elif date_format == "ts": + y_label = f"Time ({self.ts.time_units} ago)" + else: + y_label = f"Time difference from earliest sample ({self.ts.time_units})" if append_mutation_recurrence is None: append_mutation_recurrence = True if remove_clones: @@ -1686,7 +1707,14 @@ def draw_subtree( time = tables.mutations.time time[:] = tskit.UNKNOWN_TIME tables.mutations.time = time + #rescale to negative times + if date_format == "from_zero": + for node in reversed(self.ts.nodes(order="timeasc")): + if node.is_sample(): + break + tables.nodes.time = tables.nodes.time - node.time ts = tables.tree_sequence() + tracked_nodes = [] if tracked_pango is not None: @@ -1725,7 +1753,7 @@ def draw_subtree( if title is None: title = f"Sc2ts genealogy of {len(tracked_nodes)} samples. " - simplified_ts = self.ts.simplify( + simplified_ts = ts.simplify( order[np.where(ts.nodes_flags[order] & tskit.NODE_IS_SAMPLE)[0]] ) num_trees = simplified_ts.num_trees @@ -1822,15 +1850,32 @@ def draw_subtree( ",".join([f".node.n{u} > .sym" for u in re_nodes]) + f"{{r:{symbol_size/2*1.5:.2f}px; stroke:black; fill:white}}" ) + if date_format == "cal": + if y_ticks is not None: + raise ValueError("Cannot set y_ticks when date_format='cal'") + shown_times = ts.nodes_time[shown_nodes] + if time_scale == "rank": + _, index = np.unique(shown_times, return_index=True) + y_ticks = { + i: ts.node(shown_nodes[t]).metadata.get("date", "") + for i, t in enumerate(index) + } + else: + # only place ticks at the sample nodes + y_ticks = {t: ts.node(u).metadata.get("date", "") + for u, t in zip(shown_nodes, shown_times) + } return tree.draw_svg( time_scale=time_scale, y_axis=True, x_axis=False, + y_label=y_label, title=title, size=size, order=order, mutation_labels=mut_labels, all_edge_mutations=True, + y_ticks=y_ticks, symbol_size=symbol_size, pack_untracked_polytomies=pack_untracked_polytomies, style="".join(styles) + style,