diff --git a/c/tskit/trees.c b/c/tskit/trees.c index c6c2089023..5338f4b6cf 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2762,17 +2762,21 @@ genetic_relatedness_summary_func(size_t state_dim, const double *state, tsk_id_t i, j; size_t k; double sumx = 0; - double meanx; + double sumn = 0; + double meanx, ni, nj; for (k = 0; k < state_dim; k++) { sumx += x[k]; + sumn += args.sample_set_sizes[k]; } - meanx = sumx / (double) state_dim; + meanx = sumx / sumn; for (k = 0; k < result_dim; k++) { i = args.set_indexes[2 * k]; j = args.set_indexes[2 * k + 1]; - result[k] = (x[i] - meanx) * (x[j] - meanx) / 2; + ni = args.sample_set_sizes[i]; + nj = args.sample_set_sizes[j]; + result[k] = (x[i] - ni * meanx) * (x[j] - nj * meanx) / 2; } return 0; } diff --git a/python/tests/test_covariance.py b/python/tests/test_covariance.py deleted file mode 100644 index e252360a01..0000000000 --- a/python/tests/test_covariance.py +++ /dev/null @@ -1,253 +0,0 @@ -# MIT License -# -# Copyright (c) 2018-2020 Tskit Developers -# Copyright (c) 2016-2017 University of Oxford -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -""" -Test cases for covariance computation. -""" -import io -import itertools -import unittest - -import msprime -import numpy as np - -import tskit - - -def naive_genetic_relatedness(ts, proportion=True): - G = ts.genotype_matrix() - denominator = ts.sequence_length - if proportion: - all_samples = ts.samples() - num = ts.segregating_sites(all_samples) - denominator = denominator * num - G = G.T - np.mean(G, axis=1) - return G @ G.T / denominator - - -def genetic_relatedness(ts, polarised=False, proportion=True): - n = ts.num_samples - sample_sets = [[u] for u in ts.samples()] - - def f(x): - return np.array( - [ - (x[i] - sum(x) / n) * (x[j] - sum(x) / n) - for i in range(n) - for j in range(n) - ] - ) - - denominator = 2 - polarised - if proportion: - all_samples = list({u for s in sample_sets for u in s}) - num = ts.segregating_sites(all_samples) - denominator = denominator * num - return ( - ts.sample_count_stat( - sample_sets, - f, - output_dim=n * n, - mode="site", - span_normalise=True, - polarised=polarised, - ).reshape((n, n)) - / denominator - ) - - -def c_genetic_relatedness(ts, sample_sets, indexes, polarised=False, proportion=True): - m = len(indexes) - state_dim = len(sample_sets) - - def f(x): - sumx = 0 - for k in range(state_dim): - sumx += x[k] - meanx = sumx / state_dim - result = np.zeros(m) - for k in range(m): - i = indexes[k][0] - j = indexes[k][1] - result[k] = (x[i] - meanx) * (x[j] - meanx) - return result - - denominator = 2 - polarised - if proportion: - all_samples = list({u for s in sample_sets for u in s}) - num = ts.segregating_sites(all_samples) - denominator = denominator * num - return ( - ts.sample_count_stat( - sample_sets, - f, - output_dim=m, - mode="site", - span_normalise=True, - polarised=False, - strict=False, - ) - / denominator - ) - - -class TestCovariance(unittest.TestCase): - """ - Tests on covariance matrix computation - """ - - def verify(self, ts, proportion=True): - cov1 = naive_genetic_relatedness(ts, proportion=proportion) - cov2 = genetic_relatedness(ts, proportion=proportion) - sample_sets = [[u] for u in ts.samples()] - n = len(sample_sets) - indexes = [ - (n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2) - ] - cov3 = np.zeros((n, n)) - cov4 = np.zeros((n, n)) - i_upper = np.triu_indices(n) - cov3[i_upper] = c_genetic_relatedness( - ts, sample_sets, indexes, proportion=proportion - ) - cov3 = cov3 + cov3.T - np.diag(cov3.diagonal()) - cov4[i_upper] = ts.genetic_relatedness( - sample_sets, - indexes, - mode="site", - span_normalise=True, - proportion=proportion, - ) - cov4 = cov4 + cov4.T - np.diag(cov4.diagonal()) - assert np.allclose(cov1, cov2) - assert np.allclose(cov1, cov3) - assert np.allclose(cov1, cov4) - - def test_single_coalescent_tree(self): - ts = msprime.simulate(10, random_seed=1, length=10, mutation_rate=1) - self.verify(ts) - self.verify(ts, proportion=False) - - def test_coalescent_trees(self): - ts = msprime.simulate( - 8, recombination_rate=5, random_seed=1, length=2, mutation_rate=1 - ) - assert ts.num_trees > 2 - self.verify(ts) - self.verify(ts, proportion=False) - - def test_internal_samples(self): - nodes = io.StringIO( - """\ - id is_sample time - 0 0 0 - 1 1 0.1 - 2 1 0.1 - 3 1 0.2 - 4 0 0.4 - 5 1 0.5 - 6 0 0.7 - 7 0 1.0 - 8 0 0.8 - """ - ) - edges = io.StringIO( - """\ - left right parent child - 0.0 0.2 4 2,3 - 0.2 0.8 4 0,2 - 0.8 1.0 4 2,3 - 0.0 1.0 5 1,4 - 0.8 1.0 6 0,5 - 0.2 0.8 8 3,5 - 0.0 0.2 7 0,5 - """ - ) - sites = io.StringIO( - """\ - position ancestral_state - 0.1 0 - 0.5 0 - 0.9 0 - """ - ) - mutations = io.StringIO( - """\ - site node derived_state - 0 1 1 - 1 3 1 - 2 5 1 - """ - ) - ts = tskit.load_text( - nodes=nodes, edges=edges, sites=sites, mutations=mutations, strict=False - ) - self.verify(ts) - self.verify(ts, proportion=False) - - def validate_trees(self, n): - for seed in range(1, 10): - ts = msprime.simulate( - n, random_seed=seed, recombination_rate=1, mutation_rate=2 - ) - self.verify(ts) - self.verify(ts, proportion=False) - - def test_sample_5(self): - self.validate_trees(5) - - def test_sample_10(self): - self.validate_trees(10) - - def test_sample_20(self): - self.validate_trees(20) - - def validate_nonbinary_trees(self, n): - demographic_events = [ - msprime.SimpleBottleneck(0.02, 0, proportion=0.25), - msprime.SimpleBottleneck(0.2, 0, proportion=1), - ] - - for seed in range(1, 10): - ts = msprime.simulate( - n, - random_seed=seed, - demographic_events=demographic_events, - recombination_rate=1, - mutation_rate=5, - ) - # Check if this is really nonbinary - found = False - for edgeset in ts.edgesets(): - if len(edgeset.children) > 2: - found = True - break - assert found - - self.verify(ts) - self.verify(ts, proportion=False) - - def test_non_binary_sample_10(self): - self.validate_nonbinary_trees(10) - - def test_non_binary_sample_20(self): - self.validate_nonbinary_trees(20) diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py index a383e28055..84222d6ba7 100644 --- a/python/tests/test_tree_stats.py +++ b/python/tests/test_tree_stats.py @@ -1772,6 +1772,315 @@ class TestSiteDivergence(TestDivergence, MutatedTopologyExamplesMixin): mode = "site" +############################################ +# Genetic relatedness +############################################ + + +def site_genetic_relatedness( + ts, sample_sets, indexes, windows=None, span_normalise=True, proportion=True +): + out = np.zeros((len(windows) - 1, len(indexes))) + samples = [u for u in ts.samples()] + all_samples = list({u for s in sample_sets for u in s}) + sample_ind = [samples.index(x) for x in all_samples] + haps = ts.genotype_matrix(isolated_as_missing=False).T + haps = haps[sample_ind] + denom = np.ones(len(windows)) + if proportion: + denom = ts.segregating_sites( + sample_sets=all_samples, + windows=windows, + mode="site", + span_normalise=span_normalise, + ) + alleles = np.unique(haps) + for j in range(len(windows) - 1): + begin = windows[j] + end = windows[j + 1] + site_positions = [x.position for x in ts.sites()] + for i, (ix, iy) in enumerate(indexes): + X = sample_sets[ix] + Y = sample_sets[iy] + S = 0 + for a in alleles: + this_haps = haps == a + haps_mean = this_haps.mean(axis=0) + haps_centered = this_haps - haps_mean + for k in range(ts.num_sites): + if (site_positions[k] >= begin) and (site_positions[k] < end): + for x in X: + x_index = np.where(all_samples == x)[0][0] + for y in Y: + y_index = np.where(all_samples == y)[0][0] + S += ( + haps_centered[x_index][k] + * haps_centered[y_index][k] + / 2 + ) + with np.errstate(invalid="ignore", divide="ignore"): + out[j][i] = S / denom[j] + if span_normalise: + out[j][i] /= end - begin + return out + + +def branch_genetic_relatedness( + ts, sample_sets, indexes, windows=None, span_normalise=True, proportion=True +): + out = np.zeros((len(windows) - 1, len(indexes))) + all_samples = list({u for s in sample_sets for u in s}) + denom = np.ones(len(windows)) + if proportion: + denom = ts.segregating_sites( + sample_sets=all_samples, + windows=windows, + mode="branch", + span_normalise=span_normalise, + ) + for j in range(len(windows) - 1): + begin = windows[j] + end = windows[j + 1] + for tr in ts.trees(): + if tr.interval[1] <= begin: + continue + if tr.interval[0] >= end: + break + branches = [(c, tr.parent(c)) for c in tr.nodes()] + span = min(end, tr.interval[1]) - max(begin, tr.interval[0]) + for B in branches: + v = B[0] + area = tr.branch_length(v) * span + haps = np.zeros(len(all_samples)) + for x, u in enumerate(all_samples): + haps[x] = np.int(tr.is_descendant(u, v)) + haps_mean = haps.mean() + haps_centered = haps - haps_mean + for i, (ix, iy) in enumerate(indexes): + X = sample_sets[ix] + Y = sample_sets[iy] + for x in X: + x_index = np.where(all_samples == x)[0][0] + for y in Y: + y_index = np.where(all_samples == y)[0][0] + out[j][i] += ( + area * haps_centered[x_index] * haps_centered[y_index] + ) + for i in range(len(indexes)): + with np.errstate(invalid="ignore", divide="ignore"): + out[j][i] /= denom[j] + if span_normalise: + out[j][i] /= end - begin + return out + + +def node_genetic_relatedness( + ts, sample_sets, indexes, windows=None, span_normalise=True, proportion=True +): + out = np.zeros((len(windows) - 1, ts.num_nodes, len(indexes))) + all_samples = list({u for s in sample_sets for u in s}) + denom = np.ones((len(windows), ts.num_nodes)) + if proportion: + denom = ts.segregating_sites( + sample_sets=all_samples, + windows=windows, + mode="node", + span_normalise=span_normalise, + ) + for j in range(len(windows) - 1): + begin = windows[j] + end = windows[j + 1] + for tr in ts.trees(): + span = min(end, tr.interval[1]) - max(begin, tr.interval[0]) + if tr.interval[1] <= begin: + continue + if tr.interval[0] >= end: + break + for v in tr.nodes(): + haps = np.zeros(len(all_samples)) + for x, u in enumerate(all_samples): + haps[x] = np.int(tr.is_descendant(u, v)) + haps_mean = haps.mean() + haps_centered = haps - haps_mean + for i, (ix, iy) in enumerate(indexes): + X = sample_sets[ix] + Y = sample_sets[iy] + for x in X: + x_index = np.where(all_samples == x)[0][0] + for y in Y: + y_index = np.where(all_samples == y)[0][0] + out[j][v][i] += ( + haps_centered[x_index] * haps_centered[y_index] * span + ) + for i in range(len(indexes)): + for v in ts.nodes(): + iV = v.id + with np.errstate(invalid="ignore", divide="ignore"): + out[j, iV, i] /= denom[j, iV] + if span_normalise: + out[j, iV, i] /= end - begin + return out + + +def genetic_relatedness( + ts, + sample_sets, + indexes=None, + windows=None, + mode="site", + span_normalise=True, + proportion=True, +): + """ + Computes genetic relatedness between two random choices from x + over the window specified. + """ + windows = ts.parse_windows(windows) + if indexes is None: + indexes = [(0, 1)] + method_map = { + "site": site_genetic_relatedness, + "node": node_genetic_relatedness, + "branch": branch_genetic_relatedness, + } + return method_map[mode]( + ts, + sample_sets, + indexes=indexes, + windows=windows, + span_normalise=span_normalise, + proportion=proportion, + ) + + +class TestGeneticRelatedness(StatsTestCase, TwoWaySampleSetStatsMixin): + + # Derived classes define this to get a specific stats mode. + mode = None + + def verify_definition( + self, + ts, + sample_sets, + indexes, + windows, + summary_func, + ts_method, + definition, + proportion, + ): + def wrapped_summary_func(x): + with suppress_division_by_zero_warning(): + return summary_func(x) + + W = np.array([[u in A for A in sample_sets] for u in ts.samples()], dtype=float) + # Determine output_dim of the function + M = len(wrapped_summary_func(W[0])) + denom = 1 + if proportion: + all_samples = list({u for s in sample_sets for u in s}) + denom = ts.segregating_sites( + sample_sets=[all_samples], windows=windows, mode=self.mode + ) + + with np.errstate(divide="ignore", invalid="ignore"): + sigma1 = ( + ts.general_stat(W, wrapped_summary_func, M, windows, mode=self.mode) + / denom + ) + sigma2 = ( + general_stat(ts, W, wrapped_summary_func, windows, mode=self.mode) + / denom + ) + sigma3 = ts_method( + sample_sets, + indexes=indexes, + windows=windows, + mode=self.mode, + proportion=proportion, + ) + sigma4 = definition( + ts, + sample_sets, + indexes=indexes, + windows=windows, + mode=self.mode, + proportion=proportion, + ) + assert sigma1.shape == sigma2.shape + assert sigma1.shape == sigma3.shape + assert sigma1.shape == sigma4.shape + self.assertArrayAlmostEqual(sigma1, sigma2) + self.assertArrayAlmostEqual(sigma1, sigma3) + self.assertArrayAlmostEqual(sigma1, sigma4) + + def verify_sample_sets_indexes(self, ts, sample_sets, indexes, windows): + + n = np.array([len(x) for x in sample_sets]) + n_total = sum(n) + + def f(x): + mx = np.sum(x) / n_total + return np.array( + [(x[i] - n[i] * mx) * (x[j] - n[j] * mx) / 2 for i, j in indexes] + ) + + for proportion in [True, False]: + self.verify_definition( + ts, + sample_sets, + indexes, + windows, + f, + ts.genetic_relatedness, + genetic_relatedness, + proportion, + ) + + +class TestBranchGeneticRelatedness(TestGeneticRelatedness, TopologyExamplesMixin): + mode = "branch" + + +class TestNodeGeneticRelatedness(TestGeneticRelatedness, TopologyExamplesMixin): + mode = "node" + + +class TestSiteGeneticRelatedness(TestGeneticRelatedness, MutatedTopologyExamplesMixin): + mode = "site" + + def test_match_K_c0(self): + # This test checks that ts.genetic_relatedness() matches K_c0 + # from Speed & Balding (2014) https://www.nature.com/articles/nrg3821 + ts = msprime.simulate( + 10, mutation_rate=0.01, length=100, recombination_rate=0.01, random_seed=23 + ) + samples = [u for u in ts.samples()] + sample_sets = [[0, 1], [2, 3], [4, 5]] + all_samples = list({u for s in sample_sets for u in s}) + sample_ind = [samples.index(x) for x in all_samples] + indexes = [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)] + A = ts.genetic_relatedness( + sample_sets, indexes=indexes, mode="site", span_normalise=False + ) + # Genotype covariance as in Speed and Balding + G = ts.genotype_matrix().T + G = G[sample_ind] + G_centered = G - G.mean(axis=0) + B = np.zeros(len(indexes)) + for i, (ix, iy) in enumerate(indexes): + x1 = sample_sets[ix][0] + x2 = sample_sets[ix][1] + y1 = sample_sets[iy][0] + y2 = sample_sets[iy][1] + B[i] = ( + (G_centered[x1] + G_centered[x2]) + @ (G_centered[y1] + G_centered[y2]) + / ts.segregating_sites(sample_sets=all_samples, span_normalise=False) + ) + self.assertArrayAlmostEqual(A, B) + + ############################################ # Fst ############################################ diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 418c642289..7633cf1251 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -5758,7 +5758,7 @@ def genetic_relatedness( # TODO this should be done in C also all_samples = list({u for s in sample_sets for u in s}) denominator = self.segregating_sites( - sample_sets=all_samples, + sample_sets=[all_samples], windows=windows, mode=mode, span_normalise=span_normalise, @@ -5776,7 +5776,10 @@ def genetic_relatedness( span_normalise=span_normalise, polarised=polarised, ) - return numerator / denominator + with np.errstate(divide="ignore", invalid="ignore"): + out = numerator / denominator + + return out def trait_covariance(self, W, windows=None, mode="site", span_normalise=True): """