From c02c4f33aff0443979f2a11862cd3afbb53b759a Mon Sep 17 00:00:00 2001 From: Andrew White Date: Thu, 17 Oct 2024 11:50:17 -0700 Subject: [PATCH] Addressing PR comments --- paperqa/llms.py | 13 +++++-------- pyproject.toml | 6 ++++-- uv.lock | 9 ++++++--- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index bdc15dcf..8782843b 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -222,8 +222,8 @@ def __init__(self, **kwargs): self._model = SentenceTransformer(self.name) def set_mode(self, mode: EmbeddingModes) -> None: - """Set the embedding mode. SentenceTransformer does not support modes, so this is a no-op.""" # SentenceTransformer does not support different modes. + pass async def embed_documents(self, texts: list[str]) -> list[list[float]]: """ @@ -897,11 +897,8 @@ def embedding_model_factory(embedding: str, **kwargs) -> EmbeddingModel: Args: embedding: The embedding model identifier. Supports prefixes like "st-" for SentenceTransformer - and "hybrid-" for combining multiple embedding models. + and "hybrid-" for combining multiple embedding models. **kwargs: Additional keyword arguments for the embedding model. - - Returns: - EmbeddingModel: An instance of a subclass of EmbeddingModel. """ embedding = embedding.strip() # Remove any leading/trailing whitespace @@ -930,18 +927,18 @@ def embedding_model_factory(embedding: str, **kwargs) -> EmbeddingModel: return SentenceTransformerEmbeddingModel( name=model_name, - config=kwargs, # Pass any additional configurations via config + config=kwargs, ) if embedding.startswith("litellm-"): # Extract the LiteLLM model name after "litellm-" model_name = embedding[len("litellm-") :].strip() if not model_name: - raise ValueError("LiteLLM model name must be specified after 'litellm-'.") + raise ValueError("model name must be specified after 'litellm-'.") return LiteLLMEmbeddingModel( name=model_name, - config=kwargs, # Pass any additional configurations via config + config=kwargs, ) if embedding == "sparse": diff --git a/pyproject.toml b/pyproject.toml index 1b1ded7f..8f562f3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,9 @@ datasets = [ ldp = [ "ldp>=0.9", # For alg namespace grouping ] +sentence-transformers = [ + "sentence-transformers", +] typing = [ "pandas-stubs", "types-PyYAML", @@ -405,7 +408,7 @@ trailing_comma_inline_array = true dev-dependencies = [ "ipython>=8", # Pin to keep recent "mypy>=1.8", # Pin for mutable-override - "paper-qa[datasets,ldp,typing,zotero]", + "paper-qa[datasets,ldp,typing,zotero,sentence-transformers]", "pre-commit>=3.4", # Pin to keep recent "pydantic~=2.0", "pylint-pydantic", @@ -419,6 +422,5 @@ dev-dependencies = [ "pytest>=8", # Pin to keep recent "python-dotenv", "refurb>=2", # Pin to keep recent - "sentence-transformers", # for unit tests "typeguard", ] diff --git a/uv.lock b/uv.lock index afac9482..8d55d621 100644 --- a/uv.lock +++ b/uv.lock @@ -1490,7 +1490,7 @@ wheels = [ [[package]] name = "paper-qa" -version = "5.1.2.dev23+g3d0b0c6.d20241017" +version = "5.1.2.dev25+g9faf790.d20241017" source = { editable = "." } dependencies = [ { name = "aiohttp" }, @@ -1521,6 +1521,9 @@ datasets = [ ldp = [ { name = "ldp" }, ] +sentence-transformers = [ + { name = "sentence-transformers" }, +] typing = [ { name = "pandas-stubs" }, { name = "types-pyyaml" }, @@ -1578,6 +1581,7 @@ requires-dist = [ { name = "pymupdf", specifier = ">=1.24.3" }, { name = "pyzotero", marker = "extra == 'zotero'" }, { name = "rich" }, + { name = "sentence-transformers", marker = "extra == 'sentence-transformers'" }, { name = "setuptools" }, { name = "tantivy" }, { name = "tenacity" }, @@ -1590,7 +1594,7 @@ requires-dist = [ dev = [ { name = "ipython", specifier = ">=8" }, { name = "mypy", specifier = ">=1.8" }, - { name = "paper-qa", extras = ["datasets", "ldp", "typing", "zotero"] }, + { name = "paper-qa", extras = ["datasets", "ldp", "typing", "zotero", "sentence-transformers"] }, { name = "pre-commit", specifier = ">=3.4" }, { name = "pydantic", specifier = "~=2.0" }, { name = "pylint-pydantic" }, @@ -1604,7 +1608,6 @@ dev = [ { name = "pytest-xdist" }, { name = "python-dotenv" }, { name = "refurb", specifier = ">=2" }, - { name = "sentence-transformers" }, { name = "typeguard" }, ]