From a855690b02e83dc0c8dc25f8226150f2791c091a Mon Sep 17 00:00:00 2001
From: Luca Soldaini
Date: Mon, 30 Dec 2024 15:04:37 -0800
Subject: [PATCH] New language ID

---
 python/dolma/taggers/language.py | 40 +++++++++++++++++++++++---------
 1 file changed, 29 insertions(+), 11 deletions(-)

diff --git a/python/dolma/taggers/language.py b/python/dolma/taggers/language.py
index 121fd5c6..a8ba0cf3 100644
--- a/python/dolma/taggers/language.py
+++ b/python/dolma/taggers/language.py
@@ -4,14 +4,14 @@
 @kylel, @soldni
 
 """
 
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, List, Tuple, Iterable
 
 import necessary
 import regex
 from anyascii import anyascii
 
 from ..core.data_types import DocResult, Document, Span
-from ..core.ft_tagger import BaseFastTextTagger
+from ..core.ft_tagger import BaseFastTextTagger, TextSlice, Prediction
 from ..core.registry import TaggerRegistry
 from ..core.taggers import BaseTagger
 from ..core.utils import split_paragraphs
@@ -32,14 +32,14 @@ with necessary.necessary("lingua", soft=True) as LINGUA_AVAILABLE:
     if LINGUA_AVAILABLE or TYPE_CHECKING:
-        from lingua import Language, LanguageDetectorBuilder
+        from lingua import Language, LanguageDetectorBuilder  # pylint: disable=import-error # pyright: ignore
 
 
 class BaseLanguageTagger(BaseTagger):
     INCLUDE_NEGATIVE = True
     PREDICT_ON_PARAGRAPHS = False
 
-    def predict_text(self, text: str) -> List[Tuple[str, float]]:
+    def predict_text(self, text: str) -> List[Tuple[str, float]]:  # pylint: disable=unused-argument
         return []
 
     def make_negative(self, spans: List[Span]) -> List[Span]:
@@ -79,7 +79,7 @@ def __init__(self) -> None:
             raise ImportError(f"cld3 is not installed, cannot instantiate {self.__class__.__name__}")
 
     def predict_text(self, text: str) -> List[Tuple[str, float]]:
-        pred = cld3.get_language(text)  # pyright: ignore
+        pred = cld3.get_language(text)  # pyright: ignore # pylint: disable=possibly-used-before-assignment
         score = pred.probability if pred.language == "en" else 0.0
         return [("en", score)]
 
@@ -114,7 +114,7 @@ def predict_text(self, text: str) -> List[Tuple[str, float]]:
         is_reliable = False
         for fn in (self._identity_fn, self._to_ascii_input, self._sanitize_input):
             try:
-                is_reliable, _, details = cld2.detect(fn(text))
+                is_reliable, _, details = cld2.detect(fn(text))  # pylint: disable=possibly-used-before-assignment
                 break
             except cld2.error:
                 ...
@@ -146,13 +146,16 @@ class Cld2EnglishLanguageParagraphTagger(Cld2EnglishLanguageTagger):
 
 @TaggerRegistry.add("ft_lang_id_doc_v1")
 class FastTextAllLanguagesDocumentTagger(BaseLanguageTagger, BaseFastTextTagger):
-    MODEL_PATH = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
+    MODEL_PATH = "https://dolma-artifacts.org/lang_id_models/fbai/lid.176.bin"
     INCLUDE_NEGATIVE = False
     PREDICT_ON_PARAGRAPHS = False
 
     def __init__(self):
         BaseFastTextTagger.__init__(self, model_path=self.MODEL_PATH, model_mode=self.DOCUMENT_LEVEL_TAGGER)
 
+    def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]:
+        raise RuntimeError("This method should not be called; please report this issue.")
+
     def predict_text(self, text: str) -> List[Tuple[str, float]]:
         preds = self.classifier.predict(text.lower().replace("\n", " ").strip(), k=-1)
         return [(label.replace("__label__", ""), float(score)) for label, score in zip(*preds)]
@@ -165,6 +168,16 @@ def predict_text(self, text: str) -> List[Tuple[str, float]]:
         return [(lang, round(score, 2)) for lang, score in out if score > 0.01]
 
 
+@TaggerRegistry.add("glotlid_doc_v3")
+class FastTextAllLanguagesDocumentGlotV3Tagger(FastTextAllLanguagesDocumentTagger):
+    MODEL_PATH = "https://dolma-artifacts.org/lang_id_models/cis-lmu/glotlid/model_v3.bin"
+
+
+@TaggerRegistry.add("glotlid_doc_v3_1e2")
+class FastTextAllLanguagesDocumentGlotV3MinScoreTagger(FastTextAllLanguagesDocumentMinScoreTagger):
+    MODEL_PATH = "https://dolma-artifacts.org/lang_id_models/cis-lmu/glotlid/model_v3.bin"
+
+
 @TaggerRegistry.add("ft_lang_id_paragraph_v1")
 class FastTextAllLanguageParagraphTagger(FastTextAllLanguagesDocumentTagger):
     INCLUDE_NEGATIVE = False
@@ -203,7 +216,8 @@ def __init__(self) -> None:
         if not LANGDETECT_AVAILABLE:
             raise ImportError("langdetect is not installed, please run `pip install dolma[lang]`.")
 
-        (factory := DetectorFactory()).load_profile(PROFILES_DIRECTORY)
+        factory = DetectorFactory()  # pylint: disable=possibly-used-before-assignment
+        factory.load_profile(PROFILES_DIRECTORY)  # pylint: disable=possibly-used-before-assignment
         factory.set_seed(0)
         self.detector = factory.create()
         super().__init__()
@@ -213,7 +227,7 @@ def predict_text(self, text: str) -> List[Tuple[str, float]]:
             self.detector.append(text)
             langs = self.detector.get_probabilities()
             output = [(str(r.lang.strip().lower()), float(r.prob)) for r in langs]
-        except LangDetectException:
+        except LangDetectException:  # pylint: disable=possibly-used-before-assignment
             output = []
         finally:
             self.detector.text = ""
@@ -251,9 +265,13 @@ class LinguaTagger(BaseLanguageTagger):
 
     def __init__(self) -> None:
         super().__init__()
-        if not LANGDETECT_AVAILABLE:
+        if not LINGUA_AVAILABLE:
             raise ImportError("langdetect is not installed, please run `pip install dolma[lang]`.")
-        self.detector = LanguageDetectorBuilder.from_languages(*Language.all()).build()
+
+        all_languages = Language.all()  # pylint: disable=possibly-used-before-assignment
+        self.detector = LanguageDetectorBuilder.from_languages(  # pylint: disable=possibly-used-before-assignment
+            *all_languages
+        ).build()
 
     def predict_text(self, text: str) -> List[Tuple[str, float]]:
         langs_conf = self.detector.compute_language_confidence_values(text) or []
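
Reviewer note (not part of the patch): below is a minimal usage sketch for the new GlotLID tagger registered above, built only on the predict_text() API visible in this diff. It assumes the dolma[lang] extras are installed and that the fastText model at MODEL_PATH can be fetched when the tagger is first instantiated; the sample sentence and the top-5 slice are illustrative only.

    # Hypothetical usage sketch for the new "glotlid_doc_v3" tagger.
    from dolma.taggers.language import FastTextAllLanguagesDocumentGlotV3Tagger

    # Instantiating the tagger loads the GlotLID fastText model from MODEL_PATH.
    tagger = FastTextAllLanguagesDocumentGlotV3Tagger()

    # predict_text() returns (label, score) pairs for every label the model knows
    # (k=-1 above); GlotLID labels look like "ita_Latn" once "__label__" is stripped.
    predictions = tagger.predict_text("Ciao mondo, come stai oggi?")
    for lang, score in predictions[:5]:
        print(lang, round(score, 4))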