diff --git a/README.md b/README.md index a5fd3b2a..b4511407 100644 --- a/README.md +++ b/README.md @@ -1,53 +1,40 @@ -# Paper QA- [Paper QA](#paper-qa) -- [Paper QA- Paper QA](#paper-qa--paper-qa) - - [Output Example](#output-example) - - [References](#references) - - [Hugging Face Demo](#hugging-face-demo) - - [Install](#install) - - [Usage](#usage) - - [Adding Documents](#adding-documents) - - [Choosing Model](#choosing-model) - - [Adjusting number of sources](#adjusting-number-of-sources) - - [Using Code or HTML](#using-code-or-html) - - [Version 3 Changes](#version-3-changes) - - [New Features](#new-features) - - [Naming](#naming) - - [Breaking Changes](#breaking-changes) - - [Notebooks](#notebooks) - - [Where do I get papers?](#where-do-i-get-papers) - - [Zotero](#zotero) - - [Paper Scraper](#paper-scraper) - - [PDF Reading Options](#pdf-reading-options) - - [Typewriter View](#typewriter-view) - - [LLM/Embedding Caching](#llmembedding-caching) - - [Caching Embeddings](#caching-embeddings) - - [Customizing Prompts](#customizing-prompts) - - [Pre and Post Prompts](#pre-and-post-prompts) - - [FAQ](#faq) - - [How is this different from LlamaIndex?](#how-is-this-different-from-llamaindex) - - [How is this different from LangChain?](#how-is-this-different-from-langchain) - - [Can I use different LLMs?](#can-i-use-different-llms) - - [Where do the documents come from?](#where-do-the-documents-come-from) - - [Can I save or load?](#can-i-save-or-load) - +# PaperQA [![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/whitead/paper-qa) [![tests](https://github.com/whitead/paper-qa/actions/workflows/tests.yml/badge.svg)](https://github.com/whitead/paper-qa) [![PyPI version](https://badge.fury.io/py/paper-qa.svg)](https://badge.fury.io/py/paper-qa) +## YOU ARE LOOKING AT PRE-RELEASE README + +**This is the README for an upcoming v4 release** + +You can see the current stable version [here](https://github.com/whitead/paper-qa/tree/84f13ea32c22b85924cd681a4b5f4fbd174afd71) + This is a minimal package for doing question and answering from PDFs or text files (which can be raw HTML). It strives to give very good answers, with no hallucinations, by grounding responses with in-text citations. -By default, it uses [OpenAI Embeddings](https://platform.openai.com/docs/guides/embeddings) with a vector DB called [FAISS](https://github.com/facebookresearch/faiss) to embed and search documents. However, via [langchain](https://github.com/hwchase17/langchain) you can use open-source models or embeddings (see details below). +By default, it uses [OpenAI Embeddings](https://platform.openai.com/docs/guides/embeddings) with a simple numpy vector DB to embed and search documents. However, via [langchain](https://github.com/hwchase17/langchain) you can use open-source models or embeddings (see details below). -PaperQA uses the process shown below: +paper-qa uses the process shown below: 1. embed docs into vectors 2. embed query into vector 3. search for top k passages in docs 4. create summary of each passage relevant to query -5. put summaries into prompt -6. generate answer with prompt +5. score and select only relevant summaries +6. put summaries into prompt +7. generate answer with prompt + +See our paper for more details: + +```bibtex +@article{lala2023paperqa, + title={PaperQA: Retrieval-Augmented Generative Agent for Scientific Research}, + author={L{\'a}la, Jakub and O'Donoghue, Odhran and Shtedritski, Aleksandar and Cox, Sam and Rodriques, Samuel G and White, Andrew D}, + journal={arXiv preprint arXiv:2312.07559}, + year={2023} +} +``` ## Output Example @@ -63,9 +50,10 @@ Tulevski2007: Tulevski, George S., et al. "Chemically assisted directed assembly Chen2014: Chen, Haitian, et al. "Large-scale complementary macroelectronics using hybrid integration of carbon nanotubes and IGZO thin-film transistors." Nature communications 5.1 (2014): 4097. -## Hugging Face Demo -[Hugging Face Demo](https://huggingface.co/spaces/whitead/paper-qa) +## What's New? + +Version 4 removed langchain from the package because it no longer supports pickling. This also simplifies the package a bit - especially prompts. Langchain can still be used, but it's not required. You can use any LLMs from langchain, but you will need to use the `LangchainLLMModel` class to wrap the model. ## Install @@ -75,17 +63,17 @@ Install with pip: pip install paper-qa ``` -## Usage +You need to have an LLM to use paper-qa. You can use OpenAI, llama.cpp (via Server), or any LLMs from langchain. OpenAI just works, as long as you have set your OpenAI API key (`export OPENAI_API_KEY=sk-...`). See instructions below for other LLMs. -Make sure you have set your OPENAI_API_KEY environment variable to your [openai api key](https://platform.openai.com/account/api-keys) +## Usage -To use paper-qa, you need to have a list of paths (valid extensions include: .pdf, .txt) and a list of citations (strings) that correspond to the paths. You can then use the `Docs` class to add the documents and then query them. If you don't have citations, `Docs` will try to guess them from the first page of your docs. +To use paper-qa, you need to have a list of paths/files/urls (valid extensions include: .pdf, .txt). You can then use the `Docs` class to add the documents and then query them. `Docs` will try to guess citation formats from the content of the files, but you can also provide them yourself. ```python from paperqa import Docs -# get a list of paths +my_docs = ...# get a list of paths docs = Docs() for d in my_docs: @@ -95,7 +83,7 @@ answer = docs.query("What manufacturing challenges are unique to bispecific anti print(answer.formatted_answer) ``` -The answer object has the following attributes: `formatted_answer`, `answer` (answer alone), `question`, `context` (the summaries of passages found for answer), `references` (the docs from which the passages came), and `passages` which contain the raw text of the passages as a dictionary. +The answer object has the following attributes: `formatted_answer`, `answer` (answer alone), `question` , and `context` (the summaries of passages found for answer). ### Adding Documents @@ -103,7 +91,7 @@ The answer object has the following attributes: `formatted_answer`, `answer` (an ### Choosing Model -By default, it uses a hybrid of `gpt-3.5-turbo` and `gpt-4`. If you don't have gpt-4 access or would like to save money, you can adjust: +By default, it uses a hybrid of `gpt-3.5-turbo` and `gpt-4-turbo`. You can adjust this: ```py docs = Docs(llm='gpt-3.5-turbo') @@ -112,51 +100,79 @@ docs = Docs(llm='gpt-3.5-turbo') or you can use any other model available in [langchain](https://github.com/hwchase17/langchain): ```py -from langchain.chat_models import ChatAnthropic, ChatOpenAI -model = ChatOpenAI(model='gpt-4') -summary_model = ChatAnthropic(model="claude-instant-v1-100k", anthropic_api_key="my-api-key") -docs = Docs(llm=model, summary_llm=summary_model) +from paperqa import Docs +from langchain_community.chat_models import ChatAnthropic +docs = Docs(llm="langchain", + client=ChatAnthropic()) ``` +Note we split the model into the wrapper and `client`, which is `ChatAnthropic` here. This is because `client` stores the non-pickleable part and langchain LLMs are only sometimes serializable/pickleable. The paper-qa `Docs` must always serializable. Thus, we split the model into two parts. + +```py +import pickle +docs = Docs(llm="langchain", + client=ChatAnthropic()) +model_str = pickle.dumps(docs) +docs = pickle.loads(model_str) +# but you have to set the client after loading +docs.set_client(ChatAnthropic()) +``` #### Locally Hosted -You can also use any other models (or embeddings) available in [langchain](https://github.com/hwchase17/langchain). Here's an example of using `llama.cpp` to have locally hosted paper-qa: +You can use llama.cpp to be the LLM. Note that you should be using relatively large models, because paper-qa requires following a lot of instructions. You won't get good performance with 7B models. + +The easiest way to get set-up is to download a [llama file](https://github.com/Mozilla-Ocho/llamafile) and execute it with `-cb -np 4 -a my-llm-model --embedding` which will enable continuous batching and embeddings. ```py -import paperscraper -from paperqa import Docs -from langchain.llms import LlamaCpp -from langchain import PromptTemplate, LLMChain -from langchain.callbacks.manager import CallbackManager -from langchain.embeddings import LlamaCppEmbeddings -from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler - -# Make sure the model path is correct for your system! -llm = LlamaCpp( - model_path="./ggml-model-q4_0.bin", callbacks=[StreamingStdOutCallbackHandler()] +from paperqa import Docs, LlamaEmbeddingModel +from openai import AsyncOpenAI + +# start llamap.cpp client with + +local_client = AsyncOpenAI( + base_url="http://localhost:8080/v1", + api_key = "sk-no-key-required" ) -embeddings = LlamaCppEmbeddings(model_path="./ggml-model-q4_0.bin") -docs = Docs(llm=llm, embeddings=embeddings) +docs = Docs(client=local_client, + embedding=LlamaEmbeddingModel(), + llm_model=OpenAILLMModel(config=dict(model="my-llm-model", temperature=0.1, frequency_penalty=1.5, max_tokens=512))) +``` -keyword_search = 'bispecific antibody manufacture' -papers = paperscraper.search_papers(keyword_search, limit=2) -for path,data in papers.items(): - try: - docs.add(path,chunk_chars=500) - except ValueError as e: - print('Could not read', path, e) +### Changing Embedding Model -answer = docs.query("What manufacturing challenges are unique to bispecific antibodies?") -print(answer) +You can use langchain embedding models, or the [SentenceTransformer](https://www.sbert.net/) models. For example + +```py +from paperqa import Docs, SentenceTransformerEmbeddingModel +from openai import AsyncOpenAI + +# start llamap.cpp client with + +local_client = AsyncOpenAI( + base_url="http://localhost:8080/v1", + api_key = "sk-no-key-required" +) + +docs = Docs(client=local_client, + embedding=SentenceTransformerEmbeddingModel(), + llm_model=OpenAILLMModel(config=dict(model="my-llm-model", temperature=0.1, frequency_penalty=1.5, max_tokens=512))) +``` + +Just like in the above examples, we have to split the Langchain model into a client and model to keep `Docs` serializable. +```py + +from paperqa import Docs, LangchainEmbeddingModel + +docs = Docs(embedding_model=LangchainEmbeddingModel(), embedding_client=OpenAIEmbeddings()) ``` ### Adjusting number of sources You can adjust the numbers of sources (passages of text) to reduce token usage or add more context. `k` refers to the top k most relevant and diverse (may from different sources) passages. Each passage is sent to the LLM to summarize, or determine if it is irrelevant. After this step, a limit of `max_sources` is applied so that the final answer can fit into the LLM context window. Thus, `k` > `max_sources` and `max_sources` is the number of sources used in the final answer. -```python +```py docs.query("What manufacturing challenges are unique to bispecific antibodies?", k = 5, max_sources = 2) ``` @@ -178,67 +194,32 @@ answer = docs.query("Where is the search bar in the header defined?") print(answer) ``` -## Version 3 Changes - -Version 3 includes many changes to type the code, make it more focused/modular, and enable performance to very large numbers of documents. The major breaking changes are documented below: - - -### New Features - -The following new features are in v3: - -1. Memory is now possible in `query` by setting `Docs(memory=True)` - this means follow-up questions will have a record of the previous question and answer. -2. `add_url` and `add_file` are now supported for adding from URLs and file objects -3. Prompts can be customized, and now can be executed pre and post query -4. Consistent use of `dockey` and `docname` for unique and natural language names enable better tracking with external databases -5. Texts and embeddings are no longer required to be part of `Docs` object, so you can use external databases or other strategies to manage them -6. Various simplifications, bug fixes, and performance improvements - -### Naming - -The following table shows the old names and the new names: - -| Old Name | New Name | Explanation | -| :--- | :---: | ---: | -| `key` | `name` | Name is a natural language name for text. | -| `dockey` | `docname` | Docname is a natural language name for a document. | -| `hash` | `dockey` | Dockey is a unique identifier for the document. | +### Using External DB/Vector DB and Caching +You may want to cache parsed texts and embeddings in an external database or file. You can then build a Docs object from those directly: -### Breaking Changes - - -#### Pickled objects - -The pickled objects are not compatible with the new version. - -#### Agents - -The agent functionality has been removed, as it's not a core focus of the library - -#### Caching - -Caching has been removed because it's not a core focus of the library. See FAQ below for how to use caching. - -#### Answers +```py -Answers will not include passages, but instead return dockeys that can be used to retrieve the passages. Tokens/cost will also not be counted since that is built into langchain by default (see below for an example). +docs = Docs() -#### Search Query +for ... in my_docs: + doc = Doc(docname=..., citation=..., dockey=..., citation=...) + texts = [Text(text=..., name=..., doc=doc) for ... in my_texts] + docs.add_texts(texts, doc) +``` -The search query chain has been removed. You can use langchain directly to do this. +If you want to use an external vector store, you can also do that directly via langchain. For example, to use the [FAISS](https://ai.meta.com/tools/faiss/) vector store from langchain: -## Notebooks +```py +from paperqa import LangchainVectorStore, Docs +from langchain_community.vector_store import FAISS +from langchain_openai import OpenAIEmbeddings -If you want to use this in an jupyter notebook or colab, you need to run the following command: +my_index = LangchainVectorStore(cls=FAISS, embedding_model=OpenAIEmbeddings()) +docs = Docs(texts_index=my_index) -```python -import nest_asyncio -nest_asyncio.apply() ``` -Also - if you know how to make this automated, please let me know! - ## Where do I get papers? Well that's a really good question! It's probably best to just download PDFs of papers you think will help answer your question and start from there. @@ -329,36 +310,22 @@ By default [PyPDF](https://pypi.org/project/pypdf/) is used since it's pure pyth pip install pymupdf ``` -## Typewriter View +## Callbacks Factory -To stream the completions as they occur (giving that ChatGPT typewriter look), you can simply instantiate models with those properties: +To execute a function on each chunk of LLM completions, you need to provide a function that when called with the name of the step produces a list of functions to execute on each chunk. For example, to get a typewriter view of the completions, you can do: ```python -from paperqa import Docs -from langchain.callbacks.manager import CallbackManager -from langchain.chat_models import ChatOpenAI -from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler - -my_llm = ChatOpenAI(callbacks=[StreamingStdOutCallbackHandler()], streaming=True) -docs = Docs(llm=my_llm) -``` - -## LLM/Embedding Caching - -You can using the builtin langchain caching capabilities. Just run this code at the top of yours: - -```py -from langchain.cache import InMemoryCache -langchain.llm_cache = InMemoryCache() +def make_typewriter(step_name): + def typewriter(chunk): + print(chunk, end="") + return [typewriter] # <- note that this is a list of functions +... +docs.query("What manufacturing challenges are unique to bispecific antibodies?", get_callbacks=make_typewriter) ``` ### Caching Embeddings -In general, embeddings are cached when you pickle a `Docs` regardless of what vector store you use. If you would like to manage caching embeddings via an external database or other strategy, -you can populate a `Docs` object directly via -the `add_texts` object. That can take chunked texts and documents, which are serializable objects, to populate `Docs`. - -You also can simply use a separate vector database by setting the `doc_index` and `texts_index` explicitly when building the `Docs` object. +In general, embeddings are cached when you pickle a `Docs` regardless of what vector store you use. See above for details on more explicit management of them. ## Customizing Prompts @@ -366,17 +333,14 @@ You can customize any of the prompts, using the `PromptCollection` class. For ex ```python from paperqa import Docs, Answer, PromptCollection -from langchain.prompts import PromptTemplate -my_qaprompt = PromptTemplate( - input_variables=["context", "question"], - template="Answer the question '{question}' " +my_qaprompt = "Answer the question '{question}' " "Use the context below if helpful. " "You can cite the context using the key " "like (Example2012). " "If there is insufficient context, write a poem " "about how you cannot answer.\n\n" - "Context: {context}\n\n") + "Context: {context}\n\n" prompts=PromptCollection(qa=my_qaprompt) docs = Docs(prompts=prompts) ``` @@ -395,15 +359,7 @@ It's not that different! This is similar to the tree response method in LlamaInd ### How is this different from LangChain? -It's not! We use langchain to abstract the LLMS, and the process is very similar to the `map_reduce` chain in LangChain. - -### Can I use different LLMs? - -Yes, you can use any LLMs from [langchain](https://langchain.readthedocs.io/) by passing the `llm` argument to the `Docs` class. You can use different LLMs for summarization and for question answering too. - -### Where do the documents come from? - -You can provide your own. I use some of my own code to pull papers from Google Scholar. This code is not included because it may enable people to violate Google's terms of service and publisher's terms of service. +There has been some great work on retrievers in langchain and you could say this is an example of a retreiver. ### Can I save or load? @@ -419,4 +375,6 @@ with open("my_docs.pkl", "wb") as f: # load with open("my_docs.pkl", "rb") as f: docs = pickle.load(f) + +docs.set_client() #defaults to OpenAI ``` diff --git a/dev-requirements.txt b/dev-requirements.txt index 6202c640..92f2a9b2 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -6,4 +6,6 @@ python-dotenv pymupdf build types-requests -numpy +langchain_openai +langchain_community +faiss-cpu diff --git a/paperqa/__init__.py b/paperqa/__init__.py index fa06c892..af23e993 100644 --- a/paperqa/__init__.py +++ b/paperqa/__init__.py @@ -1,5 +1,17 @@ from .docs import Answer, Docs, PromptCollection, Doc, Text, Context from .version import __version__ +from .llms import ( + LLMModel, + EmbeddingModel, + LangchainEmbeddingModel, + OpenAIEmbeddingModel, + LangchainLLMModel, + OpenAILLMModel, + LlamaEmbeddingModel, + NumpyVectorStore, + LangchainVectorStore, + SentenceTransformerEmbeddingModel, +) __all__ = [ "Docs", @@ -9,4 +21,14 @@ "Doc", "Text", "Context", + "LLMModel", + "EmbeddingModel", + "OpenAIEmbeddingModel", + "OpenAILLMModel", + "LangchainLLMModel", + "LlamaEmbeddingModel", + "SentenceTransformerEmbeddingModel", + "LangchainEmbeddingModel", + "NumpyVectorStore", + "LangchainVectorStore", ] diff --git a/paperqa/chains.py b/paperqa/chains.py deleted file mode 100644 index 75894ce6..00000000 --- a/paperqa/chains.py +++ /dev/null @@ -1,113 +0,0 @@ -import re -from typing import Any, Dict, List, Optional, cast - -from langchain.base_language import BaseLanguageModel -from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, -) -from langchain.chains import LLMChain -from langchain.chat_models import ChatOpenAI -from langchain.memory.chat_memory import BaseChatMemory -from langchain.prompts import PromptTemplate, StringPromptTemplate -from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain.schema import LLMResult, SystemMessage - -from .prompts import default_system_prompt -from .types import CBManager - -memory_prompt = PromptTemplate( - input_variables=["memory", "start"], - template="Here are previous questions and answers, which may be referenced in subsequent questions:\n\n{memory}\n\n" - "----------------------------------------\n\n" - "{start}", -) - - -class FallbackLLMChain(LLMChain): - """Chain that falls back to synchronous generation if the async generation fails.""" - - async def agenerate( - self, - input_list: List[Dict[str, Any]], - run_manager: Optional[CBManager] = None, - ) -> LLMResult: - """Generate LLM result from inputs.""" - try: - run_manager = cast(AsyncCallbackManagerForChainRun, run_manager) - return await super().agenerate(input_list, run_manager=run_manager) - except NotImplementedError: - run_manager = cast(CallbackManagerForChainRun, run_manager) - return self.generate(input_list) - - -# TODO: If upstream is fixed remove this - - -class ExtendedHumanMessagePromptTemplate(HumanMessagePromptTemplate): - prompt: StringPromptTemplate - - -def make_chain( - prompt: StringPromptTemplate, - llm: BaseLanguageModel, - skip_system: bool = False, - memory: Optional[BaseChatMemory] = None, - system_prompt: str = default_system_prompt, -) -> FallbackLLMChain: - if memory and len(memory.load_memory_variables({})["memory"]) > 0: - # we copy the prompt so we don't modify the original - # TODO: Figure out pipeline prompts to avoid this - # the problem with pipeline prompts is that - # the memory is a constant (or partial), not a prompt - # and I cannot seem to make an empty prompt (or str) - # work as an input to pipeline prompt - assert isinstance( - prompt, PromptTemplate - ), "Memory only works with prompt templates - see comment above" - assert "memory" in memory.load_memory_variables({}) - new_prompt = PromptTemplate( - input_variables=prompt.input_variables, - template=memory_prompt.format( - start=prompt.template, **memory.load_memory_variables({}) - ), - ) - prompt = new_prompt - if type(llm) == ChatOpenAI: - system_message_prompt = SystemMessage(content=system_prompt) - human_message_prompt = ExtendedHumanMessagePromptTemplate(prompt=prompt) - if skip_system: - chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt]) - else: - chat_prompt = ChatPromptTemplate.from_messages( - [system_message_prompt, human_message_prompt] - ) - return FallbackLLMChain(prompt=chat_prompt, llm=llm) - return FallbackLLMChain(prompt=prompt, llm=llm) - - -def get_score(text: str) -> int: - # check for N/A - last_line = text.split("\n")[-1] - if "N/A" in last_line or "n/a" in last_line or "NA" in last_line: - return 0 - score = re.search(r"[sS]core[:is\s]+([0-9]+)", text) - if not score: - score = re.search(r"\(([0-9])\w*\/", text) - if not score: - score = re.search(r"([0-9]+)\w*\/", text) - if score: - s = int(score.group(1)) - if s > 10: - s = int(s / 10) # sometimes becomes out of 100 - return s - last_few = text[-15:] - scores = re.findall(r"([0-9]+)", last_few) - if scores: - s = int(scores[-1]) - if s > 10: - s = int(s / 10) # sometimes becomes out of 100 - return s - if len(text) < 100: - return 1 - return 5 diff --git a/paperqa/contrib/zotero.py b/paperqa/contrib/zotero.py index a390cd1c..1d4330dc 100644 --- a/paperqa/contrib/zotero.py +++ b/paperqa/contrib/zotero.py @@ -4,10 +4,7 @@ from pathlib import Path from typing import List, Optional, Union, cast -try: - from pydantic.v1 import BaseModel -except ImportError: - from pydantic import BaseModel +from pydantic import BaseModel try: from pyzotero import zotero diff --git a/paperqa/docs.py b/paperqa/docs.py index cdc064d2..bb0b026b 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -1,34 +1,40 @@ +import nest_asyncio # isort:skip import asyncio import os import re -import sys import tempfile from datetime import datetime from io import BytesIO from pathlib import Path -from typing import BinaryIO, Dict, List, Optional, Set, Union, cast - -from langchain.chat_models import ChatOpenAI -from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.memory import ConversationTokenBufferMemory -from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema.embeddings import Embeddings -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.vectorstore import VectorStore -from langchain.vectorstores import FAISS - -try: - from pydantic.v1 import BaseModel, validator -except ImportError: - from pydantic import BaseModel, validator - -from .chains import get_score, make_chain +from typing import Any, BinaryIO, cast + +from openai import AsyncOpenAI +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from .llms import ( + LangchainEmbeddingModel, + LangchainLLMModel, + LLMModel, + NumpyVectorStore, + OpenAILLMModel, + VectorStore, + get_score, + is_openai_model, +) from .paths import PAPERQA_DIR from .readers import read_doc -from .types import Answer, CallbackFactory, Context, Doc, DocKey, PromptCollection, Text +from .types import ( + Answer, + CallbackFactory, + Context, + Doc, + DocKey, + LLMResult, + PromptCollection, + Text, +) from .utils import ( gather_with_concurrency, - get_llm_name, guess_is_4xx, maybe_is_html, maybe_is_pdf, @@ -38,80 +44,193 @@ strip_citations, ) +# Apply the patch to allow nested loops +nest_asyncio.apply() + -class Docs(BaseModel, arbitrary_types_allowed=True, smart_union=True): +class Docs(BaseModel): """A collection of documents to be used for answering questions.""" - docs: Dict[DocKey, Doc] = {} - texts: List[Text] = [] - docnames: Set[str] = set() - texts_index: Optional[VectorStore] = None - doc_index: Optional[VectorStore] = None - llm: Union[str, BaseLanguageModel] = ChatOpenAI( - temperature=0.1, model="gpt-3.5-turbo", client=None + # ephemeral vars that should not be pickled (_things) + _client: Any | None = None + _embedding_client: Any | None = None + llm: str = "default" + summary_llm: str | None = None + llm_model: LLMModel = Field( + default=OpenAILLMModel(config=dict(model="gpt-4-1106-preview", temperature=0.1)) ) - summary_llm: Optional[Union[str, BaseLanguageModel]] = None + summary_llm_model: LLMModel | None = Field(default=None, validate_default=True) + embedding: str | None = "default" + docs: dict[DocKey, Doc] = {} + texts: list[Text] = [] + docnames: set[str] = set() + texts_index: VectorStore = Field(default_factory=NumpyVectorStore) + docs_index: VectorStore = Field(default_factory=NumpyVectorStore) name: str = "default" - index_path: Optional[Path] = PAPERQA_DIR / name - embeddings: Embeddings = OpenAIEmbeddings(client=None) - max_concurrent: int = 5 - deleted_dockeys: Set[DocKey] = set() + index_path: Path | None = PAPERQA_DIR / name + batch_size: int = 1 + max_concurrent: int = 4 + deleted_dockeys: set[DocKey] = set() prompts: PromptCollection = PromptCollection() - memory: bool = False - memory_model: Optional[BaseChatMemory] = None jit_texts_index: bool = False # This is used to strip indirect citations that come up from the summary llm strip_citations: bool = True - - # TODO: Not sure how to get this to work - # while also passing mypy checks - @validator("llm", "summary_llm") - def check_llm(cls, v: Union[BaseLanguageModel, str]) -> BaseLanguageModel: - if type(v) is str: - return ChatOpenAI(temperature=0.1, model=v, client=None) - return cast(BaseLanguageModel, v) - - @validator("summary_llm", always=True) - def copy_llm_if_not_set(cls, v, values): - return v or values["llm"] - - @validator("memory_model", always=True) - def check_memory_model(cls, v, values): - if values["memory"]: - if v is None: - return ConversationTokenBufferMemory( - llm=values["summary_llm"], - max_token_limit=512, - memory_key="memory", - human_prefix="Question", - ai_prefix="Answer", - input_key="Question", - output_key="Answer", + model_config = ConfigDict(extra="forbid") + + def __init__(self, **data): + # We do it here because we need to move things to private attributes + if "embedding_client" in data: + embedding_client = data.pop("embedding_client") + # convenience to pull embedding_client from client if reasonable + elif ( + "client" in data + and data["client"] is not None + and type(data["client"]) == AsyncOpenAI + ): + # convenience + embedding_client = data["client"] + else: + if "embedding" in data and data["embedding"] != "default": + embedding_client = None + else: + embedding_client = AsyncOpenAI() + if "client" in data: + client = data.pop("client") + else: + # if llm_model is explicitly set, but not client then make it None + if "llm_model" in data and data["llm_model"] is not None: + # except if it is an OpenAILLMModel + if type(data["llm_model"]) == OpenAILLMModel: + client = AsyncOpenAI() + else: + client = None + else: + client = AsyncOpenAI() + # backwards compatibility + if "doc_index" in data: + data["docs_index"] = data.pop("doc_index") + super().__init__(**data) + self._client = client + self._embedding_client = embedding_client + # run this here (instead of automatically) so it has access to privates + # If I ever figure out a better way of validating privates + # I can move this back to the decorator + Docs.make_llm_names_consistent(self) + + @model_validator(mode="before") + @classmethod + def setup_alias_models(cls, data: Any) -> Any: + if isinstance(data, dict): + if "llm" in data and data["llm"] != "default": + if is_openai_model(data["llm"]): + data["llm_model"] = OpenAILLMModel(config=dict(model=data["llm"])) + elif data["llm"] == "langchain": + data["llm_model"] = LangchainLLMModel() + else: + raise ValueError(f"Could not guess model type for {data['llm']}. ") + if "summary_llm" in data and data["summary_llm"] is not None: + if is_openai_model(data["summary_llm"]): + data["summary_llm_model"] = OpenAILLMModel( + config=dict(model=data["summary_llm"]) + ) + else: + raise ValueError(f"Could not guess model type for {data['llm']}. ") + if "embedding" in data and data["embedding"] != "default": + if data["embedding"] == "langchain": + if "texts_index" not in data: + data["texts_index"] = NumpyVectorStore( + embedding_model=LangchainEmbeddingModel() + ) + if "docs_index" not in data: + data["docs_index"] = NumpyVectorStore( + embedding_model=LangchainEmbeddingModel() + ) + else: + raise ValueError( + f"Could not guess embedding model type for {data['embedding']}. " + ) + return data + + @model_validator(mode="after") + @classmethod + def config_summary_llm_config(cls, data: Any) -> Any: + if isinstance(data, Docs): + # check our default gpt-4/3.5-turbo config + # default check is hard - becauise either llm is set or llm_model is set + if ( + data.summary_llm_model is None + and data.llm == "default" + and type(data.llm_model) == OpenAILLMModel + ): + data.summary_llm_model = OpenAILLMModel( + config=dict(model="gpt-3.5-turbo", temperature=0.1) ) - if v.memory_variables()[0] != "memory": - raise ValueError("Memory model must have memory_variables=['memory']") - return values["memory_model"] - return None + elif data.summary_llm_model is None: + data.summary_llm_model = data.llm_model + return data + + @classmethod + def make_llm_names_consistent(cls, data: Any) -> Any: + if isinstance(data, Docs): + data.llm = data.llm_model.name + if data.llm == "langchain": + # from langchain models - kind of hacky + # langchain models cannot know type until + # it sees client + data.llm_model.infer_llm_type(data._client) + data.llm = data.llm_model.name + if data.summary_llm_model is not None: + if ( + data.summary_llm is None + and data.summary_llm_model is data.llm_model + ): + data.summary_llm = data.llm + if data.summary_llm == "langchain": + # from langchain models - kind of hacky + data.summary_llm_model.infer_llm_type(data._client) + data.summary_llm = data.summary_llm_model.name + return data def clear_docs(self): self.texts = [] self.docs = {} self.docnames = set() - def update_llm( + def __getstate__(self): + # You may wonder why make these private if we're just going + # to be overriding the behavior on setstaet/getstate anyway. + # The reason is that the other serialization methods from Pydantic - + # model_dump - will not drop private attributes. + # So - this getstate/setstate removes private attributes for pickling + # and Pydantic will handle removing private attributes for other + # serialization methods (like model_dump) + state = super().__getstate__() + # remove client from private attributes + del state["__pydantic_private__"]["_client"] + del state["__pydantic_private__"]["_embedding_client"] + return state + + def __setstate__(self, state): + # add client back to private attributes + state["__pydantic_private__"]["_client"] = None + state["__pydantic_private__"]["_embedding_client"] = None + super().__setstate__(state) + + def set_client( self, - llm: Union[BaseLanguageModel, str], - summary_llm: Optional[Union[BaseLanguageModel, str]] = None, - ) -> None: - """Update the LLM for answering questions.""" - if type(llm) is str: - llm = ChatOpenAI(temperature=0.1, model=llm, client=None) - if type(summary_llm) is str: - summary_llm = ChatOpenAI(temperature=0.1, model=summary_llm, client=None) - self.llm = cast(BaseLanguageModel, llm) - if summary_llm is None: - summary_llm = llm - self.summary_llm = cast(BaseLanguageModel, summary_llm) + client: AsyncOpenAI | None = None, + embedding_client: AsyncOpenAI | None = None, + ): + if client is None: + client = AsyncOpenAI() + self._client = client + if embedding_client is None: + if type(client) == AsyncOpenAI: + embedding_client = client + else: + embedding_client = AsyncOpenAI() + self._embedding_client = embedding_client + Docs.make_llm_names_consistent(self) def _get_unique_name(self, docname: str) -> str: """Create a unique name given proposed name""" @@ -128,11 +247,11 @@ def _get_unique_name(self, docname: str) -> str: def add_file( self, file: BinaryIO, - citation: Optional[str] = None, - docname: Optional[str] = None, - dockey: Optional[DocKey] = None, + citation: str | None = None, + docname: str | None = None, + dockey: DocKey | None = None, chunk_chars: int = 3000, - ) -> Optional[str]: + ) -> str | None: """Add a document to the collection.""" # just put in temp file and use existing method suffix = ".txt" @@ -155,11 +274,11 @@ def add_file( def add_url( self, url: str, - citation: Optional[str] = None, - docname: Optional[str] = None, - dockey: Optional[DocKey] = None, + citation: str | None = None, + docname: str | None = None, + dockey: DocKey | None = None, chunk_chars: int = 3000, - ) -> Optional[str]: + ) -> str | None: """Add a document to the collection.""" import urllib.request @@ -177,20 +296,20 @@ def add_url( def add( self, path: Path, - citation: Optional[str] = None, - docname: Optional[str] = None, + citation: str | None = None, + docname: str | None = None, disable_check: bool = False, - dockey: Optional[DocKey] = None, + dockey: DocKey | None = None, chunk_chars: int = 3000, - ) -> Optional[str]: + ) -> str | None: """Add a document to the collection.""" if dockey is None: dockey = md5sum(path) if citation is None: # skip system because it's too hesitant to answer - cite_chain = make_chain( + cite_chain = self.llm_model.make_chain( + client=self._client, prompt=self.prompts.cite, - llm=cast(BaseLanguageModel, self.summary_llm), skip_system=True, ) # peak first chunk @@ -198,7 +317,10 @@ def add( texts = read_doc(path, fake_doc, chunk_chars=chunk_chars, overlap=100) if len(texts) == 0: raise ValueError(f"Could not read document {path}. Is it empty?") - citation = cite_chain.run(texts[0].text) + chain_result = asyncio.run( + cite_chain(dict(text=texts[0].text), None), + ) + citation = chain_result.text if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation: citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}" @@ -237,7 +359,7 @@ def add( def add_texts( self, - texts: List[Text], + texts: list[Text], doc: Doc, ) -> bool: """Add chunked texts to the collection. This is useful if you have already chunked the texts yourself. @@ -253,35 +375,38 @@ def add_texts( for t in texts: t.name = t.name.replace(doc.docname, new_docname) doc.docname = new_docname - if texts[0].embeddings is None: - text_embeddings = self.embeddings.embed_documents([t.text for t in texts]) - for i, t in enumerate(texts): - t.embeddings = text_embeddings[i] - else: - text_embeddings = cast(List[List[float]], [t.embeddings for t in texts]) - if self.texts_index is not None: - try: - # TODO: Simplify - super weird - vec_store_text_and_embeddings = list( - map(lambda x: (x.text, x.embeddings), texts) + if texts[0].embedding is None: + text_embeddings = asyncio.run( + self.texts_index.embedding_model.embed_documents( + self._embedding_client, [t.text for t in texts] ) - self.texts_index.add_embeddings( # type: ignore - vec_store_text_and_embeddings, - metadatas=[t.dict(exclude={"embeddings", "text"}) for t in texts], + ) + for i, t in enumerate(texts): + t.embedding = text_embeddings[i] + if doc.embedding is None: + doc.embedding = asyncio.run( + self.docs_index.embedding_model.embed_documents( + self._embedding_client, [doc.citation] ) - except AttributeError: - raise ValueError("Need a vector store that supports adding embeddings.") - if self.doc_index is not None: - self.doc_index.add_texts([doc.citation], metadatas=[doc.dict()]) + )[0] + if not self.jit_texts_index: + self.texts_index.add_texts_and_embeddings(texts) + self.docs_index.add_texts_and_embeddings([doc]) self.docs[doc.dockey] = doc self.texts += texts self.docnames.add(doc.docname) return True def delete( - self, name: Optional[str] = None, dockey: Optional[DocKey] = None + self, + name: str | None = None, + docname: str | None = None, + dockey: DocKey | None = None, ) -> None: """Delete a document from the collection.""" + # name is an alias for docname + name = docname if name is None else name + if name is not None: doc = next((doc for doc in self.docs.values() if doc.docname == name), None) if doc is None: @@ -295,253 +420,199 @@ async def adoc_match( self, query: str, k: int = 25, - rerank: Optional[bool] = None, + rerank: bool | None = None, get_callbacks: CallbackFactory = lambda x: None, - ) -> Set[DocKey]: + answer: Answer | None = None, # used for tracking tokens + ) -> set[DocKey]: """Return a list of dockeys that match the query.""" - if self.doc_index is None: - if len(self.docs) == 0: - return set() - texts = [doc.citation for doc in self.docs.values()] - metadatas = [d.dict() for d in self.docs.values()] - self.doc_index = FAISS.from_texts( - texts, metadatas=metadatas, embedding=self.embeddings - ) - matches = self.doc_index.max_marginal_relevance_search( - query, k=k + len(self.deleted_dockeys) + matches, _ = await self.docs_index.max_marginal_relevance_search( + self._embedding_client, + query, + k=k + len(self.deleted_dockeys), + fetch_k=5 * (k + len(self.deleted_dockeys)), ) # filter the matches - matches = [ - m for m in matches if m.metadata["dockey"] not in self.deleted_dockeys + matched_docs = [ + m for m in cast(list[Doc], matches) if m.dockey not in self.deleted_dockeys ] - try: - # for backwards compatibility (old pickled objects) - matched_docs = [self.docs[m.metadata["dockey"]] for m in matches] - except KeyError: - matched_docs = [Doc(**m.metadata) for m in matches] + if len(matched_docs) == 0: return set() # this only works for gpt-4 (in my testing) try: if ( rerank is None - and get_llm_name(cast(BaseLanguageModel, self.llm)).startswith("gpt-4") + and ( + type(self.llm_model) == OpenAILLMModel + and cast(OpenAILLMModel, self).config["model"].startswith("gpt-4") + ) or rerank is True ): - chain = make_chain( - self.prompts.select, - cast(BaseLanguageModel, self.llm), + chain = self.llm_model.make_chain( + client=self._client, + prompt=self.prompts.select, skip_system=True, ) papers = [f"{d.docname}: {d.citation}" for d in matched_docs] - result = await chain.arun( # type: ignore - question=query, - papers="\n".join(papers), - callbacks=get_callbacks("filter"), + result = await chain( + dict(question=query, papers="\n".join(papers)), + get_callbacks("filter"), ) - return set([d.dockey for d in matched_docs if d.docname in result]) + if answer: + answer.add_tokens(result) + return set([d.dockey for d in matched_docs if d.docname in str(result)]) except AttributeError: pass return set([d.dockey for d in matched_docs]) - def __getstate__(self): - state = self.__dict__.copy() - if self.texts_index is not None and self.index_path is not None: - state["texts_index"].save_local(self.index_path) - del state["texts_index"] - del state["doc_index"] - return {"__dict__": state, "__fields_set__": self.__fields_set__} - - def __setstate__(self, state): - object.__setattr__(self, "__dict__", state["__dict__"]) - object.__setattr__(self, "__fields_set__", state["__fields_set__"]) - try: - self.texts_index = FAISS.load_local(self.index_path, self.embeddings) - except Exception: - # they use some special exception type, but I don't want to import it - self.texts_index = None - self.doc_index = None - - def _build_texts_index(self, keys: Optional[Set[DocKey]] = None): + def _build_texts_index(self, keys: set[DocKey] | None = None): + texts = self.texts if keys is not None and self.jit_texts_index: - del self.texts_index - self.texts_index = None - if self.texts_index is None: - texts = self.texts + # TODO: what is JIT even for?? if keys is not None: texts = [t for t in texts if t.doc.dockey in keys] if len(texts) == 0: return - raw_texts = [t.text for t in texts] - text_embeddings = [t.embeddings for t in texts] - metadatas = [t.dict(exclude={"embeddings", "text"}) for t in texts] - self.texts_index = FAISS.from_embeddings( - # wow adding list to the zip was tricky - text_embeddings=list(zip(raw_texts, text_embeddings)), - embedding=self.embeddings, - metadatas=metadatas, + self.texts_index.clear() + self.texts_index.add_texts_and_embeddings(texts) + if self.jit_texts_index and keys is None: + # Not sure what else to do here??????? + print( + "Warning: JIT text index without keys " + "requires rebuilding index each time!" ) - - def clear_memory(self): - """Clear the memory of the model.""" - if self.memory_model is not None: - self.memory_model.clear() + self.texts_index.clear() + self.texts_index.add_texts_and_embeddings(texts) def get_evidence( self, answer: Answer, k: int = 10, max_sources: int = 5, - marginal_relevance: bool = True, get_callbacks: CallbackFactory = lambda x: None, detailed_citations: bool = False, disable_vector_search: bool = False, - disable_summarization: bool = False, ) -> Answer: - # special case for jupyter notebooks - if "get_ipython" in globals() or "google.colab" in sys.modules: - import nest_asyncio - - nest_asyncio.apply() - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop.run_until_complete( + return asyncio.run( self.aget_evidence( answer, k=k, max_sources=max_sources, - marginal_relevance=marginal_relevance, get_callbacks=get_callbacks, detailed_citations=detailed_citations, disable_vector_search=disable_vector_search, - disable_summarization=disable_summarization, ) ) async def aget_evidence( self, answer: Answer, - k: int = 10, # Number of vectors to retrieve + k: int = 10, # Number of evidence pieces to retrieve max_sources: int = 5, # Number of scored contexts to use - marginal_relevance: bool = True, get_callbacks: CallbackFactory = lambda x: None, detailed_citations: bool = False, disable_vector_search: bool = False, - disable_summarization: bool = False, ) -> Answer: - if disable_vector_search: - k = k * 10000 - if len(self.docs) == 0 and self.doc_index is None: + if len(self.docs) == 0 and self.docs_index is None: + # do we have no docs? return answer self._build_texts_index(keys=answer.dockey_filter) - if self.texts_index is None: - return answer - self.texts_index = cast(VectorStore, self.texts_index) _k = k if answer.dockey_filter is not None: - _k = k * 10 # heuristic - if marginal_relevance: - matches = self.texts_index.max_marginal_relevance_search( - answer.question, k=_k, fetch_k=5 * _k - ) + _k = k * 10 # heuristic - get enough so we can downselect + if disable_vector_search: + matches = self.texts else: - matches = self.texts_index.similarity_search( - answer.question, k=_k, fetch_k=5 * _k + matches = cast( + list[Text], + ( + await self.texts_index.max_marginal_relevance_search( + self._embedding_client, answer.question, k=_k, fetch_k=5 * _k + ) + )[0], ) - # ok now filter + # ok now filter (like ones from adoc_match) if answer.dockey_filter is not None: - matches = [ - m - for m in matches - if m.metadata["doc"]["dockey"] in answer.dockey_filter - ] + matches = [m for m in matches if m.doc.dockey in answer.dockey_filter] # check if it is deleted - matches = [ - m - for m in matches - if m.metadata["doc"]["dockey"] not in self.deleted_dockeys - ] + matches = [m for m in matches if m.doc.dockey not in self.deleted_dockeys] # check if it is already in answer cur_names = [c.text.name for c in answer.contexts] - matches = [m for m in matches if m.metadata["name"] not in cur_names] + matches = [m for m in matches if m.name not in cur_names] # now finally cut down matches = matches[:k] async def process(match): - callbacks = get_callbacks("evidence:" + match.metadata["name"]) - summary_chain = make_chain( - self.prompts.summary, - self.summary_llm, - memory=self.memory_model, - system_prompt=self.prompts.system, - ) - # This is dangerous because it - # could mask errors that are important- like auth errors - # I also cannot know what the exception - # type is because any model could be used - # my best idea is see if there is a 4XX - # http code in the exception - try: - citation = match.metadata["doc"]["citation"] - if detailed_citations: - citation = match.metadata["name"] + ": " + citation - if self.prompts.skip_summary: - context = match.page_content - else: - context = await summary_chain.arun( - question=answer.question, - # Add name so chunk is stated - citation=citation, - summary_length=answer.summary_length, - text=match.page_content, - callbacks=callbacks, + callbacks = get_callbacks("evidence:" + match.name) + citation = match.doc.citation + # empty result + llm_result = LLMResult(model="", date="") + if detailed_citations: + citation = match.name + ": " + citation + + if self.prompts.skip_summary: + context = match.text + score = 5 + else: + summary_chain = self.summary_llm_model.make_chain( + client=self._client, + prompt=self.prompts.summary, + system_prompt=self.prompts.system, + ) + # This is dangerous because it + # could mask errors that are important- like auth errors + # I also cannot know what the exception + # type is because any model could be used + # my best idea is see if there is a 4XX + # http code in the exception + try: + llm_result = await summary_chain( + dict( + question=answer.question, + # Add name so chunk is stated + citation=citation, + summary_length=answer.summary_length, + text=match.text, + ), + callbacks, ) - except Exception as e: - if guess_is_4xx(str(e)): - return None - raise e - if "not applicable" in context.lower() or "not relevant" in context.lower(): - return None - if self.strip_citations: - # remove citations that collide with our grounded citations (for the answer LLM) - context = strip_citations(context) + context = llm_result.text + except Exception as e: + if guess_is_4xx(str(e)): + return None, llm_result + raise e + if ( + "not applicable" in context.lower() + or "not relevant" in context.lower() + ): + return None, llm_result + if self.strip_citations: + # remove citations that collide with our grounded citations (for the answer LLM) + context = strip_citations(context) + score = get_score(context) c = Context( context=context, + # below will remove embedding from Text/Doc text=Text( - text=match.page_content, - name=match.metadata["name"], - doc=Doc(**match.metadata["doc"]), + text=match.text, + name=match.name, + doc=Doc(**match.doc.model_dump()), ), - score=get_score(context), + score=score, ) - return c - - if disable_summarization: - contexts = [ - Context( - context=match.page_content, - score=10, - text=Text( - text=match.page_content, - name=match.metadata["name"], - doc=Doc(**match.metadata["doc"]), - ), - ) - for match in matches - ] + return c, llm_result - else: - results = await gather_with_concurrency( - self.max_concurrent, *[process(m) for m in matches] - ) - # filter out failures - contexts = [c for c in results if c is not None] + results = await gather_with_concurrency( + self.max_concurrent, [process(m) for m in matches] + ) + # update token counts + [answer.add_tokens(r[1]) for r in results] + + # filter out failures + contexts = [c for c, r in results if c is not None] answer.contexts = sorted( contexts + answer.contexts, key=lambda x: x.score, reverse=True @@ -566,28 +637,16 @@ def query( k: int = 10, max_sources: int = 5, length_prompt="about 100 words", - marginal_relevance: bool = True, - answer: Optional[Answer] = None, - key_filter: Optional[bool] = None, + answer: Answer | None = None, + key_filter: bool | None = None, get_callbacks: CallbackFactory = lambda x: None, ) -> Answer: - # special case for jupyter notebooks - if "get_ipython" in globals() or "google.colab" in sys.modules: - import nest_asyncio - - nest_asyncio.apply() - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop.run_until_complete( + return asyncio.run( self.aquery( query, k=k, max_sources=max_sources, length_prompt=length_prompt, - marginal_relevance=marginal_relevance, answer=answer, key_filter=key_filter, get_callbacks=get_callbacks, @@ -600,9 +659,8 @@ async def aquery( k: int = 10, max_sources: int = 5, length_prompt: str = "about 100 words", - marginal_relevance: bool = True, - answer: Optional[Answer] = None, - key_filter: Optional[bool] = None, + answer: Answer | None = None, + key_filter: bool | None = None, get_callbacks: CallbackFactory = lambda x: None, ) -> Answer: if k < max_sources: @@ -614,7 +672,9 @@ async def aquery( # comparable - one is chunks and one is docs if key_filter or (key_filter is None and len(self.docs) > k): keys = await self.adoc_match( - answer.question, get_callbacks=get_callbacks + answer.question, + get_callbacks=get_callbacks, + answer=answer, ) if len(keys) > 0: answer.dockey_filter = keys @@ -622,43 +682,43 @@ async def aquery( answer, k=k, max_sources=max_sources, - marginal_relevance=marginal_relevance, get_callbacks=get_callbacks, ) if self.prompts.pre is not None: - chain = make_chain( - self.prompts.pre, - cast(BaseLanguageModel, self.llm), - memory=self.memory_model, + chain = self.llm_model.make_chain( + client=self._client, + prompt=self.prompts.pre, system_prompt=self.prompts.system, ) - pre = await chain.arun( - question=answer.question, callbacks=get_callbacks("pre") + pre = await chain(dict(question=answer.question), get_callbacks("pre")) + answer.add_tokens(pre) + answer.context = ( + answer.context + "\n\nExtra background information:" + str(pre) ) - answer.context = answer.context + "\n\nExtra background information:" + pre bib = dict() - if len(answer.context) < 10 and not self.memory: + if len(answer.context) < 10: # and not self.memory: answer_text = ( "I cannot answer this question due to insufficient information." ) else: - callbacks = get_callbacks("answer") - qa_chain = make_chain( - self.prompts.qa, - cast(BaseLanguageModel, self.llm), - memory=self.memory_model, + qa_chain = self.llm_model.make_chain( + client=self._client, + prompt=self.prompts.qa, system_prompt=self.prompts.system, ) - answer_text = await qa_chain.arun( - context=answer.context, - answer_length=answer.answer_length, - question=answer.question, - callbacks=callbacks, - verbose=True, + answer_result = await qa_chain( + dict( + context=answer.context, + answer_length=answer.answer_length, + question=answer.question, + ), + get_callbacks("answer"), ) + answer_text = answer_result.text + answer.add_tokens(answer_result) # it still happens - if "(Example2012)" in answer_text: - answer_text = answer_text.replace("(Example2012)", "") + if "(Example2012Example pages 3-4)" in answer_text: + answer_text = answer_text.replace("(Example2012Example pages 3-4)", "") for c in answer.contexts: name = c.text.name citation = c.text.doc.citation @@ -676,21 +736,21 @@ async def aquery( answer.references = bib_str if self.prompts.post is not None: - chain = make_chain( - self.prompts.post, - cast(BaseLanguageModel, self.llm), - memory=self.memory_model, + chain = self.llm_model.make_chain( + client=self._client, + prompt=self.prompts.post, system_prompt=self.prompts.system, ) - post = await chain.arun(**answer.dict(), callbacks=get_callbacks("post")) - answer.answer = post + post = await chain(answer.model_dump(), get_callbacks("post")) + answer.answer = post.text + answer.add_tokens(post) answer.formatted_answer = f"Question: {answer.question}\n\n{post}\n" if len(bib) > 0: answer.formatted_answer += f"\nReferences\n\n{bib_str}\n" - if self.memory_model is not None: - answer.memory = self.memory_model.load_memory_variables(inputs={})["memory"] - self.memory_model.save_context( - {"Question": answer.question}, {"Answer": answer.answer} - ) + # if self.memory_model is not None: + # answer.memory = self.memory_model.load_memory_variables(inputs={})["memory"] + # self.memory_model.save_context( + # {"Question": answer.question}, {"Answer": answer.answer} + # ) return answer diff --git a/paperqa/llms.py b/paperqa/llms.py new file mode 100644 index 00000000..dea6645c --- /dev/null +++ b/paperqa/llms.py @@ -0,0 +1,659 @@ +import asyncio +import datetime +import re +from abc import ABC, abstractmethod +from inspect import signature +from typing import ( + Any, + AsyncGenerator, + Callable, + Coroutine, + Sequence, + Type, + cast, + get_args, + get_type_hints, +) + +import numpy as np +from openai import AsyncOpenAI +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from .prompts import default_system_prompt +from .types import Doc, Embeddable, LLMResult, Text +from .utils import batch_iter, flatten, gather_with_concurrency + + +def guess_model_type(model_name: str) -> str: + """Guess the model type from the model name for OpenAI models""" + import openai + + model_type = get_type_hints( + openai.types.chat.completion_create_params.CompletionCreateParamsBase + )["model"] + model_union = get_args(get_args(model_type)[1]) + model_arr = list(model_union) + if model_name in model_arr: + return "chat" + return "completion" + + +def is_openai_model(model_name): + import openai + + model_type = get_type_hints( + openai.types.chat.completion_create_params.CompletionCreateParamsBase + )["model"] + model_union = get_args(get_args(model_type)[1]) + model_arr = list(model_union) + + complete_model_types = get_type_hints( + openai.types.completion_create_params.CompletionCreateParamsBase + )["model"] + complete_model_union = get_args(get_args(complete_model_types)[1]) + complete_model_arr = list(complete_model_union) + + return model_name in model_arr or model_name in complete_model_arr + + +def process_llm_config(llm_config: dict) -> dict: + """Remove model_type and try to set max_tokens""" + result = {k: v for k, v in llm_config.items() if k != "model_type"} + if "max_tokens" not in result or result["max_tokens"] == -1: + model = llm_config["model"] + # now we guess - we could use tiktoken to count, + # but do have the initative right now + if model.startswith("gpt-4") or ( + model.startswith("gpt-3.5") and "1106" in model + ): + result["max_tokens"] = 3000 + else: + result["max_tokens"] = 1500 + return result + + +async def embed_documents( + client: AsyncOpenAI, texts: list[str], embedding_model: str +) -> list[list[float]]: + """Embed a list of documents with batching""" + if client is None: + raise ValueError( + "Your client is None - did you forget to set it after pickling?" + ) + response = await client.embeddings.create( + model=embedding_model, input=texts, encoding_format="float" + ) + return [e.embedding for e in response.data] + + +class EmbeddingModel(ABC, BaseModel): + @abstractmethod + async def embed_documents(self, client: Any, texts: list[str]) -> list[list[float]]: + pass + + +class OpenAIEmbeddingModel(EmbeddingModel): + embedding_model: str = Field(default="text-embedding-ada-002") + + async def embed_documents(self, client: Any, texts: list[str]) -> list[list[float]]: + return await embed_documents( + cast(AsyncOpenAI, client), texts, self.embedding_model + ) + + +class LLMModel(ABC, BaseModel): + llm_type: str | None = None + name: str + model_config = ConfigDict(extra="forbid") + + async def acomplete(self, client: Any, prompt: str) -> str: + raise NotImplementedError + + async def acomplete_iter(self, client: Any, prompt: str) -> Any: + """Return an async generator that yields chunks of the completion. + + I cannot get mypy to understand the override, so marked as Any""" + raise NotImplementedError + + async def achat(self, client: Any, messages: list[dict[str, str]]) -> str: + raise NotImplementedError + + async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: + """Return an async generator that yields chunks of the completion. + + I cannot get mypy to understand the override, so marked as Any""" + raise NotImplementedError + + def infer_llm_type(self, client: Any) -> str: + return "completion" + + def count_tokens(self, text: str) -> int: + return len(text) // 4 # gross approximation + + def make_chain( + self, + client: Any, + prompt: str, + skip_system: bool = False, + system_prompt: str = default_system_prompt, + ) -> Callable[ + [dict, list[Callable[[str], None]] | None], Coroutine[Any, Any, LLMResult] + ]: + """Create a function to execute a batch of prompts + + This replaces the previous use of langchain for combining prompts and LLMs. + + Args: + client: a ephemeral client to use + prompt: The prompt to use + skip_system: Whether to skip the system prompt + system_prompt: The system prompt to use + + Returns: + A function to execute a prompt. Its signature is: + execute(data: dict, callbacks: list[Callable[[str], None]]] | None = None) -> LLMResult + where data is a dict with keys for the input variables that will be formatted into prompt + and callbacks is a list of functions to call with each chunk of the completion. + """ + # check if it needs to be set + if self.llm_type is None: + self.llm_type = self.infer_llm_type(client) + if self.llm_type == "chat": + system_message_prompt = dict(role="system", content=system_prompt) + human_message_prompt = dict(role="user", content=prompt) + if skip_system: + chat_prompt = [human_message_prompt] + else: + chat_prompt = [system_message_prompt, human_message_prompt] + + async def execute( + data: dict, callbacks: list[Callable[[str], None]] | None = None + ) -> LLMResult: + start_clock = asyncio.get_running_loop().time() + result = LLMResult( + model=self.name, + date=datetime.datetime.now().isoformat(), + ) + messages = chat_prompt[:-1] + [ + dict(role="user", content=chat_prompt[-1]["content"].format(**data)) + ] + result.prompt_count = sum( + [self.count_tokens(m["content"]) for m in messages] + ) + sum([self.count_tokens(m["role"]) for m in messages]) + + if callbacks is None: + output = await self.achat(client, messages) + else: + completion = self.achat_iter(client, messages) # type: ignore + text_result = [] + async for chunk in completion: # type: ignore + if chunk: + if result.seconds_to_first_token == 0: + result.seconds_to_first_token = ( + asyncio.get_running_loop().time() - start_clock + ) + text_result.append(chunk) + [f(chunk) for f in callbacks] + output = "".join(text_result) + result.completion_count = self.count_tokens(output) + result.text = output + result.seconds_to_last_token = ( + asyncio.get_running_loop().time() - start_clock + ) + return result + + return execute + elif self.llm_type == "completion": + if skip_system: + completion_prompt = prompt + else: + completion_prompt = system_prompt + "\n\n" + prompt + + async def execute( + data: dict, callbacks: list[Callable[[str], None]] | None = None + ) -> LLMResult: + start_clock = asyncio.get_running_loop().time() + result = LLMResult( + model=self.name, + date=datetime.datetime.now().isoformat(), + ) + formatted_prompt = completion_prompt.format(**data) + result.prompt_count = self.count_tokens(formatted_prompt) + + if callbacks is None: + output = await self.acomplete(client, formatted_prompt) + else: + completion = self.acomplete_iter( # type: ignore + client, + formatted_prompt, + ) + text_result = [] + async for chunk in completion: # type: ignore + if chunk: + if result.seconds_to_first_token == 0: + result.seconds_to_first_token = ( + asyncio.get_running_loop().time() - start_clock + ) + text_result.append(chunk) + [f(chunk) for f in callbacks] + output = "".join(text_result) + result.completion_count = self.count_tokens(output) + result.text = output + result.seconds_to_last_token = ( + asyncio.get_running_loop().time() - start_clock + ) + return result + + return execute + raise ValueError(f"Unknown llm_type: {self.llm_type}") + + +class OpenAILLMModel(LLMModel): + config: dict = Field(default=dict(model="gpt-3.5-turbo", temperature=0.1)) + name: str = "gpt-3.5-turbo" + + def _check_client(self, client: Any) -> AsyncOpenAI: + if client is None: + raise ValueError( + "Your client is None - did you forget to set it after pickling?" + ) + if not isinstance(client, AsyncOpenAI): + raise ValueError( + f"Your client is not a required AsyncOpenAI client. It is a {type(client)}" + ) + return cast(AsyncOpenAI, client) + + @model_validator(mode="after") + @classmethod + def guess_llm_type(cls, data: Any) -> Any: + m = cast(OpenAILLMModel, data) + m.llm_type = guess_model_type(m.config["model"]) + return m + + @model_validator(mode="after") + @classmethod + def set_model_name(cls, data: Any) -> Any: + m = cast(OpenAILLMModel, data) + m.name = m.config["model"] + return m + + async def acomplete(self, client: Any, prompt: str) -> str: + aclient = self._check_client(client) + completion = await aclient.completions.create( + prompt=prompt, **process_llm_config(self.config) + ) + return completion.choices[0].text + + async def acomplete_iter(self, client: Any, prompt: str) -> Any: + aclient = self._check_client(client) + completion = await aclient.completions.create( + prompt=prompt, **process_llm_config(self.config), stream=True + ) + async for chunk in completion: + yield chunk.choices[0].text + + async def achat(self, client: Any, messages: list[dict[str, str]]) -> str: + aclient = self._check_client(client) + completion = await aclient.chat.completions.create( + messages=messages, **process_llm_config(self.config) # type: ignore + ) + return completion.choices[0].message.content or "" + + async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: + aclient = self._check_client(client) + completion = await aclient.chat.completions.create( + messages=messages, **process_llm_config(self.config), stream=True # type: ignore + ) + async for chunk in cast(AsyncGenerator, completion): + yield chunk.choices[0].delta.content + + +class LlamaEmbeddingModel(EmbeddingModel): + embedding_model: str = Field(default="llama") + + batch_size: int = Field(default=4) + concurrency: int = Field(default=1) + + async def embed_documents(self, client: Any, texts: list[str]) -> list[list[float]]: + cast(AsyncOpenAI, client) + + async def process(texts: list[str]) -> list[float]: + for i in range(3): + # access httpx client directly to avoid type casting + response = await client._client.post( + client.base_url.join("../embedding"), json={"content": texts} + ) + body = response.json() + if len(texts) == 1: + if type(body) != dict or body.get("embedding") is None: + continue + return [body["embedding"]] + else: + if type(body) != list or body[0] != "results": + continue + return [e["embedding"] for e in body[1]] + raise ValueError("Failed to embed documents - response was ", body) + + return flatten( + await gather_with_concurrency( + self.concurrency, + [process(b) for b in batch_iter(texts, self.batch_size)], + ) + ) + + +class SentenceTransformerEmbeddingModel(EmbeddingModel): + embedding_model: str = Field(default="multi-qa-MiniLM-L6-cos-v1") + _model: Any = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + try: + from sentence_transformers import SentenceTransformer + except ImportError: + raise ImportError("Please install sentence-transformers to use this model") + + self._model = SentenceTransformer(self.embedding_model) + + async def embed_documents(self, client: Any, texts: list[str]) -> list[list[float]]: + from sentence_transformers import SentenceTransformer + + embeddings = cast(SentenceTransformer, self._model).encode(texts) + return embeddings + + +def cosine_similarity(a, b): + dot_product = np.dot(a, b.T) + norm_product = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) + return dot_product / norm_product + + +class VectorStore(BaseModel, ABC): + """Interface for vector store - very similar to LangChain's VectorStore to be compatible""" + + embedding_model: EmbeddingModel = Field(default=OpenAIEmbeddingModel()) + # can be tuned for different tasks + mmr_lambda: float = Field(default=0.5) + model_config = ConfigDict(extra="forbid") + + @abstractmethod + def add_texts_and_embeddings(self, texts: Sequence[Embeddable]) -> None: + pass + + @abstractmethod + async def similarity_search( + self, client: Any, query: str, k: int + ) -> tuple[Sequence[Embeddable], list[float]]: + pass + + @abstractmethod + def clear(self) -> None: + pass + + async def max_marginal_relevance_search( + self, client: Any, query: str, k: int, fetch_k: int + ) -> tuple[Sequence[Embeddable], list[float]]: + """Vectorized implementation of Maximal Marginal Relevance (MMR) search. + + Args: + query: Query vector. + k: Number of results to return. + + Returns: + List of tuples (doc, score) of length k. + """ + if fetch_k < k: + raise ValueError("fetch_k must be greater or equal to k") + + texts, scores = await self.similarity_search(client, query, fetch_k) + if len(texts) <= k: + return texts, scores + + embeddings = np.array([t.embedding for t in texts]) + np_scores = np.array(scores) + similarity_matrix = cosine_similarity(embeddings, embeddings) + + selected_indices = [0] + remaining_indices = list(range(1, len(texts))) + + while len(selected_indices) < k: + selected_similarities = similarity_matrix[:, selected_indices] + max_sim_to_selected = selected_similarities.max(axis=1) + + mmr_scores = ( + self.mmr_lambda * np_scores + - (1 - self.mmr_lambda) * max_sim_to_selected + ) + mmr_scores[selected_indices] = -np.inf # Exclude already selected documents + + max_mmr_index = mmr_scores.argmax() + selected_indices.append(max_mmr_index) + remaining_indices.remove(max_mmr_index) + + return [texts[i] for i in selected_indices], [ + scores[i] for i in selected_indices + ] + + +class NumpyVectorStore(VectorStore): + texts: list[Embeddable] = [] + _embeddings_matrix: np.ndarray | None = None + + def clear(self) -> None: + self.texts = [] + self._embeddings_matrix = None + + def add_texts_and_embeddings( + self, + texts: Sequence[Embeddable], + ) -> None: + self.texts.extend(texts) + self._embeddings_matrix = np.array([t.embedding for t in self.texts]) + + async def similarity_search( + self, client: Any, query: str, k: int + ) -> tuple[Sequence[Embeddable], list[float]]: + if len(self.texts) == 0: + return [], [] + np_query = np.array( + (await self.embedding_model.embed_documents(client, [query]))[0] + ) + similarity_scores = cosine_similarity( + np_query.reshape(1, -1), self._embeddings_matrix + )[0] + similarity_scores = np.nan_to_num(similarity_scores, nan=-np.inf) + sorted_indices = np.argsort(similarity_scores)[::-1] + return ( + [self.texts[i] for i in sorted_indices[:k]], + [similarity_scores[i] for i in sorted_indices[:k]], + ) + + +# All the langchain stuff is below +# Many confusing woes here because langchain +# is not serializable and so we have to +# do some gymnastics to make it work + + +class LangchainLLMModel(LLMModel): + """A wrapper around the wrapper langchain""" + + name: str = "langchain" + + def infer_llm_type(self, client: Any) -> str: + from langchain_core.language_models.chat_models import BaseChatModel + + self.name = client.model_name + if isinstance(client, BaseChatModel): + return "chat" + return "completion" + + async def acomplete(self, client: Any, prompt: str) -> str: + return await client.ainvoke(prompt) + + async def acomplete_iter(self, client: Any, prompt: str) -> Any: + async for chunk in cast(AsyncGenerator, client.astream(prompt)): + yield chunk + + async def achat(self, client: Any, messages: list[dict[str, str]]) -> str: + from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage + + lc_messages: list[BaseMessage] = [] + for m in messages: + if m["role"] == "user": + lc_messages.append(HumanMessage(content=m["content"])) + elif m["role"] == "system": + lc_messages.append(SystemMessage(content=m["content"])) + else: + raise ValueError(f"Unknown role: {m['role']}") + return (await client.ainvoke(lc_messages)).content + + async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: + from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage + + lc_messages: list[BaseMessage] = [] + for m in messages: + if m["role"] == "user": + lc_messages.append(HumanMessage(content=m["content"])) + elif m["role"] == "system": + lc_messages.append(SystemMessage(content=m["content"])) + else: + raise ValueError(f"Unknown role: {m['role']}") + async for chunk in client.astream(lc_messages): + yield chunk.content + + +class LangchainEmbeddingModel(EmbeddingModel): + """A wrapper around the wrapper langchain""" + + async def embed_documents(self, client: Any, texts: list[str]) -> list[list[float]]: + return await client.aembed_documents(texts) + + +class LangchainVectorStore(VectorStore): + """A wrapper around the wrapper langchain + + Note that if you this is cleared (e.g., by `Docs` having `jit_texts_index` set to True), + this will calls the `from_texts` class method on the `store`. This means that any non-default + constructor arguments will be lost. You can override the clear method on this class. + """ + + _store_builder: Any | None = None + _store: Any | None = None + # JIT Generics - store the class type (Doc or Text) + class_type: Type[Embeddable] = Field(default=Embeddable) + model_config = ConfigDict(extra="forbid") + + def __init__(self, **data): + # we have to separate out store from the rest of the data + # because langchain objects are not serializable + store_builder = None + if "store_builder" in data: + store_builder = LangchainVectorStore.check_store_builder( + data.pop("store_builder") + ) + if "cls" in data and "embedding_model" in data: + # make a little closure + cls = data.pop("cls") + embedding_model = data.pop("embedding_model") + + def candidate(x, y): + return cls.from_embeddings(x, embedding_model, y) + + store_builder = LangchainVectorStore.check_store_builder(candidate) + super().__init__(**data) + self._store_builder = store_builder + + @classmethod + def check_store_builder(cls, builder: Any) -> Any: + # check it is a callable + if not callable(builder): + raise ValueError("store_builder must be callable") + # check it takes two arguments + # we don't use type hints because it could be + # a partial + sig = signature(builder) + if len(sig.parameters) != 2: + raise ValueError("store_builder must take two arguments") + return builder + + def __getstate__(self): + state = super().__getstate__() + # remove non-serializable private attributes + del state["__pydantic_private__"]["_store"] + del state["__pydantic_private__"]["_store_builder"] + return state + + def __setstate__(self, state): + # restore non-serializable private attributes + state["__pydantic_private__"]["_store"] = None + state["__pydantic_private__"]["_store_builder"] = None + super().__setstate__(state) + + def add_texts_and_embeddings(self, texts: Sequence[Embeddable]) -> None: + if self._store_builder is None: + raise ValueError("You must set store_builder before adding texts") + self.class_type = type(texts[0]) + if self.class_type == Text: + vec_store_text_and_embeddings = list( + map(lambda x: (x.text, x.embedding), cast(list[Text], texts)) + ) + elif self.class_type == Doc: + vec_store_text_and_embeddings = list( + map(lambda x: (x.citation, x.embedding), cast(list[Doc], texts)) + ) + else: + raise ValueError("Only embeddings of type Text are supported") + if self._store is None: + self._store = self._store_builder( # type: ignore + vec_store_text_and_embeddings, + texts, + ) + if self._store is None or not hasattr(self._store, "add_embeddings"): + raise ValueError("store_builder did not return a valid vectorstore") + self._store.add_embeddings( # type: ignore + vec_store_text_and_embeddings, + metadatas=texts, + ) + + async def similarity_search( + self, client: Any, query: str, k: int + ) -> tuple[Sequence[Embeddable], list[float]]: + if self._store is None: + return [], [] + results = await self._store.asimilarity_search_with_relevance_scores(query, k=k) + texts, scores = [self.class_type(**r[0].metadata) for r in results], [ + r[1] for r in results + ] + return texts, scores + + def clear(self) -> None: + del self._store # be explicit, because it could be large + self._store = None + + +def get_score(text: str) -> int: + # check for N/A + last_line = text.split("\n")[-1] + if "N/A" in last_line or "n/a" in last_line or "NA" in last_line: + return 0 + score = re.search(r"[sS]core[:is\s]+([0-9]+)", text) + if not score: + score = re.search(r"\(([0-9])\w*\/", text) + if not score: + score = re.search(r"([0-9]+)\w*\/", text) + if score: + s = int(score.group(1)) + if s > 10: + s = int(s / 10) # sometimes becomes out of 100 + return s + last_few = text[-15:] + scores = re.findall(r"([0-9]+)", last_few) + if scores: + s = int(scores[-1]) + if s > 10: + s = int(s / 10) # sometimes becomes out of 100 + return s + if len(text) < 100: + return 1 + return 5 diff --git a/paperqa/prompts.py b/paperqa/prompts.py index a13ddb6d..83c4d877 100644 --- a/paperqa/prompts.py +++ b/paperqa/prompts.py @@ -1,37 +1,37 @@ -from langchain.prompts import PromptTemplate - -summary_prompt = PromptTemplate( - input_variables=["text", "citation", "question", "summary_length"], - template="Summarize the text below to help answer a question. " - "Do not directly answer the question, instead summarize " - "to give evidence to help answer the question. " - "Focus on specific details, including numbers, equations, or specific quotes. " - 'Reply "Not applicable" if text is irrelevant. ' - "Use {summary_length}. At the end of your response, provide a score from 1-10 on a newline " - "indicating relevance to question. Do not explain your score. " - "\n\n" - "{text}\n\n" - "Excerpt from {citation}\n" - "Question: {question}\n" - "Relevant Information Summary:", +summary_prompt = ( + "Summarize the excerpt below to help answer a question.\n\n" + "Excerpt from {citation}\n\n----\n\n{text}\n\n----\n\n" + "Question: {question}\n\n" + "Do not directly answer the question, instead summarize to give evidence to help " + "answer the question. Stay detailed; report specific numbers, equations, or " + 'direct quotes (marked with quotation marks). Reply "Not applicable" if the ' + "excerpt is irrelevant. At the end of your response, provide an integer score " + "from 1-10 on a newline indicating relevance to question. Do not explain your score." + "\n\nRelevant Information Summary ({summary_length}):" ) -qa_prompt = PromptTemplate( - input_variables=["context", "answer_length", "question"], - template="Write an answer ({answer_length}) " - "for the question below based on the provided context. " - "If the context provides insufficient information and the question cannot be directly answered, " - 'reply "I cannot answer". ' - "For each part of your answer, indicate which sources most support it " - "via valid citation markers at the end of sentences, like (Example2012). \n" - "Context (with relevance scores):\n {context}\n" - "Question: {question}\n" - "Answer: ", +qa_prompt = ( + "Answer the question below with the context.\n\n" + "Context (with relevance scores):\n\n{context}\n\n----\n\n" + "Question: {question}\n\n" + "Write an answer based on the context. " + "If the context provides insufficient information and " + "the question cannot be directly answered, reply " + '"I cannot answer."' + "For each part of your answer, indicate which sources most support " + "it via citation keys at the end of sentences, " + "like (Example2012Example pages 3-4). Only cite from the context " + "below and only use the valid keys. Write in the style of a " + "Wikipedia article, with concise sentences and coherent paragraphs. " + "The context comes from a variety of sources and is only a summary, " + "so there may inaccuracies or ambiguities. If quotes are present and " + "relevant, use them in the answer. This answer will go directly onto " + "Wikipedia, so do not add any extraneous information.\n\n" + "Answer ({answer_length}):" ) -select_paper_prompt = PromptTemplate( - input_variables=["question", "papers"], - template="Select papers that may help answer the question below. " +select_paper_prompt = ( + "Select papers that may help answer the question below. " "Papers are listed as $KEY: $PAPER_INFO. " "Return a list of keys, separated by commas. " 'Return "None", if no papers are applicable. ' @@ -39,16 +39,13 @@ "(if the question requires timely information). \n\n" "Question: {question}\n\n" "Papers: {papers}\n\n" - "Selected keys:", + "Selected keys:" ) - -# We are unable to serialize with partial variables -# so TODO: update year next year -citation_prompt = PromptTemplate( - input_variables=["text"], - template="Provide the citation for the following text in MLA Format. The year is 2023\n" +citation_prompt = ( + "Provide the citation for the following text in MLA Format. " + "If reporting date accessed, the current year is 2024\n\n" "{text}\n\n" - "Citation:", + "Citation:" ) default_system_prompt = ( diff --git a/paperqa/readers.py b/paperqa/readers.py index 1b977306..f2363086 100644 --- a/paperqa/readers.py +++ b/paperqa/readers.py @@ -1,8 +1,9 @@ +from math import ceil from pathlib import Path from typing import List +import tiktoken from html2text import html2text -from langchain.text_splitter import TokenTextSplitter from .types import Doc, Text @@ -31,7 +32,7 @@ def parse_pdf_fitz(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List ) split = split[chunk_chars - overlap :] pages = [str(i + 1)] - if len(split) > overlap: + if len(split) > overlap or len(texts) == 0: pg = "-".join([pages[0], pages[-1]]) texts.append( Text(text=split[:chunk_chars], name=f"{doc.docname} pages {pg}", doc=doc) @@ -64,7 +65,7 @@ def parse_pdf(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List[Text ) split = split[chunk_chars - overlap :] pages = [str(i + 1)] - if len(split) > overlap: + if len(split) > overlap or len(texts) == 0: pg = "-".join([pages[0], pages[-1]]) texts.append( Text(text=split[:chunk_chars], name=f"{doc.docname} pages {pg}", doc=doc) @@ -76,6 +77,12 @@ def parse_pdf(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List[Text def parse_txt( path: Path, doc: Doc, chunk_chars: int, overlap: int, html: bool = False ) -> List[Text]: + """Parse a document into chunks, based on tiktoken encoding. + + NOTE: We get some byte continuation errors. + Currnetly ignored, but should explore more to make sure we + don't miss anything. + """ try: with open(path) as f: text = f.read() @@ -84,13 +91,32 @@ def parse_txt( text = f.read() if html: text = html2text(text) - # yo, no idea why but the texts are not split correctly - text_splitter = TokenTextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap) - raw_texts = text_splitter.split_text(text) - texts = [ - Text(text=t, name=f"{doc.docname} chunk {i}", doc=doc) - for i, t in enumerate(raw_texts) - ] + texts: list[Text] = [] + # we tokenize using tiktoken so cuts are in reasonable places + # See https://github.com/openai/tiktoken + enc = tiktoken.get_encoding("cl100k_base") + encoded = enc.encode_ordinary(text) + split = [] + # convert from characters to chunks + char_count = len(text) # e.g., 25,000 + token_count = len(encoded) # e.g., 4,500 + chars_per_token = char_count / token_count # e.g., 5.5 + chunk_tokens = chunk_chars / chars_per_token # e.g., 3000 / 5.5 = 545 + overlap_tokens = overlap / chars_per_token # e.g., 100 / 5.5 = 18 + chunk_count = ceil(token_count / chunk_tokens) # e.g., 4500 / 545 = 9 + for i in range(chunk_count): + split = encoded[ + max(int(i * chunk_tokens - overlap_tokens), 0) : int( + (i + 1) * chunk_tokens + overlap_tokens + ) + ] + texts.append( + Text( + text=enc.decode(split), + name=f"{doc.docname} chunk {i + 1}", + doc=doc, + ) + ) return texts @@ -104,7 +130,7 @@ def parse_code_txt(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List with open(path) as f: for i, line in enumerate(f): split += line - if len(split) > chunk_chars: + while len(split) > chunk_chars: texts.append( Text( text=split[:chunk_chars], @@ -114,7 +140,7 @@ def parse_code_txt(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List ) split = split[chunk_chars - overlap :] last_line = i - if len(split) > overlap: + if len(split) > overlap or len(texts) == 0: texts.append( Text( text=split[:chunk_chars], diff --git a/paperqa/types.py b/paperqa/types.py index f468d657..d259ac6e 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -1,18 +1,6 @@ -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable -from langchain.callbacks.base import BaseCallbackHandler -from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, -) -from langchain.prompts import PromptTemplate - -try: - from pydantic.v1 import BaseModel, validator -except ImportError: - from pydantic import BaseModel, validator - -import re +from pydantic import BaseModel, ConfigDict, Field, field_validator from .prompts import ( citation_prompt, @@ -21,75 +9,117 @@ select_paper_prompt, summary_prompt, ) -from .utils import extract_doi, iter_citations +# Just for clarity DocKey = Any -CBManager = Union[AsyncCallbackManagerForChainRun, CallbackManagerForChainRun] -CallbackFactory = Callable[[str], Union[None, List[BaseCallbackHandler]]] +CallbackFactory = Callable[[str], list[Callable[[str], None]] | None] + + +class LLMResult(BaseModel): + text: str = "" + prompt_count: int = 0 + completion_count: int = 0 + model: str + date: str + seconds_to_first_token: float = 0 + seconds_to_last_token: float = 0 + + def __str__(self): + return self.text + + +class Embeddable(BaseModel): + embedding: list[float] | None = Field(default=None, repr=False) -class Doc(BaseModel): + +class Doc(Embeddable): docname: str citation: str dockey: DocKey -class Text(BaseModel): +class Text(Embeddable): text: str name: str doc: Doc - embeddings: Optional[List[float]] = None + + +# Mock a dictionary and store any missing items +class _FormatDict(dict): + def __init__(self) -> None: + self.key_set: set[str] = set() + + def __missing__(self, key: str) -> str: + self.key_set.add(key) + return key + + +def get_formatted_variables(s: str) -> set[str]: + """Returns the set of variables implied by the format string""" + format_dict = _FormatDict() + s.format_map(format_dict) + return format_dict.key_set class PromptCollection(BaseModel): - summary: PromptTemplate = summary_prompt - qa: PromptTemplate = qa_prompt - select: PromptTemplate = select_paper_prompt - cite: PromptTemplate = citation_prompt - pre: Optional[PromptTemplate] = None - post: Optional[PromptTemplate] = None + summary: str = summary_prompt + qa: str = qa_prompt + select: str = select_paper_prompt + cite: str = citation_prompt + pre: str | None = None + post: str | None = None system: str = default_system_prompt skip_summary: bool = False - @validator("summary") - def check_summary(cls, v: PromptTemplate) -> PromptTemplate: - if not set(v.input_variables).issubset(set(summary_prompt.input_variables)): + @field_validator("summary") + @classmethod + def check_summary(cls, v: str) -> str: + if not set(get_formatted_variables(v)).issubset( + set(get_formatted_variables(summary_prompt)) + ): raise ValueError( - f"Summary prompt can only have variables: {summary_prompt.input_variables}" + f"Summary prompt can only have variables: {get_formatted_variables(summary_prompt)}" ) return v - @validator("qa") - def check_qa(cls, v: PromptTemplate) -> PromptTemplate: - if not set(v.input_variables).issubset(set(qa_prompt.input_variables)): + @field_validator("qa") + @classmethod + def check_qa(cls, v: str) -> str: + if not set(get_formatted_variables(v)).issubset( + set(get_formatted_variables(qa_prompt)) + ): raise ValueError( - f"QA prompt can only have variables: {qa_prompt.input_variables}" + f"QA prompt can only have variables: {get_formatted_variables(qa_prompt)}" ) return v - @validator("select") - def check_select(cls, v: PromptTemplate) -> PromptTemplate: - if not set(v.input_variables).issubset( - set(select_paper_prompt.input_variables) + @field_validator("select") + @classmethod + def check_select(cls, v: str) -> str: + if not set(get_formatted_variables(v)).issubset( + set(get_formatted_variables(select_paper_prompt)) ): raise ValueError( - f"Select prompt can only have variables: {select_paper_prompt.input_variables}" + f"Select prompt can only have variables: {get_formatted_variables(select_paper_prompt)}" ) return v - @validator("pre") - def check_pre(cls, v: Optional[PromptTemplate]) -> Optional[PromptTemplate]: + @field_validator("pre") + @classmethod + def check_pre(cls, v: str | None) -> str | None: if v is not None: - if set(v.input_variables) != set(["question"]): + if set(get_formatted_variables(v)) != set(["question"]): raise ValueError("Pre prompt must have input variables: question") return v - @validator("post") - def check_post(cls, v: Optional[PromptTemplate]) -> Optional[PromptTemplate]: + @field_validator("post") + @classmethod + def check_post(cls, v: str | None) -> str | None: if v is not None: # kind of a hack to get list of attributes in answer - attrs = [a.name for a in Answer.__fields__.values()] - if not set(v.input_variables).issubset(attrs): + attrs = set(Answer.model_fields.keys()) + if not set(get_formatted_variables(v)).issubset(attrs): raise ValueError(f"Post prompt must have input variables: {attrs}") return v @@ -113,18 +143,18 @@ class Answer(BaseModel): question: str answer: str = "" context: str = "" - contexts: List[Context] = [] + contexts: list[Context] = [] references: str = "" formatted_answer: str = "" - dockey_filter: Optional[Set[DocKey]] = None + dockey_filter: set[DocKey] | None = None summary_length: str = "about 100 words" answer_length: str = "about 100 words" - memory: Optional[str] = None - # these two below are for convenience - # and are not set. But you can set them - # if you want to use them. - cost: Optional[float] = None - token_counts: Optional[Dict[str, List[int]]] = None + memory: str | None = None + # just for convenience you can override this + cost: float | None = None + # key is model name, value is (prompt, completion) token counts + token_counts: dict[str, list[int]] = Field(default_factory=dict) + model_config = ConfigDict(extra="forbid") def __str__(self) -> str: """Return the answer as a string.""" @@ -138,87 +168,13 @@ def get_citation(self, name: str) -> str: raise ValueError(f"Could not find docname {name} in contexts") return doc.citation - def markdown(self) -> Tuple[str, str]: - """Return the answer with footnote style citations.""" - # example: This is an answer.[^1] - # [^1]: This the citation. - output = self.answer - refs: Dict[str, int] = dict() - index = 1 - for citation in iter_citations(self.answer): - compound = "" - strip = True - for c in re.split(",|;", citation): - c = c.strip("() ") - if c == "Extra background information": - continue - if c in refs: - compound += f"[^{refs[c]}]" - continue - # check if it is a citation - try: - self.get_citation(c) - except ValueError: - # not a citation - strip = False - continue - refs[c] = index - compound += f"[^{index}]" - index += 1 - if strip: - output = output.replace(citation, compound) - formatted_refs = "\n".join( - [ - f"[^{i}]: [{self.get_citation(r)}]({extract_doi(self.get_citation(r))})" - for r, i in refs.items() + def add_tokens(self, result: LLMResult): + """Update the token counts for the given result.""" + if result.model not in self.token_counts: + self.token_counts[result.model] = [ + result.prompt_count, + result.completion_count, ] - ) - # quick fix of space before period - output = output.replace(" .", ".") - return output, formatted_refs - - def combine_with(self, other: "Answer") -> "Answer": - """ - Combine this answer object with another, merging their context/answer. - """ - combined = Answer( - question=self.question + " / " + other.question, - answer=self.answer + " " + other.answer, - context=self.context + " " + other.context, - contexts=self.contexts + other.contexts, - references=self.references + " " + other.references, - formatted_answer=self.formatted_answer + " " + other.formatted_answer, - summary_length=self.summary_length, # Assuming the same summary_length for both - answer_length=self.answer_length, # Assuming the same answer_length for both - memory=self.memory if self.memory else other.memory, - cost=self.cost if self.cost else other.cost, - token_counts=self.merge_token_counts(self.token_counts, other.token_counts), - ) - # Handling dockey_filter if present in either of the Answer objects - if self.dockey_filter or other.dockey_filter: - combined.dockey_filter = ( - self.dockey_filter if self.dockey_filter else set() - ) | (other.dockey_filter if other.dockey_filter else set()) - return combined - - @staticmethod - def merge_token_counts( - counts1: Optional[Dict[str, List[int]]], counts2: Optional[Dict[str, List[int]]] - ) -> Optional[Dict[str, List[int]]]: - """ - Merge two dictionaries of token counts. - """ - if counts1 is None and counts2 is None: - return None - if counts1 is None: - return counts2 - if counts2 is None: - return counts1 - merged_counts = counts1.copy() - for key, values in counts2.items(): - if key in merged_counts: - merged_counts[key][0] += values[0] - merged_counts[key][1] += values[1] - else: - merged_counts[key] = values - return merged_counts + else: + self.token_counts[result.model][0] += result.prompt_count + self.token_counts[result.model][1] += result.completion_count diff --git a/paperqa/utils.py b/paperqa/utils.py index 6cb0d1a0..76105aa7 100644 --- a/paperqa/utils.py +++ b/paperqa/utils.py @@ -3,10 +3,9 @@ import re import string from pathlib import Path -from typing import BinaryIO, List, Union +from typing import Any, BinaryIO, Coroutine, Iterator, Union import pypdf -from langchain.base_language import BaseLanguageModel StrPath = Union[str, Path] @@ -76,7 +75,7 @@ def md5sum(file_path: StrPath) -> str: return hashlib.md5(f.read()).hexdigest() -async def gather_with_concurrency(n: int, *coros: List) -> List: +async def gather_with_concurrency(n: int, coros: list[Coroutine]) -> list[Any]: # https://stackoverflow.com/a/61478547/2392535 semaphore = asyncio.Semaphore(n) @@ -93,13 +92,6 @@ def guess_is_4xx(msg: str) -> bool: return False -def get_llm_name(llm: BaseLanguageModel) -> str: - try: - return llm.model_name # type: ignore - except AttributeError: - return llm.model # type: ignore - - def strip_citations(text: str) -> str: # Combined regex for identifying citations (see unit tests for examples) citation_regex = r"\b[\w\-]+\set\sal\.\s\([0-9]{4}\)|\((?:[^\)]*?[a-zA-Z][^\)]*?[0-9]{4}[^\)]*?)\)" @@ -108,7 +100,7 @@ def strip_citations(text: str) -> str: return text -def iter_citations(text: str) -> List[str]: +def iter_citations(text: str) -> list[str]: # Combined regex for identifying citations (see unit tests for examples) citation_regex = r"\b[\w\-]+\set\sal\.\s\([0-9]{4}\)|\((?:[^\)]*?[a-zA-Z][^\)]*?[0-9]{4}[^\)]*?)\)" result = re.findall(citation_regex, text, flags=re.MULTILINE) @@ -131,3 +123,26 @@ def extract_doi(reference: str) -> str: return "https://doi.org/" + doi_match.group() else: return "" + + +def batch_iter(iterable: list, n: int = 1) -> Iterator[list]: + """ + Batch an iterable into chunks of size n + + :param iterable: The iterable to batch + :param n: The size of the batches + :return: A list of batches + """ + length = len(iterable) + for ndx in range(0, length, n): + yield iterable[ndx : min(ndx + n, length)] + + +def flatten(iteratble: list) -> list: + """ + Flatten a list of lists + + :param l: The list of lists to flatten + :return: A flattened list + """ + return [item for sublist in iteratble for item in sublist] diff --git a/paperqa/version.py b/paperqa/version.py index 5de4250b..16763f33 100644 --- a/paperqa/version.py +++ b/paperqa/version.py @@ -1 +1 @@ -__version__ = "3.13.5" +__version__ = "4.0.0-pre.1" diff --git a/setup.py b/setup.py index d2463365..a69571ff 100644 --- a/setup.py +++ b/setup.py @@ -18,10 +18,10 @@ packages=["paperqa", "paperqa.contrib"], install_requires=[ "pypdf", - "pydantic<2", - "langchain>=0.0.303", - "openai <1", - "faiss-cpu", + "pydantic>=2", + "openai>=1", + "numpy", + "nest-asyncio", "PyCryptodome", "html2text", "tiktoken>=0.4.0", diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index 75526880..57d9a417 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -1,18 +1,22 @@ import os import pickle from io import BytesIO -from typing import Any from unittest import IsolatedAsyncioTestCase import numpy as np import requests -from langchain.callbacks.base import AsyncCallbackHandler -from langchain.llms import OpenAI -from langchain.llms.fake import FakeListLLM -from langchain.prompts import PromptTemplate - -from paperqa import Answer, Context, Doc, Docs, PromptCollection, Text -from paperqa.chains import get_score +from openai import AsyncOpenAI + +from paperqa import Answer, Doc, Docs, NumpyVectorStore, PromptCollection, Text +from paperqa.llms import ( + EmbeddingModel, + LangchainEmbeddingModel, + LangchainLLMModel, + LangchainVectorStore, + LLMModel, + OpenAILLMModel, + get_score, +) from paperqa.readers import read_doc from paperqa.utils import ( iter_citations, @@ -24,11 +28,6 @@ ) -class TestHandler(AsyncCallbackHandler): - async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - print(token) - - def test_iter_citations(): text = ( "Yes, COVID-19 vaccines are effective. Various studies have documented the " @@ -127,68 +126,17 @@ def test_citations_with_nonstandard_chars(): ) -def test_markdown(): - answer = Answer( - question="What was Fredic's greatest accomplishment?", - answer="Frederick Bates's greatest accomplishment was his role in resolving land disputes " - "and his service as governor of Missouri (Wiki2023 chunk 1, Wiki2023 chunk 2). It is said (in 2010) that foo." - "However many dispute this (Wiki2023 chunk 1).", - contexts=[ - Context( - context="", - text=Text( - text="Frederick Bates's greatest accomplishment was his role in resolving land disputes " - "and his service as governor of Missouri.", - name="Wiki2023 chunk 1", - doc=Doc( - name="Wiki2023", - docname="Wiki2023", - citation="WikiMedia Foundation, 2023, Accessed now", - texts=[], - ), - ), - score=5, - ), - Context( - context="", - text=Text( - text="It is said (in 2010) that foo.", - name="Wiki2023 chunk 2", - doc=Doc( - name="Wiki2023", - docname="Wiki2023", - citation="WikiMedia Foundation, 2023, Accessed now", - texts=[], - ), - ), - score=5, - ), - ], - ) - m, r = answer.markdown() - assert len(r.split("\n")) == 2 - assert "[^2]" in m - assert "[^3]" not in m - assert "[^1]" in m - print(m, r) - answer = answer.combine_with(answer) - m2, r2 = answer.markdown() - assert m2.startswith(m) - assert r2 == r - - def test_ablations(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf") with open(doc_path, "rb") as f: - docs = Docs() + docs = Docs(prompts=PromptCollection(skip_summary=True)) docs.add_file(f, "Wellawatte et al, XAI Review, 2023") answer = docs.get_evidence( Answer( question="Which page is the statement 'Deep learning (DL) is advancing the boundaries of computational" + "chemistry because it can accurately model non-linear structure-function relationships.' on?" - ), - disable_summarization=True, + ) ) assert ( answer.contexts[0].text.text == answer.contexts[0].context @@ -419,24 +367,66 @@ def test_extract_score(): assert get_score(sample) == 9 +class TestChains(IsolatedAsyncioTestCase): + async def test_chain_completion(self): + client = AsyncOpenAI() + llm = OpenAILLMModel(config=dict(model="babbage-002", temperature=0.2)) + call = llm.make_chain( + client, + "The {animal} says", + skip_system=True, + ) + outputs = [] + + def accum(x): + outputs.append(x) + + completion = await call(dict(animal="duck"), callbacks=[accum]) + assert completion.seconds_to_first_token > 0 + assert completion.prompt_count > 0 + assert completion.completion_count > 0 + assert str(completion) == "".join(outputs) + + completion = await call(dict(animal="duck")) + assert completion.seconds_to_first_token == 0 + assert completion.seconds_to_last_token > 0 + + async def test_chain_chat(self): + client = AsyncOpenAI() + llm = OpenAILLMModel( + config=dict(temperature=0, model="gpt-3.5-turbo", max_tokens=56) + ) + call = llm.make_chain( + client, + "The {animal} says", + skip_system=True, + ) + outputs = [] + + def accum(x): + outputs.append(x) + + completion = await call(dict(animal="duck"), callbacks=[accum]) + assert completion.seconds_to_first_token > 0 + assert completion.prompt_count > 0 + assert completion.completion_count > 0 + assert str(completion) == "".join(outputs) + + completion = await call(dict(animal="duck")) + assert completion.seconds_to_first_token == 0 + assert completion.seconds_to_last_token > 0 + + def test_docs(): - llm = OpenAI(client=None, temperature=0.1, model="text-ada-001") - docs = Docs(llm=llm) + docs = Docs(llm="babbage-002") docs.add_url( "https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day", citation="WikiMedia Foundation, 2023, Accessed now", dockey="test", ) assert docs.docs["test"].docname == "Wiki2023" - - -def test_update_llm(): - doc = Docs() - doc.update_llm("gpt-3.5-turbo") - assert doc.llm == doc.summary_llm - - doc.update_llm(OpenAI(client=None, temperature=0.1, model="text-ada-001")) - assert doc.llm == doc.summary_llm + assert docs.llm == "babbage-002" + assert docs.summary_llm == "babbage-002" def test_evidence(): @@ -488,6 +478,241 @@ def test_duplicate(): ) +def test_custom_embedding(): + class MyEmbeds(EmbeddingModel): + async def embed_documents(self, client, texts): + return [[1, 2, 3] for _ in texts] + + docs = Docs( + docs_index=NumpyVectorStore(embedding_model=MyEmbeds()), + texts_index=NumpyVectorStore(embedding_model=MyEmbeds()), + embedding_client=None, + ) + assert docs._embedding_client is None + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + assert docs.docs["test"].embedding == [1, 2, 3] + + +def test_custom_llm(): + class MyLLM(LLMModel): + name: str = "myllm" + + async def acomplete(self, client, prompt): + assert client is None + return "Echo" + + docs = Docs(llm_model=MyLLM(), client=None) + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + evidence = docs.get_evidence(Answer(question="Echo")) + assert "Echo" in evidence.context + + +def test_custom_llm_stream(): + class MyLLM(LLMModel): + name: str = "myllm" + + async def acomplete_iter(self, client, prompt): + assert client is None + yield "Echo" + + docs = Docs(llm_model=MyLLM(), client=None) + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + evidence = docs.get_evidence( + Answer(question="Echo"), get_callbacks=lambda x: [lambda y: print(y, end="")] + ) + assert "Echo" in evidence.context + + +def test_langchain_llm(): + from langchain_openai import ChatOpenAI, OpenAI + + docs = Docs(llm="langchain", client=ChatOpenAI(model="gpt-3.5-turbo")) + assert type(docs.llm_model) == LangchainLLMModel + assert type(docs.summary_llm_model) == LangchainLLMModel + assert docs.llm == "gpt-3.5-turbo" + assert docs.summary_llm == "gpt-3.5-turbo" + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + assert docs._client is not None + assert type(docs.llm_model) == LangchainLLMModel + assert docs.summary_llm_model == docs.llm_model + + docs.get_evidence( + Answer(question="What is Frederick Bates's greatest accomplishment?"), + get_callbacks=lambda x: [lambda y: print(y, end="")], + ) + + assert docs.llm_model.llm_type == "chat" + + # trying without callbacks (different codepath) + docs.get_evidence( + Answer(question="What is Frederick Bates's greatest accomplishment?") + ) + + # now completion + + docs = Docs(llm_model=LangchainLLMModel(), client=OpenAI(model="babbage-002")) + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + docs.get_evidence( + Answer(question="What is Frederick Bates's greatest accomplishment?"), + get_callbacks=lambda x: [lambda y: print(y, end="")], + ) + + assert docs.summary_llm_model.llm_type == "completion" + + # trying without callbacks (different codepath) + docs.get_evidence( + Answer(question="What is Frederick Bates's greatest accomplishment?") + ) + + # now make sure we can pickle it + docs_pickle = pickle.dumps(docs) + docs2 = pickle.loads(docs_pickle) + assert docs2._client is None + assert docs2.llm == "babbage-002" + docs2.set_client(OpenAI(model="babbage-002")) + assert docs2.summary_llm == "babbage-002" + docs2.get_evidence( + Answer(question="What is Frederick Bates's greatest accomplishment?"), + get_callbacks=lambda x: [lambda y: print(y)], + ) + + +def test_langchain_embeddings(): + from langchain_openai import OpenAIEmbeddings + + docs = Docs( + texts_index=NumpyVectorStore(embedding_model=LangchainEmbeddingModel()), + docs_index=NumpyVectorStore(embedding_model=LangchainEmbeddingModel()), + embedding_client=OpenAIEmbeddings(), + ) + assert docs._embedding_client is not None + + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + docs = Docs(embedding="langchain", embedding_client=OpenAIEmbeddings()) + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + + +class TestVectorStore(IsolatedAsyncioTestCase): + async def test_langchain_vector_store(self): + from langchain_community.vectorstores.faiss import FAISS + from langchain_openai import OpenAIEmbeddings + + some_texts = [ + Text( + embedding=OpenAIEmbeddings().embed_query("test"), + text="this is a test", + name="test", + doc=Doc(docname="test", citation="test", dockey="test"), + ) + ] + + # checks on builder + try: + index = LangchainVectorStore() + index.add_texts_and_embeddings(some_texts) + raise "Failed to check for builder" + except ValueError: + pass + + try: + index = LangchainVectorStore(store_builder=lambda x: None) + raise "Failed to count arguments" + except ValueError: + pass + + try: + index = LangchainVectorStore(store_builder="foo") + raise "Failed to check if builder is callable" + except ValueError: + pass + + # now with real builder + index = LangchainVectorStore( + store_builder=lambda x, y: FAISS.from_embeddings(x, OpenAIEmbeddings(), y) + ) + assert index._store is None + index.add_texts_and_embeddings(some_texts) + assert index._store is not None + # check search returns Text obj + data, score = await index.similarity_search(None, "test", k=1) + print(data) + assert type(data[0]) == Text + + # now try with convenience + index = LangchainVectorStore(cls=FAISS, embedding_model=OpenAIEmbeddings()) + assert index._store is None + index.add_texts_and_embeddings(some_texts) + assert index._store is not None + + docs = Docs( + texts_index=LangchainVectorStore( + cls=FAISS, embedding_model=OpenAIEmbeddings() + ) + ) + assert docs._embedding_client is not None # from docs_index default + + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + # should be embedded + + # now try with JIT + docs = Docs(texts_index=index, jit_texts_index=True) + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + # should get cleared and rebuilt here + ev = docs.get_evidence( + answer=Answer(question="What is Frederick Bates's greatest accomplishment?") + ) + assert len(ev.context) > 0 + # now with dockkey filter + docs.get_evidence( + answer=Answer( + question="What is Frederick Bates's greatest accomplishment?", + dockey_filter=["test"], + ) + ) + + # make sure we can pickle it + docs_pickle = pickle.dumps(docs) + pickle.loads(docs_pickle) + + # will not work at this point - have to reset index + + class Test(IsolatedAsyncioTestCase): async def test_aquery(self): docs = Docs() @@ -511,7 +736,6 @@ async def test_adoc_match(self): "What is Frederick Bates's greatest accomplishment?" ) assert len(sources) > 0 - docs.update_llm("gpt-3.5-turbo") sources = await docs.adoc_match( "What is Frederick Bates's greatest accomplishment?" ) @@ -524,15 +748,25 @@ def test_docs_pickle(): # get front page of wikipedia r = requests.get("https://en.wikipedia.org/wiki/Take_Your_Dog_to_Work_Day") f.write(r.text) - llm = OpenAI(client=None, temperature=0.0, model="text-curie-001") - docs = Docs(llm=llm) + docs = Docs( + llm_model=OpenAILLMModel(config=dict(temperature=0.0, model="gpt-3.5-turbo")) + ) + assert docs._client is not None + old_config = docs.llm_model.config + old_sconfig = docs.summary_llm_model.config docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now", chunk_chars=1000) os.remove(doc_path) docs_pickle = pickle.dumps(docs) docs2 = pickle.loads(docs_pickle) - docs2.update_llm(llm) - assert llm.model_name == docs2.llm.model_name - assert docs2.summary_llm.model_name == docs2.llm.model_name + # make sure it fails if we haven't set client + try: + docs2.query("What date is bring your dog to work in the US?") + except ValueError: + pass + docs2.set_client() + assert docs2._client is not None + assert docs2.llm_model.config == old_config + assert docs2.summary_llm_model.config == old_sconfig assert len(docs.docs) == len(docs2.docs) context1, context2 = ( docs.get_evidence( @@ -557,45 +791,6 @@ def test_docs_pickle(): docs.query("What date is bring your dog to work in the US?") -def test_docs_pickle_no_faiss(): - doc_path = "example.html" - with open(doc_path, "w", encoding="utf-8") as f: - # get front page of wikipedia - r = requests.get("https://en.wikipedia.org/wiki/Take_Your_Dog_to_Work_Day") - f.write(r.text) - llm = OpenAI(client=None, temperature=0.0, model="text-curie-001") - docs = Docs(llm=llm) - docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now", chunk_chars=1000) - docs.doc_index = None - docs.texts_index = None - docs_pickle = pickle.dumps(docs) - docs2 = pickle.loads(docs_pickle) - docs2.update_llm(llm) - assert len(docs.docs) == len(docs2.docs) - assert ( - strings_similarity( - docs.get_evidence( - Answer( - question="What date is bring your dog to work in the US?", - summary_length="about 20 words", - ), - k=3, - max_sources=1, - ).context, - docs2.get_evidence( - Answer( - question="What date is bring your dog to work in the US?", - summary_length="about 20 words", - ), - k=3, - max_sources=1, - ).context, - ) - > 0.75 - ) - os.remove(doc_path) - - def test_bad_context(): doc_path = "example.html" with open(doc_path, "w", encoding="utf-8") as f: @@ -615,7 +810,9 @@ def test_repeat_keys(): # get wiki page about politician r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") f.write(r.text) - docs = Docs(llm=OpenAI(client=None, temperature=0.0, model="text-ada-001")) + docs = Docs( + llm_model=OpenAILLMModel(config=dict(temperature=0.0, model="babbage-002")) + ) docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") try: docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") @@ -644,9 +841,9 @@ def test_repeat_keys(): def test_pdf_reader(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf") - docs = Docs(llm=OpenAI(client=None, temperature=0.0, model="text-curie-001")) + docs = Docs(llm_model=OpenAILLMModel(config=dict(temperature=0.0, model="gpt-4"))) docs.add(doc_path, "Wellawatte et al, XAI Review, 2023") - answer = docs.query("Are counterfactuals actionable?") + answer = docs.query("Are counterfactuals actionable? [yes/no]") assert "yes" in answer.answer or "Yes" in answer.answer @@ -654,15 +851,15 @@ def test_fileio_reader_pdf(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf") with open(doc_path, "rb") as f: - docs = Docs(llm=OpenAI(client=None, temperature=0.0, model="text-curie-001")) + docs = Docs() docs.add_file(f, "Wellawatte et al, XAI Review, 2023") - answer = docs.query("Are counterfactuals actionable?") + answer = docs.query("Are counterfactuals actionable?[yes/no]") assert "yes" in answer.answer or "Yes" in answer.answer def test_fileio_reader_txt(): # can't use curie, because it has trouble with parsed HTML - docs = Docs(llm=OpenAI(client=None, temperature=0.0)) + docs = Docs() r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") if r.status_code != 200: raise ValueError("Could not download wikipedia page") @@ -672,7 +869,7 @@ def test_fileio_reader_txt(): chunk_chars=1000, ) answer = docs.query("What country was Frederick Bates born in?") - assert "Virginia" in answer.answer + assert "United States" in answer.answer def test_pdf_pypdf_reader(): @@ -712,7 +909,9 @@ def test_prompt_length(): def test_code(): # load this script doc_path = os.path.abspath(__file__) - docs = Docs(llm=OpenAI(client=None, temperature=0.0, model="text-ada-001")) + docs = Docs( + llm_model=OpenAILLMModel(config=dict(temperature=0.0, model="babbage-002")) + ) docs.add(doc_path, "test_paperqa.py", docname="test_paperqa.py", disable_check=True) assert len(docs.docs) == 1 docs.query("What function tests the preview?") @@ -727,8 +926,10 @@ def test_citation(): docs = Docs() docs.add(doc_path) assert ( - list(docs.docs.values())[0].docname == "Wikipedia2023" - or list(docs.docs.values())[0].docname == "Frederick2023" + list(docs.docs.values())[0].docname == "Wikipedia2024" + or list(docs.docs.values())[0].docname == "Frederick2024" + or list(docs.docs.values())[0].docname == "Wikipedia" + or list(docs.docs.values())[0].docname == "Frederick" ) @@ -746,7 +947,7 @@ def test_dockey_filter(): f.write(r.text) f.write("\n") # so we don't have same hash docs.add("example.txt", "WikiMedia Foundation, 2023, Accessed now", dockey="test") - answer = Answer(question="What country is Bates from?", key_filter=["test"]) + answer = Answer(question="What country is Bates from?", dockey_filter=["test"]) docs.get_evidence(answer) @@ -763,18 +964,20 @@ def test_dockey_delete(): with open("example.txt", "w", encoding="utf-8") as f: f.write(r.text) f.write("\n\nBates could be from Angola") # so we don't have same hash - docs.add("example.txt", "WikiMedia Foundation, 2023, Accessed now", dockey="test") + docs.add("example.txt", "WikiMedia Foundation, 2023, Accessed now", docname="test") answer = Answer(question="What country was Bates born in?") - answer = docs.get_evidence(answer, marginal_relevance=False) - print(answer) + answer = docs.get_evidence( + answer, max_sources=25, k=30 + ) # we just have a lot so we get both docs keys = set([c.text.doc.dockey for c in answer.contexts]) assert len(keys) == 2 assert len(docs.docs) == 2 - docs.delete(dockey="test") - assert len(docs.docs) == 1 + docs.delete(docname="test") answer = Answer(question="What country was Bates born in?") - answer = docs.get_evidence(answer, marginal_relevance=False) + assert len(docs.docs) == 1 + assert len(docs.deleted_dockeys) == 1 + answer = docs.get_evidence(answer, max_sources=25, k=30) keys = set([c.text.doc.dockey for c in answer.contexts]) assert len(keys) == 1 @@ -800,19 +1003,6 @@ def test_query_filter(): # the filter shouldn't trigger, so just checking that it doesn't crash -def test_nonopenai_client(): - responses = ["This is a test", "This is another test"] * 50 - model = FakeListLLM(responses=responses) - doc_path = "example.txt" - with open(doc_path, "w", encoding="utf-8") as f: - # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") - f.write(r.text) - docs = Docs(llm=model) - docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") - docs.query("What country is Bates from?") - - def test_zotera(): from paperqa.contrib import ZoteroDB @@ -846,11 +1036,10 @@ def test_too_much_evidence(): def test_custom_prompts(): - my_qaprompt = PromptTemplate( - input_variables=["question", "context"], - template="Answer the question '{question}' " + my_qaprompt = ( + "Answer the question '{question}' " "using the country name alone. For example: " - "A: United States\nA: Canada\nA: Mexico\n\n Using the context:\n\n{context}\n\nA: ", + "A: United States\nA: Canada\nA: Mexico\n\n Using the context:\n\n{context}\n\nA: " ) docs = Docs(prompts=PromptCollection(qa=my_qaprompt)) @@ -862,16 +1051,11 @@ def test_custom_prompts(): f.write(r.text) docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") answer = docs.query("What country is Frederick Bates from?") - print(answer.answer) assert "United States" in answer.answer def test_pre_prompt(): - pre = PromptTemplate( - input_variables=["question"], - template="Provide context you have memorized " - "that could help answer '{question}'. ", - ) + pre = "Provide context you have memorized " "that could help answer '{question}'. " docs = Docs(prompts=PromptCollection(pre=pre)) @@ -885,13 +1069,12 @@ def test_pre_prompt(): def test_post_prompt(): - post = PromptTemplate( - input_variables=["question", "answer"], - template="We are trying to answer the question below " + post = ( + "We are trying to answer the question below " "and have an answer provided. " "Please edit the answer be extremely terse, with no extra words or formatting" "with no extra information.\n\n" - "Q: {question}\nA: {answer}\n\n", + "Q: {question}\nA: {answer}\n\n" ) docs = Docs(prompts=PromptCollection(post=post)) @@ -943,8 +1126,8 @@ def disabled_test_memory(): def test_add_texts(): - llm = OpenAI(client=None, temperature=0.1, model="text-ada-001") - docs = Docs(llm=llm) + llm_config = dict(temperature=0.1, model="babbage-02") + docs = Docs(llm_model=OpenAILLMModel(config=llm_config)) docs.add_url( "https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day", citation="WikiMedia Foundation, 2023, Accessed now", @@ -954,17 +1137,17 @@ def test_add_texts(): docs2 = Docs() texts = [Text(**dict(t)) for t in docs.texts] for t in texts: - t.embeddings = None + t.embedding = None docs2.add_texts(texts, list(docs.docs.values())[0]) for t1, t2 in zip(docs2.texts, docs.texts): assert t1.text == t2.text - assert np.allclose(t1.embeddings, t2.embeddings, atol=1e-3) + assert np.allclose(t1.embedding, t2.embedding, atol=1e-3) docs2._build_texts_index() # now do it again to test after text index is already built - llm = OpenAI(client=None, temperature=0.1, model="text-ada-001") - docs = Docs(llm=llm) + llm_config = dict(temperature=0.1, model="babbage-02") + docs = Docs(llm_model=OpenAILLMModel(config=llm_config)) docs.add_url( "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", citation="WikiMedia Foundation, 2023, Accessed now", @@ -973,7 +1156,7 @@ def test_add_texts(): texts = [Text(**dict(t)) for t in docs.texts] for t in texts: - t.embeddings = None + t.embedding = None docs2.add_texts(texts, list(docs.docs.values())[0]) assert len(docs2.docs) == 2 @@ -988,7 +1171,7 @@ def test_external_doc_index(): dockey="test", ) evidence = docs.query(query="What is the date of flag day?", key_filter=True) - docs2 = Docs(doc_index=docs.doc_index, texts_index=docs.texts_index) + docs2 = Docs(docs_index=docs.docs_index, texts_index=docs.texts_index) assert len(docs2.docs) == 0 evidence = docs2.query("What is the date of flag day?", key_filter=True) assert "February 15" in evidence.context @@ -1001,6 +1184,7 @@ def test_external_texts_index(): citation="Flag Day of Canada, WikiMedia Foundation, 2023, Accessed now", ) answer = docs.query(query="On which date is flag day annually observed?") + print(answer.model_dump()) assert "February 15" in answer.answer docs.add_url(