Skip to content

Commit

Permalink
Improves tool to compute statistics; adds deduplication options. (#135)
Browse files Browse the repository at this point in the history
* more stats

* updated stats

* updated CI

* updated analyzer

* added support for custom separator

* fixed test bug

* added options to skip short paragraphs

* missed copy

* check in mypy
  • Loading branch information
soldni authored Mar 17, 2024
1 parent 476629d commit 58dedd3
Show file tree
Hide file tree
Showing 14 changed files with 715 additions and 129 deletions.
9 changes: 9 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,12 @@ style:
rustfmt --edition 2021 src/*.rs
isort .
black .

check:
isort --check .
black --check .
mypy tests/python/
mypy python/
flake8 tests/python/
flake8 python/
rustfmt --edition 2021 src/*.rs --check
6 changes: 4 additions & 2 deletions python/dolma/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,14 @@ class BaseCli(Generic[D]):
@classmethod
def make_parser(cls, parser: A) -> A:
assert hasattr(cls, "CONFIG"), f"{cls.__name__} must have a CONFIG attribute"
return make_parser(parser, cls.CONFIG)
return make_parser(parser, cls.CONFIG) # pyright: ignore

@classmethod
def run_from_args(cls, args: Namespace, config: Optional[dict] = None):
assert hasattr(cls, "CONFIG"), f"{cls.__name__} must have a CONFIG attribute"
parsed_config = namespace_to_nested_omegaconf(args=args, structured=cls.CONFIG, config=config)
parsed_config = namespace_to_nested_omegaconf(
args=args, structured=cls.CONFIG, config=config # pyright: ignore
)
try:
return cls.run(parsed_config)
except OmegaConfBaseException as ex:
Expand Down
7 changes: 7 additions & 0 deletions python/dolma/cli/deduper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class NgramDedupeConfig:
default=1.0,
help="Fraction of ngrams that must be seen before a paragraph is considered a duplicate. By default, only full overlap is considered.",
)
skip_short_paragraphs: bool = field(
default=False, help="If true, paragraphs shorter than (ngram_length + stride) will be skipped."
)


@dataclass
Expand All @@ -32,6 +35,10 @@ class ParagraphDedupeConfig:
by_ngram: Optional[NgramDedupeConfig] = field(
default=None, help="Configuration for deduping paragraphs by ngram overlap"
)
paragraph_separator: Optional[str] = field(
default="\n",
help="String to use to separate paragraphs. By default, paragraphs are separated by newlines.",
)


@dataclass
Expand Down
114 changes: 57 additions & 57 deletions python/dolma/core/analyzer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import math
import multiprocessing
import re
import shutil
from contextlib import ExitStack
from tempfile import TemporaryDirectory
from typing import Dict, List, Optional
Expand Down Expand Up @@ -43,8 +42,8 @@ def from_tracker(cls, name: str, tracker: "BaseBucketApi", n: int) -> "SummarySp
counts, bins = tracker.summarize(n=n)
return SummarySpec(name=name, counts=counts, bins=bins)

def to_tracker(self) -> "BaseBucketApi":
tracker = _make_tracker()
def to_tracker(self, tracker_type: str = "fixed", **tracker_kwargs) -> "BaseBucketApi":
tracker = _make_tracker(type_=tracker_type, **tracker_kwargs)
tracker.add_many(values=self.bins, counts=self.counts)
return tracker

Expand Down Expand Up @@ -151,7 +150,12 @@ def aggregate_summaries(summaries_path: str, num_bins: int = 1000) -> List[Summa
decoder = Decoder(SummarySpec)

# iterator with nice progress bar
it = tqdm.tqdm(list(glob_path(summaries_path)), desc="Aggregating summaries", unit=" files", unit_scale=True)
it = tqdm.tqdm(
list(glob_path(summaries_path, autoglob_dirs=True, recursive_dirs=True, yield_dirs=False)),
desc="Aggregating summaries",
unit=" files",
unit_scale=True,
)

# load partial summaries and aggregate it
for path in it:
Expand All @@ -168,57 +172,58 @@ def aggregate_summaries(summaries_path: str, num_bins: int = 1000) -> List[Summa
return summaries


def visualize_summaries(summaries: List[SummarySpec], digits: int = 4, num_viz_bins: int = 10):
def round_values_for_visual(values: List[float], opt_sci: bool = False, max_decimal: int = 4) -> List[str]:
"""Logic to round values depending on their range"""

# we try rounding as little as possible until all values are different
# we reach the maximum number of decimal points
for decimal in range(max_decimal):
attempt_rounding = [round(val, decimal) for val in values]
if len(set(attempt_rounding)) == len(values):
# success! let's return the rounded values
return [f"{val:.{decimal}f}" for val in values]

# no luck; let's use scientific notation instead if we are allowed to or simply return the values
if opt_sci:
return [f"{val:.1e}" for val in values]
else:
return [f"{val:.{max_decimal}f}" for val in values]


def visualize_summaries(summaries: List[SummarySpec], max_decimal: int = 4, num_viz_bins: int = 10):
console = Console()
console.print()

def round_all(values: List[float], opt_sci: bool = False) -> List[str]:
"""Logic to round values depending on their range"""

if values == [0, 1]:
# just 0 and 1; no need to round or add decimal points
return ["0", "1"]
elif all(-1 <= val <= 1 for val in values):
# all values are in the range [-1, 1]; let's attempt rounding with {digits} decimal points
# unless some values are identical after rounding.
attempt_rounding = [round(val, digits) for val in values]

if len(set(attempt_rounding)) != len(values) and opt_sci:
# oops, some values collide after rounding; let's use scientific notation instead
# with one decimal point (note that we do this only if `opt_sci` is True)
return [f"{val:.1e}" for val in values]
else:
# phew, all good; let's use {digits} decimal points for all values
return [f"{round(val, digits):.{digits}f}" for val in values]
else:
# all values are outside the range [-1, 1]; let's round them to the nearest integer
return [f"{int(round(val, 0)):d}" for val in values]

for summary in summaries:
# we use fewer bins for visualization
summary = SummarySpec(
short_summary = SummarySpec(
name=summary.name,
counts=(re_binned := summary.to_tracker().summarize(n=num_viz_bins)).counts,
counts=(re_binned := summary.to_tracker().summarize(n=num_viz_bins, mode="count")).counts,
bins=re_binned.bins,
)

# build the table here
table = Table(title=summary.name, style="bold", min_width=len(summary.name))
table = Table(title=short_summary.name, style="bold", min_width=len(short_summary.name))
table.add_column("value", justify="left", style="cyan")
table.add_column("dist", justify="left", style="magenta")
table.add_column("count", justify="left", style="green")

rounded_bins = round_all(summary.bins)
# we round the bins and write them in [lo, hi) format ]
rounded_bins = round_values_for_visual(values=short_summary.bins, max_decimal=max_decimal)
ranges = (
[f"[{lo}, {hi})" for lo, hi in zip(rounded_bins, rounded_bins[1:])]
if len(summary.bins) > len(summary.counts)
[
f"[{lo}, {hi}" + ("]" if i == (len(short_summary.bins) - 2) else ")")
for i, (lo, hi) in enumerate(zip(rounded_bins, rounded_bins[1:]))
]
if len(short_summary.bins) > len(short_summary.counts)
else rounded_bins
)

counts_sum = sum(summary.counts)
counts_normed = round_all([(count / counts_sum) for count in summary.counts], opt_sci=False)
counts_sum = sum(short_summary.counts)
counts_normed = round_values_for_visual(
values=[(count / counts_sum) for count in short_summary.counts], opt_sci=False, max_decimal=max_decimal
)

for value, dist, count in zip(ranges, counts_normed, summary.counts):
for value, dist, count in zip(ranges, counts_normed, short_summary.counts):
table.add_row(value, dist, f"{count:,}")

console.print(table)
Expand Down Expand Up @@ -261,23 +266,18 @@ def create_and_run_analyzer(
mkdir_p(summaries_path)
mkdir_p(metadata_path)

try:
analyzer = AnalyzerProcessor(
source_prefix=attributes,
destination_prefix=summaries_path,
metadata_prefix=metadata_path,
debug=debug,
seed=seed,
ignore_existing=True,
retries_on_error=0,
num_processes=num_processes,
)
analyzer(num_bins=num_bins, name_regex=name_regex)

summaries = aggregate_summaries(summaries_path=summaries_path, num_bins=num_bins)
visualize_summaries(summaries=summaries)
write_output(summaries=summaries, report=report)

finally:
shutil.rmtree(summaries_path)
shutil.rmtree(metadata_path)
analyzer = AnalyzerProcessor(
source_prefix=attributes,
destination_prefix=summaries_path,
metadata_prefix=metadata_path,
debug=debug,
seed=seed,
ignore_existing=True,
retries_on_error=0,
num_processes=num_processes,
)
analyzer(num_bins=num_bins, name_regex=name_regex)

summaries = aggregate_summaries(summaries_path=summaries_path, num_bins=num_bins)
visualize_summaries(summaries=summaries)
write_output(summaries=summaries, report=report)
Loading

0 comments on commit 58dedd3

Please sign in to comment.