From 00504ff0c2a69003928eee147eeeb5b4113f21d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Wed, 22 Nov 2023 07:25:07 +0100 Subject: [PATCH] Fix a few issues of the FixedBucketsValTracker (#73) * Fix a few issues of the FixedBucketsValTracker 1. Make the default number of bins of the internal tracker smaller so it does not cause numerical issues and/or memory problems. 2. Use floor() instead of int() (trunc) for rounding to have the same behaviour for positive and negative numbers. 3. Add an extra bin in the summarization method such that the number of bins in the summary is always "number of values"+1. This is consistent with the numpy histogram convention. * fix style --------- Co-authored-by: Luca Soldaini --- python/dolma/core/analyzer.py | 3 ++- python/dolma/core/binning.py | 17 ++++++++++++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/python/dolma/core/analyzer.py b/python/dolma/core/analyzer.py index b76baa57..5e0cb5f7 100644 --- a/python/dolma/core/analyzer.py +++ b/python/dolma/core/analyzer.py @@ -1,3 +1,4 @@ +import math import multiprocessing import re import shutil @@ -27,7 +28,7 @@ def _make_tracker(type_: str = "fixed", **kwargs: int) -> BaseBucketApi: if type_ == "infer": return InferBucketsValTracker(**{"n": NUM_BINS, "b": BUFF_SIZE, **kwargs}) elif type_ == "fixed": - return FixedBucketsValTracker(**{"n": NUM_BINS, **kwargs}) + return FixedBucketsValTracker(**{"n": int(math.log10(NUM_BINS)), **kwargs}) else: raise ValueError(f"Unknown tracker type {type_}") diff --git a/python/dolma/core/binning.py b/python/dolma/core/binning.py index cd41b6b0..8f13f59d 100644 --- a/python/dolma/core/binning.py +++ b/python/dolma/core/binning.py @@ -235,14 +235,17 @@ def summarize(self, n: int, density: bool = False) -> SummaryTuple: class FixedBucketsValTracker(BaseBucketApi): def __init__(self, n: int = 2): + # we use n to determine the precision of the bins; for convenience we store it as a power of 10. + # 10**n will be the maximum number of bins for each power of 2. + # Too large numbers will cause numeric problems and can cause a lot of memory use. assert n >= 0 - # we use n to determine the precision of the bins; for convenience we store it as a power of 10 + assert n <= 100 self.n = 10**n self._bins: Dict[Tuple[int, int], int] = {} def add(self, value: Union[int, float], count: int = 1): m, e = math.frexp(value) - k = int(m * self.n), e + k = math.floor(m * self.n), e if k not in self._bins: self._bins[k] = 0 @@ -255,12 +258,20 @@ def __len__(self) -> int: def full(self) -> bool: return False + def get_bin_upper_bound(self, val: float) -> float: + """Return the upper bound of the bin containing val""" + m, e = math.frexp(val) + k = math.floor(m * self.n) + 1 # Add one to obtain the next bin + return k / self.n * 2**e + def summarize(self, n: int, density: bool = False) -> SummaryTuple: bins, counts = zip(*sorted((m / self.n * 2**e, c) for (m, e), c in self._bins.items())) if len(self) <= n: # if there are fewer than n buckets, return the buckets as is - return SummaryTuple(counts=[int(c) for c in counts], bins=[float(b) for b in bins]) + # To be consistent we also add the limit of the last bin, so the bins denote bin edges + upper_bin = self.get_bin_upper_bound(max(float(b) for b in bins)) + return SummaryTuple(counts=[int(c) for c in counts], bins=[float(b) for b in bins] + [upper_bin]) # computing the weighted histograms new_counts, new_values = np.histogram(a=bins, bins=n, weights=counts, density=density)