diff --git a/requirements/CI-tests-complete/requirements.txt b/requirements/CI-tests-complete/requirements.txt index 5c859cf8..b590814d 100644 --- a/requirements/CI-tests-complete/requirements.txt +++ b/requirements/CI-tests-complete/requirements.txt @@ -16,8 +16,9 @@ pytest-cov==3.0.0 pytest-xdist==2.5.0 seaborn==0.11.2 setuptools==65.4.1 +sgkit[vcf]==0.5.0 sortedcontainers==2.4.0 tqdm==4.64.0 tskit==0.5.3 twine==4.0.1 -zarr==2.11.3 +zarr==2.10.3 diff --git a/requirements/CI-tests-pip/requirements.txt b/requirements/CI-tests-pip/requirements.txt index bda92366..a26320bb 100644 --- a/requirements/CI-tests-pip/requirements.txt +++ b/requirements/CI-tests-pip/requirements.txt @@ -8,7 +8,8 @@ pandas==1.4.2 pytest==7.1.2 pytest-xdist==2.5.0 seaborn==0.11.2 +sgkit[vcf]==0.5.0 sortedcontainers==2.4.0 tqdm==4.64.0 tskit==0.5.3 -zarr==2.11.3 +zarr==2.10.3 diff --git a/requirements/CI-tests-pip/requirements3.7.txt b/requirements/CI-tests-pip/requirements3.7.txt index 55c17a53..52917d58 100644 --- a/requirements/CI-tests-pip/requirements3.7.txt +++ b/requirements/CI-tests-pip/requirements3.7.txt @@ -8,7 +8,8 @@ pandas==1.3.5 pytest==7.1.2 pytest-xdist==2.5.0 seaborn==0.11.2 +sgkit[vcf]==0.5.0 sortedcontainers==2.4.0 tqdm==4.64.0 tskit==0.5.3 -zarr==2.11.3 +zarr==2.10.3 diff --git a/requirements/development.txt b/requirements/development.txt index 52523426..9a2bf29e 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -33,4 +33,4 @@ pandas matplotlib seaborn colorama - +sgkit[vcf] diff --git a/tests/test_sgkit.py b/tests/test_sgkit.py new file mode 100644 index 00000000..1141f838 --- /dev/null +++ b/tests/test_sgkit.py @@ -0,0 +1,45 @@ +# +# Copyright (C) 2022 University of Oxford +# +# This file is part of tsinfer. +# +# tsinfer is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# tsinfer is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with tsinfer. If not, see . +# +""" +Tests for the data files. +""" +import msprime +import numpy as np +import sgkit.io.vcf + +import tsinfer + + +def test_sgkit_dataset(tmp_path): + ts = msprime.sim_ancestry( + samples=50, + ploidy=3, + recombination_rate=0.25, + sequence_length=50, + random_seed=100, + ) + ts = msprime.sim_mutations(ts, rate=0.025, model=msprime.BinaryMutationModel()) + with open(tmp_path / "data.vcf", "w") as f: + ts.write_vcf(f) + sgkit.io.vcf.vcf_to_zarr( + tmp_path / "data.vcf", tmp_path / "data.zarr", ploidy=3, max_alt_alleles=1 + ) + samples = tsinfer.SgkitSampleData(tmp_path / "data.zarr") + inf_ts = tsinfer.infer(samples) + assert np.array_equal(ts.genotype_matrix(), inf_ts.genotype_matrix()) diff --git a/tsinfer/formats.py b/tsinfer/formats.py index 7fcc7ceb..10f939fe 100644 --- a/tsinfer/formats.py +++ b/tsinfer/formats.py @@ -2100,7 +2100,7 @@ def variants(self, sites=None, recode_ancestral=None): genos = geno_map[genos] yield Variant(site=site, alleles=alleles, genotypes=genos) - def __all_haplotypes(self, sites=None, recode_ancestral=None): + def _all_haplotypes(self, sites=None, recode_ancestral=None): # We iterate over chunks vertically here, and it's not worth complicating # the chunk iterator to handle this. if recode_ancestral is None: @@ -2156,7 +2156,7 @@ def haplotypes(self, samples=None, sites=None, recode_ancestral=None): raise ValueError("Sample index too large.") j = 0 - for index, a in self.__all_haplotypes(sites, recode_ancestral): + for index, a in self._all_haplotypes(sites, recode_ancestral): if j == len(samples): break if index == samples[j]: @@ -2225,6 +2225,263 @@ def populations(self): yield Population(j, metadata=metadata) +class SgkitSampleData(SampleData): + + FORMAT_NAME = "tsinfer-sgkit-sample-data" + FORMAT_VERSION = (0, 1) + + def __init__(self, path): + self.path = path + self.data = zarr.open(path, mode="r") + self._num_sites, self._num_individuals, self.ploidy = self.data[ + "call_genotype" + ].shape + self._num_samples = self._num_individuals * self.ploidy + + assert self.ploidy == self.data["call_genotype"].chunks[2] + + def __metadata_schema_getter(self, zarr_group): + try: + return self.data[zarr_group].attrs["metadata_schema"] + except KeyError: + return {"codec": "json"} + + @property + def uuid(self): + return ( + "Hmm, not sure, could just generate a UUID, but then it wouldn't" + "be in the file - maybe we do need to write back on init" + ) + + @property + def format_name(self): + return self.FORMAT_NAME + + @property + def format_version(self): + return self.FORMAT_VERSION + + @property + def finalised(self): + return True + + @property + def sequence_length(self): + try: + return self.data.attrs["sequence_length"] + except KeyError: + return int(np.max(self.data["variant_position"])) + 1 + + @property + def num_sites(self): + return self._num_sites + + @property + def sites_metadata_schema(self): + return self.__metadata_schema_getter("sites") + + @property + def sites_metadata(self): + try: + return self.data["sites/metadata"] + except KeyError: + return zarr.array( + [{}] * self.num_individuals, object_codec=numcodecs.JSON() + ) + + @property + def sites_time(self): + try: + return self.data["sites/time"] + except KeyError: + return np.full(self.data["variant_position"].shape, tskit.UNKNOWN_TIME) + + @property + def sites_position(self): + return self.data["variant_position"] + + @property + def sites_alleles(self): + return self.data["variant_allele"] + + @property + def sites_ancestral_allele(self): + try: + return self.data["sites/ancestral_allele"] + except KeyError: + # Maintains backwards compatibility: in previous tsinfer versions the + # ancestral allele was always the zeroth element in the alleles list + return np.zeros(self.num_sites, dtype=np.int8) + + @property + def sites_genotypes(self): + gt = self.data["call_genotype"] + return gt[...].reshape(gt.shape[0], gt.shape[1] * gt.shape[2]) + + @property + def provenances_timestamp(self): + try: + return self.data["provenances_timestamp"] + except KeyError: + return np.array([], dtype=object) + + @property + def provenances_record(self): + try: + return self.data["provenances_record"] + except KeyError: + return np.array([], dtype=object) + + @property + def num_samples(self): + return self._num_samples + + @property + def samples_individual(self): + ret = np.zeros((self.num_samples), dtype=np.int32) + for p in range(self.ploidy): + ret[p :: self.ploidy] = np.arange(self.num_individuals) + return ret + + @property + def metadata_schema(self): + try: + return self.data.attrs["metadata_schema"] + except KeyError: + None + + @property + def metadata(self): + try: + return self.data.attrs["metadata_schema"] + except KeyError: + return b"" + + @property + def populations_metadata(self): + try: + return self.data["populations/metadata"] + except KeyError: + return np.array([], dtype=object) + + @property + def populations_metadata_schema(self): + return self.__metadata_schema_getter("populations") + + @property + def num_individuals(self): + return self._num_individuals + + @property + def individuals_time(self): + try: + return self.data["individuals/time"] + except KeyError: + return np.full(self.num_individuals, tskit.UNKNOWN_TIME) + + @property + def individuals_metadata_schema(self): + return self.__metadata_schema_getter("individuals") + + @property + def individuals_metadata(self): + try: + return self.data["individuals/metadata"] + except KeyError: + return zarr.array( + [{}] * self.num_individuals, object_codec=numcodecs.JSON() + ) + + @property + def individuals_location(self): + try: + return self.data["individuals/location"] + except KeyError: + return zarr.array([[]] * self.num_individuals, dtype=float) + + @property + def individuals_population(self): + try: + return self.data["individuals/population"] + except KeyError: + return np.full((self.num_individuals), tskit.NULL, dtype=np.int32) + + @property + def individuals_flags(self): + try: + return self.data["individuals/population"] + except KeyError: + return np.full((self.num_individuals), 0, dtype=np.int32) + + def variants(self, sites=None, recode_ancestral=None): + """ + Returns an iterator over the :class:`Variant` objects. This is equivalent to + the :meth:`tskit.TreeSequence.variants` iterator. If recode_ancestral is + ``True``, the ``.alleles`` attribute of each variant is guaranteed to return + the alleles in an order such that the ancestral state is the first item + in the list. In this case, ``variant.alleles`` may list the alleles in a + different order from the input order as listed in ``variant.site.alleles``, + and the values in genotypes array will be recoded so that the ancestral + state will have a genotype of 0. If the ancestral state is unknown, the + original input order is kept. + + :param array sites: A numpy array of ascending site ids for which to return + data. If None (default) return all sites. + :param bool recode_ancestral: If True, recode genotypes at sites where the + ancestral state is known such that the ancestral state is coded as 0, + as described above. Otherwise return genotypes in the input allele encoding. + Default: ``None`` treated as ``False``. + :return: An iterator over the variants in the sample data file. + :rtype: iter(:class:`Variant`) + """ + if recode_ancestral is None: + recode_ancestral = False + all_genotypes = chunk_iterator(self.data["call_genotype"], indexes=sites) + assert MISSING_DATA < 0 # required for geno_map to remap MISSING_DATA + for genos, site in zip(all_genotypes, self.sites(ids=sites)): + # We have an extra ploidy dimension when coming from sgkit + genos = genos.reshape(self.num_samples) + aa = site.ancestral_allele + alleles = site.alleles + if aa != MISSING_DATA and aa > 0 and recode_ancestral: + # Need to recode this site + alleles = site.reorder_alleles() + # re-map the genotypes + geno_map = np.arange(len(alleles) - MISSING_DATA, dtype=genos.dtype) + geno_map[MISSING_DATA] = MISSING_DATA + geno_map[aa] = 0 + geno_map[0:aa] += 1 + genos = geno_map[genos] + yield Variant(site=site, alleles=alleles, genotypes=genos) + + def _all_haplotypes(self, sites=None, recode_ancestral=None): + # We iterate over chunks vertically here, and it's not worth complicating + # the chunk iterator to handle this. + if recode_ancestral is None: + recode_ancestral = False + aa_index = self.sites_ancestral_allele[:] + # If ancestral allele is missing, keep the order unchanged (aa_index of zero) + aa_index[aa_index == MISSING_DATA] = 0 + gt = self.data["call_genotype"] + chunk_size = gt.chunks[1] + for j in range(self.num_individuals): + if j % chunk_size == 0: + chunk = gt[:, j : j + chunk_size, :] + indiv_gt = chunk[:, j % chunk_size, :] + for k in range(self.ploidy): + a = indiv_gt[:, k].T + if recode_ancestral: + # Remap the genotypes at all sites, depending on the aa_index + a = np.where( + a == aa_index, + 0, + np.where( + np.logical_and(a != MISSING_DATA, a < aa_index), a + 1, a + ), + ) + yield (j * self.ploidy) + k, a if sites is None else a[sites] + + @attr.s(order=False, eq=False) class Ancestor: """