diff --git a/docs/_static/ancestor_grouping.png b/docs/_static/ancestor_grouping.png
new file mode 100644
index 00000000..a79c75b5
Binary files /dev/null and b/docs/_static/ancestor_grouping.png differ
diff --git a/docs/_toc.yml b/docs/_toc.yml
index 8f15ad5a..6ba9d46d 100644
--- a/docs/_toc.yml
+++ b/docs/_toc.yml
@@ -14,6 +14,7 @@ parts:
   - caption: Inference
     chapters:
     - file: inference
+    - file: large_scale
   - caption: Interfaces
     chapters:
     - file: api
diff --git a/docs/api.rst b/docs/api.rst
index e3a5ebb6..0b9f9b8a 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -16,6 +16,11 @@ File formats
 Sample data
 +++++++++++
 
+.. autoclass:: tsinfer.VariantData
+    :members:
+    :inherited-members:
+
+
 .. autoclass:: tsinfer.SampleData
     :members:
     :inherited-members:
@@ -60,6 +65,27 @@ Running inference
 
 .. autofunction:: tsinfer.post_process
 
+*****************
+Batched inference
+*****************
+
+.. autofunction:: tsinfer.match_ancestors_batch_init
+
+.. autofunction:: tsinfer.match_ancestors_batch_groups
+
+.. autofunction:: tsinfer.match_ancestors_batch_group_partition
+
+.. autofunction:: tsinfer.match_ancestors_batch_group_finalise
+
+.. autofunction:: tsinfer.match_ancestors_batch_finalise
+
+.. autofunction:: tsinfer.match_samples_batch_init
+
+.. autofunction:: tsinfer.match_samples_batch_partition
+
+.. autofunction:: tsinfer.match_samples_batch_finalise
+
+
 *****************
 Container classes
 *****************
diff --git a/docs/inference.md b/docs/inference.md
index e6ebac21..d6b846fc 100644
--- a/docs/inference.md
+++ b/docs/inference.md
@@ -300,4 +300,4 @@ The final phase of a `tsinfer` inference consists of a number steps:
 section 2. Describe the structure of the output tree sequences; how the
 nodes are mapped, what the time values mean, etc.
 
-:::
+:::
\ No newline at end of file
diff --git a/docs/large_scale.md b/docs/large_scale.md
new file mode 100644
index 00000000..ee343d8c
--- /dev/null
+++ b/docs/large_scale.md
@@ -0,0 +1,136 @@
+---
+jupytext:
+  text_representation:
+    extension: .md
+    format_name: myst
+    format_version: 0.12
+    jupytext_version: 1.9.1
+kernelspec:
+  display_name: Python 3
+  language: python
+  name: python3
+---
+
+:::{currentmodule} tsinfer
+:::
+
+(sec_large_scale)=
+
+# Large Scale Inference
+
+tsinfer scales well and has been successfully used with datasets of up to half a
+million samples. Here we detail considerations and tips for each step of the
+inference process to help you scale up your analysis. A Snakemake pipeline
+that implements this parallelisation scheme is available at
+https://github.com/benjeffery/tsinfer-snakemake.
+
+(sec_large_scale_ancestor_generation)=
+
+## Data preparation
+
+For large-scale inference the data must be in [VCF Zarr](https://github.com/sgkit-dev/vcf-zarr-spec)
+format, read by the {class}`VariantData` class. [bio2zarr](https://github.com/sgkit-dev/bio2zarr)
+is recommended for conversion from VCF. [sgkit](https://github.com/sgkit-dev/sgkit) can then
+be used to perform initial filtering.
+
+
+## Ancestor generation
+
+Ancestor generation is generally the fastest step in inference and is not yet
+parallelised out-of-core in tsinfer. However, it scales well on machines with
+many cores and hyperthreading, via the `num_threads` argument to
+{meth}`generate_ancestors`. The limiting factor is often that the
+entire genotype array for the contig being inferred needs to fit in RAM;
+this is the high-water mark for memory usage in tsinfer.
+Note the `genotype_encoding` argument: setting this to
+{class}`tsinfer.GenotypeEncoding.ONE_BIT` reduces the memory footprint of
+the genotype array by a factor of 8, for a surprisingly small increase in
+runtime. With this encoding, the RAM needed is roughly
+`num_sites * num_samples * ploidy / 8` bytes.
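+
+As a concrete illustration, here is a minimal sketch of this step. The file
+names, the thread count, and the name of the ancestral-state array are
+hypothetical placeholders for your own dataset:
+
+```python
+import tsinfer
+
+# Hypothetical VCF Zarr dataset containing an "ancestral_state" array.
+vdata = tsinfer.VariantData("chr20.vcz", ancestral_state="ancestral_state")
+
+# ONE_BIT encoding packs 8 genotypes per byte, cutting peak RAM roughly 8-fold.
+ancestors = tsinfer.generate_ancestors(
+    vdata,
+    path="chr20.ancestors",
+    num_threads=32,
+    genotype_encoding=tsinfer.GenotypeEncoding.ONE_BIT,
+)
+```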
+
+## Ancestor matching
+
+Ancestor matching is one of the more time-consuming steps of inference. It
+proceeds in groups, progressively growing the tree sequence with younger
+ancestors. At each stage the parallelism is limited to the number of ancestors
+that can be matched simultaneously, as all the older ancestors from which an
+ancestor might inherit must have been matched in an earlier group. For a
+typical human data set the number of ancestors per group varies from single
+digits up to approximately the number of samples.
+The plot below shows the number of ancestors matched in each group for a typical
+human data set:
+
+```{figure} _static/ancestor_grouping.png
+:width: 80%
+```
+
+There are five tsinfer API methods that can be used to parallelise ancestor
+matching.
+
+Initially {meth}`match_ancestors_batch_init` should be called to
+set up the batch matching and to determine the groupings of ancestors.
+This method writes a file `metadata.json` to `work_dir` that contains
+a JSON-encoded dictionary with configuration for later steps, and a key
+`ancestor_grouping` which is a list of dictionaries, each containing the
+list of ancestors in that group (key: `ancestors`) and a proposed partitioning of
+those ancestors into sets that can be matched in parallel (key: `partitions`).
+The dictionary is also returned by the method.
+The partitioning is controlled by the `min_work_per_job` and `max_num_partitions`
+arguments. Ancestors are placed in a partition until the sum of their lengths exceeds
+`min_work_per_job`, when a new partition is started. However, the number of partitions
+is not allowed to exceed `max_num_partitions`. It is suggested to set `max_num_partitions`
+to around 3-4x the number of worker nodes available, and `min_work_per_job` to around
+2,000,000 for a typical human data set.
+
+Each group is then matched in turn, either by calling {meth}`match_ancestors_batch_groups`
+to match without partitioning, or by calling {meth}`match_ancestors_batch_group_partition`
+many times in parallel followed by a single call to {meth}`match_ancestors_batch_group_finalise`.
+Each call to {meth}`match_ancestors_batch_groups` or {meth}`match_ancestors_batch_group_finalise`
+outputs the tree sequence to `work_dir`, which is then used by the next group. The length of
+the `ancestor_grouping` list in the metadata dictionary determines the group numbers that these
+methods will need to be called for, and the length of the `partitions` list in each group determines
+the number of calls to {meth}`match_ancestors_batch_group_partition` that are needed (if any).
+
+{meth}`match_ancestors_batch_groups` matches groups, without partitioning, from
+`group_index_start` (inclusive) to `group_index_end` (exclusive). Combining
+many groups into one call reduces the overhead from job submission and start-up
+times, but note that on job failure the process can only be resumed from the
+last `group_index_end`.
+
+To match a single group in parallel, call {meth}`match_ancestors_batch_group_partition`
+once for each partition listed in the `ancestor_grouping[group_index]['partitions']` list,
+incrementing `partition_index` each time. This will match the ancestors, placing the match
+data in `work_dir`. Once all partitions are complete, a single call to
+{meth}`match_ancestors_batch_group_finalise` will then insert the matches and
+output the tree sequence for the group to `work_dir`.
+
+At any point the process can be resumed from the last successfully completed call to
+{meth}`match_ancestors_batch_groups`, as the tree sequences in `work_dir` checkpoint
+the progress.
+
+Finally, after the last group, call {meth}`match_ancestors_batch_finalise` to
+combine the groups into a single tree sequence.
+
+The partitioning in `metadata.json` does not have to be used for every group. As early groups
+are not matched against a large tree sequence, it is often faster not to partition the first
+half of the groups, depending on job set-up and queueing time on your cluster.
+
+Calls to {meth}`match_ancestors_batch_group_partition` will only use a single core, but
+{meth}`match_ancestors_batch_groups` will use as many cores as `num_threads` is set to.
+Therefore this value, and the cluster resources requested, should be scaled with the
+number of ancestors, which can be read from the metadata dictionary. A sketch of the
+whole ancestor-matching phase is shown below.
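+
+This driver sketch is illustrative only: in practice each call would typically
+be submitted as a separate cluster job, and the paths, `min_work_per_job`
+value, thread count, and the assumption that unpartitioned groups have no
+`partitions` entry are all placeholders to adapt to your own workflow:
+
+```python
+import tsinfer
+
+work_dir = "ancestors_work"
+metadata = tsinfer.match_ancestors_batch_init(
+    work_dir,
+    "chr20.vcz",
+    "ancestral_state",
+    "chr20.ancestors",
+    min_work_per_job=2_000_000,
+)
+for group_index, group in enumerate(metadata["ancestor_grouping"]):
+    if group.get("partitions") is None:
+        # Unpartitioned group: a single (multithreaded) matching job.
+        tsinfer.match_ancestors_batch_groups(
+            work_dir, group_index, group_index + 1, num_threads=8
+        )
+    else:
+        # These partition jobs are independent and can run in parallel.
+        for partition_index in range(len(group["partitions"])):
+            tsinfer.match_ancestors_batch_group_partition(
+                work_dir, group_index, partition_index
+            )
+        tsinfer.match_ancestors_batch_group_finalise(work_dir, group_index)
+ancestors_ts = tsinfer.match_ancestors_batch_finalise(work_dir)
+```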
+
+## Sample matching
+
+Sample matching is far simpler than ancestor matching, as it is essentially the same as
+matching a single group of ancestors. There are three API methods that work together to
+enable distributed sample matching.
+{meth}`match_samples_batch_init` should be called to set up the batch matching and to
+determine the groupings of samples. As with {meth}`match_ancestors_batch_init`, it has
+`min_work_per_job` and `max_num_partitions` arguments to control the level of
+parallelism. The method writes a file `metadata.json` to the directory `work_dir` that
+contains a JSON-encoded dictionary with configuration for later steps. This is also
+returned by the call. The `num_partitions` key in this dictionary is the number of times
+{meth}`match_samples_batch_partition` will need to be called, with each partition index
+as the `partition_index` argument. These calls can happen in parallel and write match
+data to `work_dir`, which is then used by {meth}`match_samples_batch_finalise` to output
+the tree sequence, as sketched below.
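+
+Again, a minimal driver sketch, with hypothetical paths and with each partition
+call normally submitted as a separate cluster job:
+
+```python
+import json
+import pathlib
+
+import tsinfer
+
+work_dir = "samples_work"
+tsinfer.match_samples_batch_init(
+    work_dir,
+    "chr20.vcz",
+    "ancestral_state",
+    "chr20.ancestors.trees",
+    min_work_per_job=2_000_000,
+)
+# The partition count is recorded under "num_partitions" in metadata.json.
+metadata = json.loads((pathlib.Path(work_dir) / "metadata.json").read_text())
+for partition_index in range(metadata["num_partitions"]):
+    # Independent jobs: run one per worker in parallel.
+    tsinfer.match_samples_batch_partition(work_dir, partition_index)
+ts = tsinfer.match_samples_batch_finalise(work_dir)
+```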
\ No newline at end of file
diff --git a/tsinfer/formats.py b/tsinfer/formats.py
index c485d792..95e23a44 100644
--- a/tsinfer/formats.py
+++ b/tsinfer/formats.py
@@ -2308,7 +2308,7 @@ class VariantData(SampleData):
     the inference process will have ``inferred_ts.num_samples`` equal to double
     the number returned by ``VariantData.num_samples``.
 
-    :param Union(str, zarr.hierarchy.Group) path_or_zarr: The input dataset in
+    :param Union(str, zarr.Group) path_or_zarr: The input dataset in
         `VCF Zarr <https://github.com/sgkit-dev/vcf-zarr-spec>`_ format. This can either
         a path to the Zarr dataset saved on disk, or the Zarr object itself.
diff --git a/tsinfer/inference.py b/tsinfer/inference.py
index 9693e0ed..b948d060 100644
--- a/tsinfer/inference.py
+++ b/tsinfer/inference.py
@@ -592,7 +592,7 @@ def match_ancestors(
 
 
 def match_ancestors_batch_init(
-    working_dir,
+    work_dir,
     sample_data_path,
     ancestral_state,
     ancestor_data_path,
@@ -613,11 +613,78 @@
     time_units=None,
     record_provenance=True,
 ):
+    """
+    match_ancestors_batch_init(work_dir, sample_data_path, ancestral_state,
+    ancestor_data_path, min_work_per_job, \\*, max_num_partitions=None,
+    sample_mask=None, site_mask=None, recombination_rate=None, mismatch_ratio=None,
+    path_compression=True)
+
+    Initialise a batched ancestor matching job. This function prepares a
+    working directory for the batched matching process. The
+    job is split into groups of ancestors, with each group further split into
+    partitions of ancestors if necessary. `work_dir` is created and details
+    are written to `metadata.json` in `work_dir`. The job can then be run
+    using :meth:`match_ancestors_batch_groups`, or
+    :meth:`match_ancestors_batch_group_partition` followed by
+    :meth:`match_ancestors_batch_group_finalise`, with a final call to
+    :meth:`match_ancestors_batch_finalise`. See
+    :ref:`large scale inference <sec_large_scale>` for more details about how
+    these methods work together. See :meth:`match_ancestors` for details on
+    ancestor matching.
+
+    :param str work_dir: The directory in which to store the working files.
+    :param str sample_data_path: The input dataset in
+        `VCF Zarr <https://github.com/sgkit-dev/vcf-zarr-spec>`_ format.
+        Path to the Zarr dataset saved on disk. See :class:`VariantData`.
+    :param Union(array, str) ancestral_state: A numpy array of strings specifying
+        the ancestral states (alleles) used in inference. This must be the same length
+        as the number of unmasked sites in the dataset. Alternatively, a single string
+        can be provided, giving the name of an array in the input dataset which contains
+        the ancestral states. Unknown ancestral states can be specified using "N".
+        Any ancestral states which do not match any of the known alleles at that site
+        will be tallied, and a warning issued summarizing the unknown ancestral states.
+    :param str ancestor_data_path: The path to the file containing the ancestors
+        generated by :meth:`generate_ancestors`.
+    :param int min_work_per_job: The minimum amount of work (as a count of genotypes) to
+        allocate to a single parallel job. If the amount of work in a group of ancestors
+        exceeds this level, it will be broken up into parallel partitions, subject to
+        the constraint of `max_num_partitions`.
+    :param int max_num_partitions: The maximum number of partitions to split a group of
+        ancestors into. Useful for limiting the number of jobs in a workflow to
+        avoid job overhead. Defaults to 1000.
+    :param Union(array, str) sample_mask: A numpy array of booleans specifying which
+        samples to mask out (exclude) from the dataset. Alternatively, a string
+        can be provided, giving the name of an array in the input dataset which contains
+        the sample mask. If ``None`` (default), all samples are included.
+    :param Union(array, str) site_mask: A numpy array of booleans specifying which
+        sites to mask out (exclude) from the dataset. Alternatively, a string
+        can be provided, giving the name of an array in the input dataset which contains
+        the site mask. If ``None`` (default), all sites are included.
+    :param recombination_rate: Either a floating point value giving a constant rate
+        :math:`\\rho` per unit length of genome, or an :class:`msprime.RateMap`
+        object. This is used to calculate the probability of recombination between
+        adjacent sites. If ``None``, all matching conflicts are resolved by
+        recombination and all inference sites will have a single mutation
+        (equivalent to mismatch_ratio near zero).
+    :type recombination_rate: float, msprime.RateMap
+    :param float mismatch_ratio: The probability of a mismatch relative to the median
+        probability of recombination between adjacent sites: can only be used if a
+        recombination rate has been set (default: ``None`` treated as 1 if
+        ``recombination_rate`` is set).
+    :param bool path_compression: Whether to merge edges that share identical
+        paths (essentially taking advantage of shared recombination breakpoints).
+    :return: A dictionary of the job metadata, as written to `metadata.json`
+        in `work_dir`.
+        The `ancestor_grouping` key in this dict contains the grouping
+        of ancestors and should be used to guide calls to
+        :meth:`match_ancestors_batch_groups` and
+        :meth:`match_ancestors_batch_group_partition`.
+    :rtype: dict
+    """
+
     if max_num_partitions is None:
         max_num_partitions = 1000
 
-    working_dir = pathlib.Path(working_dir)
-    working_dir.mkdir(parents=True, exist_ok=True)
+    work_dir = pathlib.Path(work_dir)
+    work_dir.mkdir(parents=True, exist_ok=True)
 
     ancestors = formats.AncestorData.load(ancestor_data_path)
     sample_data = formats.VariantData(
@@ -663,7 +730,7 @@ def match_ancestors_batch_init(
             current_partition_work += ancestor_lengths[ancestor]
         partitions.append(current_partition)
         if len(partitions) > 1:
-            group_dir = working_dir / f"group_{group_index}"
+            group_dir = work_dir / f"group_{group_index}"
             group_dir.mkdir()
         # TODO: Should be a dataclass
         group = {
@@ -690,7 +757,7 @@ def match_ancestors_batch_init(
         "record_provenance": record_provenance,
         "ancestor_grouping": ancestor_grouping,
     }
-    metadata_path = working_dir / "metadata.json"
+    metadata_path = work_dir / "metadata.json"
     metadata_path.write_text(json.dumps(metadata))
     return metadata
 
@@ -725,6 +792,28 @@ def initialize_ancestor_matcher(metadata, ancestors_ts=None, **kwargs):
 def match_ancestors_batch_groups(
     work_dir, group_index_start, group_index_end, num_threads=0
 ):
+    """
+    match_ancestors_batch_groups(work_dir, group_index_start,
+    group_index_end, num_threads=0)
+
+    Match a set of ancestor groups, from `group_index_start` (inclusive) to
+    `group_index_end` (exclusive), in a batched ancestor matching job. See
+    :ref:`large scale inference <sec_large_scale>` for more details.
+
+    A tree sequence file for `group_index_start - 1` must exist in `work_dir`, unless
+    `group_index_start` is 0. After matching, the tree sequence for
+    `group_index_end - 1` is written to `work_dir`.
+
+    :param str work_dir: The working directory for the batch job, as written by
+        :meth:`match_ancestors_batch_init`.
+    :param int group_index_start: The first group index to match.
+    :param int group_index_end: The group index to stop matching at.
+    :param int num_threads: The number of worker threads to use. If this is <= 1,
+        matching is performed sequentially.
+    :return: The tree sequence representing the inferred ancestors for the last group
+        matched.
+    :rtype: tskit.TreeSequence
+    """
     metadata_path = os.path.join(work_dir, "metadata.json")
     with open(metadata_path) as f:
         metadata = json.load(f)
@@ -756,6 +845,24 @@ def match_ancestors_batch_groups(
 
 
 def match_ancestors_batch_group_partition(work_dir, group_index, partition_index):
+    """
+    match_ancestors_batch_group_partition(work_dir, group_index, partition_index)
+
+    Match a single partition of ancestors from a group in a batched ancestor matching
+    job. See :ref:`large scale inference <sec_large_scale>` for more details. The
+    tree sequence for the preceding group must exist in `work_dir`. After matching,
+    the results for the partition are written to `work_dir`. Once all partitions for a
+    group have been matched, the group can be finalised using
+    :meth:`match_ancestors_batch_group_finalise`. The number of partitions in a
+    group is recorded in `metadata.json` in the work dir under the
+    `ancestor_grouping` key. This method uses a single thread.
+
+    :param str work_dir: The working directory for the batch job, as written by
+        :meth:`match_ancestors_batch_init`.
+    :param int group_index: The group index that contains the partition to match.
+    :param int partition_index: The partition index to match. Must be less than
+        the number of partitions in the batch job metadata for this group.
+    """
     metadata_path = os.path.join(work_dir, "metadata.json")
     with open(metadata_path) as f:
         metadata = json.load(f)
@@ -781,6 +888,20 @@ def match_ancestors_batch_group_partition(work_dir, group_index, partition_index
 
 
 def match_ancestors_batch_group_finalise(work_dir, group_index):
+    """
+    match_ancestors_batch_group_finalise(work_dir, group_index)
+
+    Finalise a group of partitioned ancestors in a batched ancestor matching job. See
+    :ref:`large scale inference <sec_large_scale>` for more details. The tree sequence
+    for the preceding group must exist in `work_dir`, along with the results for all
+    partitions in this group. Writes the tree sequence for the group to `work_dir`.
+
+    :param str work_dir: The working directory for the batch job, as written by
+        :meth:`match_ancestors_batch_init`.
+    :param int group_index: The group index to finalise.
+    :return: The tree sequence representing the inferred ancestors for the group.
+    :rtype: tskit.TreeSequence
+    """
     metadata_path = os.path.join(work_dir, "metadata.json")
     with open(metadata_path) as f:
         metadata = json.load(f)
@@ -805,6 +926,19 @@ def match_ancestors_batch_group_finalise(work_dir, group_index):
 
 
 def match_ancestors_batch_finalise(work_dir):
+    """
+    match_ancestors_batch_finalise(work_dir)
+
+    Finalise a batched ancestor matching job. This method should be called after all
+    groups have been matched, either by :meth:`match_ancestors_batch_groups` or
+    :meth:`match_ancestors_batch_group_finalise`. Returns the final ancestors
+    tree sequence for the batch job. `work_dir` is retained and not deleted.
+
+    :param str work_dir: The working directory for the batch job, as written by
+        :meth:`match_ancestors_batch_init`.
+    :return: The tree sequence representing the inferred ancestors for the batch job.
+    :rtype: tskit.TreeSequence
+    """
     metadata_path = os.path.join(work_dir, "metadata.json")
     with open(metadata_path) as f:
         metadata = json.load(f)
@@ -1023,6 +1157,79 @@ def match_samples_batch_init(
     record_provenance=True,
     map_additional_sites=None,
 ):
+    """
+    match_samples_batch_init(work_dir, sample_data_path, ancestral_state,
+    ancestor_ts_path, min_work_per_job, \\*, max_num_partitions=None,
+    sample_mask=None, site_mask=None, recombination_rate=None, mismatch_ratio=None,
+    path_compression=True, indexes=None, post_process=None, force_sample_times=False)
+
+    Initialise a batched sample matching job. Creates `work_dir` and writes job
+    details to `metadata.json`. The job can then be run using parallel calls to
+    :meth:`match_samples_batch_partition` and, once those are complete, a final
+    call to :meth:`match_samples_batch_finalise`.
+
+    The `num_partitions` key in the metadata dict contains the number of partitions
+    that need to be processed.
+
+    :param str work_dir: The directory in which to store the working files.
+    :param str sample_data_path: The input dataset in
+        `VCF Zarr <https://github.com/sgkit-dev/vcf-zarr-spec>`_ format.
+        Path to the Zarr dataset saved on disk. See :class:`VariantData`.
+    :param Union(array, str) ancestral_state: A numpy array of strings specifying
+        the ancestral states (alleles) used in inference. This must be the same
+        length as the number of unmasked sites in the dataset. Alternatively, a
+        single string can be provided, giving the name of an array in the input
+        dataset which contains the ancestral states. Unknown ancestral states can
+        be specified using "N".
+        Any ancestral states which do not match any of the
+        known alleles at that site will be tallied, and a warning issued
+        summarizing the unknown ancestral states.
+    :param str ancestor_ts_path: The path to the tree sequence file containing the
+        ancestors generated by :meth:`match_ancestors_batch_finalise`, or
+        :meth:`match_ancestors`.
+    :param int min_work_per_job: The minimum amount of work (as a count of
+        genotypes) to allocate to a single parallel job. If the amount of work in
+        a group of samples exceeds this level, it will be broken up into parallel
+        partitions, subject to the constraint of `max_num_partitions`.
+    :param int max_num_partitions: The maximum number of partitions to split a
+        group of samples into. Useful for limiting the number of jobs in a
+        workflow to avoid job overhead. Defaults to 1000.
+    :param Union(array, str) sample_mask: A numpy array of booleans specifying
+        which samples to mask out (exclude) from the dataset. Alternatively, a
+        string can be provided, giving the name of an array in the input dataset
+        which contains the sample mask. If ``None`` (default), all samples are
+        included.
+    :param Union(array, str) site_mask: A numpy array of booleans specifying which
+        sites to mask out (exclude) from the dataset. Alternatively, a string can
+        be provided, giving the name of an array in the input dataset which
+        contains the site mask. If ``None`` (default), all sites are included.
+    :param recombination_rate: Either a floating point value giving a constant
+        rate :math:`\\rho` per unit length of genome, or an
+        :class:`msprime.RateMap` object. This is used to calculate the
+        probability of recombination between adjacent sites. If ``None``, all
+        matching conflicts are resolved by recombination and all inference sites
+        will have a single mutation (equivalent to mismatch_ratio near zero).
+    :type recombination_rate: float, msprime.RateMap
+    :param float mismatch_ratio: The probability of a mismatch relative to the
+        median probability of recombination between adjacent sites: can only be
+        used if a recombination rate has been set (default: ``None`` treated as 1
+        if ``recombination_rate`` is set).
+    :param bool path_compression: Whether to merge edges that share identical paths
+        (essentially taking advantage of shared recombination breakpoints).
+    :param indexes: The sample indexes to match. If ``None`` (default), all
+        samples are matched.
+    :type indexes: arraylike
+    :param bool post_process: Whether to run the :func:`post_process` method on
+        the tree sequence which, among other things, removes ancestral
+        material that does not end up in the current samples (if not specified,
+        defaults to ``True``).
+    :param bool force_sample_times: After matching, should an attempt be made to
+        adjust the time of "historical samples" (those associated with an
+        individual having a non-zero time) such that the sample nodes in the tree
+        sequence appear at the time of the individual with which they are
+        associated.
+    :return: A dictionary of the job metadata, as written to `metadata.json` in
+        `work_dir`.
+    """
     if max_num_partitions is None:
         max_num_partitions = 1000
 
@@ -1079,13 +1286,26 @@ def match_samples_batch_init(
         num_samples_per_partition = 1
     wd.num_samples_per_partition = num_samples_per_partition
     wd.num_partitions = math.ceil(len(sample_indexes) / num_samples_per_partition)
-    wd_path = work_dir / "wd.json"
+    wd_path = work_dir / "metadata.json"
     wd.save(wd_path)
     return wd
 
 
 def match_samples_batch_partition(work_dir, partition_index):
-    wd_path = pathlib.Path(work_dir) / "wd.json"
+    """
+    match_samples_batch_partition(work_dir, partition_index)
+
+    Match a single partition of samples in a batched sample matching job. See
+    :ref:`large scale inference <sec_large_scale>` for more details. Match data
+    for the partition is written to `work_dir`. Uses a single thread to perform
+    matching.
+
+    :param str work_dir: The working directory for the batch job, as written by
+        :meth:`match_samples_batch_init`.
+    :param int partition_index: The partition index to match. Must be less than
+        the number of partitions in the batch job metadata key `num_partitions`.
+    """
+    wd_path = pathlib.Path(work_dir) / "metadata.json"
     wd = SampleBatchWorkDescriptor.load(wd_path)
     if partition_index >= wd.num_partitions or partition_index < 0:
         raise ValueError(f"Partition {partition_index} is out of range")
@@ -1113,7 +1333,19 @@ def match_samples_batch_partition(work_dir, partition_index):
 
 
 def match_samples_batch_finalise(work_dir):
-    wd_path = os.path.join(work_dir, "wd.json")
+    """
+    match_samples_batch_finalise(work_dir)
+
+    Finalise a batched sample matching job. This method should be called after all
+    partitions have been matched by :meth:`match_samples_batch_partition`. Returns
+    the final tree sequence for the batch job. `work_dir` is retained and not
+    deleted.
+
+    :param str work_dir: The working directory for the batch job, as written by
+        :meth:`match_samples_batch_init`.
+    :return: The tree sequence representing the inferred history of the samples.
+    :rtype: tskit.TreeSequence
+    """
+    wd_path = os.path.join(work_dir, "metadata.json")
     wd = SampleBatchWorkDescriptor.load(wd_path)
     variant_data, ancestor_ts, matcher = load_variant_data_and_ancestors_ts(wd)
     results = []