From 50abfb7cc6b6ce9053557baf47109db70011428d Mon Sep 17 00:00:00 2001 From: Kenneth Enevoldsen Date: Tue, 30 Jan 2024 07:53:32 +0100 Subject: [PATCH] Add robust median to gopher filter (#98) * Added robust median to gopher filter * Added robust median to gopher filter Co-authored-by: @TTTTao725 * Added robust median to gopher filter Co-authored-by: TTTTao725 * Added robust median to gopher filter Co-authored-by: TTTTao725 * fixed typing to use union * reformatted with black * formatted using `make style` * attempting to fix style issues * one more style change using python 3.10 * more style, make robust median always a float --------- Co-authored-by: TTTTao725 Co-authored-by: TTTTao725 Co-authored-by: Luca Soldaini --- .github/workflows/CI.yml | 5 +- Makefile | 11 ++- pyproject.toml | 39 ++++++++--- python/dolma/cli/__init__.py | 8 +-- python/dolma/core/ft_tagger.py | 1 + python/dolma/core/taggers.py | 1 + python/dolma/taggers/code/code_taggers.py | 1 + python/dolma/taggers/code/utils.py | 1 + python/dolma/taggers/gopher.py | 67 ++++++++++++++++--- python/dolma/taggers/language.py | 1 + python/dolma/taggers/pii.py | 1 - .../extras_from_module/extra_taggers.py | 2 +- .../extras_from_module_path/extra_taggers.py | 2 +- .../extras/extras_from_path/extra_taggers.py | 2 +- tests/python/test_paths.py | 8 ++- tests/python/test_utils.py | 1 - 16 files changed, 110 insertions(+), 41 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 146342a0..75aa6e91 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -107,8 +107,9 @@ jobs: task: - name: Check Python style run: | - isort --check tests/python/ && isort --check python/ - black --check tests/python/ && black --check python/ + set -e + isort --check --verbose . + black --check --verbose . - name: Check Rust style run: | diff --git a/Makefile b/Makefile index d756f386..cad13a43 100644 --- a/Makefile +++ b/Makefile @@ -31,6 +31,7 @@ publish: test: test-python test-rust test-python: + maturin develop --extras="all" pytest -vsx tests/python rm -rf tests/work/* @@ -39,13 +40,9 @@ test-rust: rm -rf tests/work/* develop: - maturin develop --extras=all + maturin develop --extras="all" style: rustfmt --edition 2021 src/*.rs - autopep8 --in-place --recursive python/ - isort python/ - black python/ - autopep8 --in-place --recursive tests/python/ - isort tests/python/ - black tests/python/ + isort . + black . diff --git a/pyproject.toml b/pyproject.toml index 58836ac5..a8d6dc84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,15 +96,14 @@ dolma = "dolma.cli.__main__:main" [project.optional-dependencies] dev = [ "black>=22.6.0", - "isort>=5.10.1", - "mypy>=0.971", - "pytest>=5.2", - "ipython>=8.4.0", - "autopep8>=1.7.0", "flake8>=5.0", - "ipdb>=0.13.0", "flake8-pyi>=22.8.1", "Flake8-pyproject>=1.1.0", + "ipdb>=0.13.0", + "ipython>=8.4.0", + "isort>=5.10.1", + "mypy>=0.971", + "pytest>=5.2", ] # extension to process code code = [ @@ -163,9 +162,7 @@ dolma = ["py.typed", "data/*"] [tool.black] line-length = 115 - include = '\.pyi?$' - exclude = ''' ( __pycache__ @@ -176,12 +173,36 @@ exclude = ''' | \.venv | \bdist\b | \bdoc\b + | \.cargo + | configs + | docs + | scripts + | sources + | src + | target + | tests/config + | tests/data + | tests/work ) ''' +target-version = ["py38", "py39", "py310", "py311", "py312"] + [tool.isort] profile = "black" -multi_line_output = 3 +py_version=38 +known_first_party = ["dolma"] +known_local_folder = ["tests", "python"] +extend_skip_glob = [ + "configs/*", + "docs/*", + "scripts/*", + "sources/*", + "src/*", + "tests/config/*", + "tests/data/*", + "tests/work/*" +] [tool.autopep8] max_line_length = 115 diff --git a/python/dolma/cli/__init__.py b/python/dolma/cli/__init__.py index 268c26af..86167196 100644 --- a/python/dolma/cli/__init__.py +++ b/python/dolma/cli/__init__.py @@ -4,14 +4,13 @@ Author: Luca Soldaini (@soldni) """ - from argparse import ArgumentParser, Namespace from collections.abc import Iterable from copy import deepcopy from dataclasses import Field from dataclasses import field as dataclass_field from dataclasses import is_dataclass -from logging import warn +from logging import warning from typing import ( Any, Dict, @@ -74,14 +73,13 @@ def make_parser(parser: A, config: Type[DataClass], prefix: Optional[str] = None typ_ = config.__annotations__.get(field_name, dt_field.metadata.get("type", MISSING)) if typ_ is MISSING: - warn(f"No type annotation for field {field_name} in {config.__name__}") + warning(f"No type annotation for field {field_name} in {config.__name__}") continue # join prefix and field name field_name = f"{prefix}.{field_name}" if prefix else field_name - # This section here is to handle Optional[T] types - # We only care for cases where T is a dataclass + # This section here is to handle Optional[T] types; we only care for cases where T is a dataclass # So we first check if type is Union since Optional[T] is just a shorthand for Union[T, None] # and that the union contains only one non-None type if get_origin(typ_) == Union: diff --git a/python/dolma/core/ft_tagger.py b/python/dolma/core/ft_tagger.py index e68a01a3..53bcc031 100644 --- a/python/dolma/core/ft_tagger.py +++ b/python/dolma/core/ft_tagger.py @@ -5,6 +5,7 @@ @kylel, @soldni """ + import os from tempfile import NamedTemporaryFile from typing import Iterable, Literal, NamedTuple, Optional diff --git a/python/dolma/core/taggers.py b/python/dolma/core/taggers.py index 342985fd..59a414f7 100644 --- a/python/dolma/core/taggers.py +++ b/python/dolma/core/taggers.py @@ -5,6 +5,7 @@ @kylel, @soldni """ + from abc import abstractmethod from typing import List diff --git a/python/dolma/taggers/code/code_taggers.py b/python/dolma/taggers/code/code_taggers.py index feefbefd..31b57087 100644 --- a/python/dolma/taggers/code/code_taggers.py +++ b/python/dolma/taggers/code/code_taggers.py @@ -5,6 +5,7 @@ @akshitab """ + import logging import re from typing import List diff --git a/python/dolma/taggers/code/utils.py b/python/dolma/taggers/code/utils.py index e7c2c474..04c7fc78 100644 --- a/python/dolma/taggers/code/utils.py +++ b/python/dolma/taggers/code/utils.py @@ -5,6 +5,7 @@ @akshitab, @soldni """ + import json import logging from pathlib import Path diff --git a/python/dolma/taggers/gopher.py b/python/dolma/taggers/gopher.py index 9a5a28a3..5b5348e0 100644 --- a/python/dolma/taggers/gopher.py +++ b/python/dolma/taggers/gopher.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from statistics import median from typing import Counter as CounterType -from typing import List, Tuple +from typing import List, Tuple, Union from ..core.data_types import DocResult, Document, Span from ..core.registry import TaggerRegistry @@ -14,13 +14,19 @@ BULLET_POINTS = {"*", "-"} +def robust_median(values: List[Union[int, float]]) -> float: + if not values: + return 0.0 + return float(median(values)) + + @dataclass class GopherAttributes: fraction_of_characters_in_most_common_ngram: List[Tuple[int, float]] fraction_of_characters_in_duplicate_ngrams: List[Tuple[int, float]] character_count: int = 0 word_count: int = 0 - median_word_length: float = False + median_word_length: float = 0.0 symbol_to_word_ratio: float = 0.0 fraction_of_words_with_alpha_character: float = 0.0 required_word_count: int = 0 @@ -33,20 +39,51 @@ def as_spans(self) -> List[Span]: spans = [] spans.extend( [ - Span(0, self.character_count, f"fraction_of_characters_in_most_common_{n}grams", v) + Span( + 0, + self.character_count, + f"fraction_of_characters_in_most_common_{n}grams", + v, + ) for n, v in self.fraction_of_characters_in_most_common_ngram ] ) spans.extend( [ - Span(0, self.character_count, f"fraction_of_characters_in_duplicate_{n}grams", v) + Span( + 0, + self.character_count, + f"fraction_of_characters_in_duplicate_{n}grams", + v, + ) for n, v in self.fraction_of_characters_in_duplicate_ngrams ] ) - spans.append(Span(0, self.character_count, type="character_count", score=self.character_count)) + spans.append( + Span( + 0, + self.character_count, + type="character_count", + score=self.character_count, + ) + ) spans.append(Span(0, self.character_count, type="word_count", score=self.word_count)) - spans.append(Span(0, self.character_count, type="median_word_length", score=self.median_word_length)) - spans.append(Span(0, self.character_count, type="symbol_to_word_ratio", score=self.symbol_to_word_ratio)) + spans.append( + Span( + 0, + self.character_count, + type="median_word_length", + score=self.median_word_length, + ) + ) + spans.append( + Span( + 0, + self.character_count, + type="symbol_to_word_ratio", + score=self.symbol_to_word_ratio, + ) + ) spans.append( Span( 0, @@ -55,7 +92,14 @@ def as_spans(self) -> List[Span]: score=self.fraction_of_words_with_alpha_character, ) ) - spans.append(Span(0, self.character_count, type="required_word_count", score=self.required_word_count)) + spans.append( + Span( + 0, + self.character_count, + type="required_word_count", + score=self.required_word_count, + ) + ) spans.append( Span( 0, @@ -74,7 +118,10 @@ def as_spans(self) -> List[Span]: ) spans.append( Span( - 0, self.character_count, type="fraction_of_duplicate_lines", score=self.fraction_of_duplicate_lines + 0, + self.character_count, + type="fraction_of_duplicate_lines", + score=self.fraction_of_duplicate_lines, ) ) spans.append( @@ -100,7 +147,7 @@ def get_attributes(text: str) -> GopherAttributes: character_count = sum(len(word) for word in words) attrs.word_count = word_count - attrs.median_word_length = median([len(word) for word in words]) + attrs.median_word_length = robust_median([len(word) for word in words]) attrs.symbol_to_word_ratio = sum(1 for word in words if any(s in word for s in SYMBOLS)) / word_count attrs.fraction_of_words_with_alpha_character = ( sum(1 for word in words if any(c.isalpha() for c in word)) / word_count diff --git a/python/dolma/taggers/language.py b/python/dolma/taggers/language.py index 3e0884af..504fc2ab 100644 --- a/python/dolma/taggers/language.py +++ b/python/dolma/taggers/language.py @@ -5,6 +5,7 @@ @kylel, @soldni """ + from typing import TYPE_CHECKING, Iterable, List, Tuple import necessary diff --git a/python/dolma/taggers/pii.py b/python/dolma/taggers/pii.py index afa7c8df..75fe508d 100644 --- a/python/dolma/taggers/pii.py +++ b/python/dolma/taggers/pii.py @@ -6,7 +6,6 @@ """ - import re from typing import List from warnings import warn diff --git a/tests/python/extras/extras_from_module/extra_taggers.py b/tests/python/extras/extras_from_module/extra_taggers.py index 92a3fb61..3e65ef7b 100644 --- a/tests/python/extras/extras_from_module/extra_taggers.py +++ b/tests/python/extras/extras_from_module/extra_taggers.py @@ -3,4 +3,4 @@ @add_tagger("extra_v1") class ExtraV1Tagger(BaseTagger): - ... + pass diff --git a/tests/python/extras/extras_from_module_path/extra_taggers.py b/tests/python/extras/extras_from_module_path/extra_taggers.py index c88c7750..d01ce945 100644 --- a/tests/python/extras/extras_from_module_path/extra_taggers.py +++ b/tests/python/extras/extras_from_module_path/extra_taggers.py @@ -3,4 +3,4 @@ @add_tagger("extra_v3") class ExtraV1Tagger(BaseTagger): - ... + pass diff --git a/tests/python/extras/extras_from_path/extra_taggers.py b/tests/python/extras/extras_from_path/extra_taggers.py index bbe0b2ea..64e97c1b 100644 --- a/tests/python/extras/extras_from_path/extra_taggers.py +++ b/tests/python/extras/extras_from_path/extra_taggers.py @@ -3,4 +3,4 @@ @add_tagger("extra_v2") class ExtraV2Tagger(BaseTagger): - ... + pass diff --git a/tests/python/test_paths.py b/tests/python/test_paths.py index 4f3518dd..25a67c88 100644 --- a/tests/python/test_paths.py +++ b/tests/python/test_paths.py @@ -74,9 +74,11 @@ def test_local_glob_with_recursive(self): paths = list(glob_path(local_glob)) expected = list( itertools.chain.from_iterable( - (str(fp),) - if (fp := LOCAL_DATA / fn).is_file() and "paragraphs" in fn - else ((str(fp / sn) for sn in os.listdir(fp) if "paragraphs" in sn) if fp.is_dir() else ()) + ( + (str(fp),) + if (fp := LOCAL_DATA / fn).is_file() and "paragraphs" in fn + else ((str(fp / sn) for sn in os.listdir(fp) if "paragraphs" in sn) if fp.is_dir() else ()) + ) for fn in os.listdir(LOCAL_DATA) ) ) diff --git a/tests/python/test_utils.py b/tests/python/test_utils.py index b5c41676..38bf268d 100644 --- a/tests/python/test_utils.py +++ b/tests/python/test_utils.py @@ -6,7 +6,6 @@ """ - from unittest import TestCase from dolma.core.data_types import TextSlice