Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show one progress bar per chain when sampling #7634

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f849206
One progress bar per chain when samplings
jessegrabowski Jan 6, 2025
b74f4fb
Add guard against divide by zero when computing draws per second
jessegrabowski Jan 6, 2025
06572c6
No more purple
jessegrabowski Jan 7, 2025
23d122f
Step samplers are responsible for setting up progress bars
jessegrabowski Jan 8, 2025
1c5b734
Fix typos
jessegrabowski Jan 8, 2025
959f073
Add progressbar defaults to BlockedStep ABC
jessegrabowski Jan 8, 2025
50394e3
pre-commit
jessegrabowski Jan 8, 2025
4945153
Only update NUTS divergence stats after tuning
jessegrabowski Jan 9, 2025
1cec794
Add `Elapsed` and `Remaining` columns
jessegrabowski Jan 10, 2025
a96d7bb
Remove green color when chain finishes
jessegrabowski Jan 23, 2025
d61ddf6
Create `ProgressManager` class to handle progress bars
jessegrabowski Jan 23, 2025
1e13cf9
Yield `stats` from `_iter_sample`
jessegrabowski Jan 23, 2025
28a80c1
Use `ProgressManager` in `_sample_many`
jessegrabowski Jan 23, 2025
345faff
pre-commit
jessegrabowski Jan 23, 2025
741cf36
Explicit case handling for `progressbar` argument
jessegrabowski Jan 23, 2025
f4ccbd5
Allow all permutations of arguments to progressbar
jessegrabowski Jan 23, 2025
9649d66
Appease mypy
jessegrabowski Jan 23, 2025
a629a97
Add True case
jessegrabowski Jan 23, 2025
e024991
Fix final count when `progress = "combined"`
jessegrabowski Jan 25, 2025
4e535d4
Update docstrings
jessegrabowski Jan 25, 2025
b9b0583
mypy + cleanup
jessegrabowski Jan 25, 2025
9de9930
Syntax error in typehint
jessegrabowski Jan 25, 2025
79d1248
Simplify progressbar choices, update docstring
jessegrabowski Jan 26, 2025
161d10c
Incorporate feedback
jessegrabowski Jan 27, 2025
b381e5d
Be verbose with progressbar settings
jessegrabowski Jan 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymc/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
RunType: TypeAlias = Run
HAS_MCB = True
except ImportError:
TraceOrBackend = BaseTrace # type: ignore[misc]
TraceOrBackend = BaseTrace # type: ignore[assignment, misc]
RunType = type(None) # type: ignore[assignment, misc]


