diff --git a/docs/taggers.md b/docs/taggers.md index 42a9f75f..f5e6ddbd 100644 --- a/docs/taggers.md +++ b/docs/taggers.md @@ -24,6 +24,7 @@ The following parameters are supported either via CLI (e.g. `dolma tag --paramet |`destination`|No| One or more paths for output attribute files. Each accepts a single wildcard `*` character. Can be local, or an S3-compatible cloud path. If not provided, the destination will be derived from the document path. | |`experiment`|No| Used to name output attribute files. One output file will be created for each input document file, where the key is obtained by substituting `documents` with `attributes/`. If not provided, we will use `attributes/`. | |`taggers`|Yes| One or more taggers to run. | +|`tagger_modules`|No| List of one or more Python modules to load taggers from. See section [*"Using Custom Taggers"*](#using-custom-taggers) for more details. | |`processes`|No| Number of processes to use for tagging. One process is used by default. | |`ignore_existing`|No| If true, ignore existing outputs and re-run the taggers. | |`dryrun`|No| If true, only print the configuration and exit without running the taggers. | @@ -80,10 +81,9 @@ All taggers inherit from the `BaseTagger` class defined in [`core/taggers.py`](h import random from dolma.core.data_types import DocResult, Document, Span -from dolma.core.registry import TaggerRegistry -from dolma.core.taggers import BaseTagger +from dolma import add_tagger, BaseTagger -@TaggerRegistry.add("random_number_v1") +@add_tagger("new_random_number") class RandomNumberTagger(BaseTagger): def predict(self, doc: Document) -> DocResult: # first, we generate a random number @@ -102,4 +102,13 @@ class RandomNumberTagger(BaseTagger): return DocResult(doc=doc, spans=[span]) ``` -Name for each tagger is specified using the `@TaggerRegistry.add` decorator. The name must be unique. +The name of each tagger is specified using the `add_tagger` decorator. The name must be unique. 
+ +## Using Custom Taggers + +Taggers can be added either as part of the Dolma package, or they can be imported at runtime by providing the `tagger_modules` parameter. + +For example, let's assume the `new_random_number` tagger is saved in a file called `my_taggers.py` inside the directory `my_module`. Then, we can run the tagger using one of the following commands: + +- `dolma tag --taggers new_random_number --tagger_modules path/to/my_module/my_taggers.py ...` +- `PYTHONPATH="path/to/my_module" dolma tag --taggers new_random_number --tagger_modules my_taggers` diff --git a/python/dolma/__init__.py b/python/dolma/__init__.py index 635767b2..b5b90f1b 100644 --- a/python/dolma/__init__.py +++ b/python/dolma/__init__.py @@ -9,9 +9,19 @@ # must import taggers to register them # we import the rust extension here and wrap it in a python module from . import dolma as _dolma # type: ignore # noqa: E402 +from .core import TaggerRegistry # noqa: E402 from .core.errors import DolmaRustPipelineError # noqa: E402 +from .core.taggers import BaseTagger # noqa: E402 from .taggers import * # noqa: E402 +__all__ = [ + "add_tagger", + "BaseTagger", +] + +# we create a shortcut to easily add taggers to the registry +add_tagger = TaggerRegistry.add + def deduper(config: dict): try: diff --git a/python/dolma/cli/tagger.py b/python/dolma/cli/tagger.py index 7e14d2a2..3a594fe7 100644 --- a/python/dolma/cli/tagger.py +++ b/python/dolma/cli/tagger.py @@ -12,6 +12,7 @@ from dolma.core.paths import glob_path from dolma.core.registry import TaggerRegistry from dolma.core.runtime import create_and_run_tagger +from dolma.core.utils import import_modules @dataclass @@ -53,6 +54,14 @@ class TaggerConfig: "If not provided, destination will be derived from the document path." ), ) + tagger_modules: List[str] = field( + default=[], + help=( + "Additional modules to import taggers from; this is useful for taggers that are not part of Dolma. " + "Modules must be available in $PYTHONPATH or a path to module. 
Taggers should be registered using the " + "@dolma.add_tagger(...) decorator." + ), + ) taggers: List[str] = field( default=[], help="List of taggers to run.", @@ -122,6 +131,7 @@ def run(cls, parsed_config: TaggerConfig): destination=parsed_config.destination, metadata=work_dirs.output, taggers=taggers, + taggers_modules=parsed_config.tagger_modules, ignore_existing=parsed_config.ignore_existing, num_processes=parsed_config.processes, experiment=parsed_config.experiment, @@ -135,7 +145,10 @@ def run(cls, parsed_config: TaggerConfig): @dataclass class ListTaggerConfig: - ... + tagger_modules: List[str] = field( + default=[], + help="List of Python modules $PYTHONPATH to import custom taggers from.", + ) class ListTaggerCli(BaseCli): @@ -144,6 +157,9 @@ class ListTaggerCli(BaseCli): @classmethod def run(cls, parsed_config: ListTaggerConfig): + # import tagger modules + import_modules(parsed_config.tagger_modules) + table = Table(title="dolma taggers", style="bold") table.add_column("name", justify="left", style="cyan") table.add_column("class", justify="left", style="magenta") diff --git a/python/dolma/core/registry.py b/python/dolma/core/registry.py index e520b49c..6ebd27ea 100644 --- a/python/dolma/core/registry.py +++ b/python/dolma/core/registry.py @@ -36,6 +36,19 @@ def _add( return _add + @classmethod + def remove(cls, name: str) -> bool: + """Remove a tagger from the registry.""" + if name in cls.__taggers: + cls.__taggers.pop(name) + return True + return False + + @classmethod + def has(cls, name: str) -> bool: + """Check if a tagger exists in the registry.""" + return name in cls.__taggers + @classmethod def get(cls, name: str) -> Type[BaseTagger]: """Get a tagger from the registry; raise ValueError if it doesn't exist.""" diff --git a/python/dolma/core/runtime.py b/python/dolma/core/runtime.py index 2faf713d..62ea0f5e 100644 --- a/python/dolma/core/runtime.py +++ b/python/dolma/core/runtime.py @@ -23,7 +23,7 @@ from .parallel import BaseParallelProcessor, 
QueueType from .paths import delete_dir, join_path, make_relative, mkdir_p, split_glob, split_path from .registry import TaggerRegistry -from .utils import make_variable_name +from .utils import import_modules, make_variable_name # this placeholder gets used when a user has provided no experiment name, and we want to use taggers' # names as experiment names. @@ -220,6 +220,10 @@ def process_single( **kwargs, ): """Lets count run the taggers! We will use the destination path to save each tagger output.""" + # import tagger modules + taggers_modules = kwargs.get("taggers_modules", None) + if taggers_modules is not None: + import_modules(taggers_modules) # get names of taggers taggers_names = kwargs.get("taggers_names", None) @@ -329,6 +333,7 @@ def profiler( def create_and_run_tagger( documents: List[str], taggers: List[str], + taggers_modules: Optional[List[str]] = None, experiment: Optional[str] = None, destination: Union[None, str, List[str]] = None, metadata: Union[None, str, List[str]] = None, @@ -421,6 +426,7 @@ def create_and_run_tagger( tagger( experiment_name=experiment, taggers_names=taggers, + taggers_modules=taggers_modules, skip_on_failure=skip_on_failure, steps=profile_steps, ) diff --git a/python/dolma/core/utils.py b/python/dolma/core/utils.py index e21882cd..6ec012a9 100644 --- a/python/dolma/core/utils.py +++ b/python/dolma/core/utils.py @@ -1,6 +1,9 @@ +import importlib +import os import re import string -from typing import List +import sys +from typing import List, Union try: import blingfire @@ -19,8 +22,10 @@ from .data_types import TextSlice +from .loggers import get_logger sent_tokenizer = PunktSentenceTokenizer() +logger = get_logger(__name__) def make_variable_name(name: str, remove_multiple_underscores: bool = False) -> str: @@ -56,7 +61,7 @@ def split_sentences(text: str, remove_empty: bool = True) -> List[TextSlice]: Split a string into sentences. 
""" if text and BLINGFIRE_AVAILABLE: - _, offsets = blingfire.text_to_sentences_and_offsets(text) + _, offsets = blingfire.text_to_sentences_and_offsets(text) # pyright: ignore elif text: offsets = [(start, end) for start, end in sent_tokenizer.span_tokenize(text)] else: @@ -66,3 +71,53 @@ def split_sentences(text: str, remove_empty: bool = True) -> List[TextSlice]: return [TextSlice(doc=text, start=start, end=end) for (start, end) in offsets] else: raise NotImplementedError("remove_empty=False is not implemented yet") + + +def import_modules(modules_path: Union[List[str], None]): + """Import one or more user modules from either names or paths. + Importing from path is modeled after fairseq's import_user_module function: + github.com/facebookresearch/fairseq/blob/da8fb630880d529ab47e53381c30ddc8ad235216/fairseq/utils.py#L464 + + Args: + modules_path (Union[List[str], None]): List of module names or paths to import. + """ + + for module_path in modules_path or []: + # try importing the module directly + try: + importlib.import_module(module_path) + continue + except ModuleNotFoundError: + pass + except Exception as exp: + raise RuntimeError(f"Could not import module {module_path}: {exp}") from exp + + # if that fails, try importing the module as a path + + # check if this function has a memoization attribute; if not, create it + # the memoization attribute is used to ensure that user modules are only imported once + if (modules_memo := getattr(import_modules, "memo", None)) is None: + modules_memo = set() + import_modules.memo = modules_memo # type: ignore + + # ensure that user modules are only imported once + if module_path not in modules_memo: + modules_memo.add(module_path) + + if not os.path.exists(module_path): + raise FileNotFoundError(f"Could not find module {module_path}") + + # the format is `{module_parent}/{module_name}.py` or `{module_parent}/{module_name}` + module_parent, module_name = os.path.split(module_path) + module_name, _ = os.path.splitext(module_name) + if module_name not in sys.modules: + 
sys.path.insert(0, module_parent) + importlib.import_module(module_name) + elif module_path in sys.modules[module_name].__path__: + logger.info(f"{module_path} has already been imported.") + else: + raise ImportError( + f"Failed to import {module_path} because the corresponding module name " + f"({module_name}) is not globally unique. Please rename the directory to " + "something unique and try again." + ) diff --git a/tests/python/extras/__init__.py b/tests/python/extras/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/python/extras/extras_from_module/__init__.py b/tests/python/extras/extras_from_module/__init__.py new file mode 100644 index 00000000..5b40202a --- /dev/null +++ b/tests/python/extras/extras_from_module/__init__.py @@ -0,0 +1 @@ +from .extra_taggers import * # noqa diff --git a/tests/python/extras/extras_from_module/extra_taggers.py b/tests/python/extras/extras_from_module/extra_taggers.py new file mode 100644 index 00000000..92a3fb61 --- /dev/null +++ b/tests/python/extras/extras_from_module/extra_taggers.py @@ -0,0 +1,6 @@ +from dolma import BaseTagger, add_tagger + + +@add_tagger("extra_v1") +class ExtraV1Tagger(BaseTagger): + ... diff --git a/tests/python/extras/extras_from_module_path/__init__.py b/tests/python/extras/extras_from_module_path/__init__.py new file mode 100644 index 00000000..5b40202a --- /dev/null +++ b/tests/python/extras/extras_from_module_path/__init__.py @@ -0,0 +1 @@ +from .extra_taggers import * # noqa diff --git a/tests/python/extras/extras_from_module_path/extra_taggers.py b/tests/python/extras/extras_from_module_path/extra_taggers.py new file mode 100644 index 00000000..c88c7750 --- /dev/null +++ b/tests/python/extras/extras_from_module_path/extra_taggers.py @@ -0,0 +1,6 @@ +from dolma import BaseTagger, add_tagger + + +@add_tagger("extra_v3") +class ExtraV1Tagger(BaseTagger): + ... 
diff --git a/tests/python/extras/extras_from_path/__init__.py b/tests/python/extras/extras_from_path/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/python/extras/extras_from_path/extra_taggers.py b/tests/python/extras/extras_from_path/extra_taggers.py new file mode 100644 index 00000000..bbe0b2ea --- /dev/null +++ b/tests/python/extras/extras_from_path/extra_taggers.py @@ -0,0 +1,6 @@ +from dolma import BaseTagger, add_tagger + + +@add_tagger("extra_v2") +class ExtraV2Tagger(BaseTagger): + ... diff --git a/tests/python/test_extra.py b/tests/python/test_extra.py new file mode 100644 index 00000000..be2866d7 --- /dev/null +++ b/tests/python/test_extra.py @@ -0,0 +1,24 @@ +import sys +import unittest +from pathlib import Path + +from dolma.core.registry import TaggerRegistry +from dolma.core.utils import import_modules + + +class TestExtra(unittest.TestCase): + def setUp(self) -> None: + self.current_path = Path(__file__).parent.absolute() + + def test_import_from_module(self): + sys.path.append(f"{self.current_path}/extras") + import_modules(["extras_from_module"]) + self.assertTrue(TaggerRegistry.has("extra_v1")) + + def test_import_from_path(self): + import_modules([f"{self.current_path}/extras/extras_from_path/extra_taggers.py"]) + self.assertTrue(TaggerRegistry.has("extra_v2")) + + def test_import_from_module_path(self): + import_modules([f"{self.current_path}/extras/extras_from_module_path"]) + self.assertTrue(TaggerRegistry.has("extra_v3"))