From 59ed8d317497c1b848f3029a8f1a34d44b014912 Mon Sep 17 00:00:00 2001
From: Andrew White
Date: Tue, 16 Jan 2024 14:50:20 -0800
Subject: [PATCH] Made it easier to access LLM names

---
 README.md             |  9 ++++-----
 paperqa/docs.py       | 24 +++++++++++++++++++++++-
 paperqa/llms.py       | 12 ++++++++++++
 tests/test_paperqa.py | 13 +++++++++++--
 4 files changed, 50 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index dd575ff63..b4511407e 100644
--- a/README.md
+++ b/README.md
@@ -100,17 +100,17 @@ docs = Docs(llm='gpt-3.5-turbo')
 or you can use any other model available in [langchain](https://github.com/hwchase17/langchain):
 
 ```py
-from paperqa import Docs, LangchainLLMModel
+from paperqa import Docs
 from langchain_community.chat_models import ChatAnthropic
 
-docs = Docs(llm_model=LangchainLLMModel(),
+docs = Docs(llm="langchain",
             client=ChatAnthropic())
 ```
 
-Note we split the model into `LangchainLLMModel` (always empty) and `client` which is `ChatAnthropic`. This is because `client` stores the non-pickleable part and langchain LLMs are only sometimes serializable/pickleable. The paper-qa `Docs` must always serializable. Thus, we split the model into two parts.
+Note we split the model into the wrapper and `client`, which is `ChatAnthropic` here. The `client` holds the non-pickleable part, because langchain LLMs are only sometimes serializable/pickleable while the paper-qa `Docs` object must always be serializable.
 
 ```py
 import pickle
-docs = Docs(llm_model=LangchainLLMModel(),
+docs = Docs(llm="langchain",
             client=ChatAnthropic())
 model_str = pickle.dumps(docs)
 docs = pickle.loads(model_str)
@@ -118,7 +118,6 @@ docs = pickle.loads(model_str)
 docs.set_client(ChatAnthropic())
 ```
 
-
 #### Locally Hosted
 
 You can use llama.cpp to be the LLM. Note that you should be using relatively large models, because paper-qa requires following a lot of instructions. You won't get good performance with 7B models.
diff --git a/paperqa/docs.py b/paperqa/docs.py
index c9c45dc38..8d486a111 100644
--- a/paperqa/docs.py
+++ b/paperqa/docs.py
@@ -103,6 +103,10 @@ def __init__(self, **data):
         super().__init__(**data)
         self._client = client
         self._embedding_client = embedding_client
+        # run this here (instead of automatically) so it has access to privates;
+        # if I ever figure out a better way of validating privates,
+        # I can move this back to the decorator
+        Docs.make_llm_names_consistent(self)
 
     @model_validator(mode="before")
     @classmethod
@@ -136,7 +140,6 @@ def setup_alias_models(cls, data: Any) -> Any:
                 raise ValueError(
                     f"Could not guess embedding model type for {data['embedding']}. "
" ) - return data @model_validator(mode="after") @@ -157,6 +160,24 @@ def config_summary_llm_config(cls, data: Any) -> Any: data.summary_llm_model = data.llm_model return data + @classmethod + def make_llm_names_consistent(cls, data: Any) -> Any: + if isinstance(data, Docs): + data.llm = data.llm_model.name + if data.llm == "langchain": + # from langchain models - kind of hacky + # langchain models cannot know type until + # it sees client + data.llm_model.infer_llm_type(data._client) + data.llm = data.llm_model.name + if data.summary_llm_model is not None: + if data.summary_llm == "langchain": + # from langchain models - kind of hacky + data.summary_llm_model.infer_llm_type(data._client) + data.summary_llm = data.summary_llm_model.name + + return data + def clear_docs(self): self.texts = [] self.docs = {} @@ -193,6 +214,7 @@ def set_client( else: embedding_client = AsyncOpenAI() self._embedding_client = embedding_client + Docs.make_llm_names_consistent(self) def _get_unique_name(self, docname: str) -> str: """Create a unique name given proposed name""" diff --git a/paperqa/llms.py b/paperqa/llms.py index a36ecba9d..7f00f67bb 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -100,6 +100,7 @@ async def embed_documents(self, client: Any, texts: list[str]) -> list[list[floa class LLMModel(ABC, BaseModel): llm_type: str | None = None + name: str model_config = ConfigDict(extra="forbid") async def acomplete(self, client: Any, prompt: str) -> str: @@ -208,6 +209,7 @@ async def execute( class OpenAILLMModel(LLMModel): config: dict = Field(default=dict(model="gpt-3.5-turbo", temperature=0.1)) + name: str = "gpt-3.5-turbo" def _check_client(self, client: Any) -> AsyncOpenAI: if client is None: @@ -227,6 +229,13 @@ def guess_llm_type(cls, data: Any) -> Any: m.llm_type = guess_model_type(m.config["model"]) return m + @model_validator(mode="after") + @classmethod + def set_model_name(cls, data: Any) -> Any: + m = cast(OpenAILLMModel, data) + m.name = m.config["model"] + return m + async def acomplete(self, client: Any, prompt: str) -> str: aclient = self._check_client(client) completion = await aclient.completions.create( @@ -428,9 +437,12 @@ async def similarity_search( class LangchainLLMModel(LLMModel): """A wrapper around the wrapper langchain""" + name: str = "langchain" + def infer_llm_type(self, client: Any) -> str: from langchain_core.language_models.chat_models import BaseChatModel + self.name = client.model_name if isinstance(client, BaseChatModel): return "chat" return "completion" diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index d9d05d417..fc483e3de 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -406,14 +406,15 @@ def accum(x): def test_docs(): - llm_config = dict(temperature=0.1, model="text-ada-001", model_type="completion") - docs = Docs(llm_model=OpenAILLMModel(config=llm_config)) + docs = Docs(llm="babbage-002") docs.add_url( "https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day", citation="WikiMedia Foundation, 2023, Accessed now", dockey="test", ) assert docs.docs["test"].docname == "Wiki2023" + assert docs.llm == "babbage-002" + assert docs.summary_llm == "babbage-002" def test_evidence(): @@ -486,6 +487,8 @@ async def embed_documents(self, client, texts): def test_custom_llm(): class MyLLM(LLMModel): + name: str = "myllm" + async def acomplete(self, client, prompt): assert client is None return "Echo" @@ -502,6 +505,8 @@ async def acomplete(self, client, prompt): def test_custom_llm_stream(): class MyLLM(LLMModel): + name: str = "myllm" + 
         async def acomplete_iter(self, client, prompt):
             assert client is None
             yield "Echo"
@@ -522,6 +527,8 @@ def test_langchain_llm():
     from langchain_openai import ChatOpenAI, OpenAI
 
     docs = Docs(llm="langchain", client=ChatOpenAI(model="gpt-3.5-turbo"))
+    assert docs.llm == "gpt-3.5-turbo"
+    assert docs.summary_llm == "gpt-3.5-turbo"
     docs.add_url(
         "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)",
         citation="WikiMedia Foundation, 2023, Accessed now",
@@ -567,7 +574,9 @@ def test_langchain_llm():
     docs_pickle = pickle.dumps(docs)
     docs2 = pickle.loads(docs_pickle)
     assert docs2._client is None
+    assert docs2.llm == "babbage-002"
     docs2.set_client(OpenAI(model="babbage-002"))
+    assert docs2.summary_llm == "babbage-002"
     docs2.get_evidence(
         Answer(question="What is Frederick Bates's greatest accomplishment?"),
         get_callbacks=lambda x: [lambda y: print(y)],
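
With this patch, the model name is readable directly from the `llm` and `summary_llm` aliases instead of digging into `llm_model.config`. A minimal sketch of the new behavior (mirroring the updated `test_docs` above; assumes an OpenAI API key is configured in the environment):

```py
from paperqa import Docs

# the string alias is resolved to an OpenAILLMModel, and the model's
# name is copied back onto the Docs fields by make_llm_names_consistent
docs = Docs(llm="babbage-002")
print(docs.llm)          # "babbage-002"
print(docs.summary_llm)  # "babbage-002" (defaults to the main LLM)
```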
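
For langchain models, the name cannot be known until the client is attached, so `"langchain"` acts as a placeholder that `infer_llm_type` replaces once it sees the client. A sketch of the pickle round trip (mirroring the updated `test_langchain_llm`; assumes `langchain_openai` is installed and an OpenAI key is set):

```py
import pickle

from langchain_openai import ChatOpenAI
from paperqa import Docs

docs = Docs(llm="langchain", client=ChatOpenAI(model="gpt-3.5-turbo"))
assert docs.llm == "gpt-3.5-turbo"  # placeholder resolved from the client

# the client is dropped on pickle; restoring it re-runs name resolution
docs2 = pickle.loads(pickle.dumps(docs))
assert docs2._client is None
docs2.set_client(ChatOpenAI(model="gpt-3.5-turbo"))
assert docs2.llm == "gpt-3.5-turbo"
```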