Skip to content

Commit

Permalink
New language ID
Browse files Browse the repository at this point in the history
  • Loading branch information
soldni authored Dec 30, 2024
1 parent a824220 commit a855690
Showing 1 changed file with 29 additions and 11 deletions.
40 changes: 29 additions & 11 deletions python/dolma/taggers/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
@kylel, @soldni
"""

from typing import TYPE_CHECKING, List, Tuple
from typing import TYPE_CHECKING, List, Tuple, Iterable

import necessary
import regex
from anyascii import anyascii

from ..core.data_types import DocResult, Document, Span
from ..core.ft_tagger import BaseFastTextTagger
from ..core.ft_tagger import BaseFastTextTagger, TextSlice, Prediction
from ..core.registry import TaggerRegistry
from ..core.taggers import BaseTagger
from ..core.utils import split_paragraphs
Expand All @@ -32,14 +32,14 @@

with necessary.necessary("lingua", soft=True) as LINGUA_AVAILABLE:
if LINGUA_AVAILABLE or TYPE_CHECKING:
from lingua import Language, LanguageDetectorBuilder
from lingua import Language, LanguageDetectorBuilder # pylint: disable=import-error # pyright: ignore


class BaseLanguageTagger(BaseTagger):
INCLUDE_NEGATIVE = True
PREDICT_ON_PARAGRAPHS = False

def predict_text(self, text: str) -> List[Tuple[str, float]]:
def predict_text(self, text: str) -> List[Tuple[str, float]]: # pylint: disable=unused-argument
return []

def make_negative(self, spans: List[Span]) -> List[Span]:
Expand Down Expand Up @@ -79,7 +79,7 @@ def __init__(self) -> None:
raise ImportError(f"cld3 is not installed, cannot instantiate {self.__class__.__name__}")

def predict_text(self, text: str) -> List[Tuple[str, float]]:
pred = cld3.get_language(text) # pyright: ignore
pred = cld3.get_language(text) # pyright: ignore # pylint: disable=possibly-used-before-assignment
score = pred.probability if pred.language == "en" else 0.0
return [("en", score)]

Expand Down Expand Up @@ -114,7 +114,7 @@ def predict_text(self, text: str) -> List[Tuple[str, float]]:
is_reliable = False
for fn in (self._identity_fn, self._to_ascii_input, self._sanitize_input):
try:
is_reliable, _, details = cld2.detect(fn(text))
is_reliable, _, details = cld2.detect(fn(text)) # pylint: disable=possibly-used-before-assignment
break
except cld2.error:
...
Expand Down Expand Up @@ -146,13 +146,16 @@ class Cld2EnglishLanguageParagraphTagger(Cld2EnglishLanguageTagger):

@TaggerRegistry.add("ft_lang_id_doc_v1")
class FastTextAllLanguagesDocumentTagger(BaseLanguageTagger, BaseFastTextTagger):
MODEL_PATH = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
MODEL_PATH = "https://dolma-artifacts/lang_id_models/fbai/lid.176.bin"
INCLUDE_NEGATIVE = False
PREDICT_ON_PARAGRAPHS = False

def __init__(self):
BaseFastTextTagger.__init__(self, model_path=self.MODEL_PATH, model_mode=self.DOCUMENT_LEVEL_TAGGER)

def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]:
raise RuntimeError("This method should not be called; please report this issue.")

def predict_text(self, text: str) -> List[Tuple[str, float]]:
preds = self.classifier.predict(text.lower().replace("\n", " ").strip(), k=-1)
return [(label.replace("__label__", ""), float(score)) for label, score in zip(*preds)]
Expand All @@ -165,6 +168,16 @@ def predict_text(self, text: str) -> List[Tuple[str, float]]:
return [(lang, round(score, 2)) for lang, score in out if score > 0.01]


@TaggerRegistry.add("glotlid_doc_v3")
class FastTextAllLanguagesDocumentGlotV3Tagger(FastTextAllLanguagesDocumentTagger):
MODEL_PATH = "https://dolma-artifacts.org/lang_id_models/cis-lmu/glotlid/model_v3.bin"


@TaggerRegistry.add("glotlid_doc_v3_1e2")
class FastTextAllLanguagesDocumentGlotV3MinScoreTagger(FastTextAllLanguagesDocumentMinScoreTagger):
MODEL_PATH = "https://dolma-artifacts.org/lang_id_models/cis-lmu/glotlid/model_v3.bin"


@TaggerRegistry.add("ft_lang_id_paragraph_v1")
class FastTextAllLanguageParagraphTagger(FastTextAllLanguagesDocumentTagger):
INCLUDE_NEGATIVE = False
Expand Down Expand Up @@ -203,7 +216,8 @@ def __init__(self) -> None:
if not LANGDETECT_AVAILABLE:
raise ImportError("langdetect is not installed, please run `pip install dolma[lang]`.")

(factory := DetectorFactory()).load_profile(PROFILES_DIRECTORY)
factory = DetectorFactory() # pylint: disable=possibly-used-before-assignment
factory.load_profile(PROFILES_DIRECTORY) # pylint: disable=possibly-used-before-assignment
factory.set_seed(0)
self.detector = factory.create()
super().__init__()
Expand All @@ -213,7 +227,7 @@ def predict_text(self, text: str) -> List[Tuple[str, float]]:
self.detector.append(text)
langs = self.detector.get_probabilities()
output = [(str(r.lang.strip().lower()), float(r.prob)) for r in langs]
except LangDetectException:
except LangDetectException: # pylint: disable=possibly-used-before-assignment
output = []
finally:
self.detector.text = ""
Expand Down Expand Up @@ -251,9 +265,13 @@ class LinguaTagger(BaseLanguageTagger):

def __init__(self) -> None:
super().__init__()
if not LINGUA_AVAILABLE:
if not LANGDETECT_AVAILABLE:
raise ImportError("langdetect is not installed, please run `pip install dolma[lang]`.")
self.detector = LanguageDetectorBuilder.from_languages(*Language.all()).build()

all_languages = Language.all() # pylint: disable=possibly-used-before-assignment
self.detector = LanguageDetectorBuilder.from_languages( # pylint: disable=possibly-used-before-assignment
*all_languages
).build()

def predict_text(self, text: str) -> List[Tuple[str, float]]:
langs_conf = self.detector.compute_language_confidence_values(text) or []
Expand Down

0 comments on commit a855690

Please sign in to comment.