Expand Down
146 changes: 76 additions & 70 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@
from arviz import InferenceData, dict_to_dataset
from arviz.data.base import make_attrs
from pytensor.graph.basic import Variable
from rich.console import Console
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme
from threadpoolctl import threadpool_limits
from typing_extensions import Protocol
Expand Down Expand Up @@ -67,7 +65,8 @@
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
from pymc.step_methods.hmc import quadpotential
from pymc.util import (
CustomProgress,
ProgressBarManager,
ProgressBarType,
RandomSeed,
RandomState,
_get_seeds_per_chain,
Expand Down Expand Up @@ -278,7 +277,7 @@ def _print_step_hierarchy(s: Step, level: int = 0) -> None:
else:
varnames = ", ".join(
[
get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name
get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name # type: ignore[misc]
for v in s.vars
]
)
Expand Down Expand Up @@ -424,7 +423,7 @@ def sample(
chains: int | None = None,
cores: int | None = None,
random_seed: RandomState = None,
progressbar: bool = True,
progressbar: bool | ProgressBarType = True,
progressbar_theme: Theme | None = default_progress_theme,
step=None,
var_names: Sequence[str] | None = None,
Expand Down Expand Up @@ -456,7 +455,7 @@ def sample(
chains: int | None = None,
cores: int | None = None,
random_seed: RandomState = None,
progressbar: bool = True,
progressbar: bool | ProgressBarType = True,
progressbar_theme: Theme | None = default_progress_theme,
step=None,
var_names: Sequence[str] | None = None,
Expand Down Expand Up @@ -488,8 +487,8 @@ def sample(
chains: int | None = None,
cores: int | None = None,
random_seed: RandomState = None,
progressbar: bool = True,
progressbar_theme: Theme | None = default_progress_theme,
progressbar: bool | ProgressBarType = True,
progressbar_theme: Theme | None = None,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
Expand Down Expand Up @@ -539,11 +538,18 @@ def sample(
A ``TypeError`` will be raised if a legacy :py:class:`~numpy.random.RandomState` object is passed.
We no longer support ``RandomState`` objects because their seeding mechanism does not allow
easy spawning of new independent random streams that are needed by the step methods.
progressbar : bool, optional default=True
Whether or not to display a progress bar in the command line. The bar shows the percentage
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
time until completion ("expected time of arrival"; ETA).
Only applicable to the pymc nuts sampler.
progressbar: bool or ProgressBarType, optional
How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask
for one of the following:
- "combined": A single progress bar that displays the total progress across all chains. Only timing
information is shown.
- "split": A separate progress bar for each chain. Only timing information is shown.
- "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all
chains. Aggregate sample statistics are also displayed.
- "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain
are also displayed.
If True, the default "split+stats" is used.
step : function or iterable of functions
A step function or collection of functions. If there are variables without step methods,
step methods for those variables will be assigned automatically. By default the NUTS step
Expand Down Expand Up @@ -709,6 +715,10 @@ def sample(
if isinstance(trace, list):
raise ValueError("Please use `var_names` keyword argument for partial traces.")

# progressbar might be a string, which is used by the ProgressBarManager in the pymc samplers. External samplers and
# ADVI initialization expect just a bool.
progress_bool = bool(progressbar)

model = modelcontext(model)
if not model.free_RVs:
raise SamplingError(
Expand Down Expand Up @@ -805,7 +815,7 @@ def joined_blas_limiter():
initvals=initvals,
model=model,
var_names=var_names,
progressbar=progressbar,
progressbar=progress_bool,
idata_kwargs=idata_kwargs,
compute_convergence_checks=compute_convergence_checks,
nuts_sampler_kwargs=nuts_sampler_kwargs,
Expand All @@ -824,7 +834,7 @@ def joined_blas_limiter():
n_init=n_init,
model=model,
random_seed=random_seed_list,
progressbar=progressbar,
progressbar=progress_bool,
jitter_max_retries=jitter_max_retries,
tune=tune,
initvals=initvals,
Expand Down Expand Up @@ -1138,34 +1148,44 @@ def _sample_many(
Step function
"""
initial_step_state = step.sampling_state
for i in range(chains):
step.sampling_state = initial_step_state
_sample(
draws=draws,
chain=i,
start=start[i],
step=step,
trace=traces[i],
rng=rngs[i],
callback=callback,
**kwargs,
)
progress_manager = ProgressBarManager(
step_method=step,
chains=chains,
draws=draws - kwargs.get("tune", 0),
tune=kwargs.get("tune", 0),
progressbar=kwargs.get("progressbar", True),
progressbar_theme=kwargs.get("progressbar_theme", default_progress_theme),
)

with progress_manager:
for i in range(chains):
step.sampling_state = initial_step_state
_sample(
draws=draws,
chain=i,
start=start[i],
step=step,
trace=traces[i],
rng=rngs[i],
callback=callback,
progress_manager=progress_manager,
**kwargs,
)
return


def _sample(
*,
chain: int,
progressbar: bool,
rng: np.random.Generator,
start: PointType,
draws: int,
step: Step,
trace: IBaseTrace,
tune: int,
model: Model | None = None,
progressbar_theme: Theme | None = default_progress_theme,
callback=None,
progress_manager: ProgressBarManager,
**kwargs,
) -> None:
"""Sample one chain (singleprocess).
Expand All @@ -1176,27 +1196,23 @@ def _sample(
----------
chain : int
Number of the chain that the samples will belong to.
progressbar : bool
Whether or not to display a progress bar in the command line. The bar shows the percentage
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
time until completion ("expected time of arrival"; ETA).
random_seed : single random seed
rng : np.random.Generator
Random number generator used to sample this chain.
start : dict
Starting point in parameter space (or partial point)
draws : int
The number of samples to draw
step : function
Step function
step : Step
Step class instance used to generate samples.
trace
A chain backend to record draws and stats.
tune : int
Number of iterations to tune.
model : Model (optional if in ``with`` context)
progressbar_theme : Theme
Optional custom theme for the progress bar.
model : Model, optional
PyMC model. If None, the model is taken from the current context.
progress_manager: ProgressBarManager
Helper class used to handle progress bar styling and updates
"""
skip_first = kwargs.get("skip_first", 0)

sampling_gen = _iter_sample(
draws=draws,
step=step,
Expand All @@ -1208,32 +1224,19 @@ def _sample(
rng=rng,
callback=callback,
)
_pbar_data = {"chain": chain, "divergences": 0}
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"

progress = CustomProgress(
"[progress.description]{task.description}",
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
console=Console(theme=progressbar_theme),
disable=not progressbar,
)
try:
for it, stats in enumerate(sampling_gen):
progress_manager.update(
chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it > tune
)

with progress:
try:
task = progress.add_task(_desc.format(**_pbar_data), completed=0, total=draws)
for it, diverging in enumerate(sampling_gen):
if it >= skip_first and diverging:
_pbar_data["divergences"] += 1
progress.update(task, description=_desc.format(**_pbar_data), completed=it)
progress.update(
task, description=_desc.format(**_pbar_data), completed=draws, refresh=True
if not progress_manager.combined_progress or chain == progress_manager.chains - 1:
progress_manager.update(
chain_idx=chain, is_last=True, draw=it, stats=stats, tuning=False
)
except KeyboardInterrupt:
pass

except KeyboardInterrupt:
pass


def _iter_sample(
Expand All @@ -1247,7 +1250,7 @@ def _iter_sample(
rng: np.random.Generator,
model: Model | None = None,
callback: SamplingIteratorCallback | None = None,
) -> Iterator[bool]:
) -> Iterator[list[dict[str, Any]]]:
"""Sample one chain with a generator (singleprocess).
Parameters
Expand All @@ -1270,8 +1273,8 @@ def _iter_sample(
Yields
------
diverging : bool
Indicates if the draw is divergent. Only available with some samplers.
stats : list of dict
List of sampler statistics dictionaries returned by the step method for the current draw.
"""
draws = int(draws)

Expand All @@ -1293,22 +1296,25 @@ def _iter_sample(
step.iter_count = 0
if i == tune:
step.stop_tuning()

point, stats = step.step(point)
trace.record(point, stats)
log_warning_stats(stats)
diverging = i > tune and len(stats) > 0 and (stats[0].get("diverging") is True)

if callback is not None:
callback(
trace=trace,
draw=Draw(chain, i == draws, i, i < tune, stats, point),
)

yield diverging
yield stats

except (KeyboardInterrupt, BaseException):
if isinstance(trace, ZarrChain):
trace.record_sampling_state(step=step)
trace.close()
raise

else:
if isinstance(trace, ZarrChain):
trace.record_sampling_state(step=step)
Expand Down
47 changes: 12 additions & 35 deletions pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,14 @@
import cloudpickle
import numpy as np

from rich.console import Console
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme
from threadpoolctl import threadpool_limits

from pymc.backends.zarr import ZarrChain
from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.util import (
CustomProgress,
ProgressBarManager,
RandomGeneratorState,
default_progress_theme,
get_state_from_generator,
Expand Down Expand Up @@ -485,23 +483,14 @@ def __init__(
self._max_active = cores

self._in_context = False

self._progress = CustomProgress(
"[progress.description]{task.description}",
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
console=Console(theme=progressbar_theme),
disable=not progressbar,
self._progress = ProgressBarManager(
step_method=step_method,
chains=chains,
draws=draws,
tune=tune,
progressbar=progressbar,
progressbar_theme=progressbar_theme,
)
self._show_progress = progressbar
self._divergences = 0
self._completed_draws = 0
self._total_draws = chains * (draws + tune)
self._desc = "Sampling {0._chains:d} chains, {0._divergences:,d} divergences"
self._chains = chains

def _make_active(self):
while self._inactive and len(self._active) < self._max_active:
Expand All @@ -516,32 +505,20 @@ def __iter__(self):
raise ValueError("Use ParallelSampler as context manager.")
self._make_active()

with self._progress as progress:
task = progress.add_task(
self._desc.format(self),
completed=self._completed_draws,
total=self._total_draws,
)

with self._progress:
while self._active:
draw = ProcessAdapter.recv_draw(self._active)
proc, is_last, draw, tuning, stats = draw
self._completed_draws += 1
if not tuning and stats and stats[0].get("diverging"):
self._divergences += 1
progress.update(
task,
completed=self._completed_draws,
total=self._total_draws,
description=self._desc.format(self),

self._progress.update(
chain_idx=proc.chain, is_last=is_last, draw=draw, tuning=tuning, stats=stats
)

if is_last:
proc.join()
self._active.remove(proc)
self._finished.append(proc)
self._make_active()
progress.update(task, description=self._desc.format(self), refresh=True)

# We could also yield proc.shared_point_view directly,
# and only call proc.write_next() after the yield returns.
Expand Down
Loading
Loading