Skip to content

Commit

Permalink
Addressing PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
whitead committed Oct 17, 2024
1 parent 9faf790 commit c02c4f3
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 13 deletions.
13 changes: 5 additions & 8 deletions paperqa/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ def __init__(self, **kwargs):
self._model = SentenceTransformer(self.name)

def set_mode(self, mode: EmbeddingModes) -> None:
"""Set the embedding mode. SentenceTransformer does not support modes, so this is a no-op."""
# SentenceTransformer does not support different modes.
pass

async def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""
Expand Down Expand Up @@ -897,11 +897,8 @@ def embedding_model_factory(embedding: str, **kwargs) -> EmbeddingModel:
Args:
embedding: The embedding model identifier. Supports prefixes like "st-" for SentenceTransformer
and "hybrid-" for combining multiple embedding models.
and "hybrid-" for combining multiple embedding models.
**kwargs: Additional keyword arguments for the embedding model.
Returns:
EmbeddingModel: An instance of a subclass of EmbeddingModel.
"""
embedding = embedding.strip() # Remove any leading/trailing whitespace

Expand Down Expand Up @@ -930,18 +927,18 @@ def embedding_model_factory(embedding: str, **kwargs) -> EmbeddingModel:

return SentenceTransformerEmbeddingModel(
name=model_name,
config=kwargs, # Pass any additional configurations via config
config=kwargs,
)

if embedding.startswith("litellm-"):
# Extract the LiteLLM model name after "litellm-"
model_name = embedding[len("litellm-") :].strip()
if not model_name:
raise ValueError("LiteLLM model name must be specified after 'litellm-'.")
raise ValueError("model name must be specified after 'litellm-'.")

return LiteLLMEmbeddingModel(
name=model_name,
config=kwargs, # Pass any additional configurations via config
config=kwargs,
)

if embedding == "sparse":
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ datasets = [
ldp = [
"ldp>=0.9", # For alg namespace grouping
]
sentence-transformers = [
"sentence-transformers",
]
typing = [
"pandas-stubs",
"types-PyYAML",
Expand Down Expand Up @@ -405,7 +408,7 @@ trailing_comma_inline_array = true
dev-dependencies = [
"ipython>=8", # Pin to keep recent
"mypy>=1.8", # Pin for mutable-override
"paper-qa[datasets,ldp,typing,zotero]",
"paper-qa[datasets,ldp,typing,zotero,sentence-transformers]",
"pre-commit>=3.4", # Pin to keep recent
"pydantic~=2.0",
"pylint-pydantic",
Expand All @@ -419,6 +422,5 @@ dev-dependencies = [
"pytest>=8", # Pin to keep recent
"python-dotenv",
"refurb>=2", # Pin to keep recent
"sentence-transformers", # for unit tests
"typeguard",
]
9 changes: 6 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit c02c4f3

Please sign in to comment.