Skip to content

Commit

Permalink
Porting missing code taggers, adding repetition tagger (#86)
Browse files Browse the repository at this point in the history
- adds support for taggers that use metadata
- ports code taggers from `allenai/LLM`
- adds new taggers to count repetitions with regex and tokenizers
- added tagger to count length without whitespaces
- added script to make plots for dolma papers (`scripts/dolma_paper_plots.sh`, `scripts/wandb_to_plot.py`)
- added script to find document from tokenizer offset (`scripts/find_offset.py`)
- added tests for new taggers
- improved GitHub Action to cache state
  • Loading branch information
soldni authored Nov 27, 2023
1 parent 734afa3 commit 38fa168
Show file tree
Hide file tree
Showing 35 changed files with 3,246 additions and 359 deletions.
103 changes: 66 additions & 37 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ permissions:
env:
DOLMA_TESTS_SKIP_AWS: ${{ secrets.AWS_ACCESS_KEY_ID == '' && 'true' || 'false' }}
DOLMA_TEST_S3_PREFIX: s3://dolma-tests
RUST_CHANNEL: stable


jobs:
Expand All @@ -38,17 +39,69 @@ jobs:
echo "PR base repo: ${{ github.event.pull_request.base.repo.full_name }}/tree/${{ github.event.pull_request.base.ref }}"
echo "PR head repo: ${{ github.event.pull_request.head.repo.full_name }}/tree/${{ github.event.pull_request.head.ref }}"
# Builds the Python virtualenv (with the dolma wheel installed) once and stores it in the Actions cache.
# Every step after the cache lookup is guarded by `cache-hit != 'true'`, so a cache hit skips the whole build.
prepare-venv:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3

- name: Cache Virtual Env
uses: actions/cache@v3
# step id; later steps read steps.cache-venv.outputs.cache-hit
id: cache-venv
with:
# what we cache: the virtualenv
path: ./.venv/
# The cache key combines the runner OS, a hash of pyproject.toml and the Cargo manifests,
# and a hash of everything under python/ and src/.
# NOTE(review): '**/Cargo.toml, **/Cargo.lock' is written as ONE quoted pattern (misplaced quote),
# so the Cargo files are probably not hashed -- TODO confirm. Any fix must be applied to the
# identical key in the `tests` job in the same change, otherwise that job will miss this cache.
key: ${{ runner.os }}-venv-${{ hashFiles('**/pyproject.toml', '**/Cargo.toml, **/Cargo.lock') }}--${{ hashFiles('python/**', 'src/**') }}

# System packages installed before the Rust/Python build steps (skipped on cache hit).
- name: Setup system libraries
if: steps.cache-venv.outputs.cache-hit != 'true'
run: |
sudo apt-get update
sudo apt-get install --yes --upgrade build-essential cmake protobuf-compiler libssl-dev glibc-source
# Installs and selects the toolchain named by the RUST_CHANNEL env variable.
- name: Install Rust toolchain
if: steps.cache-venv.outputs.cache-hit != 'true'
run: |
rustup update ${{ env.RUST_CHANNEL }}
rustup component add --toolchain ${{ env.RUST_CHANNEL }} rustfmt rust-src
rustup default ${{ env.RUST_CHANNEL }}
- name: Install Python
if: steps.cache-venv.outputs.cache-hit != 'true'
uses: actions/setup-python@v4
with:
python-version: '3.8'
architecture: "x64"
cache: 'pip'

# The virtualenv is created at ./.venv, the same path that the cache step saves.
- name: Create a new Python environment & install maturin
if: steps.cache-venv.outputs.cache-hit != 'true'
run: |
python -m venv .venv
source .venv/bin/activate
pip install -U pip
pip install maturin
# Builds a release wheel into dist/ and installs it with the `all` extra into the virtualenv.
- name: Install dolma wheels
if: steps.cache-venv.outputs.cache-hit != 'true'
run: |
source .venv/bin/activate
maturin build --release -i $(which python) --out dist
wheel_path=$(ls dist/*.whl)
pip install "${wheel_path}[all]"
tests:
runs-on: ubuntu-latest
needs: prepare-venv
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
if: ${{ github.event_name == 'pull_request' || github.event_name == 'push' }}
strategy:
fail-fast: true
matrix:
python: [3.8]
task:
- name: Check Python style
run: |
Expand All @@ -73,50 +126,23 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v1
uses: actions/checkout@v3

- name: Setup system libraries
run: |
sudo apt-get update
sudo apt-get install --yes --upgrade build-essential cmake protobuf-compiler libssl-dev glibc-source
- name: Install Rust
uses: actions-rs/toolchain@v1
- name: Cache Virtual Env
uses: actions/cache@v3
# name for referring later
id: cache-venv
with:
toolchain: stable
components: rustfmt

- name: Install Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python }}
architecture: "x64"
sccache: true

- name: Create a new Python environment & install maturin
run: |
python -m venv .venv
source .venv/bin/activate
pip install -U pip
pip install maturin
- name: Install dolma wheels
run: |
source .venv/bin/activate
maturin develop --extras=dev
# what we cache: the virtualenv
path: ./.venv/
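# The cache key combines the runner OS, a hash of pyproject.toml and the Cargo manifests, and a hash of python/ and src/.
# It must stay identical to the key used by the prepare-venv job, which is the job that populates this cache.
# NOTE(review): '**/Cargo.toml, **/Cargo.lock' is ONE quoted pattern (misplaced quote); fix both keys together.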
key: ${{ runner.os }}-venv-${{ hashFiles('**/pyproject.toml', '**/Cargo.toml, **/Cargo.lock') }}--${{ hashFiles('python/**', 'src/**') }}

- name: ${{ matrix.task.name }}
run: |
source .venv/bin/activate
${{ matrix.task.run }}
- name: Clean up
if: always()
run: |
source .venv/bin/activate
pip uninstall -y dolma
build-linux:
if: ${{ github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/tags/') }}
Expand All @@ -132,6 +158,7 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: '3.10'
cache: 'pip'
- name: Setup environment
run: |
sudo apt-get update
Expand Down Expand Up @@ -165,6 +192,7 @@ jobs:
with:
python-version: '3.10'
architecture: ${{ matrix.target }}
cache: 'pip'
- name: Build wheels
uses: PyO3/maturin-action@v1
with:
Expand All @@ -188,6 +216,7 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: '3.10'
cache: 'pip'
- name: Build wheels
uses: PyO3/maturin-action@v1
with:
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,7 @@ target/

# ignore vscode directory
.vscode

# ignore temporary directories
/tmp/
/temp/
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ test-rust:
rm -rf tests/work/*

develop:
maturin develop --extras=dev
maturin develop --extras=all

style:
rustfmt --edition 2021 src/*.rs
Expand Down
53 changes: 36 additions & 17 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,20 @@ requires-python = ">=3.8"
dependencies = [
"anyascii>=0.3.2",
"blingfire==0.1.8",
"boto3",
"boto3>=1.28",
"cached-path==1.3.4",
"detect-secrets==1.4.0",
# "fasttext==0.9.2", # broken with new version of setuptools; using fasttext-wheel instead
"fasttext-wheel==0.9.2",
"fsspec",
"fsspec>=2023.6.0",
"msgspec>=0.14.2",
"nltk==3.8.1",
"omegaconf>=2.3.0",
"presidio_analyzer==2.2.32",
"pycld2==0.41",
# "pycld3==0.22", # does not install correctly
"pyyaml",
"requests",
"rich",
"s3fs",
"s3fs>=2023.6.0",
"smart-open",
"tokenizers>=0.13.3,<1.0.0",
"tqdm",
Expand Down Expand Up @@ -108,18 +106,39 @@ dev = [
"flake8-pyi>=22.8.1",
"Flake8-pyproject>=1.1.0",
]
warc = [
"warcio>=1.7.4",
"trafilatura>=1.6.1",
"justext>=3.0.0",
"goose3>=3.1.17",

# following are all for speeding up trafilatura
"brotli",
"cchardet >= 2.1.7; python_version < '3.11'", # build issue
"faust-cchardet >= 2.1.18; python_version >= '3.11'", # fix for build
"htmldate[speed] >= 1.4.3",
"py3langid >= 0.2.2",
# extension to process code
code = [
"detect-secrets==1.4.0",
"beautifulsoup4>=4",
"pygments",
"regex"
]
# extension to detect PIIs using presidio
pii = [
"presidio_analyzer==2.2.32",
"regex"
]
# # extension to parse warc files
# warc = [
# "warcio>=1.7.4",
# "trafilatura>=1.6.1",
# "justext>=3.0.0",
# "goose3>=3.1.17",

# # following are all for speeding up trafilatura
# "brotli",
# "cchardet >= 2.1.7; python_version < '3.11'", # build issue
# "faust-cchardet >= 2.1.18; python_version >= '3.11'", # fix for build
# "htmldate[speed] >= 1.4.3",
# "py3langid >= 0.2.2",
# ]

# all extensions
all = [
"dolma[dev]",
"dolma[code]",
"dolma[pii]",
# "dolma[warc]",
]

[build-system]
Expand Down
75 changes: 70 additions & 5 deletions python/dolma/core/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ class InputSpec(Struct):
text: str
source: str = ""
version: Optional[str] = None
# ignoring metadata for now; taggers run on text only
# metadata: Optional[Dict[str, Any]] = None


# Variant of InputSpec that also carries an optional, free-form `metadata` mapping
# (defaults to None when the field is absent); all other fields are inherited from InputSpec.
class InputSpecWithMetadata(InputSpec):
metadata: Optional[Dict[str, Any]] = None


class OutputSpec(Struct):
Expand All @@ -48,17 +50,67 @@ def to_spec(self) -> InputSpec:
return InputSpec(source=self.source, version=self.version, id=self.id, text=self.text)

@classmethod
def from_json(cls, d: Dict) -> "Document":
def from_json(cls, d: Dict[str, Any]) -> "Document":
return Document(source=d["source"], version=d["version"], id=d["id"], text=d["text"])

def to_json(self) -> Dict:
def to_json(self) -> Dict[str, Any]:
return {"source": self.source, "version": self.version, "id": self.id, "text": self.text}

def __str__(self) -> str:
attributes_string = ",".join([f"{k}:{repr(v)}" for k, v in self.to_json()])
attributes_string = ",".join([f"{k}:{repr(v)}" for k, v in self.to_json().items()])
return f"{self.__class__.__name__}({attributes_string})"


class DocumentWithMetadata(Document):
    """A :class:`Document` that additionally carries a ``metadata`` dictionary.

    The ``metadata`` attribute is always a dict: ``None`` (or any falsy value)
    passed to the constructor is replaced by a fresh empty dict.
    """

    __slots__ = ("metadata",)

    def __init__(self, *args, metadata: Optional[Dict[str, Any]] = None, **kwargs) -> None:
        """Forward all positional/keyword arguments to ``Document`` and store ``metadata``."""
        super().__init__(*args, **kwargs)
        self.metadata = metadata or {}

    @classmethod
    def from_spec(cls, spec: InputSpecWithMetadata) -> "DocumentWithMetadata":
        """Build a document from an ``InputSpecWithMetadata``, copying every field."""
        return DocumentWithMetadata(
            source=spec.source,
            version=spec.version,
            id=spec.id,
            text=spec.text,
            metadata=spec.metadata,
        )

    def to_spec(self) -> InputSpecWithMetadata:
        """Convert this document back into an ``InputSpecWithMetadata``."""
        return InputSpecWithMetadata(
            source=self.source,
            version=self.version,
            id=self.id,
            text=self.text,
            metadata=self.metadata,
        )

    @classmethod
    def from_json(cls, d: Dict[str, Any]) -> "DocumentWithMetadata":
        """Build a document from a decoded JSON object.

        ``source``, ``version``, ``id`` and ``text`` are required keys (a missing
        one raises ``KeyError``). ``metadata`` is optional, mirroring the
        constructor and ``InputSpecWithMetadata``: when absent it becomes ``{}``.
        """
        return DocumentWithMetadata(
            source=d["source"],
            version=d["version"],
            id=d["id"],
            text=d["text"],
            # metadata is optional everywhere else, so do not fail on records without it
            metadata=d.get("metadata"),
        )

    def to_json(self) -> Dict[str, Any]:
        """Return the document, including its metadata, as a plain dictionary."""
        return {
            "source": self.source,
            "version": self.version,
            "id": self.id,
            "text": self.text,
            "metadata": self.metadata,
        }

    def __str__(self) -> str:
        """Return a compact one-line description with the metadata elided.

        The base fields are formatted from ``Document.to_json`` explicitly.
        Going through ``super().__str__()`` would dispatch to the overridden
        ``to_json`` above and print the full metadata dict before the elided
        ``metadata=...`` marker, defeating the elision.
        """
        base_fields = ",".join(f"{k}:{repr(v)}" for k, v in Document.to_json(self).items())
        metadata_marker = "..." if self.metadata else "none"
        return f"{self.__class__.__name__}({base_fields},metadata={metadata_marker})"


class Span:
__slots__ = "start", "end", "type", "score", "experiment", "tagger"

Expand Down Expand Up @@ -127,6 +179,19 @@ def __str__(self) -> str:
cls_name = self.__class__.__name__
return f"{cls_name}(start={self.start},end={self.end},type={repr(self.type)},score={self.score:.5f})"

def __repr__(self) -> str:
# Delegates to __str__, so repr(span) and str(span) produce the same text.
return str(self)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return False
return (
self.start == other.start
and self.end == other.end
and self.type == other.type
and self.score == other.score
)


class DocResult:
__slots__ = "doc", "spans"
Expand Down
7 changes: 6 additions & 1 deletion python/dolma/core/loggers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import logging
import multiprocessing


def get_logger(name: str) -> logging.Logger:
name = f"dolma.{name}"
if (proc_name := multiprocessing.current_process().name) == "MainProcess":
proc_name = "main"
proc_name = proc_name.replace(" ", "_")

name = f"{proc_name}.dolma.{name}"
logger = logging.getLogger(name)
logger.setLevel(logging.WARN)

Expand Down
Loading

0 comments on commit 38fa168

Please sign in to comment.