Skip to content

Commit

Permalink
Add sgkit dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery committed Nov 3, 2022
1 parent 49d6423 commit 5ce3c24
Show file tree
Hide file tree
Showing 6 changed files with 311 additions and 6 deletions.
3 changes: 2 additions & 1 deletion requirements/CI-tests-complete/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ pytest-cov==3.0.0
pytest-xdist==2.5.0
seaborn==0.11.2
setuptools==65.4.1
sgkit[vcf]==0.5.0
sortedcontainers==2.4.0
tqdm==4.64.0
tskit==0.5.3
twine==4.0.1
zarr==2.11.3
zarr==2.10.3
3 changes: 2 additions & 1 deletion requirements/CI-tests-pip/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ pandas==1.4.2
pytest==7.1.2
pytest-xdist==2.5.0
seaborn==0.11.2
sgkit[vcf]==0.5.0
sortedcontainers==2.4.0
tqdm==4.64.0
tskit==0.5.3
zarr==2.11.3
zarr==2.10.3
3 changes: 2 additions & 1 deletion requirements/CI-tests-pip/requirements3.7.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ pandas==1.3.5
pytest==7.1.2
pytest-xdist==2.5.0
seaborn==0.11.2
sgkit[vcf]==0.5.0
sortedcontainers==2.4.0
tqdm==4.64.0
tskit==0.5.3
zarr==2.11.3
zarr==2.10.3
2 changes: 1 addition & 1 deletion requirements/development.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ pandas
matplotlib
seaborn
colorama

sgkit[vcf]
45 changes: 45 additions & 0 deletions tests/test_sgkit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#
# Copyright (C) 2022 University of Oxford
#
# This file is part of tsinfer.
#
# tsinfer is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tsinfer is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with tsinfer. If not, see <http://www.gnu.org/licenses/>.
#
"""
Tests for the data files.
"""
import msprime
import numpy as np
import sgkit.io.vcf

import tsinfer


def test_sgkit_dataset(tmp_path):
ts = msprime.sim_ancestry(
samples=50,
ploidy=3,
recombination_rate=0.25,
sequence_length=50,
random_seed=100,
)
ts = msprime.sim_mutations(ts, rate=0.025, model=msprime.BinaryMutationModel())
with open(tmp_path / "data.vcf", "w") as f:
ts.write_vcf(f)
sgkit.io.vcf.vcf_to_zarr(
tmp_path / "data.vcf", tmp_path / "data.zarr", ploidy=3, max_alt_alleles=1
)
samples = tsinfer.SgkitSampleData(tmp_path / "data.zarr")
inf_ts = tsinfer.infer(samples)
assert np.array_equal(ts.genotype_matrix(), inf_ts.genotype_matrix())
261 changes: 259 additions & 2 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2100,7 +2100,7 @@ def variants(self, sites=None, recode_ancestral=None):
genos = geno_map[genos]
yield Variant(site=site, alleles=alleles, genotypes=genos)

def __all_haplotypes(self, sites=None, recode_ancestral=None):
def _all_haplotypes(self, sites=None, recode_ancestral=None):
# We iterate over chunks vertically here, and it's not worth complicating
# the chunk iterator to handle this.
if recode_ancestral is None:
Expand Down Expand Up @@ -2156,7 +2156,7 @@ def haplotypes(self, samples=None, sites=None, recode_ancestral=None):
raise ValueError("Sample index too large.")
j = 0

for index, a in self.__all_haplotypes(sites, recode_ancestral):
for index, a in self._all_haplotypes(sites, recode_ancestral):
if j == len(samples):
break
if index == samples[j]:
Expand Down Expand Up @@ -2225,6 +2225,263 @@ def populations(self):
yield Population(j, metadata=metadata)


class SgkitSampleData(SampleData):

FORMAT_NAME = "tsinfer-sgkit-sample-data"
FORMAT_VERSION = (0, 1)

def __init__(self, path):
self.path = path
self.data = zarr.open(path, mode="r")
self._num_sites, self._num_individuals, self.ploidy = self.data[
"call_genotype"
].shape
self._num_samples = self._num_individuals * self.ploidy

assert self.ploidy == self.data["call_genotype"].chunks[2]

def __metadata_schema_getter(self, zarr_group):
try:
return self.data[zarr_group].attrs["metadata_schema"]
except KeyError:
return {"codec": "json"}

@property
def uuid(self):
return (
"Hmm, not sure, could just generate a UUID, but then it wouldn't"
"be in the file - maybe we do need to write back on init"
)

@property
def format_name(self):
return self.FORMAT_NAME

@property
def format_version(self):
return self.FORMAT_VERSION

@property
def finalised(self):
return True

@property
def sequence_length(self):
try:
return self.data.attrs["sequence_length"]
except KeyError:
return int(np.max(self.data["variant_position"])) + 1

@property
def num_sites(self):
return self._num_sites

@property
def sites_metadata_schema(self):
return self.__metadata_schema_getter("sites")

@property
def sites_metadata(self):
try:
return self.data["sites/metadata"]
except KeyError:
return zarr.array(
[{}] * self.num_individuals, object_codec=numcodecs.JSON()
)

@property
def sites_time(self):
try:
return self.data["sites/time"]
except KeyError:
return np.full(self.data["variant_position"].shape, tskit.UNKNOWN_TIME)

@property
def sites_position(self):
return self.data["variant_position"]

@property
def sites_alleles(self):
return self.data["variant_allele"]

