Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tagger_modules option to tagger cli #69

Merged
merged 9 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions docs/taggers.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ The following parameters are supported either via CLI (e.g. `dolma tag --paramet
|`destination`|No| One or more paths for output attribute files. Each accepts a single wildcard `*` character. Can be local, or an S3-compatible cloud path. If not provided, the destination will be derived from the document path. |
|`experiment`|No| Used to name output attribute files. One output file will be created for each input document file, where the key is obtained by substituting `documents` with `attributes/<experiment>`. If not provided, we will use `attributes/<tagger_name>`. |
|`taggers`|Yes| One or more taggers to run. |
|`tagger_modules`|No| List of one or more Python modules to load taggers from. See section [*"Using Custom Taggers"*](#using-custom-taggers) for more details. |
|`processes`|No| Number of processes to use for tagging. One process is used by default. |
|`ignore_existing`|No| If true, ignore existing outputs and re-run the taggers. |
|`dryrun`|No| If true, only print the configuration and exit without running the taggers. |
Expand Down Expand Up @@ -80,10 +81,9 @@ All taggers inherit from the `BaseTagger` class defined in [`core/taggers.py`](h
import random

from dolma.core.data_types import DocResult, Document, Span
from dolma.core.registry import TaggerRegistry
from dolma.core.taggers import BaseTagger
from dolma import add_tagger, BaseTagger

@TaggerRegistry.add("random_number_v1")
@add_tagger("new_random_number")
class RandomNumberTagger(BaseTagger):
def predict(self, doc: Document) -> DocResult:
# first, we generate a random number
Expand All @@ -102,4 +102,13 @@ class RandomNumberTagger(BaseTagger):
return DocResult(doc=doc, spans=[span])
```

Name for each tagger is specified using the `@TaggerRegistry.add` decorator. The name must be unique.
The name of each tagger is specified using the `add_tagger` decorator. The name must be unique.

## Using Custom Taggers

Taggers can be added either as part of the Dolma package, or they can be imported at runtime by providing the `tagger_modules` parameter.

For example, let's assume the `new_random_number` tagger is saved in a file called `my_taggers.py` inside the directory `path/to/my_module`. Then, we can run the tagger using one of the following commands:

- `dolma tag --taggers new_random_number --tagger_modules path/to/my_module/my_taggers.py ...`
- `PYTHONPATH="path/to/my_module" dolma tag --taggers new_random_number --tagger_modules my_taggers`
10 changes: 10 additions & 0 deletions python/dolma/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,19 @@
# must import taggers to register them
# we import the rust extension here and wrap it in a python module
from . import dolma as _dolma # type: ignore # noqa: E402
from .core import TaggerRegistry # noqa: E402
from .core.errors import DolmaRustPipelineError # noqa: E402
from .core.taggers import BaseTagger # noqa: E402
from .taggers import * # noqa: E402

__all__ = [
"add_tagger",
"BaseTagger",
]

# we create a shortcut to easily add taggers to the registry
add_tagger = TaggerRegistry.add


def deduper(config: dict):
try:
Expand Down
18 changes: 17 additions & 1 deletion python/dolma/cli/tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dolma.core.paths import glob_path
from dolma.core.registry import TaggerRegistry
from dolma.core.runtime import create_and_run_tagger
from dolma.core.utils import import_modules


@dataclass
Expand Down Expand Up @@ -53,6 +54,14 @@ class TaggerConfig:
"If not provided, destination will be derived from the document path."
),
)
tagger_modules: List[str] = field(
default=[],
help=(
"Additional modules to import taggers from; this is useful for taggers that are not part of Dolma. "
"Modules must be available in $PYTHONPATH or a path to module. Taggers should be registered using the "
"@dolma.add_tagger(...) decorator."
),
)
taggers: List[str] = field(
default=[],
help="List of taggers to run.",
Expand Down Expand Up @@ -122,6 +131,7 @@ def run(cls, parsed_config: TaggerConfig):
destination=parsed_config.destination,
metadata=work_dirs.output,
taggers=taggers,
taggers_modules=parsed_config.tagger_modules,
ignore_existing=parsed_config.ignore_existing,
num_processes=parsed_config.processes,
experiment=parsed_config.experiment,
Expand All @@ -135,7 +145,10 @@ def run(cls, parsed_config: TaggerConfig):

@dataclass
class ListTaggerConfig:
    """Configuration for the CLI command that lists the registered taggers."""

    # Extra modules to import before listing, so taggers they register show up too.
    # NOTE(review): `field` is the project's own helper (not visible in this chunk);
    # presumably it copies the list default rather than sharing it — verify.
    tagger_modules: List[str] = field(
        default=[],
        help=(
            "Additional modules to import custom taggers from before listing them. "
            "Modules must be available in $PYTHONPATH, or be given as a path to the module."
        ),
    )


class ListTaggerCli(BaseCli):
Expand All @@ -144,6 +157,9 @@ class ListTaggerCli(BaseCli):

@classmethod
def run(cls, parsed_config: ListTaggerConfig):
# import tagger modules
import_modules(parsed_config.tagger_modules)

table = Table(title="dolma taggers", style="bold")
table.add_column("name", justify="left", style="cyan")
table.add_column("class", justify="left", style="magenta")
Expand Down
13 changes: 13 additions & 0 deletions python/dolma/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ def _add(

return _add

@classmethod
def remove(cls, name: str) -> bool:
    """Unregister the tagger stored under ``name``.

    Returns:
        True when an entry was deleted; False when no tagger was registered
        under that name.
    """
    try:
        del cls.__taggers[name]
    except KeyError:
        return False
    return True

@classmethod
def has(cls, name: str) -> bool:
    """Tell whether a tagger is currently registered under ``name``."""
    registered = cls.__taggers
    return name in registered

@classmethod
def get(cls, name: str) -> Type[BaseTagger]:
"""Get a tagger from the registry; raise ValueError if it doesn't exist."""
Expand Down
8 changes: 7 additions & 1 deletion python/dolma/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .parallel import BaseParallelProcessor, QueueType
from .paths import delete_dir, join_path, make_relative, mkdir_p, split_glob, split_path
from .registry import TaggerRegistry
from .utils import make_variable_name
from .utils import import_modules, make_variable_name

# this placeholder gets used when a user has provided no experiment name, and we want to use taggers'
# names as experiment names.
Expand Down Expand Up @@ -220,6 +220,10 @@ def process_single(
**kwargs,
):
"""Lets count run the taggers! We will use the destination path to save each tagger output."""
# import tagger modules
taggers_modules = kwargs.get("taggers_modules", None)
if taggers_modules is not None:
import_modules(taggers_modules)

# get names of taggers
taggers_names = kwargs.get("taggers_names", None)
Expand Down Expand Up @@ -329,6 +333,7 @@ def profiler(
def create_and_run_tagger(
documents: List[str],
taggers: List[str],
taggers_modules: Optional[List[str]] = None,
experiment: Optional[str] = None,
destination: Union[None, str, List[str]] = None,
metadata: Union[None, str, List[str]] = None,
Expand Down Expand Up @@ -421,6 +426,7 @@ def create_and_run_tagger(
tagger(
experiment_name=experiment,
taggers_names=taggers,
taggers_modules=taggers_modules,
skip_on_failure=skip_on_failure,
steps=profile_steps,
)
Expand Down
59 changes: 57 additions & 2 deletions python/dolma/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import importlib
import os
import re
import string
from typing import List
import sys
from typing import List, Union

try:
import blingfire
Expand All @@ -19,8 +22,10 @@


from .data_types import TextSlice
from .loggers import get_logger

sent_tokenizer = PunktSentenceTokenizer()
logger = get_logger(__name__)


def make_variable_name(name: str, remove_multiple_underscores: bool = False) -> str:
Expand Down Expand Up @@ -56,7 +61,7 @@ def split_sentences(text: str, remove_empty: bool = True) -> List[TextSlice]:
Split a string into sentences.
"""
if text and BLINGFIRE_AVAILABLE:
_, offsets = blingfire.text_to_sentences_and_offsets(text)
_, offsets = blingfire.text_to_sentences_and_offsets(text) # pyright: ignore
elif text:
offsets = [(start, end) for start, end in sent_tokenizer.span_tokenize(text)]
else:
Expand All @@ -66,3 +71,53 @@ def split_sentences(text: str, remove_empty: bool = True) -> List[TextSlice]:
return [TextSlice(doc=text, start=start, end=end) for (start, end) in offsets]
else:
raise NotImplementedError("remove_empty=False is not implemented yet")


def import_modules(modules_path: Union[List[str], None]):
    """Import one or more user modules from either names or paths.

    Each entry is first tried as a regular module name. If no module with that
    name can be found, the entry is treated as a filesystem path of the form
    `<module_parent>/<module_name>.py` or `<module_parent>/<module_name>`; the
    parent directory is prepended to `sys.path` and `<module_name>` is imported.

    Importing from path is modeled after fairseq's import_user_module function:
    github.com/facebookresearch/fairseq/blob/da8fb630880d529ab47e53381c30ddc8ad235216/fairseq/utils.py#L464

    Args:
        modules_path (Union[List[str], None]): List of module names or paths to import.
            `None` or an empty list is a no-op.

    Raises:
        RuntimeError: if importing an entry by name fails for any reason other
            than the module not being found.
        FileNotFoundError: if an entry is neither an importable name nor an existing path.
        ImportError: if a different module with the same name is already loaded.
    """

    def _canonical(path: str) -> str:
        # compare locations independently of relative/absolute spelling
        return os.path.normcase(os.path.abspath(path))

    # the memoization attribute is used to ensure that user modules given as
    # paths are only imported once; create it on first use
    if (modules_memo := getattr(import_modules, "memo", None)) is None:
        modules_memo = set()
        import_modules.memo = modules_memo  # type: ignore

    for module_path in modules_path or []:
        # try importing the module directly; entries starting with "." can only
        # be paths (importlib would treat them as relative imports and fail)
        if not module_path.startswith("."):
            try:
                importlib.import_module(module_path)
                continue
            except ModuleNotFoundError:
                pass
            except Exception as exp:
                raise RuntimeError(f"Could not import module {module_path}: {exp}") from exp

        # if that fails, try importing the module as a path

        # ensure that user modules are only imported once
        if module_path in modules_memo:
            continue

        if not os.path.exists(module_path):
            raise FileNotFoundError(f"Could not find module {module_path}")

        # the format is `<module_parent>/<module_name>.py` or `<module_parent>/<module_name>`;
        # normpath drops a trailing separator that would otherwise yield an empty name
        module_parent, module_name = os.path.split(os.path.normpath(module_path))
        module_name, _ = os.path.splitext(module_name)

        if module_name not in sys.modules:
            sys.path.insert(0, module_parent)
            importlib.import_module(module_name)
        else:
            # packages expose their location(s) via __path__, single-file modules via __file__
            loaded = sys.modules[module_name]
            locations = list(getattr(loaded, "__path__", None) or [])
            loaded_file = getattr(loaded, "__file__", None)
            if loaded_file:
                locations.append(loaded_file)

            if _canonical(module_path) in {_canonical(loc) for loc in locations}:
                logger.info(f"{module_path} has already been imported.")
            else:
                raise ImportError(
                    f"Failed to import {module_path} because the corresponding module name "
                    f"({module_name}) is not globally unique. Please rename the directory to "
                    "something unique and try again."
                )

        # only remember the path once it has been imported successfully, so that a
        # failed attempt is reported again instead of being silently skipped
        modules_memo.add(module_path)
Empty file added tests/python/extras/__init__.py
Empty file.
1 change: 1 addition & 0 deletions tests/python/extras/extras_from_module/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .extra_taggers import * # noqa
6 changes: 6 additions & 0 deletions tests/python/extras/extras_from_module/extra_taggers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from dolma import BaseTagger, add_tagger


# Test fixture: registers a tagger under the name "extra_v1" via the add_tagger decorator.
@add_tagger("extra_v1")
class ExtraV1Tagger(BaseTagger):
    """Placeholder tagger with no behavior of its own; it only needs to be registered."""

    ...
1 change: 1 addition & 0 deletions tests/python/extras/extras_from_module_path/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .extra_taggers import * # noqa
6 changes: 6 additions & 0 deletions tests/python/extras/extras_from_module_path/extra_taggers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from dolma import BaseTagger, add_tagger


# Test fixture: registers a tagger under the name "extra_v3" via the add_tagger decorator.
# NOTE(review): the class is named ExtraV1Tagger although it registers "extra_v3";
# looks like a copy-paste leftover — confirm whether ExtraV3Tagger was intended.
@add_tagger("extra_v3")
class ExtraV1Tagger(BaseTagger):
    """Placeholder tagger with no behavior of its own; it only needs to be registered."""

    ...
Empty file.
6 changes: 6 additions & 0 deletions tests/python/extras/extras_from_path/extra_taggers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from dolma import BaseTagger, add_tagger


# Test fixture: registers a tagger under the name "extra_v2" via the add_tagger decorator.
@add_tagger("extra_v2")
class ExtraV2Tagger(BaseTagger):
    """Placeholder tagger with no behavior of its own; it only needs to be registered."""

    ...
24 changes: 24 additions & 0 deletions tests/python/test_extra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import sys
import unittest
from pathlib import Path

from dolma.core.registry import TaggerRegistry
from dolma.core.utils import import_modules


class TestExtra(unittest.TestCase):
    """Check that `import_modules` registers taggers defined outside the package.

    Each test imports one fixture from the `extras` directory next to this file
    and then asserts that the tagger name it declares is present in the registry.
    """

    def setUp(self) -> None:
        # directory containing this test file; fixtures live in `<dir>/extras`
        self.current_path = Path(__file__).parent.absolute()

    @staticmethod
    def _discard_sys_path_entry(entry: str) -> None:
        """Remove `entry` from `sys.path` if it is still there."""
        if entry in sys.path:
            sys.path.remove(entry)

    def test_import_from_module(self):
        """A module name importable through `sys.path` is imported by name."""
        extras_dir = f"{self.current_path}/extras"
        sys.path.append(extras_dir)
        # undo the sys.path change so it does not leak into other tests
        self.addCleanup(self._discard_sys_path_entry, extras_dir)

        import_modules(["extras_from_module"])
        self.assertTrue(TaggerRegistry.has("extra_v1"))

    def test_import_from_path(self):
        """A path to a single `.py` file is imported as a module."""
        import_modules([f"{self.current_path}/extras/extras_from_path/extra_taggers.py"])
        self.assertTrue(TaggerRegistry.has("extra_v2"))

    def test_import_from_module_path(self):
        """A path to a package directory is imported as a module."""
        import_modules([f"{self.current_path}/extras/extras_from_module_path"])
        self.assertTrue(TaggerRegistry.has("extra_v3"))
Loading