diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 684ccb68a3d..be2444921d7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -115,6 +115,7 @@ jobs: - | tests/backends/test_mcbackend.py + tests/backends/test_zarr.py tests/distributions/test_truncated.py tests/logprob/test_abstract.py tests/logprob/test_basic.py @@ -240,6 +241,7 @@ jobs: - | tests/backends/test_arviz.py + tests/backends/test_zarr.py tests/variational/test_updates.py fail-fast: false runs-on: ${{ matrix.os }} diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index 71b6c78ed43..de0572e0a24 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -19,6 +19,7 @@ dependencies: - scipy>=1.4.1 - typing-extensions>=3.7.4 - threadpoolctl>=3.1.0 +- zarr>=2.5.0,<3 # Extra dependencies for dev, testing and docs build - ipython>=7.16 - jax diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index f795fca078a..c399a3e24af 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -17,6 +17,7 @@ dependencies: - scipy>=1.4.1 - typing-extensions>=3.7.4 - threadpoolctl>=3.1.0 +- zarr>=2.5.0,<3 # Extra dependencies for docs build - ipython>=7.16 - jax diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml index 48649a617df..39deb8a41a9 100644 --- a/conda-envs/environment-jax.yml +++ b/conda-envs/environment-jax.yml @@ -10,6 +10,7 @@ dependencies: - cachetools>=4.2.1 - cloudpickle - h5py>=2.7 +- zarr>=2.5.0,<3 # Jaxlib version must not be greater than jax version! - blackjax>=1.2.2 - jax>=0.4.28 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index e6fe9857e0a..79c57a44c64 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -21,6 +21,7 @@ dependencies: - scipy>=1.4.1 - typing-extensions>=3.7.4 - threadpoolctl>=3.1.0 +- zarr>=2.5.0,<3 # Extra dependencies for testing - ipython>=7.16 - pre-commit>=2.8.0 diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index ee5bd206f41..bbcba9149f8 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -20,6 +20,7 @@ dependencies: - scipy>=1.4.1 - typing-extensions>=3.7.4 - threadpoolctl>=3.1.0 +- zarr>=2.5.0,<3 # Extra dependencies for dev, testing and docs build - ipython>=7.16 - myst-nb<=1.0.0 diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index fa598528300..399fab811b6 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -23,6 +23,7 @@ dependencies: - scipy>=1.4.1 - typing-extensions>=3.7.4 - threadpoolctl>=3.1.0 +- zarr>=2.5.0,<3 # Extra dependencies for testing - ipython>=7.16 - pre-commit>=2.8.0 diff --git a/docs/source/api/backends.rst b/docs/source/api/backends.rst index ca00a56d816..8f0c76f4533 100644 --- a/docs/source/api/backends.rst +++ b/docs/source/api/backends.rst @@ -20,3 +20,5 @@ Internal structures NDArray base.BaseTrace base.MultiTrace + zarr.ZarrTrace + zarr.ZarrChain diff --git a/docs/source/conf.py b/docs/source/conf.py index 74ac0d97463..b9afc12e733 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -309,6 +309,7 @@ "python": ("https://docs.python.org/3/", None), "scipy": ("https://docs.scipy.org/doc/scipy/", None), "xarray": ("https://docs.xarray.dev/en/stable/", None), + "zarr": ("https://zarr.readthedocs.io/en/stable/", None), } diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index cd007cf3c0a..8bcba42301c 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -72,9 +72,11 @@ from pymc.backends.arviz import predictions_to_inference_data, to_inference_data from pymc.backends.base import BaseTrace, IBaseTrace from pymc.backends.ndarray import NDArray +from pymc.backends.zarr import ZarrTrace from pymc.blocking import PointType from pymc.model import Model from pymc.step_methods.compound import BlockedStep, CompoundStep +from pymc.util import get_random_generator HAS_MCB = False try: @@ -102,11 +104,13 @@ def _init_trace( model: Model, trace_vars: list[TensorVariable] | None = None, initial_point: PointType | None = None, + rng: np.random.Generator | None = None, ) -> BaseTrace: """Initialize a trace backend for a chain.""" + rng_ = get_random_generator(rng) strace: BaseTrace if trace is None: - strace = NDArray(model=model, vars=trace_vars, test_point=initial_point) + strace = NDArray(model=model, vars=trace_vars, test_point=initial_point, rng=rng_) elif isinstance(trace, BaseTrace): if len(trace) > 0: raise ValueError("Continuation of traces is no longer supported.") @@ -120,15 +124,29 @@ def _init_trace( def init_traces( *, - backend: TraceOrBackend | None, + backend: TraceOrBackend | ZarrTrace | None, chains: int, expected_length: int, step: BlockedStep | CompoundStep, initial_point: PointType, model: Model, trace_vars: list[TensorVariable] | None = None, + tune: int = 0, + rng: np.random.Generator | None = None, ) -> tuple[RunType | None, Sequence[IBaseTrace]]: """Initialize a trace recorder for each chain.""" + if isinstance(backend, ZarrTrace): + backend.init_trace( + chains=chains, + draws=expected_length - tune, + tune=tune, + step=step, + model=model, + vars=trace_vars, + test_point=initial_point, + rng=rng, + ) + return None, backend.straces if HAS_MCB and isinstance(backend, Backend): return init_chain_adapters( backend=backend, @@ -136,6 +154,7 @@ def init_traces( initial_point=initial_point, step=step, model=model, + rng=rng, ) assert backend is None or isinstance(backend, BaseTrace) @@ -148,7 +167,8 @@ def init_traces( model=model, trace_vars=trace_vars, initial_point=initial_point, + rng=rng_, ) - for chain_number in range(chains) + for chain_number, rng_ in enumerate(get_random_generator(rng).spawn(chains)) ] return None, traces diff --git a/pymc/backends/base.py b/pymc/backends/base.py index 5a2a043a396..1188efbfafc 100644 --- a/pymc/backends/base.py +++ b/pymc/backends/base.py @@ -34,7 +34,7 @@ from pymc.backends.report import SamplerReport from pymc.model import modelcontext -from pymc.pytensorf import compile +from pymc.pytensorf import compile, copy_function_with_new_rngs from pymc.util import get_var_name logger = logging.getLogger(__name__) @@ -159,6 +159,7 @@ def __init__( fn=None, var_shapes=None, var_dtypes=None, + rng=None, ): model = modelcontext(model) @@ -177,6 +178,8 @@ def __init__( on_unused_input="ignore", ) fn.trust_input = True + if rng is not None: + fn = copy_function_with_new_rngs(fn=fn, rng=rng) # Get variable shapes. Most backends will need this # information. diff --git a/pymc/backends/mcbackend.py b/pymc/backends/mcbackend.py index 3d2c8fd9e7e..b6342a21822 100644 --- a/pymc/backends/mcbackend.py +++ b/pymc/backends/mcbackend.py @@ -29,7 +29,7 @@ from pymc.backends.base import IBaseTrace from pymc.model import Model -from pymc.pytensorf import PointFunc +from pymc.pytensorf import PointFunc, copy_function_with_new_rngs from pymc.step_methods.compound import ( BlockedStep, CompoundStep, @@ -38,6 +38,7 @@ flat_statname, flatten_steps, ) +from pymc.util import get_random_generator _log = logging.getLogger(__name__) @@ -96,7 +97,11 @@ class ChainRecordAdapter(IBaseTrace): """Wraps an McBackend ``Chain`` as an ``IBaseTrace``.""" def __init__( - self, chain: mcb.Chain, point_fn: PointFunc, stats_bijection: StatsBijection + self, + chain: mcb.Chain, + point_fn: PointFunc, + stats_bijection: StatsBijection, + rng: np.random.Generator | None = None, ) -> None: # Assign attributes required by IBaseTrace self.chain = chain.cmeta.chain_number @@ -107,8 +112,11 @@ def __init__( for sstats in stats_bijection._stat_groups ] + self._rng = rng self._chain = chain self._point_fn = point_fn + if rng is not None: + self._point_fn = copy_function_with_new_rngs(self._point_fn, rng) self._statsbj = stats_bijection super().__init__() @@ -257,6 +265,7 @@ def init_chain_adapters( initial_point: Mapping[str, np.ndarray], step: CompoundStep | BlockedStep, model: Model, + rng: np.random.Generator | None, ) -> tuple[mcb.Run, list[ChainRecordAdapter]]: """Create an McBackend metadata description for the MCMC run. @@ -286,7 +295,8 @@ def init_chain_adapters( chain=run.init_chain(chain_number=chain_number), point_fn=point_fn, stats_bijection=statsbj, + rng=rng_, ) - for chain_number in range(chains) + for chain_number, rng_ in enumerate(get_random_generator(rng).spawn(chains)) ] return run, adapters diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py new file mode 100644 index 00000000000..2fb7134303a --- /dev/null +++ b/pymc/backends/zarr.py @@ -0,0 +1,867 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable, Mapping, MutableMapping, Sequence +from typing import Any + +import arviz as az +import numcodecs +import numpy as np +import xarray as xr +import zarr + +from arviz.data.base import make_attrs +from arviz.data.inference_data import WARMUP_TAG +from numcodecs.abc import Codec +from pytensor.tensor.variable import TensorVariable + +import pymc + +from pymc.backends.arviz import ( + coords_and_dims_for_inferencedata, + find_constants, + find_observations, +) +from pymc.backends.base import BaseTrace +from pymc.blocking import StatDtype, StatShape +from pymc.model.core import Model, modelcontext +from pymc.pytensorf import copy_function_with_new_rngs +from pymc.step_methods.compound import ( + BlockedStep, + CompoundStep, + StatsBijection, + get_stats_dtypes_shapes_from_steps, +) +from pymc.util import ( + UNSET, + _UnsetType, + get_default_varnames, + get_random_generator, + is_transformed_name, +) + +try: + from zarr.storage import BaseStore, default_compressor + from zarr.sync import Synchronizer + + _zarr_available = True +except ImportError: + _zarr_available = False + + +class ZarrChain(BaseTrace): + """Interface object to interact with a single chain in a :class:`~.ZarrTrace`. + + Parameters + ---------- + store : zarr.storage.BaseStore | collections.abc.MutableMapping + The store object where the zarr groups and arrays will be stored and read from. + This store must exist before creating a ``ZarrChain`` object. ``ZarrChain`` are + only intended to be used as interfaces to the individual chains of + :class:`~.ZarrTrace` objects. This means that the :class:`~.ZarrTrace` should + be the one that creates the store that is then provided to a ``ZarrChain``. + stats_bijection : pymc.step_methods.compound.StatsBijection + An object that maps between a list of step method stats and a dictionary of + said stats with the accompanying stepper index. + synchronizer : zarr.sync.Synchronizer | None + The synchronizer to use for the underlying zarr arrays. + model : Model + If None, the model is taken from the `with` context. + vars : Sequence[TensorVariable] | None + Sampling values will be stored for these variables. If None, + `model.unobserved_RVs` is used. + test_point : dict[str, numpy.ndarray] | None + This is not used and is inherited from the signature of :class:`~.BaseTrace`, + which uses it to determine the shape and dtype of `vars`. + draws_per_chunk : int + The number of draws that make up a chunk in the variable's posterior array. + The interface only writes the samples to the store once a chunk is completely + filled. + """ + + def __init__( + self, + store: BaseStore | MutableMapping, + stats_bijection: StatsBijection, + synchronizer: Synchronizer | None = None, + model: Model | None = None, + vars: Sequence[TensorVariable] | None = None, + test_point: dict[str, np.ndarray] | None = None, + draws_per_chunk: int = 1, + fn: Callable | None = None, + ): + if not _zarr_available: + raise RuntimeError("You must install zarr to be able to create ZarrChain instances") + super().__init__(name="zarr", model=model, vars=vars, test_point=test_point, fn=fn) + self._step_method: BlockedStep | CompoundStep | None = None + self.unconstrained_variables = { + var.name for var in self.vars if is_transformed_name(var.name) + } + self.draw_idx = 0 + self._buffers: dict[str, dict[str, list]] = { + "posterior": {}, + "sample_stats": {}, + } + self._buffered_draws = 0 + self.draws_per_chunk = int(draws_per_chunk) + assert self.draws_per_chunk > 0 + self._posterior = zarr.open_group( + store, synchronizer=synchronizer, path="posterior", mode="a" + ) + if self.unconstrained_variables: + self._unconstrained_posterior = zarr.open_group( + store, synchronizer=synchronizer, path="unconstrained_posterior", mode="a" + ) + self._buffers["unconstrained_posterior"] = {} + self._sample_stats = zarr.open_group( + store, synchronizer=synchronizer, path="sample_stats", mode="a" + ) + self._sampling_state = zarr.open_group( + store, synchronizer=synchronizer, path="_sampling_state", mode="a" + ) + self.stats_bijection = stats_bijection + + def link_stepper(self, step_method: BlockedStep | CompoundStep): + """Provide a reference to the step method used during sampling. + + This reference can be used to facilite writing the stepper's sampling state + each time the samples are flushed into the storage. + """ + self._step_method = step_method + + def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None): # type: ignore[override] + self.chain = chain + self.total_draws = draws + self.draws_until_flush = min([self.draws_per_chunk, draws - self.draw_idx]) + self.clear_buffers() + + def clear_buffers(self): + for group in self._buffers: + self._buffers[group] = {} + self._buffered_draws = 0 + + def buffer(self, group, var_name, value): + buffer = self._buffers[group] + if var_name not in buffer: + buffer[var_name] = [] + buffer[var_name].append(value) + + def record( + self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]] + ) -> bool | None: + """Record the step method's returned draw and stats. + + The draws and stats are first stored in an internal buffer. Once the buffer is + filled, the samples and stats are written (flushed) onto the desired zarr store. + + Returns + ------- + flushed : bool | None + Returns ``True`` only if the data was written onto the desired zarr store. + Any other time that the recorded draw and stats are written into the + internal buffer, ``None`` is returned. + + See Also + -------- + :meth:`~ZarrChain.flush` + """ + unconstrained_variables = self.unconstrained_variables + for var_name, var_value in zip(self.varnames, self.fn(**draw)): + if var_name in unconstrained_variables: + self.buffer(group="unconstrained_posterior", var_name=var_name, value=var_value) + else: + self.buffer(group="posterior", var_name=var_name, value=var_value) + for var_name, var_value in self.stats_bijection.map(stats).items(): + self.buffer(group="sample_stats", var_name=var_name, value=var_value) + self._buffered_draws += 1 + if self._buffered_draws == self.draws_until_flush: + self.flush() + return True + return None + + def record_sampling_state(self, step: BlockedStep | CompoundStep | None = None): + """Record the sampling state information to the store's ``_sampling_state`` group. + + The sampling state includes the number of draws taken so far (``draw_idx``) and + the step method's ``sampling_state``. + + Parameters + ---------- + step : BlockedStep | CompoundStep | None + The step method from which to take the ``sampling_state``. If ``None``, + the ``step`` is taken to be the step method that was linked to the + ``ZarrChain`` when calling :meth:`~ZarrChain.link_stepper`. If this method was never + called, no step method ``sampling_state`` information is stored in the + chain. + """ + if step is None: + step = self._step_method + if step is not None: + self.store_sampling_state(step.sampling_state) + self._sampling_state.draw_idx.set_coordinate_selection(self.chain, self.draw_idx) + + def store_sampling_state(self, sampling_state): + self._sampling_state.sampling_state.set_coordinate_selection( + self.chain, np.array([sampling_state], dtype="object") + ) + + def flush(self): + """Write the data stored in the internal buffer to the desired zarr store. + + After writing the draws and stats returned by each step of the step method, + the :meth:`~ZarrChain.record_sampling_state` is called, the internal buffer is cleared and + the number of steps until the next flush is determined. + """ + chain = self.chain + draw_slice = slice(self.draw_idx, self.draw_idx + self.draws_until_flush) + for group_name, buffer in self._buffers.items(): + group = getattr(self, f"_{group_name}") + for var_name, var_value in buffer.items(): + group[var_name].set_orthogonal_selection( + (chain, draw_slice), + np.stack(var_value), + ) + self.draw_idx += self.draws_until_flush + self.record_sampling_state() + self.clear_buffers() + self.draws_until_flush = min([self.draws_per_chunk, self.total_draws - self.draw_idx]) + + +FILL_VALUE_TYPE = float | int | bool | str | np.datetime64 | np.timedelta64 | None +DEFAULT_FILL_VALUES: dict[Any, FILL_VALUE_TYPE] = { + np.floating: np.nan, + np.integer: 0, + np.bool_: False, + np.str_: "", + np.datetime64: np.datetime64(0, "Y"), + np.timedelta64: np.timedelta64(0, "Y"), +} + + +def get_initial_fill_value_and_codec( + dtype: Any, +) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]: + _dtype = np.dtype(dtype) + fill_value: FILL_VALUE_TYPE = None + codec = None + try: + fill_value = DEFAULT_FILL_VALUES[_dtype] + except KeyError: + for key in DEFAULT_FILL_VALUES: + if np.issubdtype(_dtype, key): + fill_value = DEFAULT_FILL_VALUES[key] + break + else: + codec = numcodecs.Pickle() + return fill_value, _dtype, codec + + +class ZarrTrace: + """Object that stores and enables access to MCMC draws stored in a :class:`zarr.hierarchy.Group` objects. + + This class creats a zarr hierarchy to represent the sampling information which is + intended to mimic :class:`arviz.InferenceData`. The hierarchy looks like this: + + | root + | |--> constant_data + | |--> observed_data + | |--> posterior + | |--> unconstrained_posterior + | |--> sample_stats + | |--> warmup_posterior + | |--> warmup_unconstrained_posterior + | |--> warmup_sample_stats + | |--> _sampling_state + + The root group is created when the ``ZarrTrace`` object is initialized. The rest of + the groups are created once :meth:`~ZarrChain.init_trace` is called with a few exceptions: + unconstrained_posterior is only created if ``include_transformed = True``, and the + groups prefixed with ``warmup_`` are created only after calling + :meth:`~ZarrTrace.split_warmup_groups`. + + Since ``ZarrTrace`` objects are intended to be as close to + :class:`arviz.InferenceData` objects as possible, the groups store the dimension + and coordinate information following the `xarray zarr standard `_. + + Parameters + ---------- + store : zarr.storage.BaseStore | collections.abc.MutableMapping | None + The store object where the zarr groups and arrays will be stored and read from. + Any zarr compatible storage object works. Keep in mind that if ``None`` is + provided, a :class:`zarr.storage.MemoryStore` will be used, which means that + information won't be visible to other processes and won't persist after the + ``ZarrTrace`` life-cycle ends. If you want to have persistent storage, please + use one of the multiple disk backed zarr storage options, e.g. + :class:`~zarr.storage.DirectoryStore` or :class:`~zarr.storage.ZipStore`. + synchronizer : zarr.sync.Synchronizer | None + The synchronizer to use for the underlying zarr arrays. + compressor : numcodec.abc.Codec | None | pymc.util.UNSET + The compressor to use for the underlying zarr arrays. If ``None``, no compressor + is used. If ``UNSET``, zarr's default compressor is used. + draws_per_chunk : int + The number of draws that make up a chunk in the variable's posterior array. + Each variable's array shape is set to ``(n_chains, n_draws, *rv_shape)``, but + the chunks are set to ``(1, draws_per_chunk, *rv_shape)``. This means that each + chain will have it's own chunk to read or write to, allowing for concurrent + write operations of different chains not to interfere with each other, and that + multiple draws can belong to the same chunk. The variable's core dimension + however, will never be split across different chunks. + include_transformed : bool + If ``True``, the transformed, unconstrained value variables are included in the + storage group. + + Notes + ----- + ``ZarrTrace`` objects represent the storage information. If the underlying store + persists on disk or over the network (e.g. with a :class:`zarr.storage.FSStore`) + multiple process will be able to concurrently access the same storage and read or + write to it. + + The intended division of labour is for ``ZarrTrace`` to handle the creation and + management of the zarr group and storage objects and arrays, and for individual + :class:`~.ZarrChain` objects to handle recording MCMC samples to the trace. This + division was chosen to stay close to the existing `pymc.backends.base.MultiTrace` + and `pymc.backends.ndarray.NDArray` way of working with the existing samplers. + + One extra feature of ``ZarrTrace`` is that it enables direct access to any array's + metadata. ``ZarrTrace`` takes advantage of this to tag arrays as ``deterministic`` + or ``freeRV`` depending on what kind of variable they were in the defining model. + + See Also + -------- + :class:`~pymc.backends.zarr.ZarrChain` + """ + + def __init__( + self, + store: BaseStore | MutableMapping | None = None, + synchronizer: Synchronizer | None = None, + compressor: Codec | None | _UnsetType = UNSET, + draws_per_chunk: int = 1, + include_transformed: bool = False, + ): + if not _zarr_available: + raise RuntimeError("You must install zarr to be able to create ZarrTrace instances") + self.synchronizer = synchronizer + if compressor is UNSET: + compressor = default_compressor + self.compressor = compressor + self.root = zarr.group( + store=store, + overwrite=True, + synchronizer=synchronizer, + ) + + self.draws_per_chunk = int(draws_per_chunk) + assert self.draws_per_chunk >= 1 + + self.include_transformed = include_transformed + + self._is_base_setup = False + + def groups(self) -> list[str]: + return [str(group_name) for group_name, _ in self.root.groups()] + + @property + def posterior(self) -> zarr.Group: + return self.root.posterior + + @property + def unconstrained_posterior(self) -> zarr.Group: + return self.root.unconstrained_posterior + + @property + def sample_stats(self) -> zarr.Group: + return self.root.sample_stats + + @property + def constant_data(self) -> zarr.Group: + return self.root.constant_data + + @property + def observed_data(self) -> zarr.Group: + return self.root.observed_data + + @property + def _sampling_state(self) -> zarr.Group: + return self.root._sampling_state + + def init_trace( + self, + chains: int, + draws: int, + tune: int, + step: BlockedStep | CompoundStep, + model: Model | None = None, + vars: Sequence[TensorVariable] | None = None, + test_point: dict[str, np.ndarray] | None = None, + rng: np.random.Generator | None = None, + ): + """Initialize the trace groups and arrays. + + This function creates and fills with default values the groups below the + ``ZarrTrace.root`` group. It creates the ``constant_data``, ``observed_data``, + ``posterior``, ``unconstrained_posterior`` (if ``include_transformed = True``), + ``sample_stats``, and ``_sampling_state`` zarr groups, and all of the relevant + arrays that must be stored there. + + Every array in the posterior and sample stats groups will have the + (chains, tune + draws) batch dimensions to the left of the core dimensions of + the model's random variable or the step method's stat shape. The warmup (tuning + draws) and the posterior samples are split at a later stage, once + :meth:`~ZarrTrace.split_warmup_groups` is called. + + After the creation if the zarr hierarchies, it initializes the list of + :class:`~pymc.backends.zarr.Zarrchain` instances (one for each chain) under the + ``straces`` attribute. These objects serve as the interface to record draws and + samples generated by the step methods for each chain. + + Parameters + ---------- + chains : int + The number of chains to use to initialize the arrays. + draws : int + The number of posterior draws to use to initialize the arrays. + tune : int + The number of tuning steps to use to initialize the arrays. + step : pymc.step_methods.compound.BlockedStep | pymc.step_methods.compound.CompoundStep + The step method that will be used to generate the draws and stats. + model : pymc.model.core.Model | None + If None, the model is taken from the ``with`` context. + vars : Sequence[TensorVariable] | None + Sampling values will be stored for these variables. If ``None``, + ``model.unobserved_RVs`` is used. + test_point : dict[str, numpy.ndarray] | None + This is not used and is a product of the inheritance of :class:`ZarrChain` + from :class:`~.BaseTrace`, which uses it to determine the shape and dtype + of `vars`. + rng : numpy.random.Generator | None + A random generator to use to seed the shared random generators that are + present in the pytensor function that maps samples drawn by step methods + onto samples in the posterior trace. Note that this only does anything + if there are deterministic variables that are generated by raw pytensor + random variables. + """ + if self._is_base_setup: + raise RuntimeError("The ZarrTrace has already been initialized") # pragma: no cover + model = modelcontext(model) + self.model = model + self.coords, self.vars_to_dims = coords_and_dims_for_inferencedata(model) + if vars is None: + vars = model.unobserved_value_vars + + unnamed_vars = {var for var in vars if var.name is None} + assert not unnamed_vars, f"Can't trace unnamed variables: {unnamed_vars}" + self.varnames = get_default_varnames( + [var.name for var in vars], include_transformed=self.include_transformed + ) + self.vars = [var for var in vars if var.name in self.varnames] + + self.fn = model.compile_fn( + self.vars, + inputs=model.value_vars, + on_unused_input="ignore", + point_fn=False, + ) + + # Get variable shapes. Most backends will need this + # information. + if test_point is None: + test_point = model.initial_point() + var_values = list(zip(self.varnames, self.fn(**test_point))) + self.var_dtype_shapes = { + var: (value.dtype, value.shape) + for var, value in var_values + if not is_transformed_name(var) + } + extra_var_attrs = { + var: { + "kind": "freeRV" + if is_transformed_name(var) or model[var] in model.free_RVs + else "deterministic" + } + for var in self.var_dtype_shapes + } + self.unc_var_dtype_shapes = { + var: (value.dtype, value.shape) for var, value in var_values if is_transformed_name(var) + } + extra_unc_var_attrs = {var: {"kind": "freeRV"} for var in self.unc_var_dtype_shapes} + + self.create_group( + name="constant_data", + data_dict=find_constants(self.model), + ) + + self.create_group( + name="observed_data", + data_dict=find_observations(self.model), + ) + + # Create the posterior that includes warmup draws + self.init_group_with_empty( + group=self.root.create_group(name="posterior", overwrite=True), + var_dtype_and_shape=self.var_dtype_shapes, + chains=chains, + draws=tune + draws, + extra_var_attrs=extra_var_attrs, + ) + + # Create the unconstrained posterior group that includes warmup draws + if self.include_transformed and self.unc_var_dtype_shapes: + self.init_group_with_empty( + group=self.root.create_group(name="unconstrained_posterior", overwrite=True), + var_dtype_and_shape=self.unc_var_dtype_shapes, + chains=chains, + draws=tune + draws, + extra_var_attrs=extra_unc_var_attrs, + ) + + # Create the sample stats that include warmup draws + stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps( + [step] if isinstance(step, BlockedStep) else step.methods + ) + self.init_group_with_empty( + group=self.root.create_group(name="sample_stats", overwrite=True), + var_dtype_and_shape=stats_dtypes_shapes, + chains=chains, + draws=tune + draws, + ) + + self.init_sampling_state_group(tune=tune, chains=chains) + + self.straces = [ + ZarrChain( + store=self.root.store, + synchronizer=self.synchronizer, + model=self.model, + vars=self.vars, + test_point=test_point, + stats_bijection=StatsBijection(step.stats_dtypes), + draws_per_chunk=self.draws_per_chunk, + fn=copy_function_with_new_rngs(self.fn, rng_), + ) + for rng_ in get_random_generator(rng).spawn(chains) + ] + for chain, strace in enumerate(self.straces): + strace.setup(draws=tune + draws, chain=chain, sampler_vars=None) + + def split_warmup_groups(self): + """Split the warmup and standard groups. + + This method takes the entries in the arrays in the posterior, sample_stats + and unconstrained_posterior that happened in the tuning phase and moves them + into the warmup_ groups. If the ``warmup_posterior`` group already exists, then + nothing is done. + + See Also + -------- + :meth:`~ZarrTrace.split_warmup` + """ + if "warmup_posterior" not in self.groups(): + self.split_warmup("posterior", error_if_already_split=False) + self.split_warmup("sample_stats", error_if_already_split=False) + try: + self.split_warmup("unconstrained_posterior", error_if_already_split=False) + except KeyError: + pass + + @property + def tuning_steps(self): + try: + return int(self._sampling_state.tuning_steps.get_basic_selection()) + except AttributeError: # pragma: no cover + raise ValueError( + "ZarrTrace has not been initialized and there is no tuning step information available" + ) + + @property + def sampling_time(self): + try: + return float(self._sampling_state.sampling_time.get_basic_selection()) + except AttributeError: # pragma: no cover + raise ValueError( + "ZarrTrace has not been initialized and there is no sampling time information available" + ) + + @sampling_time.setter + def sampling_time(self, value): + self._sampling_state.sampling_time.set_basic_selection((), float(value)) + + def init_sampling_state_group(self, tune: int, chains: int): + state = self.root.create_group(name="_sampling_state", overwrite=True) + sampling_state = state.empty( + name="sampling_state", + overwrite=True, + shape=(chains,), + chunks=(1,), + dtype="object", + object_codec=numcodecs.Pickle(), + compressor=self.compressor, + ) + sampling_state.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]}) + draw_idx = state.array( + name="draw_idx", + overwrite=True, + data=np.zeros(chains, dtype="int"), + chunks=(1,), + dtype="int", + fill_value=-1, + compressor=self.compressor, + ) + draw_idx.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]}) + + state.array( + name="tuning_steps", + data=tune, + overwrite=True, + dtype="int", + fill_value=0, + compressor=self.compressor, + ) + state.array( + name="sampling_time", + data=0.0, + dtype="float", + fill_value=0.0, + compressor=self.compressor, + ) + state.array( + name="sampling_start_time", + data=0.0, + dtype="float", + fill_value=0.0, + compressor=self.compressor, + ) + + chain = state.array( + name="chain", + data=np.arange(chains), + compressor=self.compressor, + ) + + chain.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]}) + + state.empty( + name="global_warnings", + dtype="object", + object_codec=numcodecs.Pickle(), + shape=(0,), + ) + + def init_group_with_empty( + self, + group: zarr.Group, + var_dtype_and_shape: dict[str, tuple[StatDtype, StatShape]], + chains: int, + draws: int, + extra_var_attrs: dict | None = None, + ) -> zarr.Group: + group_coords: dict[str, Any] = {"chain": range(chains), "draw": range(draws)} + for name, (_dtype, shape) in var_dtype_and_shape.items(): + fill_value, dtype, object_codec = get_initial_fill_value_and_codec(_dtype) + shape = shape or () + array = group.full( + name=name, + dtype=dtype, + fill_value=fill_value, + object_codec=object_codec, + shape=(chains, draws, *shape), + chunks=(1, self.draws_per_chunk, *shape), + compressor=self.compressor, + ) + try: + dims = self.vars_to_dims[name] + for dim in dims: + group_coords[dim] = self.coords[dim] + except KeyError: + dims = [] + for i, shape_i in enumerate(shape): + dim = f"{name}_dim_{i}" + dims.append(dim) + group_coords[dim] = np.arange(shape_i, dtype="int") + dims = ("chain", "draw", *dims) + attrs = extra_var_attrs[name] if extra_var_attrs is not None else {} + attrs.update({"_ARRAY_DIMENSIONS": dims}) + array.attrs.update(attrs) + for dim, coord in group_coords.items(): + array = group.array( + name=dim, + data=coord, + fill_value=None, + compressor=self.compressor, + ) + array.attrs.update({"_ARRAY_DIMENSIONS": [dim]}) + return group + + def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> zarr.Group | None: + group: zarr.Group | None = None + if data_dict: + group_coords = {} + group = self.root.create_group(name=name, overwrite=True) + for var_name, var_value in data_dict.items(): + fill_value, dtype, object_codec = get_initial_fill_value_and_codec(var_value.dtype) + array = group.array( + name=var_name, + data=var_value, + fill_value=fill_value, + dtype=dtype, + object_codec=object_codec, + compressor=self.compressor, + ) + try: + dims = self.vars_to_dims[var_name] + for dim in dims: + group_coords[dim] = self.coords[dim] + except KeyError: + dims = [] + for i in range(var_value.ndim): + dim = f"{var_name}_dim_{i}" + dims.append(dim) + group_coords[dim] = np.arange(var_value.shape[i], dtype="int") + array.attrs.update({"_ARRAY_DIMENSIONS": dims}) + for dim, coord in group_coords.items(): + array = group.array( + name=dim, + data=coord, + fill_value=None, + compressor=self.compressor, + ) + array.attrs.update({"_ARRAY_DIMENSIONS": [dim]}) + return group + + def split_warmup(self, group_name: str, error_if_already_split: bool = True): + """Split the arrays of a group into the warmup and regular groups. + + This function takes the first ``self.tuning_steps`` draws of supplied + ``group_name`` and moves them into a new zarr group called + ``f"warmup_{group_name}"``. + + Parameters + ---------- + group_name : str + The name of the group that should be split. + error_if_already_split : bool + If ``True`` and if the ``f"warmup_{group_name}"`` group already exists in + the root hierarchy, a ``ValueError`` is raised. If this flag is ``False`` + but the warmup group already exists, the contents of that group are + overwritten. + """ + if error_if_already_split and f"{WARMUP_TAG}{group_name}" in { + group_name for group_name, _ in self.root.groups() + }: + raise RuntimeError(f"Warmup data for {group_name} has already been split") + posterior_group = self.root[group_name] + tune = self.tuning_steps + warmup_group = self.root.create_group(f"{WARMUP_TAG}{group_name}", overwrite=True) + if tune == 0: + try: + self.root.pop(f"{WARMUP_TAG}{group_name}") + except KeyError: + pass + return + for name, array in posterior_group.arrays(): + array_attrs = array.attrs.asdict() + if name == "draw": + warmup_array = warmup_group.array( + name="draw", + data=np.arange(tune), + dtype="int", + compressor=self.compressor, + ) + posterior_array = posterior_group.array( + name=name, + data=np.arange(len(array) - tune), + dtype="int", + overwrite=True, + compressor=self.compressor, + ) + posterior_array.attrs.update(array_attrs) + else: + dims = array.attrs["_ARRAY_DIMENSIONS"] + warmup_idx: slice | tuple[slice, slice] + if len(dims) >= 2 and dims[:2] == ["chain", "draw"]: + must_overwrite_posterior = True + warmup_idx = (slice(None), slice(None, tune, None)) + posterior_idx = (slice(None), slice(tune, None, None)) + else: + must_overwrite_posterior = False + warmup_idx = slice(None) + fill_value, dtype, object_codec = get_initial_fill_value_and_codec(array.dtype) + warmup_array = warmup_group.array( + name=name, + data=array[warmup_idx], + chunks=array.chunks, + dtype=dtype, + fill_value=fill_value, + object_codec=object_codec, + compressor=self.compressor, + ) + if must_overwrite_posterior: + posterior_array = posterior_group.array( + name=name, + data=array[posterior_idx], + chunks=array.chunks, + dtype=dtype, + fill_value=fill_value, + object_codec=object_codec, + overwrite=True, + compressor=self.compressor, + ) + posterior_array.attrs.update(array_attrs) + warmup_array.attrs.update(array_attrs) + + def to_inferencedata(self, save_warmup: bool = False) -> az.InferenceData: + """Convert ``ZarrTrace`` to :class:`~.arviz.InferenceData`. + + This converts all the groups in the ``ZarrTrace.root`` hierarchy into an + ``InferenceData`` object. The only exception is that ``_sampling_state`` is + excluded. + + Parameters + ---------- + save_warmup : bool + If ``True``, all of the warmup groups are stored in the inference data + object. + + Notes + ----- + ``xarray`` and in turn ``arviz`` require the zarr groups to have consolidated + metadata. To achieve this, a new consolidated store is constructed by calling + :func:`zarr.consolidate_metadata` on the root's store. This means that the + returned ``InferenceData`` object will operate on a different storage unit + than the calling ``ZarrTrace``, so future changes to the ``ZarrTrace`` won't be + automatically reflected in the returned ``InferenceData`` object. + """ + self.split_warmup_groups() + # Xarray complains if we try to open a zarr hierarchy that doesn't have consolidated metadata + consolidated_root = zarr.consolidate_metadata(self.root.store) + # The ConsolidatedMetadataStore looks like an empty store from xarray's point of view + # we need to actually grab the underlying store so that xarray doesn't produce completely + # empty arrays + store = consolidated_root.store.store + groups = {} + try: + global_attrs = { + "tuning_steps": self.tuning_steps, + "sampling_time": self.sampling_time, + } + except AttributeError: + global_attrs = {} # pragma: no cover + for name, _ in self.root.groups(): + if name.startswith("_") or (not save_warmup and name.startswith(WARMUP_TAG)): + continue + data = xr.open_zarr(store, group=name, mask_and_scale=False) + attrs = {**data.attrs, **global_attrs} + data.attrs = make_attrs(attrs=attrs, library=pymc) + groups[name] = data.load() if az.rcParams["data.load"] == "eager" else data + return az.InferenceData(**groups) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index f665d5931cb..7e07cfa5b5d 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -14,7 +14,7 @@ import warnings from collections.abc import Callable, Generator, Iterable, Sequence -from typing import cast +from typing import cast, overload import numpy as np import pandas as pd @@ -22,6 +22,7 @@ import pytensor.tensor as pt import scipy.sparse as sps +from pytensor import shared from pytensor.compile import Function, Mode, get_mode from pytensor.compile.builders import OpFromGraph from pytensor.gradient import grad @@ -37,12 +38,13 @@ ) from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.op import Op +from pytensor.link.jax.linker import JAXLinker from pytensor.scalar.basic import Cast from pytensor.scan.op import Scan from pytensor.tensor.basic import _as_tensor_variable from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.random.op import RandomVariable -from pytensor.tensor.random.type import RandomType +from pytensor.tensor.random.type import RandomGeneratorType, RandomType from pytensor.tensor.random.var import RandomGeneratorSharedVariable from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding from pytensor.tensor.rewriting.shape import ShapeFeature @@ -51,7 +53,7 @@ from pytensor.tensor.variable import TensorVariable from pymc.exceptions import NotConstantValueError -from pymc.util import makeiter +from pymc.util import RandomGeneratorState, makeiter, random_generator_from_state from pymc.vartypes import continuous_types, isgenerator, typefilter PotentialShapeType = int | np.ndarray | Sequence[int | Variable] | TensorVariable @@ -1163,3 +1165,64 @@ def normalize_rng_param(rng: None | Variable) -> Variable: "The type of rng should be an instance of either RandomGeneratorType or RandomStateType" ) return rng + + +@overload +def copy_function_with_new_rngs( + fn: PointFunc, rng: np.random.Generator | RandomGeneratorState +) -> PointFunc: ... + + +@overload +def copy_function_with_new_rngs( + fn: Function, rng: np.random.Generator | RandomGeneratorState +) -> Function: ... + + +def copy_function_with_new_rngs( + fn: Function, rng: np.random.Generator | RandomGeneratorState +) -> Function: + """Copy a compiled pytensor function and replace the random Generators with spawns. + + Parameters + ---------- + fn : pytensor.compile.function.types.Function | pymc.util.PointFunc + The compiled function + rng : numpy.random.Generator | RandomGeneratorState + The random generator or its state + + Returns + ------- + fn_out : pytensor.compile.function.types.Function | pymc.pytensorf.PointFunc + A copy of the input function with the shared random generator states set to + spawns of the supplied ``rng``. If the function has no shared random generators + in it, the input ``fn`` is returned without any changes. + If ``fn`` is a :clas:`~pymc.pytensorf.PointFunc` instance, and the inner + pytensor function has random variables, then the inner pytensor function is + copied, setting new random generators, and a new ``PointFunc`` instance is + returned. + """ + # Copy the function and replace any shared RNGs + # This is needed so that it can work correctly with multiple traces + # This will be costly if set_rng is called too often! + rng_gen = rng if isinstance(rng, np.random.Generator) else random_generator_from_state(rng) + fn_ = fn.f if isinstance(fn, PointFunc) else fn + shared_rngs = [var for var in fn_.get_shared() if isinstance(var.type, RandomGeneratorType)] + n_shared_rngs = len(shared_rngs) + if n_shared_rngs > 0 and isinstance(fn_.maker.linker, JAXLinker): + # Reseeding RVs in JAX backend requires a different logic, becuase the SharedVariables + # used internally are not the ones that `function.get_shared()` returns. + warnings.warn( + "At the moment, it is not possible to set the random generator's key for " + "JAX linked functions. This means that the draws yielded by the random " + "variables that are requested by 'Deterministic' will not be reproducible." + ) + return fn + swap = { + old_shared_rng: shared(rng, borrow=True) + for old_shared_rng, rng in zip(shared_rngs, rng_gen.spawn(n_shared_rngs), strict=True) + } + if isinstance(fn, PointFunc): + return PointFunc(fn.f.copy(swap=swap)) if n_shared_rngs > 0 else fn + else: + return fn.copy(swap=swap) if n_shared_rngs > 0 else fn diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index b2d643a5f1b..8d7972832d3 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -26,6 +26,7 @@ Any, Literal, TypeAlias, + cast, overload, ) @@ -40,6 +41,7 @@ from rich.theme import Theme from threadpoolctl import threadpool_limits from typing_extensions import Protocol +from zarr.storage import MemoryStore import pymc as pm @@ -50,6 +52,7 @@ find_observations, ) from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains +from pymc.backends.zarr import ZarrChain, ZarrTrace from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain @@ -503,7 +506,7 @@ def sample( model: Model | None = None, compile_kwargs: dict | None = None, **kwargs, -) -> InferenceData | MultiTrace: +) -> InferenceData | MultiTrace | ZarrTrace: r"""Draw samples from the posterior using the given step methods. Multiple step methods are supported via compound step methods. @@ -570,7 +573,13 @@ def sample( Number of iterations of initializer. Only works for 'ADVI' init methods. trace : backend, optional A backend instance or None. - If None, the NDArray backend is used. + If ``None``, a ``MultiTrace`` object with underlying ``NDArray`` trace objects + is used. If ``trace`` is a :class:`~pymc.backends.zarr.ZarrTrace` instance, + the drawn samples will be written onto the desired storage while sampling is + on-going. This means sampling runs that, for whatever reason, die in the middle + of their execution will write the partial results onto the storage. If the + storage persist on disk, these results should be available even after a server + crash. See :class:`~pymc.backends.zarr.ZarrTrace` for more information. discard_tuned_samples : bool Whether to discard posterior samples of the tune interval. compute_convergence_checks : bool, default=True @@ -607,8 +616,12 @@ def sample( Returns ------- - trace : pymc.backends.base.MultiTrace or arviz.InferenceData - A ``MultiTrace`` or ArviZ ``InferenceData`` object that contains the samples. + trace : pymc.backends.base.MultiTrace | pymc.backends.zarr.ZarrTrace | arviz.InferenceData + A ``MultiTrace``, :class:`~arviz.InferenceData` or + :class:`~pymc.backends.zarr.ZarrTrace` object that contains the samples. A + ``ZarrTrace`` is only returned if the supplied ``trace`` argument is a + ``ZarrTrace`` instance. Refer to :class:`~pymc.backends.zarr.ZarrTrace` for + the benefits this backend provides. Notes ----- @@ -741,7 +754,7 @@ def joined_blas_limiter(): rngs = get_random_generator(random_seed).spawn(chains) random_seed_list = [rng.integers(2**30) for rng in rngs] - if not discard_tuned_samples and not return_inferencedata: + if not discard_tuned_samples and not return_inferencedata and not isinstance(trace, ZarrTrace): warnings.warn( "Tuning samples will be included in the returned `MultiTrace` object, which can lead to" " complications in your downstream analysis. Please consider to switch to `InferenceData`:\n" @@ -852,6 +865,8 @@ def joined_blas_limiter(): trace_vars=trace_vars, initial_point=initial_points[0], model=model, + tune=tune, + rng=rngs[0].spawn(1)[0], ) sample_args = { @@ -934,7 +949,7 @@ def joined_blas_limiter(): # into a function to make it easier to test and refactor. return _sample_return( run=run, - traces=traces, + traces=trace if isinstance(trace, ZarrTrace) else traces, tune=tune, t_sampling=t_sampling, discard_tuned_samples=discard_tuned_samples, @@ -949,7 +964,7 @@ def joined_blas_limiter(): def _sample_return( *, run: RunType | None, - traces: Sequence[IBaseTrace], + traces: Sequence[IBaseTrace] | ZarrTrace, tune: int, t_sampling: float, discard_tuned_samples: bool, @@ -958,18 +973,69 @@ def _sample_return( keep_warning_stat: bool, idata_kwargs: dict[str, Any], model: Model, -) -> InferenceData | MultiTrace: +) -> InferenceData | MultiTrace | ZarrTrace: """Pick/slice chains, run diagnostics and convert to the desired return type. Final step of `pm.sampler`. """ + if isinstance(traces, ZarrTrace): + # Split warmup from posterior samples + traces.split_warmup_groups() + + # Set sampling time + traces.sampling_time = t_sampling + + # Compute number of actual draws per chain + total_draws_per_chain = traces._sampling_state.draw_idx[:] + n_chains = len(traces.straces) + desired_tune = traces.tuning_steps + desired_draw = len(traces.posterior.draw) + tuning_steps_per_chain = np.clip(total_draws_per_chain, 0, desired_tune) + draws_per_chain = total_draws_per_chain - tuning_steps_per_chain + + total_n_tune = tuning_steps_per_chain.sum() + total_draws = draws_per_chain.sum() + + _log.info( + f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations ' + f"(Actually sampled {total_n_tune:_d} tune and {total_draws:_d} draws total) " + f"took {t_sampling:.0f} seconds." + ) + + if compute_convergence_checks or return_inferencedata: + idata = traces.to_inferencedata(save_warmup=not discard_tuned_samples) + log_likelihood = idata_kwargs.pop("log_likelihood", False) + if log_likelihood: + from pymc.stats.log_density import compute_log_likelihood + + idata = compute_log_likelihood( + idata, + var_names=None if log_likelihood is True else log_likelihood, + extend_inferencedata=True, + model=model, + sample_dims=["chain", "draw"], + progressbar=False, + ) + if compute_convergence_checks: + warns = run_convergence_checks(idata, model) + for warn in warns: + traces._sampling_state.global_warnings.append(np.array([warn])) + log_warnings(warns) + + if return_inferencedata: + # By default we drop the "warning" stat which contains `SamplerWarning` + # objects that can not be stored with `.to_netcdf()`. + if not keep_warning_stat: + return drop_warning_stat(idata) + return idata + return traces + # Pick and slice chains to keep the maximum number of samples if discard_tuned_samples: traces, length = _choose_chains(traces, tune) else: traces, length = _choose_chains(traces, 0) mtrace = MultiTrace(traces)[:length] - # count the number of tune/draw iterations that happened # ideally via the "tune" statistic, but not all samplers record it! if "tune" in mtrace.stat_names: @@ -1212,6 +1278,8 @@ def _iter_sample( step.set_rng(rng) point = start + if isinstance(trace, ZarrChain): + trace.link_stepper(step) try: step.tune = bool(tune) @@ -1233,13 +1301,14 @@ def _iter_sample( ) yield diverging - except KeyboardInterrupt: - trace.close() - raise - except BaseException: + except (KeyboardInterrupt, BaseException): + if isinstance(trace, ZarrChain): + trace.record_sampling_state(step=step) trace.close() raise else: + if isinstance(trace, ZarrChain): + trace.record_sampling_state(step=step) trace.close() @@ -1298,6 +1367,19 @@ def _mp_sample( # We did draws += tune in pm.sample draws -= tune + zarr_chains: list[ZarrChain] | None = None + zarr_recording = False + if all(isinstance(trace, ZarrChain) for trace in traces): + if isinstance(cast(ZarrChain, traces[0])._posterior.store, MemoryStore): + warnings.warn( + "Parallel sampling with MemoryStore zarr store wont write the processes " + "step method sampling state. If you wish to be able to access the step " + "method sampling state, please use a different storage backend, e.g. " + "DirectoryStore or ZipStore" + ) + else: + zarr_chains = cast(list[ZarrChain], traces) + zarr_recording = True sampler = ps.ParallelSampler( draws=draws, @@ -1311,13 +1393,16 @@ def _mp_sample( progressbar_theme=progressbar_theme, blas_cores=blas_cores, mp_ctx=mp_ctx, + zarr_chains=zarr_chains, ) try: try: with sampler: for draw in sampler: strace = traces[draw.chain] - strace.record(draw.point, draw.stats) + if not zarr_recording: + # Zarr recording happens in each process + strace.record(draw.point, draw.stats) log_warning_stats(draw.stats) if callback is not None: diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 67417e0d8f1..28e74d5e8ae 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -22,6 +22,7 @@ from collections import namedtuple from collections.abc import Sequence +from typing import cast import cloudpickle import numpy as np @@ -31,6 +32,7 @@ from rich.theme import Theme from threadpoolctl import threadpool_limits +from pymc.backends.zarr import ZarrChain from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.util import ( @@ -104,13 +106,27 @@ def __init__( tune: int, rng_state: RandomGeneratorState, blas_cores, + chain: int, + zarr_chains: list[ZarrChain] | bytes | None = None, + zarr_chains_is_pickled: bool = False, ): - # For some strange reason, spawn multiprocessing doesn't copy the rng - # seed sequence, so we have to rebuild it from scratch + # Because of https://github.com/numpy/numpy/issues/27727, we can't send + # the rng instance to the child process because pickling (copying) looses + # the seed sequence state information. For this reason, we send a + # RandomGeneratorState instead. rng = random_generator_from_state(rng_state) self._msg_pipe = msg_pipe self._step_method = step_method self._step_method_is_pickled = step_method_is_pickled + self.chain = chain + self._zarr_recording = False + self._zarr_chain: ZarrChain | None = None + if zarr_chains_is_pickled: + self._zarr_chain = cloudpickle.loads(zarr_chains)[self.chain] + elif zarr_chains is not None: + self._zarr_chain = cast(list[ZarrChain], zarr_chains)[self.chain] + self._zarr_recording = self._zarr_chain is not None + self._shared_point = shared_point self._rng = rng self._draws = draws @@ -135,6 +151,7 @@ def run(self): # We do not create this in __init__, as pickling this # would destroy the shared memory. self._unpickle_step_method() + self._link_step_to_zarrchain() self._point = self._make_numpy_refs() self._start_loop() except KeyboardInterrupt: @@ -148,6 +165,10 @@ def run(self): finally: self._msg_pipe.close() + def _link_step_to_zarrchain(self): + if self._zarr_recording: + self._zarr_chain.link_stepper(self._step_method) + def _wait_for_abortion(self): while True: msg = self._recv_msg() @@ -170,6 +191,7 @@ def _recv_msg(self): return self._msg_pipe.recv() def _start_loop(self): + zarr_recording = self._zarr_recording self._step_method.set_rng(self._rng) draw = 0 @@ -199,6 +221,8 @@ def _start_loop(self): if msg[0] == "abort": raise KeyboardInterrupt() elif msg[0] == "write_next": + if zarr_recording: + self._zarr_chain.record(point, stats) self._write_point(point) is_last = draw + 1 == self._draws + self._tune self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats)) @@ -225,6 +249,8 @@ def __init__( start: dict[str, np.ndarray], blas_cores, mp_ctx, + zarr_chains: list[ZarrChain] | None = None, + zarr_chains_pickled: bytes | None = None, ): self.chain = chain process_name = f"worker_chain_{chain}" @@ -247,6 +273,16 @@ def __init__( self._readable = True self._num_samples = 0 + zarr_chains_send: list[ZarrChain] | bytes | None = None + if zarr_chains_pickled is not None: + zarr_chains_send = zarr_chains_pickled + elif zarr_chains is not None: + if mp_ctx.get_start_method() == "spawn": + raise ValueError( + "please provide a pre-pickled zarr_chains when multiprocessing start method is 'spawn'" + ) + zarr_chains_send = zarr_chains + if step_method_pickled is not None: step_method_send = step_method_pickled else: @@ -270,6 +306,9 @@ def __init__( tune, get_state_from_generator(rng), blas_cores, + self.chain, + zarr_chains_send, + zarr_chains_pickled is not None, ), ) self._process.start() @@ -392,6 +431,7 @@ def __init__( progressbar_theme: Theme | None = default_progress_theme, blas_cores: int | None = None, mp_ctx=None, + zarr_chains: list[ZarrChain] | None = None, ): if any(len(arg) != chains for arg in [rngs, start_points]): raise ValueError(f"Number of rngs and start_points must be {chains}.") @@ -412,8 +452,15 @@ def __init__( mp_ctx = multiprocessing.get_context(mp_ctx) step_method_pickled = None + zarr_chains_pickled = None + self.zarr_recording = False + if zarr_chains is not None: + assert all(isinstance(zarr_chain, ZarrChain) for zarr_chain in zarr_chains) + self.zarr_recording = True if mp_ctx.get_start_method() != "fork": step_method_pickled = cloudpickle.dumps(step_method, protocol=-1) + if zarr_chains is not None: + zarr_chains_pickled = cloudpickle.dumps(zarr_chains, protocol=-1) self._samplers = [ ProcessAdapter( @@ -426,6 +473,8 @@ def __init__( start, blas_cores, mp_ctx, + zarr_chains=zarr_chains, + zarr_chains_pickled=zarr_chains_pickled, ) for chain, rng, start in zip(range(chains), rngs, start_points) ] diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 4e5a2299601..b8a7ba593a7 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -27,6 +27,7 @@ from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from pymc.backends.base import BaseTrace +from pymc.backends.zarr import ZarrChain from pymc.initial_point import PointType from pymc.model import Model, modelcontext from pymc.stats.convergence import log_warning_stats @@ -36,6 +37,7 @@ PopulationArrayStepShared, StatsType, ) +from pymc.step_methods.compound import StepMethodState from pymc.step_methods.metropolis import DEMetropolis from pymc.util import CustomProgress @@ -81,6 +83,11 @@ def _sample_population( Show progress bars? (defaults to True) parallelize : bool Setting for multiprocess parallelization + traces : Sequence[BaseTrace] + A sequences of chain traces where the sampling results will be stored. Can be + a sequence of :py:class:`~pymc.backends.ndarray.NDArray`, + :py:class:`~pymc.backends.mcbackend.ChainRecordAdapter`, or + :py:class:`~pymc.backends.zarr.ZarrChain`. """ warn_population_size( step=step, @@ -263,6 +270,9 @@ def _run_secondary(c, stepper_dumps, secondary_end, task, progress): # receiving a None is the signal to exit if incoming is None: break + elif incoming == "sampling_state": + secondary_end.send((c, stepper.sampling_state)) + continue tune_stop, population = incoming if tune_stop: stepper.stop_tuning() @@ -307,6 +317,14 @@ def step(self, tune_stop: bool, population) -> list[tuple[PointType, StatsType]] updates.append(self._steppers[c].step(population[c])) return updates + def request_sampling_state(self, chain) -> StepMethodState: + if self.is_parallelized: + self._primary_ends[chain].send(("sampling_state",)) + _, sampling_state = self._primary_ends[chain].recv() + else: + sampling_state = self._steppers[chain].sampling_state + return sampling_state + def _prepare_iter_population( *, @@ -332,6 +350,11 @@ def _prepare_iter_population( Start points for each chain parallelize : bool Setting for multiprocess parallelization + traces : Sequence[BaseTrace] + A sequences of chain traces where the sampling results will be stored. Can be + a sequence of :py:class:`~pymc.backends.ndarray.NDArray`, + :py:class:`~pymc.backends.mcbackend.ChainRecordAdapter`, or + :py:class:`~pymc.backends.zarr.ZarrChain`. tune : int Number of iterations to tune. rngs: sequence of random Generators @@ -411,8 +434,11 @@ def _iter_population( the helper object for (parallelized) stepping of chains steppers : list The step methods for each chain - traces : list - Traces for each chain + traces : Sequence[BaseTrace] + A sequences of chain traces where the sampling results will be stored. Can be + a sequence of :py:class:`~pymc.backends.ndarray.NDArray`, + :py:class:`~pymc.backends.mcbackend.ChainRecordAdapter`, or + :py:class:`~pymc.backends.zarr.ZarrChain`. points : list population of chain states @@ -432,8 +458,11 @@ def _iter_population( # apply the update to the points and record to the traces for c, strace in enumerate(traces): points[c], stats = updates[c] - strace.record(points[c], stats) + flushed = strace.record(points[c], stats) log_warning_stats(stats) + if flushed and isinstance(strace, ZarrChain): + sampling_state = popstep.request_sampling_state(c) + strace.store_sampling_state(sampling_state) # yield the state of all chains in parallel yield i except KeyboardInterrupt: diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index d0393afd570..1fcb3d2673f 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -22,6 +22,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence +from dataclasses import field from enum import IntEnum, unique from typing import Any @@ -96,6 +97,7 @@ def infer_warn_stats_info( @dataclass_state class StepMethodState(DataClassState): + var_names: list[str] = field(metadata={"tensor_name": True, "frozen": True}) rng: RandomGeneratorState diff --git a/pymc/step_methods/state.py b/pymc/step_methods/state.py index e24276cf143..ec7bbbae483 100644 --- a/pymc/step_methods/state.py +++ b/pymc/step_methods/state.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy -from dataclasses import Field, dataclass, fields +from dataclasses import MISSING, Field, dataclass, fields from typing import Any, ClassVar import numpy as np @@ -67,7 +67,16 @@ def sampling_state(self) -> DataClassState: state_class = self._state_class kwargs = {} for field in fields(state_class): - val = getattr(self, field.name) + is_tensor_name = field.metadata.get("tensor_name", False) + val: Any + if is_tensor_name: + val = [var.name for var in getattr(self, "vars")] + else: + val = getattr(self, field.name, field.default) + if val is MISSING: + raise AttributeError( + f"{type(self).__name__!r} object has no attribute {field.name!r}" + ) _val: Any if isinstance(val, WithSamplingState): _val = val.sampling_state @@ -85,11 +94,17 @@ def sampling_state(self, state: DataClassState): state, state_class ), f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'" for field in fields(state_class): + is_tensor_name = field.metadata.get("tensor_name", False) state_val = deepcopy(getattr(state, field.name)) if isinstance(state_val, RandomGeneratorState): state_val = random_generator_from_state(state_val) - self_val = getattr(self, field.name) is_frozen = field.metadata.get("frozen", False) + self_val: Any + if is_tensor_name: + self_val = [var.name for var in getattr(self, "vars")] + assert is_frozen + else: + self_val = getattr(self, field.name, field.default) if is_frozen: if not equal_dataclass_values(state_val, self_val): raise ValueError( diff --git a/pymc/util.py b/pymc/util.py index 8a059d7e0d6..63576676eb2 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -13,6 +13,7 @@ # limitations under the License. import functools +import re import warnings from collections import namedtuple @@ -276,7 +277,12 @@ def drop_warning_stat(idata: arviz.InferenceData) -> arviz.InferenceData: nidata = arviz.InferenceData(attrs=idata.attrs) for gname, group in idata.items(): if "sample_stat" in gname: - group = group.drop_vars(names=["warning", "warning_dim_0"], errors="ignore") + warning_vars = [ + name + for name in group.data_vars + if name == "warning" or re.match(r"sampler_\d+__warning", str(name)) + ] + group = group.drop_vars(names=[*warning_vars, "warning_dim_0"], errors="ignore") nidata.add_groups({gname: group}, coords=group.coords, dims=group.dims) return nidata diff --git a/requirements-dev.txt b/requirements-dev.txt index 56f7f964fcf..e7e3644aae6 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -32,3 +32,4 @@ threadpoolctl>=3.1.0 types-cachetools typing-extensions>=3.7.4 watermark +zarr>=2.5.0,<3 diff --git a/tests/backends/test_zarr.py b/tests/backends/test_zarr.py new file mode 100644 index 00000000000..32f508ef1ae --- /dev/null +++ b/tests/backends/test_zarr.py @@ -0,0 +1,538 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools + +from dataclasses import asdict + +import numpy as np +import pytest +import xarray as xr +import zarr + +from arviz import InferenceData + +import pymc as pm + +from pymc.backends.zarr import ZarrTrace +from pymc.stats.convergence import SamplerWarning +from pymc.step_methods import NUTS, CompoundStep, Metropolis +from pymc.step_methods.state import equal_dataclass_values +from tests.helpers import equal_sampling_states + + +@pytest.fixture(scope="module") +def model(): + time_int = np.array([np.timedelta64(np.timedelta64(i, "h"), "ns") for i in range(25)]) + coords = { + "dim_int": range(3), + "dim_str": ["A", "B"], + "dim_time": np.datetime64("2024-10-16") + time_int, + "dim_interval": time_int, + } + rng = np.random.default_rng(42) + with pm.Model(coords=coords) as model: + data1 = pm.Data("data1", np.ones(3, dtype="bool"), dims=["dim_int"]) + data2 = pm.Data("data2", np.ones(3, dtype="bool")) + time = pm.Data("time", time_int / np.timedelta64(1, "h"), dims="dim_time") + + a = pm.Normal("a", shape=(len(coords["dim_int"]), len(coords["dim_str"]))) + b = pm.Normal("b", dims=["dim_int", "dim_str"]) + c = pm.Deterministic("c", a + b, dims=["dim_int", "dim_str"]) + + d = pm.LogNormal("d", dims="dim_time") + e = pm.Deterministic("e", (time + d)[:, None] + c[0], dims=["dim_interval", "dim_str"]) + + obs = pm.Normal( + "obs", + mu=e, + observed=rng.normal(size=(len(coords["dim_time"]), len(coords["dim_str"]))), + dims=["dim_time", "dim_str"], + ) + + return model + + +@pytest.fixture(params=["include_transformed", "discard_transformed"]) +def include_transformed(request): + return request.param == "include_transformed" + + +@pytest.fixture(params=["frequent_writes", "sparse_writes"]) +def draws_per_chunk(request): + spec = { + "frequent_writes": 1, + "sparse_writes": 7, + } + return spec[request.param] + + +@pytest.fixture(params=["single_step", "compound_step"]) +def model_step(request, model): + rng = np.random.default_rng(42) + with model: + if request.param == "single_step": + step = NUTS(rng=rng) + else: + rngs = rng.spawn(2) + step = CompoundStep( + [ + Metropolis(vars=model["a"], rng=rngs[0]), + NUTS(vars=[rv for rv in model.value_vars if rv.name != "a"], rng=rngs[1]), + ] + ) + return step + + +def test_record(model, model_step, include_transformed, draws_per_chunk): + store = zarr.TempStore() + trace = ZarrTrace( + store=store, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk + ) + draws = 5 + tune = 5 + trace.init_trace(chains=1, draws=draws, tune=tune, model=model, step=model_step) + + # Assert that init was successful + expected_groups = { + "_sampling_state", + "sample_stats", + "posterior", + "constant_data", + "observed_data", + } + if include_transformed: + expected_groups.add("unconstrained_posterior") + assert {group_name for group_name, _ in trace.root.groups()} == expected_groups + + # Record samples from the ZarrChain + manually_collected_warmup_draws = [] + manually_collected_warmup_stats = [] + manually_collected_draws = [] + manually_collected_stats = [] + point = model.initial_point() + for draw in range(tune + draws): + tuning = draw < tune + if not tuning: + model_step.stop_tuning() + point, stats = model_step.step(point) + if tuning: + manually_collected_warmup_draws.append(point) + manually_collected_warmup_stats.append(stats) + else: + manually_collected_draws.append(point) + manually_collected_stats.append(stats) + trace.straces[0].record(point, stats) + trace.straces[0].record_sampling_state(model_step) + assert {group_name for group_name, _ in trace.root.groups()} == expected_groups + + # Assert split warmup + trace.split_warmup("posterior") + trace.split_warmup("sample_stats") + expected_groups = { + "_sampling_state", + "sample_stats", + "posterior", + "warmup_sample_stats", + "warmup_posterior", + "constant_data", + "observed_data", + } + if include_transformed: + trace.split_warmup("unconstrained_posterior") + expected_groups.add("unconstrained_posterior") + expected_groups.add("warmup_unconstrained_posterior") + assert {group_name for group_name, _ in trace.root.groups()} == expected_groups + # trace.consolidate() + + # Assert observed data is correct + assert set(dict(trace.observed_data.arrays())) == {"obs", "dim_time", "dim_str"} + assert list(trace.observed_data.obs.attrs["_ARRAY_DIMENSIONS"]) == ["dim_time", "dim_str"] + np.testing.assert_array_equal(trace.observed_data.dim_time[:], model.coords["dim_time"]) + np.testing.assert_array_equal(trace.observed_data.dim_str[:], model.coords["dim_str"]) + + # Assert constant data is correct + assert set(dict(trace.constant_data.arrays())) == { + "data1", + "data2", + "data2_dim_0", + "time", + "dim_time", + "dim_int", + } + assert list(trace.constant_data.data1.attrs["_ARRAY_DIMENSIONS"]) == ["dim_int"] + assert list(trace.constant_data.data2.attrs["_ARRAY_DIMENSIONS"]) == ["data2_dim_0"] + assert list(trace.constant_data.time.attrs["_ARRAY_DIMENSIONS"]) == ["dim_time"] + np.testing.assert_array_equal(trace.constant_data.dim_time[:], model.coords["dim_time"]) + np.testing.assert_array_equal(trace.constant_data.dim_int[:], model.coords["dim_int"]) + + # Assert unconstrained posterior has correct shapes and kinds + assert {rv.name for rv in model.free_RVs + model.deterministics} <= set( + dict(trace.posterior.arrays()) + ) + if include_transformed: + assert {"d_log__", "chain", "draw", "d_log___dim_0"} == set( + dict(trace.unconstrained_posterior.arrays()) + ) + assert list(trace.unconstrained_posterior.d_log__.attrs["_ARRAY_DIMENSIONS"]) == [ + "chain", + "draw", + "d_log___dim_0", + ] + assert trace.unconstrained_posterior.d_log__.attrs["kind"] == "freeRV" + np.testing.assert_array_equal(trace.unconstrained_posterior.chain, np.arange(1)) + np.testing.assert_array_equal(trace.unconstrained_posterior.draw, np.arange(draws)) + np.testing.assert_array_equal( + trace.unconstrained_posterior.d_log___dim_0, np.arange(len(model.coords["dim_time"])) + ) + + # Assert posterior has correct shapes and kinds + posterior_dims = set() + for kind, rv_name in [ + (kind, rv.name) + for kind, rv in itertools.chain( + itertools.zip_longest([], model.free_RVs, fillvalue="freeRV"), + itertools.zip_longest([], model.deterministics, fillvalue="deterministic"), + ) + ]: + if rv_name == "a": + expected_dims = ["a_dim_0", "a_dim_1"] + else: + expected_dims = model.named_vars_to_dims[rv_name] + posterior_dims |= set(expected_dims) + assert list(trace.posterior[rv_name].attrs["_ARRAY_DIMENSIONS"]) == [ + "chain", + "draw", + *expected_dims, + ] + assert trace.posterior[rv_name].attrs["kind"] == kind + for posterior_dim in posterior_dims: + try: + model_coord = model.coords[posterior_dim] + except KeyError: + model_coord = { + "a_dim_0": np.arange(len(model.coords["dim_int"])), + "a_dim_1": np.arange(len(model.coords["dim_str"])), + "chain": np.arange(1), + "draw": np.arange(draws), + }[posterior_dim] + np.testing.assert_array_equal(trace.posterior[posterior_dim][:], model_coord) + + # Assert sample stats have correct shape + stats_bijection = trace.straces[0].stats_bijection + for draw_idx, (draw, stat) in enumerate( + zip(manually_collected_draws, manually_collected_stats) + ): + stat = stats_bijection.map(stat) + for var, value in draw.items(): + if var in trace.posterior.arrays(): + assert np.array_equal(trace.posterior[var][0, draw_idx], value) + for var, value in stat.items(): + sample_stats = trace.root["sample_stats"] + stat_val = sample_stats[var][0, draw_idx] + if not isinstance(stat_val, SamplerWarning): + unequal_stats = stat_val != value + else: + unequal_stats = not equal_dataclass_values(asdict(stat_val), asdict(value)) + if unequal_stats and not (np.isnan(stat_val) and np.isnan(value)): + raise AssertionError(f"{var} value does not match: {stat_val} != {value}") + + # Assert manually collected warmup samples match + for draw_idx, (draw, stat) in enumerate( + zip(manually_collected_warmup_draws, manually_collected_warmup_stats) + ): + stat = stats_bijection.map(stat) + for var, value in draw.items(): + if var == "d_log__": + if not include_transformed: + continue + posterior = trace.root["warmup_unconstrained_posterior"] + else: + posterior = trace.root["warmup_posterior"] + if var in posterior.arrays(): + assert np.array_equal(posterior[var][0, draw_idx], value) + for var, value in stat.items(): + sample_stats = trace.root["warmup_sample_stats"] + stat_val = sample_stats[var][0, draw_idx] + if not isinstance(stat_val, SamplerWarning): + unequal_stats = stat_val != value + else: + unequal_stats = not equal_dataclass_values(asdict(stat_val), asdict(value)) + if unequal_stats and not (np.isnan(stat_val) and np.isnan(value)): + raise AssertionError(f"{var} value does not match: {stat_val} != {value}") + + # Assert manually collected posterior samples match + for draw_idx, (draw, stat) in enumerate( + zip(manually_collected_draws, manually_collected_stats) + ): + stat = stats_bijection.map(stat) + for var, value in draw.items(): + if var == "d_log__": + if not include_transformed: + continue + posterior = trace.root["unconstrained_posterior"] + else: + posterior = trace.root["posterior"] + if var in posterior.arrays(): + assert np.array_equal(posterior[var][0, draw_idx], value) + for var, value in stat.items(): + sample_stats = trace.root["sample_stats"] + stat_val = sample_stats[var][0, draw_idx] + if not isinstance(stat_val, SamplerWarning): + unequal_stats = stat_val != value + else: + unequal_stats = not equal_dataclass_values(asdict(stat_val), asdict(value)) + if unequal_stats and not (np.isnan(stat_val) and np.isnan(value)): + raise AssertionError(f"{var} value does not match: {stat_val} != {value}") + + # Assert sampling_state is correct + assert list(trace._sampling_state.draw_idx[:]) == [draws + tune] + assert equal_sampling_states( + trace._sampling_state.sampling_state[0], + model_step.sampling_state, + ) + + # Assert to inference data returns the expected groups + idata = trace.to_inferencedata(save_warmup=True) + expected_groups = { + "posterior", + "constant_data", + "observed_data", + "sample_stats", + "warmup_posterior", + "warmup_sample_stats", + } + if include_transformed: + expected_groups.add("unconstrained_posterior") + expected_groups.add("warmup_unconstrained_posterior") + assert set(idata.groups()) == expected_groups + for group in idata.groups(): + for name, value in itertools.chain( + idata[group].data_vars.items(), idata[group].coords.items() + ): + try: + array = getattr(trace, group)[name][:] + except AttributeError: + array = trace.root[group][name][:] + if "sample_stats" in group and "warning" in name: + continue + np.testing.assert_array_equal(array, value) + + +@pytest.mark.parametrize("tune", [0, 5, 10]) +def test_split_warmup(tune, model, model_step, include_transformed): + store = zarr.MemoryStore() + trace = ZarrTrace(store=store, include_transformed=include_transformed) + draws = 10 - tune + trace.init_trace(chains=1, draws=draws, tune=tune, model=model, step=model_step) + + trace.split_warmup("posterior") + trace.split_warmup("sample_stats") + assert len(trace.root.posterior.draw) == draws + assert len(trace.root.sample_stats.draw) == draws + if tune == 0: + with pytest.raises(KeyError): + trace.root["warmup_posterior"] + else: + assert len(trace.root["warmup_posterior"].draw) == tune + assert len(trace.root["warmup_sample_stats"].draw) == tune + + with pytest.raises(RuntimeError): + trace.split_warmup("posterior") + + for var_name, posterior_array in trace.posterior.arrays(): + dims = posterior_array.attrs["_ARRAY_DIMENSIONS"] + if len(dims) >= 2 and dims[1] == "draw": + assert posterior_array.shape[1] == draws + assert trace.root["warmup_posterior"][var_name].shape[1] == tune + for var_name, sample_stats_array in trace.sample_stats.arrays(): + dims = sample_stats_array.attrs["_ARRAY_DIMENSIONS"] + if len(dims) >= 2 and dims[1] == "draw": + assert sample_stats_array.shape[1] == draws + assert trace.root["warmup_sample_stats"][var_name].shape[1] == tune + + +@pytest.fixture(scope="function", params=["discard_tuning", "keep_tuning"]) +def discard_tuned_samples(request): + return request.param == "discard_tuning" + + +@pytest.fixture(scope="function", params=["return_idata", "return_zarr"]) +def return_inferencedata(request): + return request.param == "return_idata" + + +@pytest.fixture( + scope="function", params=[True, False], ids=["keep_warning_stat", "discard_warning_stat"] +) +def keep_warning_stat(request): + return request.param + + +@pytest.fixture( + scope="function", params=[True, False], ids=["parallel_sampling", "sequential_sampling"] +) +def parallel(request): + return request.param + + +@pytest.fixture(scope="function", params=[True, False], ids=["compute_loglike", "no_loglike"]) +def log_likelihood(request): + return request.param + + +def test_sample( + model, + model_step, + include_transformed, + discard_tuned_samples, + return_inferencedata, + keep_warning_stat, + parallel, + log_likelihood, + draws_per_chunk, +): + if not return_inferencedata and not log_likelihood: + pytest.skip( + reason="log_likelihood is only computed if an inference data object is returned" + ) + store = zarr.TempStore() + trace = ZarrTrace( + store=store, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk + ) + tune = 2 + draws = 3 + if parallel: + chains = 2 + cores = 2 + else: + chains = 1 + cores = 1 + with model: + out_trace = pm.sample( + draws=draws, + tune=tune, + chains=chains, + cores=cores, + trace=trace, + step=model_step, + discard_tuned_samples=discard_tuned_samples, + return_inferencedata=return_inferencedata, + keep_warning_stat=keep_warning_stat, + idata_kwargs={"log_likelihood": log_likelihood}, + ) + + if not return_inferencedata: + assert isinstance(out_trace, ZarrTrace) + assert out_trace.root.store is trace.root.store + else: + assert isinstance(out_trace, InferenceData) + + expected_groups = {"posterior", "constant_data", "observed_data", "sample_stats"} + if include_transformed: + expected_groups |= {"unconstrained_posterior"} + if not return_inferencedata or not discard_tuned_samples: + expected_groups |= {"warmup_posterior", "warmup_sample_stats"} + if include_transformed: + expected_groups |= {"warmup_unconstrained_posterior"} + if not return_inferencedata: + expected_groups |= {"_sampling_state"} + elif log_likelihood: + expected_groups |= {"log_likelihood"} + assert set(out_trace.groups()) == expected_groups + + if return_inferencedata: + warning_stat = ( + "sampler_1__warning" if isinstance(model_step, CompoundStep) else "sampler_0__warning" + ) + if keep_warning_stat: + assert warning_stat in out_trace.sample_stats + else: + assert warning_stat not in out_trace.sample_stats + + # Assert that all variables have non empty samples (not NaNs) + if return_inferencedata: + assert all( + (not np.any(np.isnan(v))) and v.shape[:2] == (chains, draws) + for v in out_trace.posterior.data_vars.values() + ) + else: + dimensions = {*model.coords, "a_dim_0", "a_dim_1", "chain", "draw"} + assert all( + (not np.any(np.isnan(v[:]))) and v.shape[:2] == (chains, draws) + for name, v in out_trace.posterior.arrays() + if name not in dimensions + ) + + # Assert that the trace has valid sampling state stored for each chain + for step_method_state in trace._sampling_state.sampling_state[:]: + # We have no access to the actual step method that was using by each chain in pymc.sample + # The best way to see if the step method state is valid is by trying to set + # the model_step sampling state to the one stored in the trace. + model_step.sampling_state = step_method_state + + +def test_sampling_consistency( + model, + model_step, + draws_per_chunk, +): + # Test that pm.sample will generate the same posterior and sampling state + # regardless of whether sampling was done in parallel or not. + store1 = zarr.TempStore() + parallel_trace = ZarrTrace( + store=store1, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk + ) + store2 = zarr.TempStore() + sequential_trace = ZarrTrace( + store=store2, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk + ) + tune = 2 + draws = 3 + chains = 2 + random_seed = 12345 + initial_step_state = model_step.sampling_state + with model: + parallel_idata = pm.sample( + draws=draws, + tune=tune, + chains=chains, + cores=chains, + trace=parallel_trace, + step=model_step, + discard_tuned_samples=True, + return_inferencedata=True, + keep_warning_stat=False, + idata_kwargs={"log_likelihood": False}, + random_seed=random_seed, + ) + model_step.sampling_state = initial_step_state + sequential_idata = pm.sample( + draws=draws, + tune=tune, + chains=chains, + cores=1, + trace=sequential_trace, + step=model_step, + discard_tuned_samples=True, + return_inferencedata=True, + keep_warning_stat=False, + idata_kwargs={"log_likelihood": False}, + random_seed=random_seed, + ) + for chain in range(chains): + assert equal_sampling_states( + parallel_trace._sampling_state.sampling_state[chain], + sequential_trace._sampling_state.sampling_state[chain], + ) + xr.testing.assert_equal(parallel_idata.posterior, sequential_idata.posterior) diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 41b068e0427..2330c043be6 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -909,3 +909,49 @@ def test_sample(self, seeded_test): np.testing.assert_allclose( x_pred, pp_trace1.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1 ) + + +@pytest.fixture(scope="function", params=[None, "mcbackend", "zarr"]) +def trace_backend(request): + if request.param is None: + return None + elif request.param == "mcbackend": + try: + import mcbackend as mcb + except ImportError: + pytest.skip("Requires McBackend to be installed.") + return mcb.NumPyBackend() + elif request.param == "zarr": + try: + trace = pm.backends.zarr.ZarrTrace() + except RuntimeError: + pytest.skip("Requires zarr to be installed") + return trace + + +@pytest.fixture(scope="function", params=["FAST_COMPILE", "NUMBA", "JAX"]) +def pytensor_mode(request): + return request.param + + +def test_random_deterministics(trace_backend, pytensor_mode): + with pm.Model() as m: + x = pm.Bernoulli("x", p=0.5) * 0 # Force it to be zero + pm.Deterministic("y", x + pm.Normal.dist()) + + if pytensor_mode == "JAX": + expected_warning = ( + "At the moment, it is not possible to set the random generator's key for " + "JAX linked functions. This means that the draws yielded by the random " + "variables that are requested by 'Deterministic' will not be reproducible." + ) + with pytest.warns(UserWarning, match=expected_warning): + with pytensor.config.change_flags(mode=pytensor_mode): + idata1 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend) + idata2 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend) + assert not idata1.posterior.equals(idata2.posterior) + else: + with pytensor.config.change_flags(mode=pytensor_mode): + idata1 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend) + idata2 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend) + assert idata1.posterior.equals(idata2.posterior)