@property
def sites_ancestral_allele(self):
try:
return self.data["sites/ancestral_allele"]
except KeyError:
# Maintains backwards compatibility: in previous tsinfer versions the
# ancestral allele was always the zeroth element in the alleles list
return np.zeros(self.num_sites, dtype=np.int8)

@property
def sites_genotypes(self):
gt = self.data["call_genotype"]
return gt[...].reshape(gt.shape[0], gt.shape[1] * gt.shape[2])

@property
def provenances_timestamp(self):
try:
return self.data["provenances_timestamp"]
except KeyError:
return np.array([], dtype=object)

@property
def provenances_record(self):
try:
return self.data["provenances_record"]
except KeyError:
return np.array([], dtype=object)

@property
def num_samples(self):
return self._num_samples

@property
def samples_individual(self):
ret = np.zeros((self.num_samples), dtype=np.int32)
for p in range(self.ploidy):
ret[p :: self.ploidy] = np.arange(self.num_individuals)
return ret

@property
def metadata_schema(self):
try:
return self.data.attrs["metadata_schema"]
except KeyError:
None

@property
def metadata(self):
try:
return self.data.attrs["metadata_schema"]
except KeyError:
return b""

@property
def populations_metadata(self):
try:
return self.data["populations/metadata"]
except KeyError:
return np.array([], dtype=object)

@property
def populations_metadata_schema(self):
return self.__metadata_schema_getter("populations")

@property
def num_individuals(self):
return self._num_individuals

@property
def individuals_time(self):
try:
return self.data["individuals/time"]
except KeyError:
return np.full(self.num_individuals, tskit.UNKNOWN_TIME)

@property
def individuals_metadata_schema(self):
return self.__metadata_schema_getter("individuals")

@property
def individuals_metadata(self):
try:
return self.data["individuals/metadata"]
except KeyError:
return zarr.array(
[{}] * self.num_individuals, object_codec=numcodecs.JSON()
)

@property
def individuals_location(self):
try:
return self.data["individuals/location"]
except KeyError:
return zarr.array([[]] * self.num_individuals, dtype=float)

@property
def individuals_population(self):
try:
return self.data["individuals/population"]
except KeyError:
return np.full((self.num_individuals), tskit.NULL, dtype=np.int32)

@property
def individuals_flags(self):
try:
return self.data["individuals/population"]
except KeyError:
return np.full((self.num_individuals), 0, dtype=np.int32)

def variants(self, sites=None, recode_ancestral=None):
"""
Returns an iterator over the :class:`Variant` objects. This is equivalent to
the :meth:`tskit.TreeSequence.variants` iterator. If recode_ancestral is
``True``, the ``.alleles`` attribute of each variant is guaranteed to return
the alleles in an order such that the ancestral state is the first item
in the list. In this case, ``variant.alleles`` may list the alleles in a
different order from the input order as listed in ``variant.site.alleles``,
and the values in genotypes array will be recoded so that the ancestral
state will have a genotype of 0. If the ancestral state is unknown, the
original input order is kept.
:param array sites: A numpy array of ascending site ids for which to return
data. If None (default) return all sites.
:param bool recode_ancestral: If True, recode genotypes at sites where the
ancestral state is known such that the ancestral state is coded as 0,
as described above. Otherwise return genotypes in the input allele encoding.
Default: ``None`` treated as ``False``.
:return: An iterator over the variants in the sample data file.
:rtype: iter(:class:`Variant`)
"""
if recode_ancestral is None:
recode_ancestral = False
all_genotypes = chunk_iterator(self.data["call_genotype"], indexes=sites)
assert MISSING_DATA < 0 # required for geno_map to remap MISSING_DATA
for genos, site in zip(all_genotypes, self.sites(ids=sites)):
# We have an extra ploidy dimension when coming from sgkit
genos = genos.reshape(self.num_samples)
aa = site.ancestral_allele
alleles = site.alleles
if aa != MISSING_DATA and aa > 0 and recode_ancestral:
# Need to recode this site
alleles = site.reorder_alleles()
# re-map the genotypes
geno_map = np.arange(len(alleles) - MISSING_DATA, dtype=genos.dtype)
geno_map[MISSING_DATA] = MISSING_DATA
geno_map[aa] = 0
geno_map[0:aa] += 1
genos = geno_map[genos]
yield Variant(site=site, alleles=alleles, genotypes=genos)

def _all_haplotypes(self, sites=None, recode_ancestral=None):
# We iterate over chunks vertically here, and it's not worth complicating
# the chunk iterator to handle this.
if recode_ancestral is None:
recode_ancestral = False
aa_index = self.sites_ancestral_allele[:]
# If ancestral allele is missing, keep the order unchanged (aa_index of zero)
aa_index[aa_index == MISSING_DATA] = 0
gt = self.data["call_genotype"]
chunk_size = gt.chunks[1]
for j in range(self.num_individuals):
if j % chunk_size == 0:
chunk = gt[:, j : j + chunk_size, :]
indiv_gt = chunk[:, j % chunk_size, :]
for k in range(self.ploidy):
a = indiv_gt[:, k].T
if recode_ancestral:
# Remap the genotypes at all sites, depending on the aa_index
a = np.where(
a == aa_index,
0,
np.where(
np.logical_and(a != MISSING_DATA, a < aa_index), a + 1, a
),
)
yield (j * self.ploidy) + k, a if sites is None else a[sites]


@attr.s(order=False, eq=False)
class Ancestor:
"""
Expand Down

0 comments on commit 5ce3c24

Please sign in to comment.