From 181b2295e769148404cf284a504abf5346220d07 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Fri, 8 Nov 2024 13:03:40 +0000 Subject: [PATCH] Remove lineages.py and functions that use it --- sc2ts/lineages.py | 116 -------------------------------------- sc2ts/utils.py | 96 ++----------------------------- tests/data/cache/match.db | Bin 0 -> 16384 bytes 3 files changed, 4 insertions(+), 208 deletions(-) delete mode 100644 sc2ts/lineages.py create mode 100644 tests/data/cache/match.db diff --git a/sc2ts/lineages.py b/sc2ts/lineages.py deleted file mode 100644 index e56cd59..0000000 --- a/sc2ts/lineages.py +++ /dev/null @@ -1,116 +0,0 @@ -import json -from collections import defaultdict -import pandas as pd - - -class MutationContainer: - def __init__(self): - self.names = {} - self.positions = [] - self.alts = [] - self.size = 0 - self.all_positions = set() - - def add_root(self, root_lineage_name): - self.names[root_lineage_name] = self.size - self.size += 1 - self.positions.append([]) - self.alts.append([]) - - def add_item(self, item, position, alt): - if item not in self.names: - self.names[item] = self.size - self.positions.append([position]) - self.alts.append([alt]) - self.size += 1 - else: - index = self.names[item] - self.positions[index].append(position) - self.alts[index].append(alt) - if position not in self.all_positions: - self.all_positions.add(position) - - def get_mutations(self, item): - index = self.names[item] - return self.positions[index], self.alts[index] - - -def read_in_mutations(json_filepath, verbose=False): - """ - Read in lineage-defining mutations from COVIDCG input json file. - Assumes root lineage is B. - """ - - with open(json_filepath, "r") as file: - linmuts = json.load(file) - - # Read in lineage defining mutations - linmuts_dict = MutationContainer() - linmuts_dict.add_root("B") - if verbose: - check_multiallelic_sites = defaultdict( - set - ) # will check how many multi-allelic sites there are - - for item in linmuts: - if item["alt"] != "-" and item["ref"] != "-": # ignoring indels - linmuts_dict.add_item(item["name"], item["pos"], item["alt"]) - if verbose: - check_multiallelic_sites[item["pos"]].add(item["ref"]) - if verbose: - check_multiallelic_sites[item["pos"]].add(item["alt"]) - - if verbose: - multiallelic_sites_count = 0 - for value in check_multiallelic_sites.values(): - if len(value) > 2: - multiallelic_sites_count += 1 - print( - "Multiallelic sites:", - multiallelic_sites_count, - "out of", - len(check_multiallelic_sites), - ) - print("Number of lineages:", linmuts_dict.size) - - return linmuts_dict - - -class OHE_transform: - """ - One hot encoder using pandas get_dummies() for dealing with categorical data (alleles at each position) - """ - - def __init__(self): - self.new_colnames = None - self.old_colnames = None - - def fit(self, X): - self.old_colnames = X.columns - X = pd.get_dummies(X, drop_first=True) - self.new_colnames = X.columns - return X - - def transform(self, X): - X = pd.get_dummies(X) - X = X.reindex(columns=self.new_colnames, fill_value=0) - return X - - -def read_in_mutations_json(json_filepath): - """ - Read in COVIDCG json file of lineage-defining mutations into a pandas data frame - """ - df = pd.read_json(json_filepath) - df = df.loc[df["ref"] != "-"] - df = df.loc[df["alt"] != "-"] - df = df.pivot_table( - index="name", columns="pos", values="alt", aggfunc="min", fill_value="." - ) - idx = df.index.append(pd.Index(["B"])) - df = df.reindex(idx, fill_value=".") - ohe = OHE_transform() - df_ohe = ohe.fit(df) - return df, df_ohe, ohe - - diff --git a/sc2ts/utils.py b/sc2ts/utils.py index daaa084..0db671f 100644 --- a/sc2ts/utils.py +++ b/sc2ts/utils.py @@ -4,26 +4,17 @@ import collections import dataclasses import itertools -import operator -import warnings -import datetime -import logging import tskit -import tszip import numpy as np -import pandas as pd import tqdm import matplotlib.pyplot as plt from matplotlib import colors from IPython.display import Markdown, HTML import networkx as nx -import numba import sc2ts -from . import core -from . import lineages @dataclasses.dataclass @@ -255,7 +246,7 @@ def plot_subgraph( nodes, ts, ti=None, - mutations_json_filepath=None, + show_mutation_positions=None, # NB - can pass linmuts.all_positions filepath=None, *, ax=None, @@ -284,10 +275,10 @@ def plot_subgraph( ``None`` calculate the TreeInfo within this function. However, as calculating the TreeInfo class takes some time, if you have it calculated already, it is far more efficient to pass it in here. - :param str mutations_json_filepath: The path to a list of mutations (only relevant + :param str show_mutation_positions: A set of integer positions (only relevant if ``edge_labels`` is ``None``). If provided, only mutations in this file will be listed on edges of the plot, with others shown as "+N mutations". If ``None`` - (default), list all mutations. If "", only plot the number of mutations. + (default), show all mutations. If the empty set, only plot the number of mutations. :param str filepath: If given, save the plot to this file path. :param plt.Axes ax: a matplotlib axis object on which to plot the graph. This allows the graph to be placed as a subplot or the size and aspect ratio @@ -394,15 +385,6 @@ def sort_mutation_label(s): "show_descendant_samples must be one of 'samples', 'tips', 'sample_tips', 'all', or '' / False" ) - # Read in characteristic mutations info - linmuts_dict = None - if mutations_json_filepath is not None: - if mutations_json_filepath == "": - TmpClass = collections.namedtuple("TmpClass", ["all_positions"]) - linmuts_dict = TmpClass({}) # an empty dict - else: - linmuts_dict = lineages.read_in_mutations(mutations_json_filepath) - exterior_edges = None if exterior_edge_len != 0: G, exterior_edges = to_nx_subgraph(ts, nodes, return_external_edges=True) @@ -464,7 +446,7 @@ def sort_mutation_label(s): mutstr = f"$\\bf{{{inherited_state.upper()}{pos}{m.derived_state.upper()}}}$" else: mutstr = f"{inherited_state.upper()}{pos}{m.derived_state.upper()}" - if linmuts_dict is None or pos in linmuts_dict.all_positions: + if show_mutation_positions is None or pos in show_mutation_positions: includemut = True if includemut: mutation_labels[(edge.parent, edge.child)].add(mutstr) @@ -753,76 +735,6 @@ def sample_subgraph(sample_node, ts, ti=None, **kwargs): return plot_subgraph(nodes, ts, ti, **kwargs) -def add_gisaid_lineages_to_ts(ts, node_gisaid_lineages, linmuts_dict): - """ - Adds lineages from GISAID to ts metadata (as 'GISAID_lineage'). - """ - tables = ts.tables - new_metadata = [] - ndiffs = 0 - for node in ts.nodes(): - md = node.metadata - if node_gisaid_lineages[node.id] is not None: - if node_gisaid_lineages[node.id] in linmuts_dict.names: - md["GISAID_lineage"] = str(node_gisaid_lineages[node.id]) - else: - md["GISAID_lineage"] = md["Nextclade_pango"] - ndiffs += 1 - new_metadata.append(md) - validated_metadata = [ - tables.nodes.metadata_schema.validate_and_encode_row(row) - for row in new_metadata - ] - tables.nodes.packset_metadata(validated_metadata) - edited_ts = tables.tree_sequence() - print("Filling in missing GISAID lineages with Nextclade lineages:", ndiffs) - return edited_ts - - -# NOTE: this is broken since moving to Viridian metadata, we no longer have -# GISAID EPI ISL in the metadata -def check_lineages( - ts, - ti, - gisaid_data, - linmuts_dict, - diff_filehandle="lineage_disagreement", -): - n_diffs = 0 - total = 0 - diff_file = diff_filehandle + ".csv" - node_gisaid_lineages = [None] * ts.num_nodes - with tqdm.tqdm(total=len(gisaid_data)) as pbar: - with open(diff_file, "w") as file: - file.write("sample_node,gisaid_epi_isl,gisaid_lineage,ts_lineage\n") - for gisaid_id, gisaid_lineage in gisaid_data: - if gisaid_id in ti.epi_isl_map: - sample_node = ts.node(ti.epi_isl_map[gisaid_id]) - if gisaid_lineage != sample_node.metadata["Nextclade_pango"]: - n_diffs += 1 - file.write( - str(sample_node.id) - + "," - + gisaid_id - + "," - + gisaid_lineage - + "," - + sample_node.metadata["Nextclade_pango"] - + "\n" - ) - node_gisaid_lineages[sample_node.id] = gisaid_lineage - total += 1 - pbar.update(1) - print("ts number of samples:", ts.num_samples) - print("number matched to gisaid data:", total) - print("number of differences:", n_diffs) - print("proportion:", n_diffs / total) - - edited_ts = add_gisaid_lineages_to_ts(ts, node_gisaid_lineages, linmuts_dict) - - return edited_ts - - def compute_left_bound(ts, parents, right): right_index = np.searchsorted(ts.sites_position, right) assert ts.sites_position[right_index] == right diff --git a/tests/data/cache/match.db b/tests/data/cache/match.db new file mode 100644 index 0000000000000000000000000000000000000000..c924799372f88a47dfe6e0ffc3dbba41b6e7e382 GIT binary patch literal 16384 zcmeI#%TB^T6oBDrA)+J*I~LuX9gwJrapB5D&@PDe0CBn=+RZ#xEH7~7c{uJnft;y~(lm7`qLk9Z5)aEVHrv{A;GbGo zC-0+F`dh^gKUMl5{S{UO0tg_000IagfB*srAb;{5i8>{plt1AIW5*jv&8*js)&^EtwhVUreY-vM0|~cpte0RaRMKmY**5I_I{1Q0*~f!!DQ0eh%%RsaA1 literal 0 HcmV?d00001