Skip to content

Commit

Permalink
Use ProgressManager in _sample_many
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Jan 23, 2025
1 parent df55bc8 commit 8a28596
Showing 1 changed file with 44 additions and 53 deletions.
97 changes: 44 additions & 53 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
from arviz import InferenceData, dict_to_dataset
from arviz.data.base import make_attrs
from pytensor.graph.basic import Variable
from rich.console import Console
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme
from threadpoolctl import threadpool_limits
from typing_extensions import Protocol
Expand Down Expand Up @@ -65,7 +63,7 @@
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
from pymc.step_methods.hmc import quadpotential
from pymc.util import (
CustomProgress,
ProgressManager,
RandomSeed,
RandomState,
_get_seeds_per_chain,
Expand Down Expand Up @@ -1069,34 +1067,44 @@ def _sample_many(
Step function
"""
initial_step_state = step.sampling_state
for i in range(chains):
step.sampling_state = initial_step_state
_sample(
draws=draws,
chain=i,
start=start[i],
step=step,
trace=traces[i],
rng=rngs[i],
callback=callback,
**kwargs,
)
progress_manager = ProgressManager(
step_method=step,
chains=chains,
draws=draws - kwargs.get("tune", 0),
tune=kwargs.get("tune", 0),
progressbar=kwargs.get("progressbar", True),
progressbar_theme=kwargs.get("progressbar_theme", default_progress_theme),
)

with progress_manager:
for i in range(chains):
step.sampling_state = initial_step_state
_sample(
draws=draws,
chain=i,
start=start[i],
step=step,
trace=traces[i],
rng=rngs[i],
callback=callback,
progress_manager=progress_manager,
**kwargs,
)
return


def _sample(
*,
chain: int,
progressbar: bool,
rng: np.random.Generator,
start: PointType,
draws: int,
step: Step,
trace: IBaseTrace,
tune: int,
model: Model | None = None,
progressbar_theme: Theme | None = default_progress_theme,
callback=None,
progress_manager: ProgressManager,
**kwargs,
) -> None:
"""Sample one chain (singleprocess).
Expand All @@ -1107,27 +1115,23 @@ def _sample(
----------
chain : int
Number of the chain that the samples will belong to.
progressbar : bool
Whether or not to display a progress bar in the command line. The bar shows the percentage
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
time until completion ("expected time of arrival"; ETA).
random_seed : single random seed
random_seed : Generator
Single random seed
start : dict
Starting point in parameter space (or partial point)
draws : int
The number of samples to draw
step : function
Step function
step : Step
Step class instance used to generate samples.
trace
A chain backend to record draws and stats.
tune : int
Number of iterations to tune.
model : Model (optional if in ``with`` context)
progressbar_theme : Theme
Optional custom theme for the progress bar.
model : Model, optional
PyMC model. If None, the model is taken from the current context.
progress_manager: ProgressManager
Helper class used to handle progress bar styling and updates
"""
skip_first = kwargs.get("skip_first", 0)

sampling_gen = _iter_sample(
draws=draws,
step=step,
Expand All @@ -1139,32 +1143,19 @@ def _sample(
rng=rng,
callback=callback,
)
_pbar_data = {"chain": chain, "divergences": 0}
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"

progress = CustomProgress(
"[progress.description]{task.description}",
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
console=Console(theme=progressbar_theme),
disable=not progressbar,
)
try:
for it, stats in enumerate(sampling_gen):
progress_manager.update(
chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it > tune
)

with progress:
try:
task = progress.add_task(_desc.format(**_pbar_data), completed=0, total=draws)
for it, diverging in enumerate(sampling_gen):
if it >= skip_first and diverging:
_pbar_data["divergences"] += 1
progress.update(task, description=_desc.format(**_pbar_data), completed=it)
progress.update(
task, description=_desc.format(**_pbar_data), completed=draws, refresh=True
if not progress_manager.combined_progress or chain == progress_manager.chains - 1:
progress_manager.update(
chain_idx=chain, is_last=True, draw=it, stats=stats, tuning=False
)
except KeyboardInterrupt:
pass

except KeyboardInterrupt:
pass


def _iter_sample(
Expand Down

0 comments on commit 8a28596

Please sign in to comment.