Skip to content

Commit

Permalink
Add robust median to gopher filter (#98)
Browse files Browse the repository at this point in the history
* Added robust median to gopher filter

* Added robust median to gopher filter

Co-authored-by: @TTTTao725

* Added robust median to gopher filter

Co-authored-by: TTTTao725 <[email protected]>

* Added robust median to gopher filter

Co-authored-by: TTTTao725 <[email protected]>

* fixed typing to use union

* reformatted with black

* formatted using `make style`

* attempting to fix style issues

* one more style change using python 3.10

* more style, make robust median always a float

---------

Co-authored-by: TTTTao725 <[email protected]>
Co-authored-by: TTTTao725 <[email protected]>
Co-authored-by: Luca Soldaini <[email protected]>
  • Loading branch information
4 people authored Jan 30, 2024
1 parent a44489f commit 50abfb7
Show file tree
Hide file tree
Showing 16 changed files with 110 additions and 41 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ jobs:
task:
- name: Check Python style
run: |
isort --check tests/python/ && isort --check python/
black --check tests/python/ && black --check python/
set -e
isort --check --verbose .
black --check --verbose .
- name: Check Rust style
run: |
Expand Down
11 changes: 4 additions & 7 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ publish:
test: test-python test-rust

test-python:
maturin develop --extras="all"
pytest -vsx tests/python
rm -rf tests/work/*

Expand All @@ -39,13 +40,9 @@ test-rust:
rm -rf tests/work/*

develop:
maturin develop --extras=all
maturin develop --extras="all"

style:
rustfmt --edition 2021 src/*.rs
autopep8 --in-place --recursive python/
isort python/
black python/
autopep8 --in-place --recursive tests/python/
isort tests/python/
black tests/python/
isort .
black .
39 changes: 30 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,14 @@ dolma = "dolma.cli.__main__:main"
[project.optional-dependencies]
dev = [
"black>=22.6.0",
"isort>=5.10.1",
"mypy>=0.971",
"pytest>=5.2",
"ipython>=8.4.0",
"autopep8>=1.7.0",
"flake8>=5.0",
"ipdb>=0.13.0",
"flake8-pyi>=22.8.1",
"Flake8-pyproject>=1.1.0",
"ipdb>=0.13.0",
"ipython>=8.4.0",
"isort>=5.10.1",
"mypy>=0.971",
"pytest>=5.2",
]
# extension to process code
code = [
Expand Down Expand Up @@ -163,9 +162,7 @@ dolma = ["py.typed", "data/*"]

[tool.black]
line-length = 115

include = '\.pyi?$'

exclude = '''
(
__pycache__
Expand All @@ -176,12 +173,36 @@ exclude = '''
| \.venv
| \bdist\b
| \bdoc\b
| \.cargo
| configs
| docs
| scripts
| sources
| src
| target
| tests/config
| tests/data
| tests/work
)
'''
target-version = ["py38", "py39", "py310", "py311", "py312"]


[tool.isort]
profile = "black"
multi_line_output = 3
py_version=38
known_first_party = ["dolma"]
known_local_folder = ["tests", "python"]
extend_skip_glob = [
"configs/*",
"docs/*",
"scripts/*",
"sources/*",
"src/*",
"tests/config/*",
"tests/data/*",
"tests/work/*"
]

[tool.autopep8]
max_line_length = 115
Expand Down
8 changes: 3 additions & 5 deletions python/dolma/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
Author: Luca Soldaini (@soldni)
"""


from argparse import ArgumentParser, Namespace
from collections.abc import Iterable
from copy import deepcopy
from dataclasses import Field
from dataclasses import field as dataclass_field
from dataclasses import is_dataclass
from logging import warn
from logging import warning
from typing import (
Any,
Dict,
Expand Down Expand Up @@ -74,14 +73,13 @@ def make_parser(parser: A, config: Type[DataClass], prefix: Optional[str] = None
typ_ = config.__annotations__.get(field_name, dt_field.metadata.get("type", MISSING))

if typ_ is MISSING:
warn(f"No type annotation for field {field_name} in {config.__name__}")
warning(f"No type annotation for field {field_name} in {config.__name__}")
continue

# join prefix and field name
field_name = f"{prefix}.{field_name}" if prefix else field_name

# This section here is to handle Optional[T] types
# We only care for cases where T is a dataclass
# This section here is to handle Optional[T] types; we only care for cases where T is a dataclass
# So we first check if type is Union since Optional[T] is just a shorthand for Union[T, None]
# and that the union contains only one non-None type
if get_origin(typ_) == Union:
Expand Down
1 change: 1 addition & 0 deletions python/dolma/core/ft_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@kylel, @soldni
"""

import os
from tempfile import NamedTemporaryFile
from typing import Iterable, Literal, NamedTuple, Optional
Expand Down
1 change: 1 addition & 0 deletions python/dolma/core/taggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@kylel, @soldni
"""

from abc import abstractmethod
from typing import List

Expand Down
1 change: 1 addition & 0 deletions python/dolma/taggers/code/code_taggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@akshitab
"""

import logging
import re
from typing import List
Expand Down
1 change: 1 addition & 0 deletions python/dolma/taggers/code/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@akshitab, @soldni
"""

import json
import logging
from pathlib import Path
Expand Down
67 changes: 57 additions & 10 deletions python/dolma/taggers/gopher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from statistics import median
from typing import Counter as CounterType
from typing import List, Tuple
from typing import List, Tuple, Union

from ..core.data_types import DocResult, Document, Span
from ..core.registry import TaggerRegistry
Expand All @@ -14,13 +14,19 @@
BULLET_POINTS = {"*", "-"}


def robust_median(values: List[Union[int, float]]) -> float:
if not values:
return 0.0
return float(median(values))


@dataclass
class GopherAttributes:
fraction_of_characters_in_most_common_ngram: List[Tuple[int, float]]
fraction_of_characters_in_duplicate_ngrams: List[Tuple[int, float]]
character_count: int = 0
word_count: int = 0
median_word_length: float = False
median_word_length: float = 0.0
symbol_to_word_ratio: float = 0.0
fraction_of_words_with_alpha_character: float = 0.0
required_word_count: int = 0
Expand All @@ -33,20 +39,51 @@ def as_spans(self) -> List[Span]:
spans = []
spans.extend(
[
Span(0, self.character_count, f"fraction_of_characters_in_most_common_{n}grams", v)
Span(
0,
self.character_count,
f"fraction_of_characters_in_most_common_{n}grams",
v,
)
for n, v in self.fraction_of_characters_in_most_common_ngram
]
)
spans.extend(
[
Span(0, self.character_count, f"fraction_of_characters_in_duplicate_{n}grams", v)
Span(
0,
self.character_count,
f"fraction_of_characters_in_duplicate_{n}grams",
v,
)
for n, v in self.fraction_of_characters_in_duplicate_ngrams
]
)
spans.append(Span(0, self.character_count, type="character_count", score=self.character_count))
spans.append(
Span(
0,
self.character_count,
type="character_count",
score=self.character_count,
)
)
spans.append(Span(0, self.character_count, type="word_count", score=self.word_count))
spans.append(Span(0, self.character_count, type="median_word_length", score=self.median_word_length))
spans.append(Span(0, self.character_count, type="symbol_to_word_ratio", score=self.symbol_to_word_ratio))
spans.append(
Span(
0,
self.character_count,
type="median_word_length",
score=self.median_word_length,
)
)
spans.append(
Span(
0,
self.character_count,
type="symbol_to_word_ratio",
score=self.symbol_to_word_ratio,
)
)
spans.append(
Span(
0,
Expand All @@ -55,7 +92,14 @@ def as_spans(self) -> List[Span]:
score=self.fraction_of_words_with_alpha_character,
)
)
spans.append(Span(0, self.character_count, type="required_word_count", score=self.required_word_count))
spans.append(
Span(
0,
self.character_count,
type="required_word_count",
score=self.required_word_count,
)
)
spans.append(
Span(
0,
Expand All @@ -74,7 +118,10 @@ def as_spans(self) -> List[Span]:
)
spans.append(
Span(
0, self.character_count, type="fraction_of_duplicate_lines", score=self.fraction_of_duplicate_lines
0,
self.character_count,
type="fraction_of_duplicate_lines",
score=self.fraction_of_duplicate_lines,
)
)
spans.append(
Expand All @@ -100,7 +147,7 @@ def get_attributes(text: str) -> GopherAttributes:
character_count = sum(len(word) for word in words)

attrs.word_count = word_count
attrs.median_word_length = median([len(word) for word in words])
attrs.median_word_length = robust_median([len(word) for word in words])
attrs.symbol_to_word_ratio = sum(1 for word in words if any(s in word for s in SYMBOLS)) / word_count
attrs.fraction_of_words_with_alpha_character = (
sum(1 for word in words if any(c.isalpha() for c in word)) / word_count
Expand Down
1 change: 1 addition & 0 deletions python/dolma/taggers/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@kylel, @soldni
"""

from typing import TYPE_CHECKING, Iterable, List, Tuple

import necessary
Expand Down
1 change: 0 additions & 1 deletion python/dolma/taggers/pii.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"""


import re
from typing import List
from warnings import warn
Expand Down
2 changes: 1 addition & 1 deletion tests/python/extras/extras_from_module/extra_taggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

@add_tagger("extra_v1")
class ExtraV1Tagger(BaseTagger):
...
pass
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

@add_tagger("extra_v3")
class ExtraV1Tagger(BaseTagger):
...
pass
2 changes: 1 addition & 1 deletion tests/python/extras/extras_from_path/extra_taggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

@add_tagger("extra_v2")
class ExtraV2Tagger(BaseTagger):
...
pass
8 changes: 5 additions & 3 deletions tests/python/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@ def test_local_glob_with_recursive(self):
paths = list(glob_path(local_glob))
expected = list(
itertools.chain.from_iterable(
(str(fp),)
if (fp := LOCAL_DATA / fn).is_file() and "paragraphs" in fn
else ((str(fp / sn) for sn in os.listdir(fp) if "paragraphs" in sn) if fp.is_dir() else ())
(
(str(fp),)
if (fp := LOCAL_DATA / fn).is_file() and "paragraphs" in fn
else ((str(fp / sn) for sn in os.listdir(fp) if "paragraphs" in sn) if fp.is_dir() else ())
)
for fn in os.listdir(LOCAL_DATA)
)
)
Expand Down
1 change: 0 additions & 1 deletion tests/python/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"""


from unittest import TestCase

from dolma.core.data_types import TextSlice
Expand Down

0 comments on commit 50abfb7

Please sign in to comment.