Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add robust median to gopher filter #98

Merged
merged 16 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ jobs:
task:
- name: Check Python style
run: |
isort --check tests/python/ && isort --check python/
black --check tests/python/ && black --check python/
set -e
isort --check --verbose .
black --check --verbose .

- name: Check Rust style
run: |
Expand Down
11 changes: 4 additions & 7 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ publish:
test: test-python test-rust

test-python:
maturin develop --extras="all"
pytest -vsx tests/python
rm -rf tests/work/*

Expand All @@ -39,13 +40,9 @@ test-rust:
rm -rf tests/work/*

develop:
maturin develop --extras=all
maturin develop --extras="all"

style:
rustfmt --edition 2021 src/*.rs
autopep8 --in-place --recursive python/
isort python/
black python/
autopep8 --in-place --recursive tests/python/
isort tests/python/
black tests/python/
isort .
black .
39 changes: 30 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,14 @@ dolma = "dolma.cli.__main__:main"
[project.optional-dependencies]
dev = [
"black>=22.6.0",
"isort>=5.10.1",
"mypy>=0.971",
"pytest>=5.2",
"ipython>=8.4.0",
"autopep8>=1.7.0",
"flake8>=5.0",
"ipdb>=0.13.0",
"flake8-pyi>=22.8.1",
"Flake8-pyproject>=1.1.0",
"ipdb>=0.13.0",
"ipython>=8.4.0",
"isort>=5.10.1",
"mypy>=0.971",
"pytest>=5.2",
]
# extension to process code
code = [
Expand Down Expand Up @@ -163,9 +162,7 @@ dolma = ["py.typed", "data/*"]

[tool.black]
line-length = 115

include = '\.pyi?$'

exclude = '''
(
__pycache__
Expand All @@ -176,12 +173,36 @@ exclude = '''
| \.venv
| \bdist\b
| \bdoc\b
| \.cargo
| configs
| docs
| scripts
| sources
| src
| target
| tests/config
| tests/data
| tests/work
)
'''
target-version = ["py38", "py39", "py310", "py311", "py312"]


[tool.isort]
profile = "black"
multi_line_output = 3
py_version=38
known_first_party = ["dolma"]
known_local_folder = ["tests", "python"]
extend_skip_glob = [
"configs/*",
"docs/*",
"scripts/*",
"sources/*",
"src/*",
"tests/config/*",
"tests/data/*",
"tests/work/*"
]

[tool.autopep8]
max_line_length = 115
Expand Down
8 changes: 3 additions & 5 deletions python/dolma/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
Author: Luca Soldaini (@soldni)
"""


from argparse import ArgumentParser, Namespace
from collections.abc import Iterable
from copy import deepcopy
from dataclasses import Field
from dataclasses import field as dataclass_field
from dataclasses import is_dataclass
from logging import warn
from logging import warning
from typing import (
Any,
Dict,
Expand Down Expand Up @@ -74,14 +73,13 @@ def make_parser(parser: A, config: Type[DataClass], prefix: Optional[str] = None
typ_ = config.__annotations__.get(field_name, dt_field.metadata.get("type", MISSING))

if typ_ is MISSING:
warn(f"No type annotation for field {field_name} in {config.__name__}")
warning(f"No type annotation for field {field_name} in {config.__name__}")
continue

# join prefix and field name
field_name = f"{prefix}.{field_name}" if prefix else field_name

# This section here is to handle Optional[T] types
# We only care for cases where T is a dataclass
# This section here is to handle Optional[T] types; we only care for cases where T is a dataclass
# So we first check if type is Union since Optional[T] is just a shorthand for Union[T, None]
# and that the union contains only one non-None type
if get_origin(typ_) == Union:
Expand Down
1 change: 1 addition & 0 deletions python/dolma/core/ft_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@kylel, @soldni

"""

import os
from tempfile import NamedTemporaryFile
from typing import Iterable, Literal, NamedTuple, Optional
Expand Down
1 change: 1 addition & 0 deletions python/dolma/core/taggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@kylel, @soldni

"""

from abc import abstractmethod
from typing import List

Expand Down
1 change: 1 addition & 0 deletions python/dolma/taggers/code/code_taggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@akshitab

"""

import logging
import re
from typing import List
Expand Down
1 change: 1 addition & 0 deletions python/dolma/taggers/code/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@akshitab, @soldni

"""

import json
import logging
from pathlib import Path
Expand Down
67 changes: 57 additions & 10 deletions python/dolma/taggers/gopher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from statistics import median
from typing import Counter as CounterType
from typing import List, Tuple
from typing import List, Tuple, Union

from ..core.data_types import DocResult, Document, Span
from ..core.registry import TaggerRegistry
Expand All @@ -14,13 +14,19 @@
BULLET_POINTS = {"*", "-"}


def robust_median(values: List[Union[int, float]]) -> float:
if not values:
return 0.0
return float(median(values))


@dataclass
class GopherAttributes:
fraction_of_characters_in_most_common_ngram: List[Tuple[int, float]]
fraction_of_characters_in_duplicate_ngrams: List[Tuple[int, float]]
character_count: int = 0
word_count: int = 0
median_word_length: float = False
median_word_length: float = 0.0
symbol_to_word_ratio: float = 0.0
fraction_of_words_with_alpha_character: float = 0.0
required_word_count: int = 0
Expand All @@ -33,20 +39,51 @@ def as_spans(self) -> List[Span]:
spans = []
spans.extend(
[
Span(0, self.character_count, f"fraction_of_characters_in_most_common_{n}grams", v)
Span(
0,
self.character_count,
f"fraction_of_characters_in_most_common_{n}grams",
v,
)
for n, v in self.fraction_of_characters_in_most_common_ngram
]
)
spans.extend(
[
Span(0, self.character_count, f"fraction_of_characters_in_duplicate_{n}grams", v)
Span(
0,
self.character_count,
f"fraction_of_characters_in_duplicate_{n}grams",
v,
)
for n, v in self.fraction_of_characters_in_duplicate_ngrams
]
)
spans.append(Span(0, self.character_count, type="character_count", score=self.character_count))
spans.append(
Span(
0,
self.character_count,
type="character_count",
score=self.character_count,
)
)
spans.append(Span(0, self.character_count, type="word_count", score=self.word_count))
spans.append(Span(0, self.character_count, type="median_word_length", score=self.median_word_length))
spans.append(Span(0, self.character_count, type="symbol_to_word_ratio", score=self.symbol_to_word_ratio))
spans.append(
Span(
0,
self.character_count,
type="median_word_length",
score=self.median_word_length,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given than median_word_length is bool | float, wouldn't this make score potentially a bool? score is supposed to be a float, so we would have to cast back.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like it starts out as a False, so tried to match the existing pattern. However the median can be undefined (empty list) but multiple value could represent that (np.nan, 0, False). I would probably go for np.nan or None if that is valid?

)
)
spans.append(
Span(
0,
self.character_count,
type="symbol_to_word_ratio",
score=self.symbol_to_word_ratio,
)
)
spans.append(
Span(
0,
Expand All @@ -55,7 +92,14 @@ def as_spans(self) -> List[Span]:
score=self.fraction_of_words_with_alpha_character,
)
)
spans.append(Span(0, self.character_count, type="required_word_count", score=self.required_word_count))
spans.append(
Span(
0,
self.character_count,
type="required_word_count",
score=self.required_word_count,
)
)
spans.append(
Span(
0,
Expand All @@ -74,7 +118,10 @@ def as_spans(self) -> List[Span]:
)
spans.append(
Span(
0, self.character_count, type="fraction_of_duplicate_lines", score=self.fraction_of_duplicate_lines
0,
self.character_count,
type="fraction_of_duplicate_lines",
score=self.fraction_of_duplicate_lines,
)
)
spans.append(
Expand All @@ -100,7 +147,7 @@ def get_attributes(text: str) -> GopherAttributes:
character_count = sum(len(word) for word in words)

attrs.word_count = word_count
attrs.median_word_length = median([len(word) for word in words])
attrs.median_word_length = robust_median([len(word) for word in words])
attrs.symbol_to_word_ratio = sum(1 for word in words if any(s in word for s in SYMBOLS)) / word_count
attrs.fraction_of_words_with_alpha_character = (
sum(1 for word in words if any(c.isalpha() for c in word)) / word_count
Expand Down
1 change: 1 addition & 0 deletions python/dolma/taggers/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@kylel, @soldni

"""

from typing import TYPE_CHECKING, Iterable, List, Tuple

import necessary
Expand Down
1 change: 0 additions & 1 deletion python/dolma/taggers/pii.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

"""


import re
from typing import List
from warnings import warn
Expand Down
2 changes: 1 addition & 1 deletion tests/python/extras/extras_from_module/extra_taggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

@add_tagger("extra_v1")
class ExtraV1Tagger(BaseTagger):
...
pass
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

@add_tagger("extra_v3")
class ExtraV1Tagger(BaseTagger):
...
pass
2 changes: 1 addition & 1 deletion tests/python/extras/extras_from_path/extra_taggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

@add_tagger("extra_v2")
class ExtraV2Tagger(BaseTagger):
...
pass
8 changes: 5 additions & 3 deletions tests/python/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@ def test_local_glob_with_recursive(self):
paths = list(glob_path(local_glob))
expected = list(
itertools.chain.from_iterable(
(str(fp),)
if (fp := LOCAL_DATA / fn).is_file() and "paragraphs" in fn
else ((str(fp / sn) for sn in os.listdir(fp) if "paragraphs" in sn) if fp.is_dir() else ())
(
(str(fp),)
if (fp := LOCAL_DATA / fn).is_file() and "paragraphs" in fn
else ((str(fp / sn) for sn in os.listdir(fp) if "paragraphs" in sn) if fp.is_dir() else ())
)
for fn in os.listdir(LOCAL_DATA)
)
)
Expand Down
1 change: 0 additions & 1 deletion tests/python/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

"""


from unittest import TestCase

from dolma.core.data_types import TextSlice
Expand Down
Loading