diff --git a/sc2ts/lineages.py b/sc2ts/lineages.py deleted file mode 100644 index e56cd59..0000000 --- a/sc2ts/lineages.py +++ /dev/null @@ -1,116 +0,0 @@ -import json -from collections import defaultdict -import pandas as pd - - -class MutationContainer: - def __init__(self): - self.names = {} - self.positions = [] - self.alts = [] - self.size = 0 - self.all_positions = set() - - def add_root(self, root_lineage_name): - self.names[root_lineage_name] = self.size - self.size += 1 - self.positions.append([]) - self.alts.append([]) - - def add_item(self, item, position, alt): - if item not in self.names: - self.names[item] = self.size - self.positions.append([position]) - self.alts.append([alt]) - self.size += 1 - else: - index = self.names[item] - self.positions[index].append(position) - self.alts[index].append(alt) - if position not in self.all_positions: - self.all_positions.add(position) - - def get_mutations(self, item): - index = self.names[item] - return self.positions[index], self.alts[index] - - -def read_in_mutations(json_filepath, verbose=False): - """ - Read in lineage-defining mutations from COVIDCG input json file. - Assumes root lineage is B. 
- """ - - with open(json_filepath, "r") as file: - linmuts = json.load(file) - - # Read in lineage defining mutations - linmuts_dict = MutationContainer() - linmuts_dict.add_root("B") - if verbose: - check_multiallelic_sites = defaultdict( - set - ) # will check how many multi-allelic sites there are - - for item in linmuts: - if item["alt"] != "-" and item["ref"] != "-": # ignoring indels - linmuts_dict.add_item(item["name"], item["pos"], item["alt"]) - if verbose: - check_multiallelic_sites[item["pos"]].add(item["ref"]) - if verbose: - check_multiallelic_sites[item["pos"]].add(item["alt"]) - - if verbose: - multiallelic_sites_count = 0 - for value in check_multiallelic_sites.values(): - if len(value) > 2: - multiallelic_sites_count += 1 - print( - "Multiallelic sites:", - multiallelic_sites_count, - "out of", - len(check_multiallelic_sites), - ) - print("Number of lineages:", linmuts_dict.size) - - return linmuts_dict - - -class OHE_transform: - """ - One hot encoder using pandas get_dummies() for dealing with categorical data (alleles at each position) - """ - - def __init__(self): - self.new_colnames = None - self.old_colnames = None - - def fit(self, X): - self.old_colnames = X.columns - X = pd.get_dummies(X, drop_first=True) - self.new_colnames = X.columns - return X - - def transform(self, X): - X = pd.get_dummies(X) - X = X.reindex(columns=self.new_colnames, fill_value=0) - return X - - -def read_in_mutations_json(json_filepath): - """ - Read in COVIDCG json file of lineage-defining mutations into a pandas data frame - """ - df = pd.read_json(json_filepath) - df = df.loc[df["ref"] != "-"] - df = df.loc[df["alt"] != "-"] - df = df.pivot_table( - index="name", columns="pos", values="alt", aggfunc="min", fill_value="." 
- ) - idx = df.index.append(pd.Index(["B"])) - df = df.reindex(idx, fill_value=".") - ohe = OHE_transform() - df_ohe = ohe.fit(df) - return df, df_ohe, ohe - - diff --git a/sc2ts/utils.py b/sc2ts/utils.py index daaa084..0db671f 100644 --- a/sc2ts/utils.py +++ b/sc2ts/utils.py @@ -4,26 +4,17 @@ import collections import dataclasses import itertools -import operator -import warnings -import datetime -import logging import tskit -import tszip import numpy as np -import pandas as pd import tqdm import matplotlib.pyplot as plt from matplotlib import colors from IPython.display import Markdown, HTML import networkx as nx -import numba import sc2ts -from . import core -from . import lineages @dataclasses.dataclass @@ -255,7 +246,7 @@ def plot_subgraph( nodes, ts, ti=None, - mutations_json_filepath=None, + show_mutation_positions=None, # NB - can pass linmuts.all_positions filepath=None, *, ax=None, @@ -284,10 +275,10 @@ def plot_subgraph( ``None`` calculate the TreeInfo within this function. However, as calculating the TreeInfo class takes some time, if you have it calculated already, it is far more efficient to pass it in here. - :param str mutations_json_filepath: The path to a list of mutations (only relevant + :param set show_mutation_positions: A set of integer positions (only relevant if ``edge_labels`` is ``None``). If provided, only mutations in this file will be listed on edges of the plot, with others shown as "+N mutations". If ``None`` - (default), list all mutations. If "", only plot the number of mutations. + (default), show all mutations. If the empty set, only plot the number of mutations. :param str filepath: If given, save the plot to this file path. :param plt.Axes ax: a matplotlib axis object on which to plot the graph. 
This allows the graph to be placed as a subplot or the size and aspect ratio @@ -394,15 +385,6 @@ def sort_mutation_label(s): "show_descendant_samples must be one of 'samples', 'tips', 'sample_tips', 'all', or '' / False" ) - # Read in characteristic mutations info - linmuts_dict = None - if mutations_json_filepath is not None: - if mutations_json_filepath == "": - TmpClass = collections.namedtuple("TmpClass", ["all_positions"]) - linmuts_dict = TmpClass({}) # an empty dict - else: - linmuts_dict = lineages.read_in_mutations(mutations_json_filepath) - exterior_edges = None if exterior_edge_len != 0: G, exterior_edges = to_nx_subgraph(ts, nodes, return_external_edges=True) @@ -464,7 +446,7 @@ def sort_mutation_label(s): mutstr = f"$\\bf{{{inherited_state.upper()}{pos}{m.derived_state.upper()}}}$" else: mutstr = f"{inherited_state.upper()}{pos}{m.derived_state.upper()}" - if linmuts_dict is None or pos in linmuts_dict.all_positions: + if show_mutation_positions is None or pos in show_mutation_positions: includemut = True if includemut: mutation_labels[(edge.parent, edge.child)].add(mutstr) @@ -753,76 +735,6 @@ def sample_subgraph(sample_node, ts, ti=None, **kwargs): return plot_subgraph(nodes, ts, ti, **kwargs) -def add_gisaid_lineages_to_ts(ts, node_gisaid_lineages, linmuts_dict): - """ - Adds lineages from GISAID to ts metadata (as 'GISAID_lineage'). 
- """ - tables = ts.tables - new_metadata = [] - ndiffs = 0 - for node in ts.nodes(): - md = node.metadata - if node_gisaid_lineages[node.id] is not None: - if node_gisaid_lineages[node.id] in linmuts_dict.names: - md["GISAID_lineage"] = str(node_gisaid_lineages[node.id]) - else: - md["GISAID_lineage"] = md["Nextclade_pango"] - ndiffs += 1 - new_metadata.append(md) - validated_metadata = [ - tables.nodes.metadata_schema.validate_and_encode_row(row) - for row in new_metadata - ] - tables.nodes.packset_metadata(validated_metadata) - edited_ts = tables.tree_sequence() - print("Filling in missing GISAID lineages with Nextclade lineages:", ndiffs) - return edited_ts - - -# NOTE: this is broken since moving to Viridian metadata, we no longer have -# GISAID EPI ISL in the metadata -def check_lineages( - ts, - ti, - gisaid_data, - linmuts_dict, - diff_filehandle="lineage_disagreement", -): - n_diffs = 0 - total = 0 - diff_file = diff_filehandle + ".csv" - node_gisaid_lineages = [None] * ts.num_nodes - with tqdm.tqdm(total=len(gisaid_data)) as pbar: - with open(diff_file, "w") as file: - file.write("sample_node,gisaid_epi_isl,gisaid_lineage,ts_lineage\n") - for gisaid_id, gisaid_lineage in gisaid_data: - if gisaid_id in ti.epi_isl_map: - sample_node = ts.node(ti.epi_isl_map[gisaid_id]) - if gisaid_lineage != sample_node.metadata["Nextclade_pango"]: - n_diffs += 1 - file.write( - str(sample_node.id) - + "," - + gisaid_id - + "," - + gisaid_lineage - + "," - + sample_node.metadata["Nextclade_pango"] - + "\n" - ) - node_gisaid_lineages[sample_node.id] = gisaid_lineage - total += 1 - pbar.update(1) - print("ts number of samples:", ts.num_samples) - print("number matched to gisaid data:", total) - print("number of differences:", n_diffs) - print("proportion:", n_diffs / total) - - edited_ts = add_gisaid_lineages_to_ts(ts, node_gisaid_lineages, linmuts_dict) - - return edited_ts - - def compute_left_bound(ts, parents, right): right_index = np.searchsorted(ts.sites_position, 
right) assert ts.sites_position[right_index] == right diff --git a/tests/data/cache/match.db b/tests/data/cache/match.db new file mode 100644 index 0000000..c924799 Binary files /dev/null and b/tests/data/cache/match.db differ