diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index 97fcb7ac..06779d0d 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -71,7 +71,7 @@ import alog # Local -from caikit_nlp.modules.text_embedding.utils import env_val_to_bool, env_val_to_int +from caikit_nlp.modules.text_embedding.utils import env_val_to_bool logger = alog.use_channel("TXT_EMB") error = error_handler.get(logger) @@ -99,19 +99,6 @@ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument SentenceTransformer = SentenceTransformerNotAvailable -embedding_cfg = get_config().get("embedding", {}) - -AUTOCAST = env_val_to_bool(val=embedding_cfg.get("autocast")) -IPEX = env_val_to_bool(val=embedding_cfg.get("ipex")) -PT2_COMPILE = env_val_to_bool(val=embedding_cfg.get("pt2_compile")) -RETRIES = env_val_to_int(val=embedding_cfg.get("retries"), default=0) -BATCH_SIZE = env_val_to_int(val=embedding_cfg.get("batch_size"), default=0) -NO_IMPLICIT_TRUNCATION = env_val_to_bool( - val=embedding_cfg.get("implicit_truncation_errors", True) -) -DEVICE = embedding_cfg.get("device", "") -TRUST_REMOTE_CODE = embedding_cfg.get("trust_remote_code") - RT = TypeVar("RT") # return type @@ -146,8 +133,6 @@ class TruncatedTokensTuple(NamedTuple): ], ) class EmbeddingModule(ModuleBase): - # Retry count if enabled to try again (was for thread contention errors) - RETRY_COUNT = max(RETRIES, 0) # Ensure non-negative, before using in loop! _ARTIFACTS_PATH_KEY = "artifacts_path" _ARTIFACTS_PATH_DEFAULT = "artifacts" @@ -159,13 +144,33 @@ def __init__( super().__init__() self.model = model + # Read config/env settings that are needed at run_* time. + embedding_cfg = get_config().get("embedding", {}) + + self.autocast = env_val_to_bool(embedding_cfg.get("autocast")) + self.no_implicit_truncation = env_val_to_bool( + embedding_cfg.get("implicit_truncation_errors", True) + ) + + self.batch_size = embedding_cfg.get("batch_size", 0) + error.type_check("", int, EMBEDDING_BATCH_SIZE=self.batch_size) + + # Retry count if enabled to try again (was for thread contention errors) + retries = embedding_cfg.get("retries", 0) + error.type_check("", int, EMBEDDING_RETRIES=retries) + self.retry_count = max( + retries, 0 + ) # Ensure non-negative, before using in loop! (treat <0 as zero) + @classmethod - def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule": + def load( + cls, model_path: Union[str, ModuleConfig], *args, **kwargs + ) -> "EmbeddingModule": """Load model Args: - model_path: str - Path to the config dir under the model_id (where the config.yml lives) + model_path (Union[str, ModuleConfig]): Path to saved model or + in-memory ModuleConfig Returns: EmbeddingModule @@ -173,28 +178,38 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule": """ config = ModuleConfig.load(model_path) - artifacts_path = config.get(cls._ARTIFACTS_PATH_KEY) + error.dir_check("", config.model_path) + artifacts_path = config.get(cls._ARTIFACTS_PATH_KEY) error.value_check( "", artifacts_path, ValueError(f"Model config missing '{cls._ARTIFACTS_PATH_KEY}'"), ) - artifacts_path = os.path.abspath(os.path.join(model_path, artifacts_path)) + artifacts_path = os.path.abspath( + os.path.join(config.model_path, artifacts_path) + ) error.dir_check("", artifacts_path) - ipex = cls._get_ipex(IPEX) - device = cls._select_device(ipex, DEVICE) + # Read config/env settings that are needed at load time. + embedding_cfg = get_config().get("embedding", {}) + + autocast = env_val_to_bool(embedding_cfg.get("autocast")) + pt2_compile = env_val_to_bool(embedding_cfg.get("pt2_compile")) + trust_remote_code = env_val_to_bool(embedding_cfg.get("trust_remote_code")) + ipex = cls._get_ipex(env_val_to_bool(embedding_cfg.get("ipex"))) + device = cls._select_device(ipex, embedding_cfg.get("device", "")) + model = SentenceTransformerWithTruncate( model_name_or_path=artifacts_path, device=device, - trust_remote_code=TRUST_REMOTE_CODE, + trust_remote_code=trust_remote_code, ) model.eval() # required for IPEX at least if device is not None: model.to(torch.device(device)) - model = EmbeddingModule._optimize(model, ipex, device, AUTOCAST, PT2_COMPILE) + model = EmbeddingModule._optimize(model, ipex, device, autocast, pt2_compile) return cls(model) @property @@ -310,16 +325,16 @@ def _optimize(model, ipex, device, autocast, pt2_compile): def _with_retry(self, fn: Callable[..., RT], *args, **kwargs) -> RT: first_exception = None - for count in range(1 + self.RETRY_COUNT): # try once plus retries (if needed) + for count in range(1 + self.retry_count): # try once plus retries (if needed) try: return fn(*args, **kwargs) except Exception as e: # pylint: disable=broad-exception-caught if first_exception is None: first_exception = e - if self.RETRY_COUNT > 0: + if self.retry_count > 0: warn_msg = f"Try {count + 1}: {fn} failed due to: {e}" logger.warning("", warn_msg, exc_info=True) - if count + 1 < self.RETRY_COUNT: + if count + 1 < self.retry_count: time.sleep(0.1 * (count * 2)) # If above return did not happen, raise the first exception @@ -334,16 +349,17 @@ def _encode_with_retry( """All encode calls should use this for consistent param adding and retry loop""" # Add the batch_size kwarg if not passed in and given a usable BATCH_SIZE - if BATCH_SIZE > 0: + if self.batch_size > 0: if kwargs is None: kwargs = {} if "batch_size" not in kwargs: - kwargs["batch_size"] = BATCH_SIZE + kwargs["batch_size"] = self.batch_size if isinstance(self.model, SentenceTransformerWithTruncate): kwargs[ "implicit_truncation_errors" - ] = NO_IMPLICIT_TRUNCATION # config/env overrides default + ] = self.no_implicit_truncation # config/env overrides default + kwargs["autocast"] = self.autocast # config/env overrides default return self._with_retry(self.model.encode, *args, **kwargs) # Else... @@ -357,6 +373,8 @@ def _encode_with_retry( del kwargs["return_token_count"] if "implicit_truncation_errors" in kwargs: del kwargs["implicit_truncation_errors"] + if "autocast" in kwargs: + del kwargs["autocast"] return self._with_retry(self.model.encode, *args, **kwargs) @EmbeddingTask.taskmethod() @@ -718,19 +736,21 @@ def add_query(q): ) @classmethod - def bootstrap(cls, model_name_or_path: str) -> "EmbeddingModule": + def bootstrap(cls, *args, **kwargs) -> "EmbeddingModule": """Bootstrap a sentence-transformers model Args: - model_name_or_path: str - Model name (Hugging Face hub) or path to model to load. + kwargs are passed to SentenceTransformer(**kwargs) """ - return cls( - model=SentenceTransformer( - model_name_or_path=model_name_or_path, - trust_remote_code=TRUST_REMOTE_CODE, + + if "trust_remote_code" not in kwargs: + # Read config/env settings that are needed at bootstrap time. + embedding_cfg = get_config().get("embedding", {}) + kwargs["trust_remote_code"] = env_val_to_bool( + embedding_cfg.get("trust_remote_code") ) - ) + + return cls(model=SentenceTransformer(*args, **kwargs)) def save(self, model_path: str, *args, **kwargs): """Save model using config in model_path @@ -1056,6 +1076,7 @@ def encode( truncate_input_tokens: int = 0, return_token_count: bool = False, implicit_truncation_errors: bool = True, + autocast: bool = False, ) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]: """ Computes sentence embeddings @@ -1083,6 +1104,7 @@ def encode( :param return_token_count: If true, a tuple is returned to add the input token count. :param implicit_truncation_errors: If true (default) implicit truncation throws an error. If false, the model default behavior or used. + :param autocast: If true (not default) run with torch.cpu.amp.autocast() :return: If return_token_count is False, the embedding is returned as a numpy matrix. @@ -1171,7 +1193,7 @@ def encode( features = batch_to_device(features, device) - if AUTOCAST: + if autocast: with torch.no_grad(), torch.cpu.amp.autocast(): out_features = self.forward(features) embeddings = out_features["sentence_embedding"] diff --git a/caikit_nlp/modules/text_embedding/utils.py b/caikit_nlp/modules/text_embedding/utils.py index 39adfb82..377f6dac 100644 --- a/caikit_nlp/modules/text_embedding/utils.py +++ b/caikit_nlp/modules/text_embedding/utils.py @@ -22,11 +22,3 @@ def env_val_to_bool(val): # For testing env vars for values that mean false (else True!) return str(val).lower().strip() not in ("no", "n", "false", "0", "f", "off", "") - - -def env_val_to_int(val, default): - """Returns the integer value of env var or default value if None or invalid integer""" - try: - return int(val) - except (TypeError, ValueError): - return default diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py index 20a70fca..cbd22ee8 100644 --- a/tests/modules/text_embedding/test_embedding.py +++ b/tests/modules/text_embedding/test_embedding.py @@ -24,6 +24,7 @@ Token, TokenizationResults, ) +import aconfig # Local from caikit_nlp.modules.text_embedding import EmbeddingModule, utils @@ -249,10 +250,22 @@ def test_save_type_checks(model_path): BOOTSTRAPPED_MODEL.save(model_path) +def test_load_without_model_path(): + """Test coverage for the error message when config has no model_path""" + match = "stat: path should be string, bytes, os.PathLike or integer, not NoneType" + with pytest.raises(TypeError, match=match): + EmbeddingModule.load(ModuleConfig({})) + + def test_load_without_artifacts(): """Test coverage for the error message when config has no artifacts to load""" - with pytest.raises(ValueError): - EmbeddingModule.load(ModuleConfig({})) + with tempfile.TemporaryDirectory(suffix="-load") as model_dir: + config_yml_path = os.path.join(model_dir, "config.yml") + with open(config_yml_path, "a") as f: + f.write("module_id: foo") + match = "value check failed: Model config missing 'artifacts_path'" + with pytest.raises(ValueError, match=match): + EmbeddingModule.load(ModuleConfig({}).load(model_dir)) def test_run_embedding_type_check(loaded_model): @@ -856,7 +869,7 @@ def fn(): def test__with_retry_fail_fail(loaded_model, monkeypatch): """fn needs a few tries, tries twice and fails.""" - monkeypatch.setattr(loaded_model, "RETRY_COUNT", 1) # less than 3 tries + monkeypatch.setattr(loaded_model, "retry_count", 1) # less than 3 tries def generate_ints(): yield from range(9) # More than enough for retry loop @@ -880,7 +893,7 @@ def fail_fail_win(): def test__with_retry_fail_fail_win(loaded_model, monkeypatch): """fn needs a few tries, logs, loops and succeeds""" - monkeypatch.setattr(loaded_model, "RETRY_COUNT", 6) # test needs at least 3 tries + monkeypatch.setattr(loaded_model, "retry_count", 6) # test needs at least 3 tries def generate_ints(): yield from range(9) # More than enough for retry loop @@ -915,21 +928,33 @@ def test_env_val_to_bool(): assert utils.env_val_to_bool(" tRuE ") -def test_env_val_to_int(): +def test_config_val_to_int(): + conf = aconfig.Config( + { + "zero": 0, + "zero_str": "0", + "false": False, + "number_str": "456", + "number_str2": " 456 ", + "true": True, + "non_int_str": "non int str", + } + ) expected_default = 12345 - assert expected_default == utils.env_val_to_int(None, expected_default) - assert expected_default == utils.env_val_to_int("", expected_default) - assert expected_default == utils.env_val_to_int(" ", expected_default) - assert expected_default == utils.env_val_to_int(" ss ", expected_default) - assert expected_default == utils.env_val_to_int(" sss ", expected_default) - assert expected_default == utils.env_val_to_int(" ssss ", expected_default) - - assert 0 == utils.env_val_to_int(0, expected_default) - assert 0 == utils.env_val_to_int("0", expected_default) - assert 0 == utils.env_val_to_int(False, expected_default) - assert 456 == utils.env_val_to_int("456", expected_default) - assert 456 == utils.env_val_to_int(" 456 ", expected_default) - assert 1 == utils.env_val_to_int(True, expected_default) + assert expected_default == conf.get("bogus", expected_default) + + assert 0 == conf.get("zero", expected_default) + assert 0 == int(conf.get("zero_str", expected_default)) + assert 0 == int(conf.get("false", expected_default)) + assert 456 == int(conf.get("number_str", expected_default)) + assert 456 == int(conf.get("number_str2", expected_default)) + assert 1 == conf.get("true", expected_default) + assert 1 == int(conf.get("true", expected_default)) + + assert "non int str" == conf.get("non_int_str", 123) # default not used + # Using a bad config (e.g., some non-integer string) with int() will raise ValueError + with pytest.raises(ValueError): + int(conf.get("non_int_str", 123)) # default not used, int("non int str") raises @pytest.mark.parametrize(