Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
soldni committed Nov 6, 2023
1 parent 8a6ba4e commit 7c293e9
Show file tree
Hide file tree
Showing 12 changed files with 125 additions and 8 deletions.
10 changes: 10 additions & 0 deletions python/dolma/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,19 @@
# must import taggers to register them
# we import the rust extension here and wrap it in a python module
from . import dolma as _dolma # type: ignore # noqa: E402
from .core import TaggerRegistry # noqa: E402
from .core.errors import DolmaRustPipelineError # noqa: E402
from .core.taggers import BaseTagger # noqa: E402
from .taggers import * # noqa: E402

__all__ = [
"add_tagger",
"BaseTagger",
]

# we create a shortcut to easily add taggers to the registry
add_tagger = TaggerRegistry.add


def deduper(config: dict):
try:
Expand Down
6 changes: 5 additions & 1 deletion python/dolma/cli/tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ class TaggerConfig:
)
tagger_modules: List[str] = field(
default=[],
help="List of Python modules in $PYTHONPATH to import custom taggers from.",
help=(
"Additional modules to import taggers from; this is useful for taggers that are not part of Dolma. "
"Modules must be available in $PYTHONPATH or a path to module. Taggers should be registered using the "
"@dolma.add_tagger(...) decorator."
),
)
taggers: List[str] = field(
default=[],
Expand Down
13 changes: 13 additions & 0 deletions python/dolma/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ def _add(

return _add

@classmethod
def remove(cls, name: str) -> bool:
"""Remove a tagger from the registry."""
if name in cls.__taggers:
cls.__taggers.pop(name)
return True
return False

@classmethod
def has(cls, name: str) -> bool:
"""Check if a tagger exists in the registry."""
return name in cls.__taggers

@classmethod
def get(cls, name: str) -> Type[BaseTagger]:
"""Get a tagger from the registry; raise ValueError if it doesn't exist."""
Expand Down
60 changes: 53 additions & 7 deletions python/dolma/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import importlib
import os
import re
import string
from importlib import import_module
from typing import List
import sys
from typing import List, Union

try:
import blingfire
Expand All @@ -20,8 +22,10 @@


from .data_types import TextSlice
from .loggers import get_logger

sent_tokenizer = PunktSentenceTokenizer()
logger = get_logger(__name__)


def make_variable_name(name: str, remove_multiple_underscores: bool = False) -> str:
Expand Down Expand Up @@ -57,7 +61,7 @@ def split_sentences(text: str, remove_empty: bool = True) -> List[TextSlice]:
Split a string into sentences.
"""
if text and BLINGFIRE_AVAILABLE:
_, offsets = blingfire.text_to_sentences_and_offsets(text)
_, offsets = blingfire.text_to_sentences_and_offsets(text) # pyright: ignore
elif text:
offsets = [(start, end) for start, end in sent_tokenizer.span_tokenize(text)]
else:
Expand All @@ -69,9 +73,51 @@ def split_sentences(text: str, remove_empty: bool = True) -> List[TextSlice]:
raise NotImplementedError("remove_empty=False is not implemented yet")


def import_modules(modules: List[str]):
for module in modules:
def import_modules(modules_path: Union[List[str], None]):
"""Import one or more user modules from either names or paths.
Importing from path is modeled after fairseq's import_user_module function:
github.com/facebookresearch/fairseq/blob/da8fb630880d529ab47e53381c30ddc8ad235216/fairseq/utils.py#L464
Args:
modules_path (Union[List[str], None]): List of module names or paths to import.
"""

for module_path in modules_path or []:
# try importing the module directly
try:
import_module(module)
importlib.import_module(module_path)
continue
except ModuleNotFoundError:
raise RuntimeError("Did not find module named {module}")
pass
except Exception as exp:
raise RuntimeError(f"Could not import module {module_path}: {exp}") from exp

# if that fails, try importing the module as a path

# check if this function has a memorization attribute; if not; create it
# the memorization attribute is used to ensure that user modules are only imported once
if (modules_memo := getattr(import_modules, "memo", None)) is None:
modules_memo = set()
import_modules.memo = modules_memo # type: ignore

# ensure that user modules are only imported once
if module_path not in modules_memo:
modules_memo.add(module_path)

if not os.path.exists(module_path):
raise FileNotFoundError(f"Could not find module {module_path}")

# the format is `<module_parent>/<module_name>.py` or `<module_parent>/<module_name>`
module_parent, module_name = os.path.split(module_path)
module_name, _ = os.path.splitext(module_name)
if module_name not in sys.modules:
sys.path.insert(0, module_parent)
importlib.import_module(module_name)
elif module_path in sys.modules[module_name].__path__:
logger.info(f"{module_path} has already been imported.")
else:
raise ImportError(
f"Failed to import {module_path} because the corresponding module name "
f"({module_name}) is not globally unique. Please rename the directory to "
"something unique and try again."
)
Empty file added tests/python/extras/__init__.py
Empty file.
1 change: 1 addition & 0 deletions tests/python/extras/extras_from_module/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .extra_taggers import * # noqa
6 changes: 6 additions & 0 deletions tests/python/extras/extras_from_module/extra_taggers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from dolma import BaseTagger, add_tagger


@add_tagger("extra_v1")
class ExtraV1Tagger(BaseTagger):
...
1 change: 1 addition & 0 deletions tests/python/extras/extras_from_module_path/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .extra_taggers import * # noqa
6 changes: 6 additions & 0 deletions tests/python/extras/extras_from_module_path/extra_taggers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from dolma import BaseTagger, add_tagger


@add_tagger("extra_v3")
class ExtraV1Tagger(BaseTagger):
...
Empty file.
6 changes: 6 additions & 0 deletions tests/python/extras/extras_from_path/extra_taggers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from dolma import BaseTagger, add_tagger


@add_tagger("extra_v2")
class ExtraV2Tagger(BaseTagger):
...
24 changes: 24 additions & 0 deletions tests/python/test_extra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import sys
import unittest
from pathlib import Path

from dolma.core.registry import TaggerRegistry
from dolma.core.utils import import_modules


class TestExtra(unittest.TestCase):
def setUp(self) -> None:
self.current_path = Path(__file__).parent.absolute()

def test_import_from_module(self):
sys.path.append(f"{self.current_path}/extras")
import_modules(["extras_from_module"])
self.assertTrue(TaggerRegistry.has("extra_v1"))

def test_import_from_path(self):
import_modules([f"{self.current_path}/extras/extras_from_path/extra_taggers.py"])
self.assertTrue(TaggerRegistry.has("extra_v2"))

def test_import_from_module_path(self):
import_modules([f"{self.current_path}/extras/extras_from_module_path"])
self.assertTrue(TaggerRegistry.has("extra_v3"))

0 comments on commit 7c293e9

Please sign in to comment.