From 58dedd301a267095afc12a95aeab5dfd87a680ad Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Sat, 16 Mar 2024 17:50:11 -0700 Subject: [PATCH] Improves tool to compute statistics; adds deduplication options. (#135) * more stats * updated stats * updated CI * updated analyzer * added support for custom separator * fixed test bug * added options to skip short paragraphs * missed copy * chekc in mypy --- Makefile | 9 ++ python/dolma/cli/__init__.py | 6 +- python/dolma/cli/deduper.py | 7 ++ python/dolma/core/analyzer.py | 114 ++++++++++---------- python/dolma/core/binning.py | 138 ++++++++++++++++++++++-- python/dolma/core/parallel.py | 170 +++++++++++++++++++++++++----- python/dolma/core/paths.py | 125 ++++++++++++++++++---- python/dolma/core/utils.py | 10 +- src/deduper.rs | 38 +++++-- tests/config/mixer.json | 2 +- tests/config/paragraph-spans.json | 2 +- tests/python/test_analysis.py | 114 ++++++++++++++++++++ tests/python/test_deduper.py | 101 ++++++++++++++++++ tests/python/test_runtime.py | 8 +- 14 files changed, 715 insertions(+), 129 deletions(-) create mode 100644 tests/python/test_analysis.py diff --git a/Makefile b/Makefile index cad13a43..d7e2a73a 100644 --- a/Makefile +++ b/Makefile @@ -46,3 +46,12 @@ style: rustfmt --edition 2021 src/*.rs isort . black . + +check: + isort --check . + black --check . + mypy tests/python/ + mypy python/ + flake8 tests/python/ + flake8 python/ + rustfmt --edition 2021 src/*.rs --check diff --git a/python/dolma/cli/__init__.py b/python/dolma/cli/__init__.py index 443c2ad2..de6349cd 100644 --- a/python/dolma/cli/__init__.py +++ b/python/dolma/cli/__init__.py @@ -180,12 +180,14 @@ class BaseCli(Generic[D]): @classmethod def make_parser(cls, parser: A) -> A: assert hasattr(cls, "CONFIG"), f"{cls.__name__} must have a CONFIG attribute" - return make_parser(parser, cls.CONFIG) + return make_parser(parser, cls.CONFIG) # pyright: ignore @classmethod def run_from_args(cls, args: Namespace, config: Optional[dict] = None): assert hasattr(cls, "CONFIG"), f"{cls.__name__} must have a CONFIG attribute" - parsed_config = namespace_to_nested_omegaconf(args=args, structured=cls.CONFIG, config=config) + parsed_config = namespace_to_nested_omegaconf( + args=args, structured=cls.CONFIG, config=config # pyright: ignore + ) try: return cls.run(parsed_config) except OmegaConfBaseException as ex: diff --git a/python/dolma/cli/deduper.py b/python/dolma/cli/deduper.py index 5219bb61..de684200 100644 --- a/python/dolma/cli/deduper.py +++ b/python/dolma/cli/deduper.py @@ -24,6 +24,9 @@ class NgramDedupeConfig: default=1.0, help="Fraction of ngrams that must be seen before a paragraph is considered a duplicate. By default, only full overlap is considered.", ) + skip_short_paragraphs: bool = field( + default=False, help="If true, paragraphs shorter than (ngram_length + stride) will be skipped." + ) @dataclass @@ -32,6 +35,10 @@ class ParagraphDedupeConfig: by_ngram: Optional[NgramDedupeConfig] = field( default=None, help="Configuration for deduping paragraphs by ngram overlap" ) + paragraph_separator: Optional[str] = field( + default="\n", + help="String to use to separate paragraphs. 
By default, paragraphs are separated by newlines.", + ) @dataclass diff --git a/python/dolma/core/analyzer.py b/python/dolma/core/analyzer.py index 5e0cb5f7..4c856f94 100644 --- a/python/dolma/core/analyzer.py +++ b/python/dolma/core/analyzer.py @@ -1,7 +1,6 @@ import math import multiprocessing import re -import shutil from contextlib import ExitStack from tempfile import TemporaryDirectory from typing import Dict, List, Optional @@ -43,8 +42,8 @@ def from_tracker(cls, name: str, tracker: "BaseBucketApi", n: int) -> "SummarySp counts, bins = tracker.summarize(n=n) return SummarySpec(name=name, counts=counts, bins=bins) - def to_tracker(self) -> "BaseBucketApi": - tracker = _make_tracker() + def to_tracker(self, tracker_type: str = "fixed", **tracker_kwargs) -> "BaseBucketApi": + tracker = _make_tracker(type_=tracker_type, **tracker_kwargs) tracker.add_many(values=self.bins, counts=self.counts) return tracker @@ -151,7 +150,12 @@ def aggregate_summaries(summaries_path: str, num_bins: int = 1000) -> List[Summa decoder = Decoder(SummarySpec) # iterator with nice progress bar - it = tqdm.tqdm(list(glob_path(summaries_path)), desc="Aggregating summaries", unit=" files", unit_scale=True) + it = tqdm.tqdm( + list(glob_path(summaries_path, autoglob_dirs=True, recursive_dirs=True, yield_dirs=False)), + desc="Aggregating summaries", + unit=" files", + unit_scale=True, + ) # load partial summaries and aggregate it for path in it: @@ -168,57 +172,58 @@ def aggregate_summaries(summaries_path: str, num_bins: int = 1000) -> List[Summa return summaries -def visualize_summaries(summaries: List[SummarySpec], digits: int = 4, num_viz_bins: int = 10): +def round_values_for_visual(values: List[float], opt_sci: bool = False, max_decimal: int = 4) -> List[str]: + """Logic to round values depending on their range""" + + # we try rounding as little as possible until all values are different + # we reach the maximum number of decimal points + for decimal in range(max_decimal): + attempt_rounding = [round(val, decimal) for val in values] + if len(set(attempt_rounding)) == len(values): + # success! let's return the rounded values + return [f"{val:.{decimal}f}" for val in values] + + # no luck; let's use scientific notation instead if we are allowed to or simply return the values + if opt_sci: + return [f"{val:.1e}" for val in values] + else: + return [f"{val:.{max_decimal}f}" for val in values] + + +def visualize_summaries(summaries: List[SummarySpec], max_decimal: int = 4, num_viz_bins: int = 10): console = Console() console.print() - def round_all(values: List[float], opt_sci: bool = False) -> List[str]: - """Logic to round values depending on their range""" - - if values == [0, 1]: - # just 0 and 1; no need to round or add decimal points - return ["0", "1"] - elif all(-1 <= val <= 1 for val in values): - # all values are in the range [-1, 1]; let's attempt rounding with {digits} decimal points - # unless some values are identical after rounding. 
- attempt_rounding = [round(val, digits) for val in values] - - if len(set(attempt_rounding)) != len(values) and opt_sci: - # oops, some values collide after rounding; let's use scientific notation instead - # with one decimal point (note that we do this only if `opt_sci` is True) - return [f"{val:.1e}" for val in values] - else: - # phew, all good; let's use {digits} decimal points for all values - return [f"{round(val, digits):.{digits}f}" for val in values] - else: - # all values are outside the range [-1, 1]; let's round them to the nearest integer - return [f"{int(round(val, 0)):d}" for val in values] - for summary in summaries: # we use fewer bins for visualization - summary = SummarySpec( + short_summary = SummarySpec( name=summary.name, - counts=(re_binned := summary.to_tracker().summarize(n=num_viz_bins)).counts, + counts=(re_binned := summary.to_tracker().summarize(n=num_viz_bins, mode="count")).counts, bins=re_binned.bins, ) - # build the table here - table = Table(title=summary.name, style="bold", min_width=len(summary.name)) + table = Table(title=short_summary.name, style="bold", min_width=len(short_summary.name)) table.add_column("value", justify="left", style="cyan") table.add_column("dist", justify="left", style="magenta") table.add_column("count", justify="left", style="green") - rounded_bins = round_all(summary.bins) + # we round the bins and write them in [lo, hi) format ] + rounded_bins = round_values_for_visual(values=short_summary.bins, max_decimal=max_decimal) ranges = ( - [f"[{lo}, {hi})" for lo, hi in zip(rounded_bins, rounded_bins[1:])] - if len(summary.bins) > len(summary.counts) + [ + f"[{lo}, {hi}" + ("]" if i == (len(short_summary.bins) - 2) else ")") + for i, (lo, hi) in enumerate(zip(rounded_bins, rounded_bins[1:])) + ] + if len(short_summary.bins) > len(short_summary.counts) else rounded_bins ) - counts_sum = sum(summary.counts) - counts_normed = round_all([(count / counts_sum) for count in summary.counts], opt_sci=False) + counts_sum = sum(short_summary.counts) + counts_normed = round_values_for_visual( + values=[(count / counts_sum) for count in short_summary.counts], opt_sci=False, max_decimal=max_decimal + ) - for value, dist, count in zip(ranges, counts_normed, summary.counts): + for value, dist, count in zip(ranges, counts_normed, short_summary.counts): table.add_row(value, dist, f"{count:,}") console.print(table) @@ -261,23 +266,18 @@ def create_and_run_analyzer( mkdir_p(summaries_path) mkdir_p(metadata_path) - try: - analyzer = AnalyzerProcessor( - source_prefix=attributes, - destination_prefix=summaries_path, - metadata_prefix=metadata_path, - debug=debug, - seed=seed, - ignore_existing=True, - retries_on_error=0, - num_processes=num_processes, - ) - analyzer(num_bins=num_bins, name_regex=name_regex) - - summaries = aggregate_summaries(summaries_path=summaries_path, num_bins=num_bins) - visualize_summaries(summaries=summaries) - write_output(summaries=summaries, report=report) - - finally: - shutil.rmtree(summaries_path) - shutil.rmtree(metadata_path) + analyzer = AnalyzerProcessor( + source_prefix=attributes, + destination_prefix=summaries_path, + metadata_prefix=metadata_path, + debug=debug, + seed=seed, + ignore_existing=True, + retries_on_error=0, + num_processes=num_processes, + ) + analyzer(num_bins=num_bins, name_regex=name_regex) + + summaries = aggregate_summaries(summaries_path=summaries_path, num_bins=num_bins) + visualize_summaries(summaries=summaries) + write_output(summaries=summaries, report=report) diff --git 
a/python/dolma/core/binning.py b/python/dolma/core/binning.py index 8f13f59d..c5c6734e 100644 --- a/python/dolma/core/binning.py +++ b/python/dolma/core/binning.py @@ -1,19 +1,124 @@ import math from abc import abstractmethod, abstractproperty -from typing import Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union import numpy as np import numpy.typing as npt -# # # OLD IMPORT # # # -# from sortedcontainers import SortedDict - class SummaryTuple(NamedTuple): counts: List[int] bins: List[float] +def cumsum_with_reset(arr: np.ndarray, reset_value: int = 0) -> np.ndarray: + """Compute the cumulative sum of an array, but reset the sum when a certain value is encountered.""" + arr = np.array(arr) + # Cumulative sum of the array + cumsum = arr.cumsum() + # Find indices where arr is `reset_value` and set the diff at these indices to the negated cumsum + reset_indices, *_ = np.where(arr == reset_value) + # For each reset point, subtract the cumsum value up to that point from the subsequent values + for i in reset_indices: + cumsum[i:] -= cumsum[i] + return cumsum + + +def equal_count_hist( + a: npt.ArrayLike, + bins: Union[int, npt.ArrayLike, str] = 10, + density: bool = False, + weights: Optional[npt.ArrayLike] = None, +) -> Tuple[np.ndarray, np.ndarray]: + """Compute bins such that each bin has approximately the same number of elements.""" + + if not isinstance(bins, int): + return np.histogram(a, bins=bins, weights=weights, density=density) + + if weights is None: + weights = np.ones_like(a) + weights = np.array(weights, dtype=int) + + if not isinstance(a, np.ndarray): + a = np.array(a) + + if a.size == 0: + return np.array([]), np.array([]) + + # can't have more bins than elements + bins = min(bins, a.size) + + iterative_counts = np.array(weights, dtype=int) + current_n = bins - 1 + bin_end_pos: List[int] = [] + + while current_n > 0: + # this is the number of elements we want in each bin + elements_per_bin = (iterative_counts * (iterative_counts > 0).astype(int)).sum() // current_n + + # whether there are individual bins that are above the size of each bin; + # therefore, we need to isolate them before we can split the rest + new_bins_locs = iterative_counts >= elements_per_bin + + if not new_bins_locs.any(): + # bins are all of proper size, so we have to use the cumulative sum to find the + # positions of the new bins. + + # calculate the cumulative sum of the counts; note that we have to reset the sum + # in case we encounter a bin with size zero, which is the result of a previous split. + cumsum_iterative_counts = cumsum_with_reset(iterative_counts, -1) + + # we calculate how many multiple of elements_per_bin we have in each cumulative sum, + # rounded to the nearest integer + rounded_bin_multiples = np.round(cumsum_iterative_counts / elements_per_bin).astype(int) + + if rounded_bin_multiples.sum() > 0: + # we check for cases where the rounded multiples are larger than the previous rounded + # multiples, which indicates that we have a new bin! + new_bins_locs = rounded_bin_multiples - np.roll(rounded_bin_multiples, 1) > 0 + else: + # this happened because the existing bins cause the partial cumulative sums to be less + # than the elements_per_bin; in this case, we pick a split on the largest bin in the array + new_bins_locs[np.argmax(iterative_counts)] = True + + # if the last position gets selected as a bin, then we need to increase + # the number of bins by one. 
+ if new_bins_locs[-1]: + current_n += 1 + + # using the new locations, we can find the indices of the new bins + new_bins, *_ = np.where(new_bins_locs) + + # if we have more than expected new_bins_locs, we roll them so they are equally distributed + if new_bins.size > current_n: + new_bins = new_bins[0 : len(new_bins) : new_bins.size // current_n] + + # add the new bins to the list of bin end positions we found so far + bin_end_pos.extend(new_bins) + + # update the counts to reflect the positions of the bins + iterative_counts[new_bins_locs] = -1 + current_n -= len(new_bins) + + # sort the location of bins, add the last position as the end of the array + bin_end_pos = sorted(set(bin_end_pos + [len(a) - 1])) + + # bins are easy to get; we just take the values of a at the positions we found + final_bins = np.concatenate( + ([a[0]], [(a[i] + a[i + 1]) / 2 if (i + 1) < len(a) else a[i] for i in bin_end_pos]) + ) + + # for counts, we first compute cumsum of the weights, then we take the difference between + # the cumulative sum at the end of each bin and the cumulative sum at the start of each bin + final_counts = np.diff(np.concatenate(([0], np.cumsum(weights)[bin_end_pos]))) + + # if density is True, we normalize the counts + if density: + final_counts /= final_counts.sum() + + return final_counts, final_bins + + def sort_and_merge_bins( bins: npt.NDArray[np.float64], counts: npt.NDArray[np.int64], mask: Optional[npt.NDArray[np.bool_]] = None ) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int64]]: @@ -110,7 +215,7 @@ def add_many(self, values: List[Union[int, float]], counts: List[int]): self.add(value, count) @abstractmethod - def summarize(self, n: int, density: bool = False) -> SummaryTuple: + def summarize(self, n: int, density: bool = False, mode: Literal["width", "count"] = "width") -> SummaryTuple: raise NotImplementedError() @@ -216,7 +321,7 @@ def add(self, value: Union[int, float], count: int = 1): else: self._add_full(value=value, count=count) - def summarize(self, n: int, density: bool = False) -> SummaryTuple: + def summarize(self, n: int, density: bool = False, mode: Literal["width", "count"] = "width") -> SummaryTuple: """Return up to n buckets with counts of merged values""" # finalize operations @@ -226,8 +331,13 @@ def summarize(self, n: int, density: bool = False) -> SummaryTuple: # if there are fewer than n buckets, return the buckets as is return SummaryTuple(counts=self._counts.tolist(), bins=self._bins.tolist()) - # make weighted histogram using counts - new_counts, new_values = np.histogram(a=self._bins, bins=n, weights=self._counts, density=density) + if mode == "width": + # make weighted histogram using counts + new_counts, new_values = np.histogram(a=self._bins, bins=n, weights=self._counts, density=density) + elif mode == "count": + new_counts, new_values = equal_count_hist(a=self._bins, bins=n, weights=self._counts, density=density) + else: + raise ValueError(f"Invalid mode: {mode}") # return lists instead of numpy arrays return SummaryTuple(counts=new_counts.tolist(), bins=new_values.tolist()) @@ -264,7 +374,7 @@ def get_bin_upper_bound(self, val: float) -> float: k = math.floor(m * self.n) + 1 # Add one to obtain the next bin return k / self.n * 2**e - def summarize(self, n: int, density: bool = False) -> SummaryTuple: + def summarize(self, n: int, density: bool = False, mode: Literal["width", "count"] = "width") -> SummaryTuple: bins, counts = zip(*sorted((m / self.n * 2**e, c) for (m, e), c in self._bins.items())) if len(self) <= n: @@ -273,8 +383,14 @@ 
def summarize(self, n: int, density: bool = False) -> SummaryTuple: upper_bin = self.get_bin_upper_bound(max(float(b) for b in bins)) return SummaryTuple(counts=[int(c) for c in counts], bins=[float(b) for b in bins] + [upper_bin]) - # computing the weighted histograms - new_counts, new_values = np.histogram(a=bins, bins=n, weights=counts, density=density) + if mode == "width": + # computing the weighted histograms + new_counts, new_values = np.histogram(a=bins, bins=n, weights=counts, density=density) + elif mode == "count": + # make a copy of the counts + new_counts, new_values = equal_count_hist(a=bins, bins=n, weights=counts, density=density) + else: + raise ValueError(f"Invalid mode: {mode}") # return lists instead of numpy arrays return SummaryTuple(counts=new_counts.tolist(), bins=new_values.tolist()) diff --git a/python/dolma/core/parallel.py b/python/dolma/core/parallel.py index 3cdc18d1..0bbfc75f 100644 --- a/python/dolma/core/parallel.py +++ b/python/dolma/core/parallel.py @@ -11,7 +11,7 @@ from functools import partial from queue import Queue from threading import Thread -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TypeVar, Union import smart_open import tqdm @@ -34,6 +34,19 @@ # we need to quote the type alias because we want to support Python 3.8 QueueType: TypeAlias = "Queue[Union[None, Tuple[int, ...]]]" +KwargsType: TypeAlias = Dict[str, Any] +BPP = TypeVar("BPP", bound="BaseParallelProcessor") + + +class AllPathsTuple(NamedTuple): + src: List[str] + dst: List[str] + meta: List[str] + kwargs: List[KwargsType] + + @classmethod + def empty(cls) -> "AllPathsTuple": + return AllPathsTuple([], [], [], []) class BaseParallelProcessor: @@ -61,6 +74,7 @@ def __init__( exclude_paths: Optional[List[str]] = None, files_regex_pattern: Optional[str] = None, retries_on_error: int = 0, + process_single_kwargs: Union[None, KwargsType, List[KwargsType]] = None, ): """Initialize the parallel processor. @@ -87,6 +101,14 @@ def __init__( that match one of the paths will be processed. Defaults to None. exclude_paths (Optional[List[str]], optional): A list of paths to exclude. If provided, files that match one of the paths will be skipped. Defaults to None. + files_regex_pattern (Optional[str], optional): A regex pattern to match files. If provided, only + files that match the pattern will be processed. Defaults to None. + retries_on_error (int, optional): The number of retries to attempt if an error occurs. + Defaults to 0. + process_single_kwargs (Union[None, KwargsType, List[KwargsType]], optional): Additional kwargs to + pass to the process_single method. If a single dict is provided, it will be used for all source + prefixes. If a list of dicts is provided, each dict will be used for the corresponding source. + By default, no additional kwargs are passed. 
""" self.src_prefixes = [source_prefix] if isinstance(source_prefix, str) else source_prefix @@ -103,6 +125,13 @@ def __init__( self.files_regex_pattern = re.compile(files_regex_pattern) if files_regex_pattern else None self.retries_on_error = retries_on_error + # this are additional kwargs to pass to the process_single method + process_single_kwargs = process_single_kwargs or {} + if isinstance(process_single_kwargs, dict): + self.process_single_kwargs = [process_single_kwargs] * len(self.src_prefixes) + else: + self.process_single_kwargs = process_single_kwargs + # checking that the increment_progressbar method is subclassed correctly sig = inspect.signature(self.increment_progressbar) if "queue" not in sig.parameters or sig.parameters["queue"].kind != inspect.Parameter.POSITIONAL_ONLY: @@ -131,6 +160,11 @@ def __init__( "The number of source and metadata prefixes must be the same." f"(got {len(self.src_prefixes)} and {len(self.meta_prefixes)})" ) + elif len(self.src_prefixes) != len(self.process_single_kwargs): + raise ValueError( + "The number of source prefixes and process_single_kwargs must be the same." + f"(got {len(self.src_prefixes)} and {len(self.process_single_kwargs)})" + ) if len(self.src_prefixes) == 0: raise ValueError("At least one source prefix must be provided.") @@ -251,6 +285,7 @@ def _debug_run_all( all_source_paths: List[str], all_destination_paths: List[str], all_metadata_paths: List[str], + all_process_kwargs: Union[List[KwargsType], None] = None, **process_single_kwargs: Any, ): """Run files one by one on the main process @@ -259,30 +294,84 @@ def _debug_run_all( all_source_paths (List[MultiPath]): The list of source paths to process. all_destination_paths (List[MultiPath]): The list of destination paths to save. all_metadata_paths (List[MultiPath]): The locations where to save metadata. 
+ all_process_kwargs (Union[List[KwargsType], None]): Additional kwargs to pass to the process_single """ - it = zip(all_source_paths, all_destination_paths, all_metadata_paths) + arguments_iterator = zip( + # source paths + all_source_paths, + # destination paths + all_destination_paths, + # this is where we save the metadata to keep track of which files have been processed + all_metadata_paths, + # additional kwargs to pass to the process_single; if not provided, we use an empty dict + # will be merged with the process_single_kwargs + all_process_kwargs or [{} for _ in all_source_paths], + ) pbar_queue: QueueType = Queue() thread = Thread(target=self._run_threaded_progressbar, args=(pbar_queue, self.pbar_timeout), daemon=True) thread.start() - for source_prefix, destination_prefix, metadata_prefix in it: + for source_path, destination_path, metadata_path, process_kwargs in arguments_iterator: self._process_single_and_save_status( - source_path=source_prefix, - destination_path=destination_prefix, - metadata_path=metadata_prefix, + source_path=source_path, + destination_path=destination_path, + metadata_path=metadata_path, queue=pbar_queue, - serialized_kwargs=pickle.dumps(process_single_kwargs), + serialized_kwargs=pickle.dumps({**process_kwargs, **process_single_kwargs}), ) pbar_queue.put(None) thread.join() + def __add__(self: BPP, other: BPP) -> BPP: + """Combine two parallel processors into one.""" + if not type(self) is type(other): + raise TypeError(f"Cannot add {type(self)} and {type(other)}") + + # we try combining the two list of include paths; if they are both None, then set the combo back to none + include_paths: Union[List[str], None] = [*(self.include_paths or []), *(other.include_paths or [])] + include_paths = sorted(set(include_paths or [])) if len(include_paths or []) else None + + # do the same for exclude paths + exclude_paths: Union[List[str], None] = [*(self.exclude_paths or []), *(other.exclude_paths or [])] + exclude_paths = sorted(set(exclude_paths or [])) if len(exclude_paths or []) else None + + # for the regex, do a simple or if both are set + regex_pattern: Union[str, None] = None + if self.files_regex_pattern and other.files_regex_pattern: + regex_pattern = "(" + self.files_regex_pattern.pattern + "|" + other.files_regex_pattern.pattern + ")" + elif self.files_regex_pattern: + regex_pattern = self.files_regex_pattern.pattern + elif other.files_regex_pattern: + regex_pattern = other.files_regex_pattern.pattern + + return type(self)( + source_prefix=[*self.src_prefixes, *other.src_prefixes], + destination_prefix=[*self.dst_prefixes, *other.dst_prefixes], + metadata_prefix=[*self.meta_prefixes, *other.meta_prefixes], + num_processes=max(self.num_processes, other.num_processes), + debug=self.debug or other.debug, + seed=self.seed, + pbar_timeout=max(self.pbar_timeout, other.pbar_timeout), + ignore_existing=self.ignore_existing or other.ignore_existing, + include_paths=include_paths, + exclude_paths=exclude_paths, + files_regex_pattern=regex_pattern, + retries_on_error=max(self.retries_on_error, other.retries_on_error), + process_single_kwargs=[*self.process_single_kwargs, *other.process_single_kwargs], + ) + + def __radd__(self: BPP, other: BPP) -> BPP: + """Combine two parallel processors into one.""" + return other.__add__(self) + def _multiprocessing_run_all( self, all_source_paths: List[str], all_destination_paths: List[str], all_metadata_paths: List[str], + all_process_kwargs: Union[List[KwargsType], None] = None, **process_single_kwargs: Any, ): """Run 
files in parallel using multiprocessing. @@ -291,13 +380,38 @@ def _multiprocessing_run_all( all_source_paths (List[MultiPath]): The list of source paths to process. all_destination_paths (List[MultiPath]): The list of destination paths to save. all_metadata_paths (List[MultiPath]): The locations where to save metadata. + all_process_kwargs (Union[List[KwargsType], None]): Additional kwargs to pass to the process_single """ try: multiprocessing.set_start_method("spawn") except RuntimeError: assert multiprocessing.get_start_method() == "spawn", "Multiprocessing start method must be spawn" - with multiprocessing.Pool(processes=self.num_processes) as pool: + all_process_kwargs = all_process_kwargs or [{} for _ in all_source_paths] + + arguments_iterator = zip( + # source paths + all_source_paths, + # destination paths + all_destination_paths, + # this is where we save the metadata to keep track of which files have been processed + all_metadata_paths, + # additional kwargs to pass to the process_single; if not provided, we use an empty dict + # will be merged with the process_single_kwargs + all_process_kwargs, + ) + + # no need to be wasteful with processes: we only need as many cores a the minimum of the number of + # source paths, destination paths, metadata paths, and process kwargs. + num_processes = min( + self.num_processes, + len(all_source_paths), + len(all_destination_paths), + len(all_metadata_paths), + len(all_process_kwargs), + ) + + with multiprocessing.Pool(processes=num_processes) as pool: pbar_queue: QueueType = (manager := multiprocessing.Manager()).Queue() thread = Thread( target=self._run_threaded_progressbar, args=(pbar_queue, self.pbar_timeout), daemon=True @@ -307,14 +421,15 @@ def _multiprocessing_run_all( process_single_fn = partial(self.process_single, queue=pbar_queue) results = [] - for s, d, m in zip(all_source_paths, all_destination_paths, all_metadata_paths): + for source_path, destination_path, metadata_path, process_kwargs in arguments_iterator: process_single_fn = partial( self._process_single_and_save_status, queue=pbar_queue, - source_path=s, - destination_path=d, - metadata_path=m, - serialized_kwargs=pickle.dumps(process_single_kwargs), + source_path=source_path, + destination_path=destination_path, + metadata_path=metadata_path, + # we need to merge the process_single_kwargs with the additional kwargs + serialized_kwargs=pickle.dumps({**process_kwargs, **process_single_kwargs}), ) result = pool.apply_async(process_single_fn) results.append(result) @@ -338,11 +453,13 @@ def _valid_path(self, path: str) -> bool: return False return True - def _get_all_paths(self) -> Tuple[List[str], List[str], List[str]]: + def _get_all_paths(self) -> AllPathsTuple: """Get all paths to process using prefixes provided""" - all_source_paths, all_destination_paths, all_metadata_paths = [], [], [] + all_paths = AllPathsTuple.empty() - for src_prefix, dst_prefix, meta_prefix in zip(self.src_prefixes, self.dst_prefixes, self.meta_prefixes): + for src_prefix, dst_prefix, meta_prefix, kwargs_prefix in zip( + self.src_prefixes, self.dst_prefixes, self.meta_prefixes, self.process_single_kwargs + ): current_source_prefixes = sorted(glob_path(src_prefix)) if len(current_source_prefixes) > 1: @@ -374,28 +491,31 @@ def _get_all_paths(self) -> Tuple[List[str], List[str], List[str]]: continue # create new paths to pass to taggers - all_source_paths.append(add_suffix(prefix, path)) - all_destination_paths.append(add_suffix(dst_prefix, path)) - 
all_metadata_paths.append(add_suffix(meta_prefix, path) + METADATA_SUFFIX) + all_paths.src.append(add_suffix(prefix, path)) + all_paths.dst.append(add_suffix(dst_prefix, path)) + all_paths.meta.append(add_suffix(meta_prefix, path) + METADATA_SUFFIX) + all_paths.kwargs.append(kwargs_prefix or {}) - return all_source_paths, all_destination_paths, all_metadata_paths + return all_paths def __call__(self, **process_single_kwargs: Any): """Run the processor.""" + random.seed(self.seed) # in case the user wants to override the default kwargs for retries process_single_kwargs.setdefault("retries_on_error", self.retries_on_error) - all_source_paths, all_destination_paths, all_metadata_paths = self._get_all_paths() + all_paths = self._get_all_paths() - print(f"Found {len(all_source_paths):,} files to process") + print(f"Found {len(all_paths.src):,} files to process") fn = self._debug_run_all if self.debug else self._multiprocessing_run_all fn( - all_source_paths=all_source_paths, - all_destination_paths=all_destination_paths, - all_metadata_paths=all_metadata_paths, + all_source_paths=all_paths.src, + all_destination_paths=all_paths.dst, + all_metadata_paths=all_paths.meta, + all_process_kwargs=all_paths.kwargs, **process_single_kwargs, ) diff --git a/python/dolma/core/paths.py b/python/dolma/core/paths.py index 450ed737..5b293a91 100644 --- a/python/dolma/core/paths.py +++ b/python/dolma/core/paths.py @@ -1,16 +1,16 @@ import glob -import pickle import re from functools import partial from hashlib import sha256 from itertools import chain from pathlib import Path -from typing import Any, Dict, Iterable, Iterator, List, Tuple, Union +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union from urllib.parse import urlparse import platformdirs import smart_open from fsspec import AbstractFileSystem, get_filesystem_class +from smart_open.compression import get_supported_extensions from .loggers import get_logger @@ -155,12 +155,6 @@ def delete_dir(path: str, ignore_missing: bool = False) -> bool: return deleted -def cache_location(key: Any) -> Tuple[str, bool]: - key_hash = sha256(pickle.dumps(key)).hexdigest() - path = f"{get_cache_dir()}/{key_hash}" - return path, exists(path) - - def partition_path(path: str) -> Tuple[str, Tuple[str, ...], Tuple[str, ...]]: """Partition a path into its protocol, symbols before a glob, and symbols after a glob.""" # split the path into its protocol and path components @@ -204,7 +198,13 @@ def join_path(protocol: Union[str, None], *parts: Union[str, Iterable[str]]) -> return _unescape_glob(path) -def glob_path(path: Union[Path, str], hidden_files: bool = False, autoglob_dirs: bool = True) -> Iterator[str]: +def glob_path( + path: Union[Path, str], + hidden_files: bool = False, + autoglob_dirs: bool = True, + recursive_dirs: bool = False, + yield_dirs: bool = True, +) -> Iterator[str]: """ Expand a glob path into a list of paths. 
""" @@ -220,7 +220,19 @@ def glob_path(path: Union[Path, str], hidden_files: bool = False, autoglob_dirs: if not hidden_files and Path(gl).name.startswith("."): continue - yield join_path(protocol, gl) + if fs.isdir(gl): + if recursive_dirs: + yield from glob_path( + gl, + hidden_files=hidden_files, + autoglob_dirs=autoglob_dirs, + recursive_dirs=recursive_dirs, + yield_dirs=yield_dirs, + ) + if yield_dirs: + yield join_path(protocol, gl) + else: + yield join_path(protocol, gl) def sub_prefix(a: str, b: str) -> str: @@ -282,6 +294,22 @@ def exists(path: str) -> bool: return fs.exists(path) +def is_dir(path: str) -> bool: + """Check if a path is a directory.""" + if exists(path): + fs = _get_fs(path) + return fs.isdir(path) + return False + + +def is_file(path: str) -> bool: + """Check if a path is a file.""" + if exists(path): + fs = _get_fs(path) + return fs.isfile(path) + return False + + def parent(path: str) -> str: """Get the parent directory of a path; if the parent is the root, return the root.""" @@ -374,23 +402,16 @@ def get_cache_dir() -> str: return loc -def resource_to_filename(resource: str) -> str: +def resource_to_filename(resource: Union[str, bytes]) -> str: """ - Convert a ``resource`` into a hashed filename in a repeatable way. - If ``etag`` is specified, append its hash to the resources', delimited - by a period. - - THis is essentially the inverse of :func:`filename_to_url()`. + Convert a ``resource`` into a hashed filename in a repeatable way. Preserves the file extensions. """ _, (*_, orig_filename) = split_path(remove_params(str(resource))) + _, extensions = split_basename_and_extension(orig_filename) resource_bytes = str(resource).encode("utf-8") resource_hash = sha256(resource_bytes) - hash_filename = resource_hash.hexdigest() - - if "." in orig_filename: - _, extension = orig_filename.split(".", 1) - hash_filename += f".{extension}" + hash_filename = resource_hash.hexdigest() + extensions return hash_filename @@ -423,3 +444,65 @@ def cached_path(path: str) -> str: dest.write(src.read()) return destination + + +def split_basename_and_extension(path: str) -> Tuple[str, str]: + """ + Get the path and extension from a given file path. If a file has multiple + extensions, they will be joined with a period, e.g. "foo/bar/baz.tar.gz" + will return ("foo/bar/baz", ".tar.gz"). If the file has no extension, the + second element of the tuple will be an empty string. Works with both local + and remote (e.g. s3://) paths. + + Args: + path (str): The file path. + + Returns: + Tuple[str, str]: A tuple containing the path and extension. + """ + prot, (*parts, filename) = split_path(path) + base, *ext_parts = filename.split(".") + ext = ("." + ".".join(ext_parts)) if ext_parts else "" + return join_path(prot, *parts, base), ext + + +def decompress_path(path: str, dest: Optional[str] = None) -> str: + """ + Decompresses a file at the given path and returns the path to the decompressed file. + + Args: + path (str): The path to the file to be decompressed. + dest (str, optional): The destination path for the decompressed file. + If not provided, a destination path will be computed based on the original + file name and the cache directory. + + Returns: + str: The path to the decompressed file. If the file cannot be decompressed, + the original path will be returned. 
+ """ + for supported_ext in get_supported_extensions(): + # not the supported extension + if not path.endswith(supported_ext): + continue + + if dest is None: + # compute the name for the decompressed file; to do this, we first hash for + # resource and then remove the extension. + base_fn, ext = split_basename_and_extension(resource_to_filename(path)) + + # to get the decompressed file name, we remove the bit of the extension that + # indicates the compression type. + decompressed_fn = base_fn + ext.replace(supported_ext, "") + + # finally, we get cache directory and join the decompressed file name to it + dest = join_path("", get_cache_dir(), decompressed_fn) + + # here we do the actual decompression + with smart_open.open(path, "rb") as fr, smart_open.open(dest, "wb") as fw: + fw.write(fr.read()) + + # return the path to the decompressed file + return dest + + # already decompressed or can't be decompressed + return path diff --git a/python/dolma/core/utils.py b/python/dolma/core/utils.py index 4454a003..2f5c5eb6 100644 --- a/python/dolma/core/utils.py +++ b/python/dolma/core/utils.py @@ -3,7 +3,7 @@ import re import string import sys -from typing import List, Union +from typing import List, Union, cast try: import blingfire @@ -15,6 +15,7 @@ import nltk import uniseg.wordbreak from nltk.tokenize.punkt import PunktSentenceTokenizer +from omegaconf import OmegaConf as om try: nltk.data.find("tokenizers/punkt") @@ -140,3 +141,10 @@ def import_modules(modules_path: Union[List[str], None]): f"({module_name}) is not globally unique. Please rename the directory to " "something unique and try again." ) + + +def dataclass_to_dict(dataclass_instance) -> dict: + """Convert a dataclass instance to a dictionary through the omegaconf library.""" + + # force typecasting because a dataclass instance will always be a dict + return cast(dict, om.to_object(om.structured(dataclass_instance))) diff --git a/src/deduper.rs b/src/deduper.rs index a9e6105b..f54da183 100644 --- a/src/deduper.rs +++ b/src/deduper.rs @@ -207,10 +207,12 @@ fn write_attributes( let text = data["text"].as_str().unwrap(); let text_length = text.len(); let mut offset = 0; - let paragraphs = text.split('\n'); - let mut duplicate_paragraph_spans = Vec::new(); if text_length > 0 { + let paragraphs = + text.split(cfg.paragraph_separator.as_deref().unwrap_or("\n")); + let mut duplicate_paragraph_spans = Vec::new(); + // skip empty documents if text_length is 0 for p in paragraphs { let par_start = offset; @@ -286,19 +288,36 @@ fn write_attributes( } word_index += 1; } - if ngram_count < 2 { + if ngram_count < 2 + && !by_ngram.skip_short_paragraphs.unwrap_or(false) + { // Too few ngrams to dedupe by overlap. Just compare the whole thing let dedupe_key = VecDeque::from([p]); - if bloom_filter.contains(&dedupe_key) { + + let span_score = match bloom_filter.contains(&dedupe_key) { + // we found a match! score is 1.0 + true => 1.0, + false => { + // this is a new paragraph, push to bloom filter + if !bloom_filter.read_only { + bloom_filter.insert(&dedupe_key); + } + // score is 0.0 because it's not a duplicate + 0.0 + } + }; + + // we check if the score is above the threshold; note that + // users can set the threshold to 0.0 to always include the span, + // or 1.0 to only include spans that are exact duplicates. 
+ if span_score >= by_ngram.overlap_threshold { let span = vec![ Value::Number(par_start.into()), Value::Number(par_end.into()), - Value::from(1), + Value::from(span_score), ]; // add span to duplicate_paragraph_spans duplicate_paragraph_spans.push(Value::Array(span)); - } else if !bloom_filter.read_only { - bloom_filter.insert(&dedupe_key); } } else { let overlap_fraction = @@ -373,6 +392,9 @@ pub mod deduper_config { // If defined, remove paragraphs based on contained ngrams // Otherwise, hash the entire paragraph pub by_ngram: Option, + + // if not defined, we use '\n' as the paragraph separator + pub paragraph_separator: Option, } #[derive(Serialize, Deserialize, Clone)] @@ -383,6 +405,8 @@ pub mod deduper_config { pub stride: usize, // Treat as duplicate if more than this fraction of ngrams have been seen before pub overlap_threshold: f32, + // If true, skip checking for duplicates if the paragraph is shorter ngram_length + stride + pub skip_short_paragraphs: Option, } #[derive(Serialize, Deserialize, Clone)] diff --git a/tests/config/mixer.json b/tests/config/mixer.json index 2aaaf9f6..37c74395 100644 --- a/tests/config/mixer.json +++ b/tests/config/mixer.json @@ -3,7 +3,7 @@ { "name": "mixer-test", "documents": [ - "tests/data/provided/documents/*" + "tests/data/provided/documents/*.gz" ], "output": { "path": "tests/work/output/mixer", diff --git a/tests/config/paragraph-spans.json b/tests/config/paragraph-spans.json index 9f25e2da..a242dbe7 100644 --- a/tests/config/paragraph-spans.json +++ b/tests/config/paragraph-spans.json @@ -3,7 +3,7 @@ { "name": "paragraph-spans-test", "documents": [ - "tests/data/provided/documents/*" + "tests/data/provided/documents/*.gz" ], "output": { "path": "tests/work/output/paragraph-spans", diff --git a/tests/python/test_analysis.py b/tests/python/test_analysis.py new file mode 100644 index 00000000..17f37bb3 --- /dev/null +++ b/tests/python/test_analysis.py @@ -0,0 +1,114 @@ +from unittest import TestCase + +import numpy as np + +from dolma.core.binning import cumsum_with_reset, equal_count_hist + + +class TestResetCumsumNpNoSplit(TestCase): + def test_multiple_zeros(self): + arr = np.array([1, 2, 3, 0, 4, 5, 0, 6]) + expected = np.array([1, 3, 6, 0, 4, 9, 0, 6]) + result = np.array(cumsum_with_reset(arr)) + np.testing.assert_array_equal(result, expected) + + def test_no_zeros(self): + arr = np.array([1, 2, 3, 4, 5]) + expected = np.array([1, 3, 6, 10, 15]) + result = np.array(cumsum_with_reset(arr)) + np.testing.assert_array_equal(result, expected) + + def test_start_with_zero(self): + arr = np.array([0, 1, 2, 3, 4]) + expected = np.array([0, 1, 3, 6, 10]) + result = np.array(cumsum_with_reset(arr)) + np.testing.assert_array_equal(result, expected) + + def test_end_with_zero(self): + arr = np.array([1, 2, 3, 0]) + expected = np.array([1, 3, 6, 0]) + result = np.array(cumsum_with_reset(arr)) + np.testing.assert_array_equal(result, expected) + + def test_all_zeros(self): + arr = np.array([0, 0, 0, 0]) + expected = np.array([0, 0, 0, 0]) + result = np.array(cumsum_with_reset(arr)) + np.testing.assert_array_equal(result, expected) + + def test_empty_array(self): + arr = np.array([]) + expected = np.array([]) + result = np.array(cumsum_with_reset(arr)) + np.testing.assert_array_equal(result, expected) + + +class TestEqualCountHist(TestCase): + def test_basic(self): + arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) + counts, bins = equal_count_hist(a=arr, bins=9) + + self.assertEqual(len(counts), 9) + self.assertEqual(len(bins), 10) + 
self.assertEqual(counts.sum(), len(arr)) + np.testing.assert_array_equal(counts, np.ones_like(counts)) + np.testing.assert_array_equal(bins, [1] + [i + 0.5 for i in range(1, 9)] + [9]) + + def test_bin_in_three(self): + arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) + counts, bins = equal_count_hist(a=arr, bins=3) + self.assertEqual(len(arr), counts.sum()) + np.testing.assert_array_equal(counts, [3, 3, 3]) + np.testing.assert_array_equal(bins, [1.0, 3.5, 6.5, 9]) + + def test_very_large_bins(self): + arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) + weights = np.array([1, 1000, 1, 1, 1, 1000, 1, 1, 1]) + counts, bins = equal_count_hist(a=arr, bins=3, weights=weights) + self.assertEqual(weights.sum(), counts.sum()) + np.testing.assert_array_equal(counts, [1001, 1003, 3]) + np.testing.assert_array_equal(bins, [1.0, 2.5, 6.5, 9]) + + def test_large_end_bin(self): + arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) + weights = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1000]) + counts, bins = equal_count_hist(a=arr, bins=3, weights=weights) + self.assertEqual(weights.sum(), counts.sum()) + np.testing.assert_array_equal(counts, [3, 3, 1002]) + np.testing.assert_array_equal(bins, [1.0, 3.5, 6.5, 9]) + + def test_zero_bins(self): + arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) + weights = np.array([1, 0, 0, 0, 0, 0, 0, 0, 1000]) + counts, bins = equal_count_hist(a=arr, bins=3, weights=weights) + self.assertEqual(weights.sum(), counts.sum()) + np.testing.assert_array_equal(counts, [1, 0, 1000]) + np.testing.assert_array_equal(bins, [1.0, 1.5, 5.5, 9]) + + def test_short_array(self): + arr = np.array([0, 1]) + counts, bins = equal_count_hist(a=arr, bins=3) + self.assertEqual(len(arr), counts.sum()) + np.testing.assert_array_equal(counts, [1, 1]) + np.testing.assert_array_equal(bins, [0, 0.5, 1]) + + def test_empty_array(self): + arr = np.array([]) + counts, bins = equal_count_hist(a=arr, bins=3) + np.testing.assert_array_equal(counts, []) + np.testing.assert_array_equal(bins, []) + + def test_single_array(self): + arr = np.array([1]) + counts, bins = equal_count_hist(a=arr, bins=3) + self.assertEqual(len(arr), counts.sum()) + np.testing.assert_array_equal(counts, [1]) + np.testing.assert_array_equal(bins, [1, 1]) + + def test_no_natural_splits(self): + array = np.array([1, 2, 3, 4, 5, 6, 7]) + weights = np.array([1000, 1, 1, 1000, 1, 1, 1000]) + counts, bins = equal_count_hist(a=array, bins=4, weights=weights) + self.assertEqual(weights.sum(), counts.sum()) + np.testing.assert_array_equal(counts, [1000, 1, 1001, 1002]) + np.testing.assert_array_equal(bins, [1, 1.5, 2.5, 4.5, 7]) diff --git a/tests/python/test_deduper.py b/tests/python/test_deduper.py index d53a7c0c..ba255ee1 100644 --- a/tests/python/test_deduper.py +++ b/tests/python/test_deduper.py @@ -9,6 +9,7 @@ from typing_extensions import TypedDict from dolma.cli.__main__ import main +from dolma.core.utils import split_words from .utils import ( TestCasePipeline, @@ -98,6 +99,106 @@ def test_dedupe_paragraphs(self): ) return self._compare_dedupe_output(expected, computed) # pyright: ignore + def test_dedupe_paragraphs_change_splitter(self): + with open(DEDUPE_PARAGRAPHS, "r") as f: + config = json.load(f) + + config["documents"] = [f'{self.local_temp_dir}/{config["documents"][0]}'] + config["bloom_filter"]["file"] = f'{self.local_temp_dir}/{config["bloom_filter"]["file"]}' + + split_seq = "tt" + + # separate on characters "tt" instead of "\n" + config["dedupe"]["paragraphs"]["paragraph_separator"] = split_seq + + # this will ensure that the deduper will output 
something for each paragraph + config["dedupe"]["paragraphs"]["by_ngram"] = {"ngram_length": 1, "stride": 1, "overlap_threshold": 0.0} + + with NamedTemporaryFile("w") as f: + json.dump(config, f) + f.flush() + + main(argv=["-c", f.name, "dedupe"]) + + documents = load_jsonl(f"{self.local_temp_dir}/tests/data/provided/deduper/documents/000.json.gz") + attributes = load_jsonl( + f"{self.local_temp_dir}/tests/data/provided/deduper/attributes/dedupe_paragraphs/000.json.gz" + ) + for doc, attr in zip(documents, attributes): + self.assertEqual( + len(doc["text"].split(split_seq)), len(attr["attributes"]["bff_duplicate_paragraph_spans"]) + ) + + def test_dedupe_paragraphs_stride_math(self): + with open(DEDUPE_PARAGRAPHS, "r") as f: + config = json.load(f) + + config["documents"] = [f'{self.local_temp_dir}/{config["documents"][0]}'] + config["bloom_filter"]["file"] = f'{self.local_temp_dir}/{config["bloom_filter"]["file"]}' + + # this will ensure that the deduper will output something for each paragraph + config["dedupe"]["paragraphs"]["by_ngram"] = {"ngram_length": 10, "stride": 5, "overlap_threshold": 0.0} + + with NamedTemporaryFile("w") as f: + json.dump(config, f) + f.flush() + + main(argv=["-c", f.name, "dedupe"]) + + documents = load_jsonl(f"{self.local_temp_dir}/tests/data/provided/deduper/documents/000.json.gz") + attributes = load_jsonl( + f"{self.local_temp_dir}/tests/data/provided/deduper/attributes/dedupe_paragraphs/000.json.gz" + ) + for doc, attr in zip(documents, attributes): + valid_paragraphs = [] + i = 0 + for para in doc["text"].split("\n"): + j = min(i + len(para) + 1, len(doc["text"])) + valid_paragraphs.append((i, j)) + i = j + spans = attr["attributes"]["bff_duplicate_paragraph_spans"] + + self.assertEqual(len(valid_paragraphs), len(spans)) + for (start_para, end_para), (start_span, end_span, _) in zip(valid_paragraphs, spans): + self.assertEqual(doc["text"][start_para:end_para], doc["text"][start_span:end_span]) + + def test_dedupe_paragraphs_stride_math_skip_short(self): + with open(DEDUPE_PARAGRAPHS, "r") as f: + config = json.load(f) + + config["documents"] = [f'{self.local_temp_dir}/{config["documents"][0]}'] + config["bloom_filter"]["file"] = f'{self.local_temp_dir}/{config["bloom_filter"]["file"]}' + + # this will ensure that the deduper will output something for each paragraph + config["dedupe"]["paragraphs"]["by_ngram"] = ( + ng_cfg := {"ngram_length": 20, "stride": 5, "overlap_threshold": 0.0, "skip_short_paragraphs": True} + ) + + with NamedTemporaryFile("w") as f: + json.dump(config, f) + f.flush() + + main(argv=["-c", f.name, "dedupe"]) + + documents = load_jsonl(f"{self.local_temp_dir}/tests/data/provided/deduper/documents/000.json.gz") + attributes = load_jsonl( + f"{self.local_temp_dir}/tests/data/provided/deduper/attributes/dedupe_paragraphs/000.json.gz" + ) + for doc, attr in zip(documents, attributes): + valid_paragraphs = [] + i = 0 + for para in doc["text"].split("\n"): + j = min(i + len(para) + 1, len(doc["text"])) + if len(split_words(para)) >= ng_cfg["ngram_length"]: + valid_paragraphs.append((i, j)) + i = j + spans = attr["attributes"]["bff_duplicate_paragraph_spans"] + + self.assertEqual(len(valid_paragraphs), len(spans)) + + for (start_para, end_para), (start_span, end_span, _) in zip(valid_paragraphs, spans): + self.assertEqual(doc["text"][start_para:end_para], doc["text"][start_span:end_span]) + def test_dedupe_paragraph_ngrams(self): with open(DEDUPE_PARAGRAPH_NGRAMS, "r") as f: config = json.load(f) diff --git a/tests/python/test_runtime.py 
b/tests/python/test_runtime.py index e15a50d1..18f7ee4a 100644 --- a/tests/python/test_runtime.py +++ b/tests/python/test_runtime.py @@ -167,10 +167,12 @@ def test_multiple_taggers(self, experiment_name: Optional[str] = None): (temp_path / "documents").mkdir(exist_ok=True) for path in documents_dir.iterdir(): - shutil.copy(path, temp_path / "documents" / path.name) + # ignore non-json files, like .DS_Store + if path.suffix.endswith(".gz"): + shutil.copy(path, temp_path / "documents" / path.name) create_and_run_tagger( - documents=[os.path.join(temp_dir, "documents") + "/*"], + documents=[f"{temp_dir}/documents/*"], taggers=taggers, experiment=experiment_name, debug=True, @@ -188,7 +190,7 @@ def test_multiple_taggers(self, experiment_name: Optional[str] = None): # collect all attributes for all documents here attributes = [] - for fn in documents_dir.iterdir(): + for fn in (temp_path / "documents").iterdir(): # collect all attributes for the current document here current_attrs: List[dict] = []
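
Usage note (illustration only, not part of the diff above): the new paragraph-dedupe options can be exercised the same way the added tests in tests/python/test_deduper.py do — load an existing dedupe config, set dedupe.paragraphs.paragraph_separator and dedupe.paragraphs.by_ngram.skip_short_paragraphs, and run the dedupe subcommand. The sketch below follows that pattern; the base config path and the document/bloom-filter locations are placeholders, not files introduced by this patch.

import json
from tempfile import NamedTemporaryFile

from dolma.cli.__main__ import main

# Placeholder: start from an existing dedupe config (the tests use the
# tests/config fixtures); adjust "documents" and bloom_filter paths as needed.
with open("path/to/dedupe-config.json") as f:
    config = json.load(f)

# New option: split documents on a custom separator instead of the default "\n".
config["dedupe"]["paragraphs"]["paragraph_separator"] = "\n\n"

# New option: when deduping by ngram overlap, skip paragraphs shorter than
# ngram_length + stride instead of falling back to whole-paragraph hashing.
config["dedupe"]["paragraphs"]["by_ngram"] = {
    "ngram_length": 10,
    "stride": 5,
    "overlap_threshold": 0.5,
    "skip_short_paragraphs": True,
}

with NamedTemporaryFile("w") as f:
    json.dump(config, f)
    f.flush()
    main(argv=["-c", f.name, "dedupe"])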
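
The equal-count histogram added to python/dolma/core/binning.py can also be called directly; the following sketch mirrors test_very_large_bins above and assumes the patched dolma package is importable. With heavily skewed weights, equal-count bins keep roughly the same total weight per bin, unlike numpy's equal-width np.histogram.

import numpy as np

from dolma.core.binning import equal_count_hist

values = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
weights = np.array([1, 1000, 1, 1, 1, 1000, 1, 1, 1])

# Three bins with (approximately) equal total weight rather than equal width.
counts, bins = equal_count_hist(a=values, bins=3, weights=weights)

# Per the test above, this yields counts [1001, 1003, 3] and bin edges
# [1.0, 2.5, 6.5, 9].
print(counts.tolist(), bins.tolist())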