From f849206e7b44eec59cb9c12b818e4aee5ba36ad4 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Mon, 6 Jan 2025 22:26:42 +0800 Subject: [PATCH 01/25] One progress bar per chain when samplings --- pymc/sampling/parallel.py | 96 +++++++++++++++++++++++++++++++-------- pymc/util.py | 95 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 168 insertions(+), 23 deletions(-) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 3c2a8c9a36..c2c6169a01 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -28,7 +28,9 @@ import numpy as np from rich.console import Console -from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn +from rich.progress import TextColumn +from rich.style import Style +from rich.table import Column from rich.theme import Theme from threadpoolctl import threadpool_limits @@ -37,6 +39,7 @@ from pymc.exceptions import SamplingError from pymc.util import ( CustomProgress, + DivergenceBarColumn, RandomGeneratorState, default_progress_theme, get_state_from_generator, @@ -487,20 +490,35 @@ def __init__( self._in_context = False self._progress = CustomProgress( - "[progress.description]{task.description}", - BarColumn(), - "[progress.percentage]{task.percentage:>3.0f}%", - TimeRemainingColumn(), - TextColumn("/"), - TimeElapsedColumn(), + DivergenceBarColumn( + table_column=Column("Progress", ratio=2), + diverging_color="tab:red", + diverging_finished_color="tab:purple", + complete_style=Style.parse("rgb(31,119,180)"), # tab:blue + finished_style=Style.parse("rgb(44,160,44)"), # tab:green + ), + TextColumn("{task.fields[draws]:,d}", table_column=Column("Draws", ratio=1)), + TextColumn( + "{task.fields[divergences]:,d}", table_column=Column("Divergences", ratio=1) + ), + TextColumn("{task.fields[step_size]:0.2f}", table_column=Column("Step size", ratio=1)), + TextColumn("{task.fields[tree_depth]:,d}", table_column=Column("Tree depth", ratio=1)), + TextColumn( + "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", + table_column=Column("Sampling Speed", ratio=1), + ), console=Console(theme=progressbar_theme), disable=not progressbar, + include_headers=True, ) + self._show_progress = progressbar self._divergences = 0 + self._divergences_by_chain = [0] * chains self._completed_draws = 0 - self._total_draws = chains * (draws + tune) - self._desc = "Sampling {0._chains:d} chains, {0._divergences:,d} divergences" + self._completed_draws_by_chain = [0] * chains + self._total_draws = draws + tune + self._desc = "Sampling chain" self._chains = chains def _make_active(self): @@ -517,31 +535,71 @@ def __iter__(self): self._make_active() with self._progress as progress: - task = progress.add_task( - self._desc.format(self), - completed=self._completed_draws, - total=self._total_draws, - ) + tasks = [ + progress.add_task( + self._desc.format(self), + completed=self._completed_draws, + total=self._total_draws, + chain_idx=chain_idx, + draws=0, + divergences=0, + step_size=0.0, + tree_depth=0, + sampling_speed=0, + speed_unit="draws/s", + ) + for chain_idx in range(self._chains) + ] while self._active: draw = ProcessAdapter.recv_draw(self._active) proc, is_last, draw, tuning, stats = draw + speed = 0 + unit = "draws/s" + self._completed_draws += 1 + self._completed_draws_by_chain[proc.chain] += 1 + if not tuning and stats and stats[0].get("diverging"): self._divergences += 1 + self._divergences_by_chain[proc.chain] += 1 + + if self._show_progress: + elapsed = progress._tasks[proc.chain].elapsed + speed = self._completed_draws_by_chain[proc.chain] / elapsed + + if speed > 1: + unit = "draws/s" + else: + unit = "s/draws" + speed = 1 / speed + progress.update( - task, - completed=self._completed_draws, - total=self._total_draws, - description=self._desc.format(self), + tasks[proc.chain], + completed=self._completed_draws_by_chain[proc.chain], + draws=draw, + divergences=self._divergences_by_chain[proc.chain], + step_size=stats[0].get("step_size", 0), + tree_depth=stats[0].get("tree_size", 0), + sampling_speed=speed, + speed_unit=unit, ) if is_last: + self._completed_draws_by_chain[proc.chain] += 1 + proc.join() self._active.remove(proc) self._finished.append(proc) self._make_active() - progress.update(task, description=self._desc.format(self), refresh=True) + progress.update( + tasks[proc.chain], + draws=draw + 1, + divergences=self._divergences_by_chain[proc.chain], + step_size=stats[0].get("step_size", 0), + tree_depth=stats[0].get("tree_size", 0), + refresh=True, + ) # We could also yield proc.shared_point_view directly, # and only call proc.write_next() after the yield returns. diff --git a/pymc/util.py b/pymc/util.py index 8dc7d16804..eb04d0f2cf 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -17,7 +17,7 @@ import warnings from collections import namedtuple -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from copy import deepcopy from typing import NewType, cast @@ -30,7 +30,10 @@ from pytensor import Variable from pytensor.compile import SharedVariable from pytensor.graph.utils import ValidatingScratchpad -from rich.progress import Progress +from rich.box import SIMPLE_HEAD +from rich.progress import BarColumn, Progress, Task +from rich.style import Style +from rich.table import Column, Table from rich.theme import Theme from pymc.exceptions import BlockModelAccessError @@ -556,8 +559,10 @@ class CustomProgress(Progress): it's `True`. """ - def __init__(self, *args, **kwargs): - self.is_enabled = kwargs.get("disable", None) is not True + def __init__(self, *args, disable=False, include_headers=False, **kwargs): + self.is_enabled = not disable + self.include_headers = include_headers + if self.is_enabled: super().__init__(*args, **kwargs) @@ -607,6 +612,88 @@ def update( ) return None + def make_tasks_table(self, tasks: Iterable[Task]) -> Table: + """Get a table to render the Progress display. + + Unlike the parent method, this one returns a full table (not a grid), allowing for column headings. + + Parameters + ---------- + tasks: Iterable[Task] + An iterable of Task instances, one per row of the table. + + Returns + ------- + table: Table + A table instance. + """ + + def call_column(column, task): + if hasattr(column, "callbacks"): + column.callbacks(task) + + return column(task) + + table_columns = ( + ( + Column(no_wrap=True) + if isinstance(_column, str) + else _column.get_table_column().copy() + ) + for _column in self.columns + ) + if self.include_headers: + table = Table( + *table_columns, + padding=(0, 1), + expand=self.expand, + show_header=True, + show_edge=True, + box=SIMPLE_HEAD, + ) + else: + table = Table.grid(*table_columns, padding=(0, 1), expand=self.expand) + + for task in tasks: + if task.visible: + table.add_row( + *( + ( + column.format(task=task) + if isinstance(column, str) + else call_column(column, task) + ) + for column in self.columns + ) + ) + + return table + + +class DivergenceBarColumn(BarColumn): + def __init__(self, *args, diverging_color="red", diverging_finished_color="purple", **kwargs): + from matplotlib.colors import to_rgb + + self.diverging_color = diverging_color + self.diverging_rgb = [int(x * 255) for x in to_rgb(self.diverging_color)] + + self.diverging_finished_color = diverging_finished_color + self.diverging_finished_rgb = [int(x * 255) for x in to_rgb(self.diverging_finished_color)] + + super().__init__(*args, **kwargs) + + self.non_diverging_style = self.complete_style + self.non_diverging_finished_style = self.finished_style + + def callbacks(self, task: "Task"): + divergences = task.fields.get("divergences", 0) + if divergences > 0: + self.complete_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) + self.finished_style = Style.parse("rgb({},{},{})".format(*self.diverging_finished_rgb)) + else: + self.complete_style = self.non_diverging_style + self.finished_style = self.non_diverging_finished_style + RandomGeneratorState = namedtuple("RandomGeneratorState", ["bit_generator_state", "seed_seq_state"]) From b74f4fbdb27365bbbd3e8691b2598961217adeb4 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Mon, 6 Jan 2025 23:12:05 +0800 Subject: [PATCH 02/25] Add guard against divide by zero when computing draws per second --- pymc/sampling/parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index c2c6169a01..b16fec8284 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -565,7 +565,7 @@ def __iter__(self): self._divergences_by_chain[proc.chain] += 1 if self._show_progress: - elapsed = progress._tasks[proc.chain].elapsed + elapsed = max(progress._tasks[proc.chain].elapsed, 1e-4) speed = self._completed_draws_by_chain[proc.chain] / elapsed if speed > 1: From 06572c649c1ae3379dbf831dcb6110927ff4783e Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Tue, 7 Jan 2025 10:56:04 +0800 Subject: [PATCH 03/25] No more purple --- pymc/sampling/parallel.py | 1 - pymc/util.py | 7 ++----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index b16fec8284..a9d4370cb4 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -493,7 +493,6 @@ def __init__( DivergenceBarColumn( table_column=Column("Progress", ratio=2), diverging_color="tab:red", - diverging_finished_color="tab:purple", complete_style=Style.parse("rgb(31,119,180)"), # tab:blue finished_style=Style.parse("rgb(44,160,44)"), # tab:green ), diff --git a/pymc/util.py b/pymc/util.py index eb04d0f2cf..b9a1ebe469 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -671,15 +671,12 @@ def call_column(column, task): class DivergenceBarColumn(BarColumn): - def __init__(self, *args, diverging_color="red", diverging_finished_color="purple", **kwargs): + def __init__(self, *args, diverging_color="red", **kwargs): from matplotlib.colors import to_rgb self.diverging_color = diverging_color self.diverging_rgb = [int(x * 255) for x in to_rgb(self.diverging_color)] - self.diverging_finished_color = diverging_finished_color - self.diverging_finished_rgb = [int(x * 255) for x in to_rgb(self.diverging_finished_color)] - super().__init__(*args, **kwargs) self.non_diverging_style = self.complete_style @@ -689,7 +686,7 @@ def callbacks(self, task: "Task"): divergences = task.fields.get("divergences", 0) if divergences > 0: self.complete_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) - self.finished_style = Style.parse("rgb({},{},{})".format(*self.diverging_finished_rgb)) + self.finished_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) else: self.complete_style = self.non_diverging_style self.finished_style = self.non_diverging_finished_style From 23d122f7120c2993ea3307c232a37f6f7c6a4b22 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Wed, 8 Jan 2025 18:13:21 +0800 Subject: [PATCH 04/25] Step samplers are responsible for setting up progress bars --- pymc/sampling/parallel.py | 79 ++++++++++----------------------- pymc/step_methods/compound.py | 32 +++++++++++++ pymc/step_methods/hmc/nuts.py | 31 +++++++++++++ pymc/step_methods/metropolis.py | 34 ++++++++++++++ pymc/step_methods/slicer.py | 29 ++++++++++++ pymc/util.py | 41 ++++++++++++++++- 6 files changed, 188 insertions(+), 58 deletions(-) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index a9d4370cb4..976b2ab075 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -27,10 +27,6 @@ import cloudpickle import numpy as np -from rich.console import Console -from rich.progress import TextColumn -from rich.style import Style -from rich.table import Column from rich.theme import Theme from threadpoolctl import threadpool_limits @@ -38,12 +34,13 @@ from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.util import ( - CustomProgress, - DivergenceBarColumn, - RandomGeneratorState, + compute_draw_speed, + create_progress_bar, default_progress_theme, + RandomGeneratorState, get_state_from_generator, random_generator_from_state, + ) logger = logging.getLogger(__name__) @@ -489,33 +486,21 @@ def __init__( self._in_context = False - self._progress = CustomProgress( - DivergenceBarColumn( - table_column=Column("Progress", ratio=2), - diverging_color="tab:red", - complete_style=Style.parse("rgb(31,119,180)"), # tab:blue - finished_style=Style.parse("rgb(44,160,44)"), # tab:green - ), - TextColumn("{task.fields[draws]:,d}", table_column=Column("Draws", ratio=1)), - TextColumn( - "{task.fields[divergences]:,d}", table_column=Column("Divergences", ratio=1) - ), - TextColumn("{task.fields[step_size]:0.2f}", table_column=Column("Step size", ratio=1)), - TextColumn("{task.fields[tree_depth]:,d}", table_column=Column("Tree depth", ratio=1)), - TextColumn( - "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", - table_column=Column("Sampling Speed", ratio=1), - ), - console=Console(theme=progressbar_theme), - disable=not progressbar, - include_headers=True, + progress_columns, progress_stats = step_method._progressbar_config(chains) + + self._progress = create_progress_bar( + progress_columns, + progress_stats, + progressbar=progressbar, + progressbar_theme=progressbar_theme, ) + self.progress_stats = progress_stats + self.update_stats = step_method._make_update_stat_function() + self._show_progress = progressbar self._divergences = 0 - self._divergences_by_chain = [0] * chains self._completed_draws = 0 - self._completed_draws_by_chain = [0] * chains self._total_draws = draws + tune self._desc = "Sampling chain" self._chains = chains @@ -537,15 +522,13 @@ def __iter__(self): tasks = [ progress.add_task( self._desc.format(self), - completed=self._completed_draws, + completed=0, + draws=0, total=self._total_draws, chain_idx=chain_idx, - draws=0, - divergences=0, - step_size=0.0, - tree_depth=0, sampling_speed=0, speed_unit="draws/s", + **{stat: value[chain_idx] for stat, value in self.progress_stats.items()}, ) for chain_idx in range(self._chains) ] @@ -553,40 +536,26 @@ def __iter__(self): while self._active: draw = ProcessAdapter.recv_draw(self._active) proc, is_last, draw, tuning, stats = draw - speed = 0 - unit = "draws/s" self._completed_draws += 1 - self._completed_draws_by_chain[proc.chain] += 1 + + speed, unit = compute_draw_speed(progress._tasks[proc.chain].elapsed, draw) if not tuning and stats and stats[0].get("diverging"): self._divergences += 1 - self._divergences_by_chain[proc.chain] += 1 - - if self._show_progress: - elapsed = max(progress._tasks[proc.chain].elapsed, 1e-4) - speed = self._completed_draws_by_chain[proc.chain] / elapsed - if speed > 1: - unit = "draws/s" - else: - unit = "s/draws" - speed = 1 / speed + self.progress_stats = self.update_stats(self.progress_stats, stats, proc.chain) progress.update( tasks[proc.chain], - completed=self._completed_draws_by_chain[proc.chain], + completed=draw, draws=draw, - divergences=self._divergences_by_chain[proc.chain], - step_size=stats[0].get("step_size", 0), - tree_depth=stats[0].get("tree_size", 0), sampling_speed=speed, speed_unit=unit, + **{stat: value[proc.chain] for stat, value in self.progress_stats.items()}, ) if is_last: - self._completed_draws_by_chain[proc.chain] += 1 - proc.join() self._active.remove(proc) self._finished.append(proc) @@ -594,9 +563,7 @@ def __iter__(self): progress.update( tasks[proc.chain], draws=draw + 1, - divergences=self._divergences_by_chain[proc.chain], - step_size=stats[0].get("step_size", 0), - tree_depth=stats[0].get("tree_size", 0), + **{stat: value[proc.chain] for stat, value in self.progress_stats.items()}, refresh=True, ) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index ff3f9c66a5..794730b4f7 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -297,6 +297,38 @@ def set_rng(self, rng: RandomGenerator): for method, _rng in zip(self.methods, _rngs): method.set_rng(_rng) + def _progressbar_config(self, n_chains=1): + from functools import reduce + + column_lists, stat_dict_list = zip( + *[method._progressbar_config(n_chains) for method in self.methods] + ) + flat_list = reduce(lambda left_list, right_list: left_list + right_list, column_lists) + + columns = [] + headers = [] + + for col in flat_list: + name = col.get_table_column().header + if name not in headers: + headers.append(name) + columns.append(col) + + stats = reduce(lambda left_dict, right_dict: left_dict | right_dict, stat_dict_list) + + return columns, stats + + def _make_update_stat_function(self): + update_fns = [method._make_update_stats_function() for method in self.methods] + + def update_stats(stats, step_stats, chain_idx): + for step_stat, update_fn in zip(step_stats, update_fns): + stats = update_fn(stats, step_stat, chain_idx) + + return stats + + return update_stats + def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]: """Flatten a hierarchy of step methods to a list.""" diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index bbda728e80..a99b1fe7c2 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -20,6 +20,8 @@ import numpy as np from pytensor import config +from rich.progress import TextColumn +from rich.table import Column from pymc.stats.convergence import SamplerWarning from pymc.step_methods.compound import Competence @@ -229,6 +231,35 @@ def competence(var, has_grad): return Competence.PREFERRED return Competence.INCOMPATIBLE + @staticmethod + def _progressbar_config(n_chains=1): + columns = [ + TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)), + TextColumn("{task.fields[step_size]:0.2f}", table_column=Column("Step size", ratio=1)), + TextColumn("{task.fields[tree_size]}", table_column=Column("Grad evals", ratio=1)), + ] + + stats = { + "divergences": [0] * n_chains, + "step_size": [0] * n_chains, + "tree_size": [0] * n_chains, + } + + return columns, stats + + @staticmethod + def _make_update_stat_function(): + def update_stats(stats, step_stats, chain_idx): + if isinstance(step_stats, list): + step_stats = step_stats[0] + + stats["divergences"][chain_idx] += step_stats["diverging"] + stats["step_size"][chain_idx] = step_stats["step_size"] + stats["tree_size"][chain_idx] = step_stats["tree_size"] + return stats + + return update_stats + # A proposal for the next position Proposal = namedtuple("Proposal", "q, q_grad, energy, logp, index_in_trajectory") diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 8e22218a13..70c650653d 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -24,6 +24,8 @@ from pytensor import tensor as pt from pytensor.graph.fg import MissingInputError from pytensor.tensor.random.basic import BernoulliRV, CategoricalRV +from rich.progress import TextColumn +from rich.table import Column import pymc as pm @@ -325,6 +327,38 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: def competence(var, has_grad): return Competence.COMPATIBLE + @staticmethod + def _progressbar_config(n_chains=1): + columns = [ + TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)), + TextColumn("{task.fields[scaling]:0.2f}", table_column=Column("Scaling", ratio=1)), + TextColumn( + "{task.fields[accept_rate]:0.2f}", table_column=Column("Accept Rate", ratio=1) + ), + ] + + stats = { + "tune": [True] * n_chains, + "scaling": [0] * n_chains, + "accept_rate": [0.0] * n_chains, + } + + return columns, stats + + @staticmethod + def _make_update_stats_function(): + def update_stats(stats, step_stats, chain_idx): + if isinstance(step_stats, list): + step_stats = step_stats[0] + + stats["tune"][chain_idx] = step_stats["tune"] + stats["accept_rate"][chain_idx] = step_stats["accept"] + stats["scaling"][chain_idx] = step_stats["scaling"] + + return stats + + return update_stats + def tune(scale, acc_rate): """ diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index ecc7967614..9c10acfdf4 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -17,6 +17,9 @@ import numpy as np +from rich.progress import TextColumn +from rich.table import Column + from pymc.blocking import RaveledVars, StatsType from pymc.initial_point import PointType from pymc.model import modelcontext @@ -195,3 +198,29 @@ def competence(var, has_grad): return Competence.PREFERRED return Competence.COMPATIBLE return Competence.INCOMPATIBLE + + @staticmethod + def _progressbar_config(n_chains=1): + columns = [ + TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)), + TextColumn("{task.fields[nstep_out]}", table_column=Column("Steps out", ratio=1)), + TextColumn("{task.fields[nstep_in]}", table_column=Column("Steps in", ratio=1)), + ] + + stats = {"tune": [True] * n_chains, "nstep_out": [0] * n_chains, "nstep_in": [0] * n_chains} + + return columns, stats + + @staticmethod + def _make_update_stats_function(): + def update_stats(stats, step_stats, chain_idx): + if isinstance(step_stats, list): + step_stats = step_stats[0] + + stats["tune"][chain_idx] = step_stats["tune"] + stats["nstep_out"][chain_idx] = step_stats["nstep_out"] + stats["nstep_in"][chain_idx] = step_stats["nstep_in"] + + return stats + + return update_stats diff --git a/pymc/util.py b/pymc/util.py index b9a1ebe469..80d1893759 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -31,7 +31,8 @@ from pytensor.compile import SharedVariable from pytensor.graph.utils import ValidatingScratchpad from rich.box import SIMPLE_HEAD -from rich.progress import BarColumn, Progress, Task +from rich.console import Console +from rich.progress import BarColumn, Progress, Task, TextColumn from rich.style import Style from rich.table import Column, Table from rich.theme import Theme @@ -684,7 +685,7 @@ def __init__(self, *args, diverging_color="red", **kwargs): def callbacks(self, task: "Task"): divergences = task.fields.get("divergences", 0) - if divergences > 0: + if isinstance(divergences, float | int) and divergences > 0: self.complete_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) self.finished_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) else: @@ -692,6 +693,42 @@ def callbacks(self, task: "Task"): self.finished_style = self.non_diverging_finished_style +def create_progress_bar(step_columns, init_stat_dict, progressbar, progressbar_theme): + columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))] + columns += step_columns + columns += [ + TextColumn( + "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", + table_column=Column("Sampling Speed", ratio=1), + ) + ] + + return CustomProgress( + DivergenceBarColumn( + table_column=Column("Progress", ratio=2), + diverging_color="tab:red", + complete_style=Style.parse("rgb(31,119,180)"), # tab:blue + finished_style=Style.parse("rgb(44,160,44)"), # tab:green + ), + *columns, + console=Console(theme=progressbar_theme), + disable=not progressbar, + include_headers=True, + ) + + +def compute_draw_speed(elapsed, draws): + speed = draws / max(elapsed, 1e-6) + + if speed > 1 or speed == 0: + unit = "draws/s" + else: + unit = "s/draws" + speed = 1 / speed + + return speed, unit + + RandomGeneratorState = namedtuple("RandomGeneratorState", ["bit_generator_state", "seed_seq_state"]) From 1c5b734249731c54d4010da335a8012def53748e Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Wed, 8 Jan 2025 18:17:33 +0800 Subject: [PATCH 05/25] Fix typos --- pymc/sampling/parallel.py | 2 +- pymc/step_methods/compound.py | 2 +- pymc/step_methods/hmc/nuts.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 976b2ab075..cd9e47cfe5 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -496,7 +496,7 @@ def __init__( ) self.progress_stats = progress_stats - self.update_stats = step_method._make_update_stat_function() + self.update_stats = step_method._make_update_stats_function() self._show_progress = progressbar self._divergences = 0 diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index 794730b4f7..a66417ece3 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -318,7 +318,7 @@ def _progressbar_config(self, n_chains=1): return columns, stats - def _make_update_stat_function(self): + def _make_update_stats_function(self): update_fns = [method._make_update_stats_function() for method in self.methods] def update_stats(stats, step_stats, chain_idx): diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index a99b1fe7c2..e3652cb4ed 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -248,7 +248,7 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stat_function(): + def _make_update_stats_function(): def update_stats(stats, step_stats, chain_idx): if isinstance(step_stats, list): step_stats = step_stats[0] From 959f073c3c7c83094b84937c899a86edb06e2cff Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Wed, 8 Jan 2025 18:29:17 +0800 Subject: [PATCH 06/25] Add progressbar defaults to BlockedStep ABC --- pymc/step_methods/compound.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index a66417ece3..d07b070f0f 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -181,6 +181,20 @@ def __new__(cls, *args, **kwargs): step.__newargs = (vars, *args), kwargs return step + @staticmethod + def _progressbar_config(n_chains=1): + columns = [] + stats = {} + + return columns, stats + + @staticmethod + def _make_update_stats_function(): + def update_stats(stats, step_stats, chain_idx): + return stats + + return update_stats + # Hack for creating the class correctly when unpickling. def __getnewargs_ex__(self): return self.__newargs From 50394e3b50c0d5aefb4312c7c50025d90b88e0e8 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Wed, 8 Jan 2025 19:13:00 +0800 Subject: [PATCH 07/25] pre-commit --- pymc/sampling/parallel.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index cd9e47cfe5..e7a9f79f48 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -34,13 +34,12 @@ from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.util import ( + RandomGeneratorState, compute_draw_speed, create_progress_bar, default_progress_theme, - RandomGeneratorState, get_state_from_generator, random_generator_from_state, - ) logger = logging.getLogger(__name__) From 4945153de331e4000762ee153c01f55bf669c823 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 9 Jan 2025 14:16:12 +0800 Subject: [PATCH 08/25] Only update NUTS divergence stats after tuning --- pymc/step_methods/hmc/nuts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index e3652cb4ed..18707c3592 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -253,7 +253,9 @@ def update_stats(stats, step_stats, chain_idx): if isinstance(step_stats, list): step_stats = step_stats[0] - stats["divergences"][chain_idx] += step_stats["diverging"] + if not step_stats["tune"]: + stats["divergences"][chain_idx] += step_stats["diverging"] + stats["step_size"][chain_idx] = step_stats["step_size"] stats["tree_size"][chain_idx] = step_stats["tree_size"] return stats From 1cec794f25c46ca8586ba0a31b55f1bc5b6b8133 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Fri, 10 Jan 2025 12:31:43 +0800 Subject: [PATCH 09/25] Add `Elapsed` and `Remaining` columns --- pymc/sampling/parallel.py | 2 +- pymc/util.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index e7a9f79f48..ba7a85c8ae 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -523,7 +523,7 @@ def __iter__(self): self._desc.format(self), completed=0, draws=0, - total=self._total_draws, + total=self._total_draws - 1, chain_idx=chain_idx, sampling_speed=0, speed_unit="draws/s", diff --git a/pymc/util.py b/pymc/util.py index 80d1893759..8940014bb1 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -32,7 +32,14 @@ from pytensor.graph.utils import ValidatingScratchpad from rich.box import SIMPLE_HEAD from rich.console import Console -from rich.progress import BarColumn, Progress, Task, TextColumn +from rich.progress import ( + BarColumn, + Progress, + Task, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) from rich.style import Style from rich.table import Column, Table from rich.theme import Theme @@ -59,6 +66,8 @@ def __getattr__(name): { "bar.complete": "#1764f4", "bar.finished": "green", + "progress.remaining": "none", + "progress.elapsed": "none", } ) @@ -700,7 +709,9 @@ def create_progress_bar(step_columns, init_stat_dict, progressbar, progressbar_t TextColumn( "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", table_column=Column("Sampling Speed", ratio=1), - ) + ), + TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)), + TimeRemainingColumn(table_column=Column("Remaining", ratio=1)), ] return CustomProgress( From a96d7bbf39c7c6098a53a53304195d286b046619 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 23 Jan 2025 18:00:00 +0800 Subject: [PATCH 10/25] Remove green color when chain finishes --- pymc/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/util.py b/pymc/util.py index 8940014bb1..514fc356fa 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -719,7 +719,7 @@ def create_progress_bar(step_columns, init_stat_dict, progressbar, progressbar_t table_column=Column("Progress", ratio=2), diverging_color="tab:red", complete_style=Style.parse("rgb(31,119,180)"), # tab:blue - finished_style=Style.parse("rgb(44,160,44)"), # tab:green + finished_style=Style.parse("rgb(31,119,180)"), # tab:blue ), *columns, console=Console(theme=progressbar_theme), From d61ddf692324310184a43ba72d7e84329c456043 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 23 Jan 2025 19:52:22 +0800 Subject: [PATCH 11/25] Create `ProgressManager` class to handle progress bars --- pymc/sampling/parallel.py | 64 ++------------ pymc/util.py | 182 +++++++++++++++++++++++++++++++++----- 2 files changed, 167 insertions(+), 79 deletions(-) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index ba7a85c8ae..549d205e90 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -34,9 +34,8 @@ from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.util import ( + ProgressManager, RandomGeneratorState, - compute_draw_speed, - create_progress_bar, default_progress_theme, get_state_from_generator, random_generator_from_state, @@ -484,26 +483,15 @@ def __init__( self._max_active = cores self._in_context = False - - progress_columns, progress_stats = step_method._progressbar_config(chains) - - self._progress = create_progress_bar( - progress_columns, - progress_stats, + self._progress = ProgressManager( + step_method=step_method, + chains=chains, + draws=draws, + tune=tune, progressbar=progressbar, progressbar_theme=progressbar_theme, ) - self.progress_stats = progress_stats - self.update_stats = step_method._make_update_stats_function() - - self._show_progress = progressbar - self._divergences = 0 - self._completed_draws = 0 - self._total_draws = draws + tune - self._desc = "Sampling chain" - self._chains = chains - def _make_active(self): while self._inactive and len(self._active) < self._max_active: proc = self._inactive.pop(0) @@ -517,41 +505,13 @@ def __iter__(self): raise ValueError("Use ParallelSampler as context manager.") self._make_active() - with self._progress as progress: - tasks = [ - progress.add_task( - self._desc.format(self), - completed=0, - draws=0, - total=self._total_draws - 1, - chain_idx=chain_idx, - sampling_speed=0, - speed_unit="draws/s", - **{stat: value[chain_idx] for stat, value in self.progress_stats.items()}, - ) - for chain_idx in range(self._chains) - ] - + with self._progress: while self._active: draw = ProcessAdapter.recv_draw(self._active) proc, is_last, draw, tuning, stats = draw - self._completed_draws += 1 - - speed, unit = compute_draw_speed(progress._tasks[proc.chain].elapsed, draw) - - if not tuning and stats and stats[0].get("diverging"): - self._divergences += 1 - - self.progress_stats = self.update_stats(self.progress_stats, stats, proc.chain) - - progress.update( - tasks[proc.chain], - completed=draw, - draws=draw, - sampling_speed=speed, - speed_unit=unit, - **{stat: value[proc.chain] for stat, value in self.progress_stats.items()}, + self._progress.update( + chain_idx=proc.chain, is_last=is_last, draw=draw, tuning=tuning, stats=stats ) if is_last: @@ -559,12 +519,6 @@ def __iter__(self): self._active.remove(proc) self._finished.append(proc) self._make_active() - progress.update( - tasks[proc.chain], - draws=draw + 1, - **{stat: value[proc.chain] for stat, value in self.progress_stats.items()}, - refresh=True, - ) # We could also yield proc.shared_point_view directly, # and only call proc.write_next() after the yield returns. diff --git a/pymc/util.py b/pymc/util.py index 514fc356fa..3d727d78c4 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -702,30 +702,164 @@ def callbacks(self, task: "Task"): self.finished_style = self.non_diverging_finished_style -def create_progress_bar(step_columns, init_stat_dict, progressbar, progressbar_theme): - columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))] - columns += step_columns - columns += [ - TextColumn( - "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", - table_column=Column("Sampling Speed", ratio=1), - ), - TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)), - TimeRemainingColumn(table_column=Column("Remaining", ratio=1)), - ] - - return CustomProgress( - DivergenceBarColumn( - table_column=Column("Progress", ratio=2), - diverging_color="tab:red", - complete_style=Style.parse("rgb(31,119,180)"), # tab:blue - finished_style=Style.parse("rgb(31,119,180)"), # tab:blue - ), - *columns, - console=Console(theme=progressbar_theme), - disable=not progressbar, - include_headers=True, - ) +class ProgressManager: + def __init__(self, step_method, chains, draws, tune, progressbar, progressbar_theme): + mode = "chain" + stats = "full" + + if isinstance(progressbar, bool): + show_progress = progressbar + else: + show_progress = True + + if "+" in progressbar: + mode, stats = progressbar.split("+") + else: + mode = progressbar + stats = "full" + + if mode not in ["chain", "combined"]: + raise ValueError('Invalid mode. Valid values are "chain" and "combined"') + if stats not in ["full", "simple"]: + raise ValueError('Invalid stats. Valid values are "full" and "simple"') + + progress_columns, progress_stats = step_method._progressbar_config(chains) + self.combined_progress = mode == "combined" + self.full_stats = stats == "full" + + self._progress = self.create_progress_bar( + progress_columns, + progressbar=progressbar, + progressbar_theme=progressbar_theme, + ) + + self.progress_stats = progress_stats + self.update_stats = step_method._make_update_stats_function() + + self._show_progress = show_progress + self.divergences = 0 + self.completed_draws = 0 + self.total_draws = draws + tune + self.desc = "Sampling chain" + self.chains = chains + + self._tasks: list[Task] | None = None + + def __enter__(self): + self._initialize_tasks() + + return self._progress.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + return self._progress.__exit__(exc_type, exc_val, exc_tb) + + def _initialize_tasks(self): + if self.combined_progress: + self.tasks = [ + self._progress.add_task( + self.desc.format(self), + completed=0, + draws=0, + total=self.total_draws * self.chains - 1, + chain_idx=0, + sampling_speed=0, + speed_unit="draws/s", + **{stat: value[0] for stat, value in self.progress_stats.items()}, + ) + ] + + else: + self.tasks = [ + self._progress.add_task( + self.desc.format(self), + completed=0, + draws=0, + total=self.total_draws - 1, + chain_idx=chain_idx, + sampling_speed=0, + speed_unit="draws/s", + **{stat: value[chain_idx] for stat, value in self.progress_stats.items()}, + ) + for chain_idx in range(self.chains) + ] + + def compute_draw_speed(self, chain_idx, draws): + elapsed = self._progress.tasks[chain_idx].elapsed + speed = draws / max(elapsed, 1e-6) + + if speed > 1 or speed == 0: + unit = "draws/s" + else: + unit = "s/draws" + speed = 1 / speed + + return speed, unit + + def update(self, chain_idx, is_last, draw, tuning, stats): + if not self._show_progress: + return + + self.completed_draws += 1 + if self.combined_progress: + draw = self.completed_draws + chain_idx = 0 + + speed, unit = self.compute_draw_speed(chain_idx, draw) + + if not tuning and stats and stats[0].get("diverging"): + self.divergences += 1 + + self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx) + more_updates = ( + {stat: value[chain_idx] for stat, value in self.progress_stats.items()} + if self.full_stats + else {} + ) + + self._progress.update( + self.tasks[chain_idx], + completed=draw, + draws=draw, + sampling_speed=speed, + speed_unit=unit, + **more_updates, + ) + + if is_last: + self._progress.update( + self.tasks[chain_idx], + draws=draw + 1 if not self.combined_progress else draw - 1, + **more_updates, + refresh=True, + ) + + def create_progress_bar(self, step_columns, progressbar, progressbar_theme): + columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))] + + if self.full_stats: + columns += step_columns + + columns += [ + TextColumn( + "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", + table_column=Column("Sampling Speed", ratio=1), + ), + TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)), + TimeRemainingColumn(table_column=Column("Remaining", ratio=1)), + ] + + return CustomProgress( + DivergenceBarColumn( + table_column=Column("Progress", ratio=2), + diverging_color="tab:red", + complete_style=Style.parse("rgb(31,119,180)"), # tab:blue + finished_style=Style.parse("rgb(31,119,180)"), # tab:blue + ), + *columns, + console=Console(theme=progressbar_theme), + disable=not progressbar, + include_headers=True, + ) def compute_draw_speed(elapsed, draws): From 1e13cf923c5537bca913748461cbdc0d2a9e1b27 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 23 Jan 2025 19:53:12 +0800 Subject: [PATCH 12/25] Yield `stats` from `_iter_sample` --- pymc/sampling/mcmc.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 6fb80284fd..395904bda5 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1293,22 +1293,25 @@ def _iter_sample( step.iter_count = 0 if i == tune: step.stop_tuning() + point, stats = step.step(point) trace.record(point, stats) log_warning_stats(stats) - diverging = i > tune and len(stats) > 0 and (stats[0].get("diverging") is True) + if callback is not None: callback( trace=trace, draw=Draw(chain, i == draws, i, i < tune, stats, point), ) - yield diverging + yield stats + except (KeyboardInterrupt, BaseException): if isinstance(trace, ZarrChain): trace.record_sampling_state(step=step) trace.close() raise + else: if isinstance(trace, ZarrChain): trace.record_sampling_state(step=step) From 28a80c1832841fe827f1b80bc3ce4ceb2864644f Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 23 Jan 2025 19:53:55 +0800 Subject: [PATCH 13/25] Use `ProgressManager` in `_sample_many` --- pymc/sampling/mcmc.py | 97 ++++++++++++++++++++----------------------- 1 file changed, 44 insertions(+), 53 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 395904bda5..661099b942 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -36,8 +36,6 @@ from arviz import InferenceData, dict_to_dataset from arviz.data.base import make_attrs from pytensor.graph.basic import Variable -from rich.console import Console -from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from rich.theme import Theme from threadpoolctl import threadpool_limits from typing_extensions import Protocol @@ -67,7 +65,7 @@ from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared from pymc.step_methods.hmc import quadpotential from pymc.util import ( - CustomProgress, + ProgressManager, RandomSeed, RandomState, _get_seeds_per_chain, @@ -1138,25 +1136,35 @@ def _sample_many( Step function """ initial_step_state = step.sampling_state - for i in range(chains): - step.sampling_state = initial_step_state - _sample( - draws=draws, - chain=i, - start=start[i], - step=step, - trace=traces[i], - rng=rngs[i], - callback=callback, - **kwargs, - ) + progress_manager = ProgressManager( + step_method=step, + chains=chains, + draws=draws - kwargs.get("tune", 0), + tune=kwargs.get("tune", 0), + progressbar=kwargs.get("progressbar", True), + progressbar_theme=kwargs.get("progressbar_theme", default_progress_theme), + ) + + with progress_manager: + for i in range(chains): + step.sampling_state = initial_step_state + _sample( + draws=draws, + chain=i, + start=start[i], + step=step, + trace=traces[i], + rng=rngs[i], + callback=callback, + progress_manager=progress_manager, + **kwargs, + ) return def _sample( *, chain: int, - progressbar: bool, rng: np.random.Generator, start: PointType, draws: int, @@ -1164,8 +1172,8 @@ def _sample( trace: IBaseTrace, tune: int, model: Model | None = None, - progressbar_theme: Theme | None = default_progress_theme, callback=None, + progress_manager: ProgressManager, **kwargs, ) -> None: """Sample one chain (singleprocess). @@ -1176,27 +1184,23 @@ def _sample( ---------- chain : int Number of the chain that the samples will belong to. - progressbar : bool - Whether or not to display a progress bar in the command line. The bar shows the percentage - of completion, the sampling speed in samples per second (SPS), and the estimated remaining - time until completion ("expected time of arrival"; ETA). - random_seed : single random seed + random_seed : Generator + Single random seed start : dict Starting point in parameter space (or partial point) draws : int The number of samples to draw - step : function - Step function + step : Step + Step class instance used to generate samples. trace A chain backend to record draws and stats. tune : int Number of iterations to tune. - model : Model (optional if in ``with`` context) - progressbar_theme : Theme - Optional custom theme for the progress bar. + model : Model, optional + PyMC model. If None, the model is taken from the current context. + progress_manager: ProgressManager + Helper class used to handle progress bar styling and updates """ - skip_first = kwargs.get("skip_first", 0) - sampling_gen = _iter_sample( draws=draws, step=step, @@ -1208,32 +1212,19 @@ def _sample( rng=rng, callback=callback, ) - _pbar_data = {"chain": chain, "divergences": 0} - _desc = "Sampling chain {chain:d}, {divergences:,d} divergences" - - progress = CustomProgress( - "[progress.description]{task.description}", - BarColumn(), - "[progress.percentage]{task.percentage:>3.0f}%", - TimeRemainingColumn(), - TextColumn("/"), - TimeElapsedColumn(), - console=Console(theme=progressbar_theme), - disable=not progressbar, - ) + try: + for it, stats in enumerate(sampling_gen): + progress_manager.update( + chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it > tune + ) - with progress: - try: - task = progress.add_task(_desc.format(**_pbar_data), completed=0, total=draws) - for it, diverging in enumerate(sampling_gen): - if it >= skip_first and diverging: - _pbar_data["divergences"] += 1 - progress.update(task, description=_desc.format(**_pbar_data), completed=it) - progress.update( - task, description=_desc.format(**_pbar_data), completed=draws, refresh=True + if not progress_manager.combined_progress or chain == progress_manager.chains - 1: + progress_manager.update( + chain_idx=chain, is_last=True, draw=it, stats=stats, tuning=False ) - except KeyboardInterrupt: - pass + + except KeyboardInterrupt: + pass def _iter_sample( From 345faffa8ae329f973b8a951774723f92fd8876c Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 23 Jan 2025 20:01:44 +0800 Subject: [PATCH 14/25] pre-commit --- pymc/sampling/mcmc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 661099b942..60d4adf0a8 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1296,13 +1296,13 @@ def _iter_sample( ) yield stats - + except (KeyboardInterrupt, BaseException): if isinstance(trace, ZarrChain): trace.record_sampling_state(step=step) trace.close() raise - + else: if isinstance(trace, ZarrChain): trace.record_sampling_state(step=step) From 741cf36b6c0e0de51b9ea8d14dbfb21c077aa62c Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 23 Jan 2025 23:49:08 +0800 Subject: [PATCH 15/25] Explicit case handling for `progressbar` argument --- pymc/util.py | 50 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/pymc/util.py b/pymc/util.py index 3d727d78c4..8f41e47f38 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -704,28 +704,38 @@ def callbacks(self, task: "Task"): class ProgressManager: def __init__(self, step_method, chains, draws, tune, progressbar, progressbar_theme): - mode = "chain" - stats = "full" - - if isinstance(progressbar, bool): - show_progress = progressbar - else: - show_progress = True - - if "+" in progressbar: - mode, stats = progressbar.split("+") - else: - mode = progressbar - stats = "full" - - if mode not in ["chain", "combined"]: - raise ValueError('Invalid mode. Valid values are "chain" and "combined"') - if stats not in ["full", "simple"]: - raise ValueError('Invalid stats. Valid values are "full" and "simple"') + self.combined_progress = False + self.full_stats = True + show_progress = True + + match progressbar: + case True: + show_progress = True + case False: + show_progress = False + case "combined": + self.combined_progress = True + case "chain": + self.combined_progress = False + case "combined+full": + self.combined_progress = True + self.full_stats = True + case "combined+simple": + self.combined_progress = True + self.full_stats = False + case "chain+full": + self.combined_progress = False + self.full_stats = True + case "chain+simple": + self.combined_progress = False + self.full_stats = False + case _: + raise ValueError( + "Invalid value for `progressbar`. Valid values are True (default), False (no progress bar), " + "or one of 'combined', 'chain', 'combined+full', 'combined+simple', 'chain+full', 'chain+simple'." + ) progress_columns, progress_stats = step_method._progressbar_config(chains) - self.combined_progress = mode == "combined" - self.full_stats = stats == "full" self._progress = self.create_progress_bar( progress_columns, From f4ccbd5d0358657d0d4095f0ec4daf91c4af8f97 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 23 Jan 2025 23:54:57 +0800 Subject: [PATCH 16/25] Allow all permutations of arguments to progressbar --- pymc/util.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/pymc/util.py b/pymc/util.py index 8f41e47f38..99f5db9d96 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -709,30 +709,32 @@ def __init__(self, step_method, chains, draws, tune, progressbar, progressbar_th show_progress = True match progressbar: - case True: - show_progress = True case False: show_progress = False case "combined": self.combined_progress = True case "chain": self.combined_progress = False - case "combined+full": + case "simple": + self.full_stats = False + case "full": + self.full_stats = True + case "combined+full" | "full+combined": self.combined_progress = True self.full_stats = True - case "combined+simple": + case "combined+simple" | "simple+combined": self.combined_progress = True self.full_stats = False - case "chain+full": + case "chain+full" | "full+chain": self.combined_progress = False self.full_stats = True - case "chain+simple": + case "chain+simple" | "simple+chain": self.combined_progress = False self.full_stats = False case _: raise ValueError( "Invalid value for `progressbar`. Valid values are True (default), False (no progress bar), " - "or one of 'combined', 'chain', 'combined+full', 'combined+simple', 'chain+full', 'chain+simple'." + "one of 'combined', 'chain', 'simple', 'full', or a '+' delimited pair of two of these values." ) progress_columns, progress_stats = step_method._progressbar_config(chains) From 9649d66b247a5f3af20b5843bd88b9dbfbf0582a Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Fri, 24 Jan 2025 00:09:50 +0800 Subject: [PATCH 17/25] Appease mypy --- pymc/backends/__init__.py | 2 +- pymc/sampling/mcmc.py | 8 ++++---- pymc/util.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index d3f7620882..eaa484a13f 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -87,7 +87,7 @@ RunType: TypeAlias = Run HAS_MCB = True except ImportError: - TraceOrBackend = BaseTrace # type: ignore[misc] + TraceOrBackend = BaseTrace # type: ignore[assignment, misc] RunType = type(None) # type: ignore[assignment, misc] diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 60d4adf0a8..5b76d498ab 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -276,7 +276,7 @@ def _print_step_hierarchy(s: Step, level: int = 0) -> None: else: varnames = ", ".join( [ - get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name + get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name # type: ignore[misc] for v in s.vars ] ) @@ -1238,7 +1238,7 @@ def _iter_sample( rng: np.random.Generator, model: Model | None = None, callback: SamplingIteratorCallback | None = None, -) -> Iterator[bool]: +) -> Iterator[list[dict[str, Any]]]: """Sample one chain with a generator (singleprocess). Parameters @@ -1261,8 +1261,8 @@ def _iter_sample( Yields ------ - diverging : bool - Indicates if the draw is divergent. Only available with some samplers. + stats : list of dict + Dictionary of statistics returned by step sampler """ draws = int(draws) diff --git a/pymc/util.py b/pymc/util.py index 99f5db9d96..c979151ed2 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -755,7 +755,7 @@ def __init__(self, step_method, chains, draws, tune, progressbar, progressbar_th self.desc = "Sampling chain" self.chains = chains - self._tasks: list[Task] | None = None + self._tasks: list[Task] | None = None # type: ignore[annotation-unchecked] def __enter__(self): self._initialize_tasks() From a629a9703be2023217de108ea201a139124c2387 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Fri, 24 Jan 2025 00:23:43 +0800 Subject: [PATCH 18/25] Add True case --- pymc/util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/util.py b/pymc/util.py index c979151ed2..63d697da66 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -709,6 +709,8 @@ def __init__(self, step_method, chains, draws, tune, progressbar, progressbar_th show_progress = True match progressbar: + case True: + show_progress = True case False: show_progress = False case "combined": From e024991492a36239ab3d94a9c716d8f13dd6ca28 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sun, 26 Jan 2025 00:05:43 +0800 Subject: [PATCH 19/25] Fix final count when `progress = "combined"` --- pymc/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/util.py b/pymc/util.py index 63d697da66..7874c21635 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -842,7 +842,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats): if is_last: self._progress.update( self.tasks[chain_idx], - draws=draw + 1 if not self.combined_progress else draw - 1, + draws=draw + 1 if not self.combined_progress else draw, **more_updates, refresh=True, ) From 4e535d4a408e55c4f6a21540020914738397bf1e Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sun, 26 Jan 2025 00:23:08 +0800 Subject: [PATCH 20/25] Update docstrings --- pymc/sampling/mcmc.py | 21 +++++++++---- pymc/util.py | 70 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 83 insertions(+), 8 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 5b76d498ab..24850e0888 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -66,6 +66,7 @@ from pymc.step_methods.hmc import quadpotential from pymc.util import ( ProgressManager, + ProgressType, RandomSeed, RandomState, _get_seeds_per_chain, @@ -486,7 +487,7 @@ def sample( chains: int | None = None, cores: int | None = None, random_seed: RandomState = None, - progressbar: bool = True, + progressbar: bool | ProgressType = True, progressbar_theme: Theme | None = default_progress_theme, step=None, var_names: Sequence[str] | None = None, @@ -537,11 +538,19 @@ def sample( A ``TypeError`` will be raised if a legacy :py:class:`~numpy.random.RandomState` object is passed. We no longer support ``RandomState`` objects because their seeding mechanism does not allow easy spawning of new independent random streams that are needed by the step methods. - progressbar : bool, optional default=True - Whether or not to display a progress bar in the command line. The bar shows the percentage - of completion, the sampling speed in samples per second (SPS), and the estimated remaining - time until completion ("expected time of arrival"; ETA). - Only applicable to the pymc nuts sampler. + progressbar: bool or ProgressType, optional + How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask + for either: + - "combined": A single progress bar that displays the progress of all chains combined. + - "chain": A separate progress bar for each chain. + + You can also combine the above options with: + - "simple": A simple progress bar that displays only timing information alongside the progress bar. + - "full": A progress bar that displays all available statistics. + + These can be combined with a "+" delimiter, for example: "combined+full" or "chain+simple". + + If True, the default is "chain+full". step : function or iterable of functions A step function or collection of functions. If there are variables without step methods, step methods for those variables will be assigned automatically. By default the NUTS step diff --git a/pymc/util.py b/pymc/util.py index 7874c21635..b3107938eb 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -19,7 +19,7 @@ from collections import namedtuple from collections.abc import Iterable, Sequence from copy import deepcopy -from typing import NewType, cast +from typing import TYPE_CHECKING, Literal, NewType, cast import arviz import cloudpickle @@ -46,6 +46,23 @@ from pymc.exceptions import BlockModelAccessError +if TYPE_CHECKING: + from pymc import BlockedStep + + +ProgressType = Literal[ + "chain", + "combined", + "simple", + "full", + "combined+full", + "full+combined", + "combined+simple", + "simple+combined", + "chain+full", + "full+chain", +] + def __getattr__(name): if name == "dataset_to_point_list": @@ -639,6 +656,7 @@ def make_tasks_table(self, tasks: Iterable[Task]) -> Table: """ def call_column(column, task): + # Subclass rich.BarColumn and add a callback method to dynamically update the display if hasattr(column, "callbacks"): column.callbacks(task) @@ -681,6 +699,8 @@ def call_column(column, task): class DivergenceBarColumn(BarColumn): + """Rich colorbar that changes color when a chain has detected a divergence.""" + def __init__(self, *args, diverging_color="red", **kwargs): from matplotlib.colors import to_rgb @@ -703,7 +723,53 @@ def callbacks(self, task: "Task"): class ProgressManager: - def __init__(self, step_method, chains, draws, tune, progressbar, progressbar_theme): + """Manage progress bars displayed during sampling.""" + + def __init__( + self, + step_method: BlockedStep, + chains: int, + draws: int, + tune: int, + progressbar: bool | ProgressType = True, + progressbar_theme: Theme = default_progress_theme, + ): + """ + Manage progress bars displayed during sampling. + + When sampling, Step classes are responsible for computing and exposing statistics that can be reported on + progress bars. Each Step implements two class methods: :meth:`pymc.step_methods.BlockedStep._progressbar_config` + and :meth:`pymc.step_methods.BlockedStep._make_update_stats_function`. `_progressbar_config` reports which + columns should be displayed on the progress bar, and `_make_update_stats_function` computes the statistics + that will be displayed on the progress bar. + + Parameters + ---------- + step_method: BlockedStep + The step method being used to sample + chains: int + Number of chains being sampled + draws: int + Number of draws per chain + tune: int + Number of tuning steps per chain + progressbar: bool or ProgressType, optional + How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask + for either: + - "combined": A single progress bar that displays the progress of all chains combined. + - "chain": A separate progress bar for each chain. + + You can also combine the above options with: + - "simple": A simple progress bar that displays only timing information alongside the progress bar. + - "full": A progress bar that displays all available statistics. + + These can be combined with a "+" delimiter, for example: "combined+full" or "chain+simple". + + If True, the default is "chain+full". + + progressbar_theme: Theme, optional + The theme to use for the progress bar. Defaults to the default theme. + """ self.combined_progress = False self.full_stats = True show_progress = True From b9b05837a05bb6dcdcfa708053159accf6b3a194 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sun, 26 Jan 2025 00:34:37 +0800 Subject: [PATCH 21/25] mypy + cleanup --- pymc/sampling/mcmc.py | 10 ++++++---- pymc/util.py | 9 ++++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 24850e0888..c60bc59f03 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -302,7 +302,7 @@ def _sample_external_nuts( initvals: StartDict | Sequence[StartDict | None] | None, model: Model, var_names: Sequence[str] | None, - progressbar: bool, + progressbar: bool | ProgressType, idata_kwargs: dict | None, compute_convergence_checks: bool, nuts_sampler_kwargs: dict | None, @@ -401,7 +401,7 @@ def _sample_external_nuts( initvals=initvals, model=model, var_names=var_names, - progressbar=progressbar, + progressbar=True if progressbar else False, nuts_sampler=sampler, idata_kwargs=idata_kwargs, compute_convergence_checks=compute_convergence_checks, @@ -488,7 +488,7 @@ def sample( cores: int | None = None, random_seed: RandomState = None, progressbar: bool | ProgressType = True, - progressbar_theme: Theme | None = default_progress_theme, + progressbar_theme: Theme | None = None, step=None, var_names: Sequence[str] | None = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", @@ -831,7 +831,9 @@ def joined_blas_limiter(): n_init=n_init, model=model, random_seed=random_seed_list, - progressbar=progressbar, + progressbar=True + if progressbar + else False, # ADVI doesn't use the ProgressManager; pass a bool only jitter_max_retries=jitter_max_retries, tune=tune, initvals=initvals, diff --git a/pymc/util.py b/pymc/util.py index b3107938eb..3a70726638 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -47,7 +47,7 @@ from pymc.exceptions import BlockModelAccessError if TYPE_CHECKING: - from pymc import BlockedStep + from pymc.step_methods.compound import BlockedStep, CompoundStep ProgressType = Literal[ @@ -727,12 +727,12 @@ class ProgressManager: def __init__( self, - step_method: BlockedStep, + step_method: "BlockedStep" | "CompoundStep", chains: int, draws: int, tune: int, progressbar: bool | ProgressType = True, - progressbar_theme: Theme = default_progress_theme, + progressbar_theme: Theme | None = None, ): """ Manage progress bars displayed during sampling. @@ -770,6 +770,9 @@ def __init__( progressbar_theme: Theme, optional The theme to use for the progress bar. Defaults to the default theme. """ + if progressbar_theme is None: + progressbar_theme = default_progress_theme + self.combined_progress = False self.full_stats = True show_progress = True From 9de9930ea3e377650bf3f02551afa3a5dde8a45e Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sun, 26 Jan 2025 00:44:16 +0800 Subject: [PATCH 22/25] Syntax error in typehint --- pymc/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/util.py b/pymc/util.py index 3a70726638..5242bae6fe 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -727,7 +727,7 @@ class ProgressManager: def __init__( self, - step_method: "BlockedStep" | "CompoundStep", + step_method: "BlockedStep | CompoundStep", chains: int, draws: int, tune: int, @@ -745,7 +745,7 @@ def __init__( Parameters ---------- - step_method: BlockedStep + step_method: BlockedStep or CompoundStep The step method being used to sample chains: int Number of chains being sampled From 79d12480a648d79f5afdd6fed39b88eb013c9641 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sun, 26 Jan 2025 21:51:40 +0800 Subject: [PATCH 23/25] Simplify progressbar choices, update docstring --- pymc/sampling/mcmc.py | 39 +++++++++++++++++---------------- pymc/util.py | 51 ++++++++++++++++--------------------------- 2 files changed, 39 insertions(+), 51 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index c60bc59f03..164ca5d99b 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -302,7 +302,7 @@ def _sample_external_nuts( initvals: StartDict | Sequence[StartDict | None] | None, model: Model, var_names: Sequence[str] | None, - progressbar: bool | ProgressType, + progressbar: bool, idata_kwargs: dict | None, compute_convergence_checks: bool, nuts_sampler_kwargs: dict | None, @@ -401,7 +401,7 @@ def _sample_external_nuts( initvals=initvals, model=model, var_names=var_names, - progressbar=True if progressbar else False, + progressbar=progressbar, nuts_sampler=sampler, idata_kwargs=idata_kwargs, compute_convergence_checks=compute_convergence_checks, @@ -423,7 +423,7 @@ def sample( chains: int | None = None, cores: int | None = None, random_seed: RandomState = None, - progressbar: bool = True, + progressbar: bool | ProgressType = True, progressbar_theme: Theme | None = default_progress_theme, step=None, var_names: Sequence[str] | None = None, @@ -455,7 +455,7 @@ def sample( chains: int | None = None, cores: int | None = None, random_seed: RandomState = None, - progressbar: bool = True, + progressbar: bool | ProgressType = True, progressbar_theme: Theme | None = default_progress_theme, step=None, var_names: Sequence[str] | None = None, @@ -540,17 +540,16 @@ def sample( easy spawning of new independent random streams that are needed by the step methods. progressbar: bool or ProgressType, optional How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask - for either: - - "combined": A single progress bar that displays the progress of all chains combined. - - "chain": A separate progress bar for each chain. - - You can also combine the above options with: - - "simple": A simple progress bar that displays only timing information alongside the progress bar. - - "full": A progress bar that displays all available statistics. - - These can be combined with a "+" delimiter, for example: "combined+full" or "chain+simple". - - If True, the default is "chain+full". + for one of the following: + - "combined": A single progress bar that displays the total progress across all chains. Only timing + information is shown. + - "split": A separate progress bar for each chain. Only timing information is shown. + - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all + chains. Aggregate sample statistics are also displayed. + - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain + are also displayed. + + If True, the default is "split+stats" is used. step : function or iterable of functions A step function or collection of functions. If there are variables without step methods, step methods for those variables will be assigned automatically. By default the NUTS step @@ -716,6 +715,10 @@ def sample( if isinstance(trace, list): raise ValueError("Please use `var_names` keyword argument for partial traces.") + # progressbar might be a string, which is used by the ProgressManager in the pymc samplers. External samplers and + # ADVI initialization expect just a bool. + progress_bool = True if progressbar else False + model = modelcontext(model) if not model.free_RVs: raise SamplingError( @@ -812,7 +815,7 @@ def joined_blas_limiter(): initvals=initvals, model=model, var_names=var_names, - progressbar=progressbar, + progressbar=progress_bool, idata_kwargs=idata_kwargs, compute_convergence_checks=compute_convergence_checks, nuts_sampler_kwargs=nuts_sampler_kwargs, @@ -831,9 +834,7 @@ def joined_blas_limiter(): n_init=n_init, model=model, random_seed=random_seed_list, - progressbar=True - if progressbar - else False, # ADVI doesn't use the ProgressManager; pass a bool only + progressbar=progress_bool, jitter_max_retries=jitter_max_retries, tune=tune, initvals=initvals, diff --git a/pymc/util.py b/pymc/util.py index 5242bae6fe..223c25248c 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -51,16 +51,12 @@ ProgressType = Literal[ - "chain", "combined", - "simple", - "full", - "combined+full", - "full+combined", - "combined+simple", - "simple+combined", - "chain+full", - "full+chain", + "split", + "combined+stats", + "stats+combined", + "split+stats", + "stats+split", ] @@ -755,17 +751,16 @@ def __init__( Number of tuning steps per chain progressbar: bool or ProgressType, optional How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask - for either: - - "combined": A single progress bar that displays the progress of all chains combined. - - "chain": A separate progress bar for each chain. + for one of the following: + - "combined": A single progress bar that displays the total progress across all chains. Only timing + information is shown. + - "split": A separate progress bar for each chain. Only timing information is shown. + - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all + chains. Aggregate sample statistics are also displayed. + - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain + are also displayed. - You can also combine the above options with: - - "simple": A simple progress bar that displays only timing information alongside the progress bar. - - "full": A progress bar that displays all available statistics. - - These can be combined with a "+" delimiter, for example: "combined+full" or "chain+simple". - - If True, the default is "chain+full". + If True, the default is "split+stats" is used. progressbar_theme: Theme, optional The theme to use for the progress bar. Defaults to the default theme. @@ -784,28 +779,20 @@ def __init__( show_progress = False case "combined": self.combined_progress = True - case "chain": + self.full_stats = False + case "split": self.combined_progress = False - case "simple": self.full_stats = False - case "full": - self.full_stats = True - case "combined+full" | "full+combined": - self.combined_progress = True + case "combined+stats" | "stats+combined": self.full_stats = True - case "combined+simple" | "simple+combined": self.combined_progress = True - self.full_stats = False - case "chain+full" | "full+chain": - self.combined_progress = False + case "split+stats" | "stats+split": self.full_stats = True - case "chain+simple" | "simple+chain": self.combined_progress = False - self.full_stats = False case _: raise ValueError( "Invalid value for `progressbar`. Valid values are True (default), False (no progress bar), " - "one of 'combined', 'chain', 'simple', 'full', or a '+' delimited pair of two of these values." + "one of 'combined', 'split', 'split+stats', or 'combined+stats." ) progress_columns, progress_stats = step_method._progressbar_config(chains) From 161d10cbd201ed40ab99e3bb2899f8b9b10761a1 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Mon, 27 Jan 2025 13:28:59 +0800 Subject: [PATCH 24/25] Incorporate feedback --- pymc/sampling/mcmc.py | 18 +++++++++--------- pymc/sampling/parallel.py | 4 ++-- pymc/util.py | 6 +++--- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 164ca5d99b..bd2425c9f5 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -65,8 +65,8 @@ from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared from pymc.step_methods.hmc import quadpotential from pymc.util import ( - ProgressManager, - ProgressType, + ProgressBarManager, + ProgressBarType, RandomSeed, RandomState, _get_seeds_per_chain, @@ -423,7 +423,7 @@ def sample( chains: int | None = None, cores: int | None = None, random_seed: RandomState = None, - progressbar: bool | ProgressType = True, + progressbar: bool | ProgressBarType = True, progressbar_theme: Theme | None = default_progress_theme, step=None, var_names: Sequence[str] | None = None, @@ -455,7 +455,7 @@ def sample( chains: int | None = None, cores: int | None = None, random_seed: RandomState = None, - progressbar: bool | ProgressType = True, + progressbar: bool | ProgressBarType = True, progressbar_theme: Theme | None = default_progress_theme, step=None, var_names: Sequence[str] | None = None, @@ -487,7 +487,7 @@ def sample( chains: int | None = None, cores: int | None = None, random_seed: RandomState = None, - progressbar: bool | ProgressType = True, + progressbar: bool | ProgressBarType = True, progressbar_theme: Theme | None = None, step=None, var_names: Sequence[str] | None = None, @@ -717,7 +717,7 @@ def sample( # progressbar might be a string, which is used by the ProgressManager in the pymc samplers. External samplers and # ADVI initialization expect just a bool. - progress_bool = True if progressbar else False + progress_bool = bool(progressbar) model = modelcontext(model) if not model.free_RVs: @@ -1148,7 +1148,7 @@ def _sample_many( Step function """ initial_step_state = step.sampling_state - progress_manager = ProgressManager( + progress_manager = ProgressBarManager( step_method=step, chains=chains, draws=draws - kwargs.get("tune", 0), @@ -1185,7 +1185,7 @@ def _sample( tune: int, model: Model | None = None, callback=None, - progress_manager: ProgressManager, + progress_manager: ProgressBarManager, **kwargs, ) -> None: """Sample one chain (singleprocess). @@ -1210,7 +1210,7 @@ def _sample( Number of iterations to tune. model : Model, optional PyMC model. If None, the model is taken from the current context. - progress_manager: ProgressManager + progress_manager: ProgressBarManager Helper class used to handle progress bar styling and updates """ sampling_gen = _iter_sample( diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 549d205e90..af2106ce6f 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -34,7 +34,7 @@ from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.util import ( - ProgressManager, + ProgressBarManager, RandomGeneratorState, default_progress_theme, get_state_from_generator, @@ -483,7 +483,7 @@ def __init__( self._max_active = cores self._in_context = False - self._progress = ProgressManager( + self._progress = ProgressBarManager( step_method=step_method, chains=chains, draws=draws, diff --git a/pymc/util.py b/pymc/util.py index 223c25248c..cbc25b6564 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -50,7 +50,7 @@ from pymc.step_methods.compound import BlockedStep, CompoundStep -ProgressType = Literal[ +ProgressBarType = Literal[ "combined", "split", "combined+stats", @@ -718,7 +718,7 @@ def callbacks(self, task: "Task"): self.finished_style = self.non_diverging_finished_style -class ProgressManager: +class ProgressBarManager: """Manage progress bars displayed during sampling.""" def __init__( @@ -727,7 +727,7 @@ def __init__( chains: int, draws: int, tune: int, - progressbar: bool | ProgressType = True, + progressbar: bool | ProgressBarType = True, progressbar_theme: Theme | None = None, ): """ From b381e5d3cf26f6910219032cf6871c21519eeaff Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Mon, 27 Jan 2025 13:59:09 +0800 Subject: [PATCH 25/25] Be verbose with progressbar settings --- pymc/util.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/pymc/util.py b/pymc/util.py index cbc25b6564..979b3beebf 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -768,27 +768,31 @@ def __init__( if progressbar_theme is None: progressbar_theme = default_progress_theme - self.combined_progress = False - self.full_stats = True - show_progress = True - match progressbar: case True: + self.combined_progress = False + self.full_stats = True show_progress = True case False: + self.combined_progress = False + self.full_stats = True show_progress = False case "combined": self.combined_progress = True self.full_stats = False + show_progress = True case "split": self.combined_progress = False self.full_stats = False + show_progress = True case "combined+stats" | "stats+combined": - self.full_stats = True self.combined_progress = True - case "split+stats" | "stats+split": self.full_stats = True + show_progress = True + case "split+stats" | "stats+split": self.combined_progress = False + self.full_stats = True + show_progress = True case _: raise ValueError( "Invalid value for `progressbar`. Valid values are True (default), False (no progress bar), "