From 89d3eabc28fb6804d08625db6b5da37fb913cee4 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Tue, 23 Jan 2024 13:17:01 -0800 Subject: [PATCH] Removed nest_asyncio (#227) * Removed nest_asyncio * Updated README for sync notes --- README.md | 26 ++++++++++++ paperqa/docs.py | 97 ++++++++++++++++++++++++++++++++++--------- paperqa/llms.py | 21 +++++++--- paperqa/utils.py | 18 ++++++++ paperqa/version.py | 2 +- setup.py | 1 - tests/test_paperqa.py | 18 +++++--- 7 files changed, 151 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index b4511407e..e0188f814 100644 --- a/README.md +++ b/README.md @@ -85,6 +85,32 @@ print(answer.formatted_answer) The answer object has the following attributes: `formatted_answer`, `answer` (answer alone), `question` , and `context` (the summaries of passages found for answer). +### Async + +paper-qa is written to be used asynchronously. The synchronous API is just a wrapper around the async. Here are the methods and their async equivalents: + +| Sync | Async | +| --- | --- | +| `Docs.add` | `Docs.aadd` | +| `Docs.add_file` | `Docs.aadd_file` | +| `Docs.add_url` | `Docs.add_url` | +| `Docs.get_evidence` | `Docs.aget_evidence` | +| `Docs.query` | `Docs.aquery` | + +The synchronous version just call the async version in a loop. Most modern python environments support async natively (including Jupyter notebooks!). So you can do this in a Jupyter Notebook: + +```py +from paperqa import Docs + +my_docs = ...# get a list of paths + +docs = Docs() +for d in my_docs: + await docs.aadd(d) + +answer = await docs.aquery("What manufacturing challenges are unique to bispecific antibodies?") +``` + ### Adding Documents `add` will add from paths. You can also use `add_file` (expects a file object) or `add_url` to work with other sources. diff --git a/paperqa/docs.py b/paperqa/docs.py index bb0b026bf..a0100b385 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -1,5 +1,3 @@ -import nest_asyncio # isort:skip -import asyncio import os import re import tempfile @@ -35,6 +33,7 @@ ) from .utils import ( gather_with_concurrency, + get_loop, guess_is_4xx, maybe_is_html, maybe_is_pdf, @@ -44,9 +43,6 @@ strip_citations, ) -# Apply the patch to allow nested loops -nest_asyncio.apply() - class Docs(BaseModel): """A collection of documents to be used for answering questions.""" @@ -251,6 +247,25 @@ def add_file( docname: str | None = None, dockey: DocKey | None = None, chunk_chars: int = 3000, + ) -> str | None: + loop = get_loop() + return loop.run_until_complete( + self.aadd_file( + file, + citation=citation, + docname=docname, + dockey=dockey, + chunk_chars=chunk_chars, + ) + ) + + async def aadd_file( + self, + file: BinaryIO, + citation: str | None = None, + docname: str | None = None, + dockey: DocKey | None = None, + chunk_chars: int = 3000, ) -> str | None: """Add a document to the collection.""" # just put in temp file and use existing method @@ -263,7 +278,7 @@ def add_file( with tempfile.NamedTemporaryFile(suffix=suffix) as f: f.write(file.read()) f.seek(0) - return self.add( + return await self.aadd( Path(f.name), citation=citation, docname=docname, @@ -278,6 +293,25 @@ def add_url( docname: str | None = None, dockey: DocKey | None = None, chunk_chars: int = 3000, + ) -> str | None: + loop = get_loop() + return loop.run_until_complete( + self.aadd_url( + url, + citation=citation, + docname=docname, + dockey=dockey, + chunk_chars=chunk_chars, + ) + ) + + async def aadd_url( + self, + url: str, + citation: str | None = None, + docname: str | None = None, + dockey: DocKey | None = None, + chunk_chars: int = 3000, ) -> str | None: """Add a document to the collection.""" import urllib.request @@ -285,7 +319,7 @@ def add_url( with urllib.request.urlopen(url) as f: # need to wrap to enable seek file = BytesIO(f.read()) - return self.add_file( + return await self.aadd_file( file, citation=citation, docname=docname, @@ -301,6 +335,27 @@ def add( disable_check: bool = False, dockey: DocKey | None = None, chunk_chars: int = 3000, + ) -> str | None: + loop = get_loop() + return loop.run_until_complete( + self.aadd( + path, + citation=citation, + docname=docname, + disable_check=disable_check, + dockey=dockey, + chunk_chars=chunk_chars, + ) + ) + + async def aadd( + self, + path: Path, + citation: str | None = None, + docname: str | None = None, + disable_check: bool = False, + dockey: DocKey | None = None, + chunk_chars: int = 3000, ) -> str | None: """Add a document to the collection.""" if dockey is None: @@ -317,9 +372,7 @@ def add( texts = read_doc(path, fake_doc, chunk_chars=chunk_chars, overlap=100) if len(texts) == 0: raise ValueError(f"Could not read document {path}. Is it empty?") - chain_result = asyncio.run( - cite_chain(dict(text=texts[0].text), None), - ) + chain_result = await cite_chain(dict(text=texts[0].text), None) citation = chain_result.text if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation: citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}" @@ -353,7 +406,7 @@ def add( raise ValueError( f"This does not look like a text document: {path}. Path disable_check to ignore this error." ) - if self.add_texts(texts, doc): + if await self.aadd_texts(texts, doc): return docname return None @@ -361,6 +414,14 @@ def add_texts( self, texts: list[Text], doc: Doc, + ) -> bool: + loop = get_loop() + return loop.run_until_complete(self.aadd_texts(texts, doc)) + + async def aadd_texts( + self, + texts: list[Text], + doc: Doc, ) -> bool: """Add chunked texts to the collection. This is useful if you have already chunked the texts yourself. @@ -376,16 +437,14 @@ def add_texts( t.name = t.name.replace(doc.docname, new_docname) doc.docname = new_docname if texts[0].embedding is None: - text_embeddings = asyncio.run( - self.texts_index.embedding_model.embed_documents( - self._embedding_client, [t.text for t in texts] - ) + text_embeddings = await self.texts_index.embedding_model.embed_documents( + self._embedding_client, [t.text for t in texts] ) for i, t in enumerate(texts): t.embedding = text_embeddings[i] if doc.embedding is None: - doc.embedding = asyncio.run( - self.docs_index.embedding_model.embed_documents( + doc.embedding = ( + await self.docs_index.embedding_model.embed_documents( self._embedding_client, [doc.citation] ) )[0] @@ -493,7 +552,7 @@ def get_evidence( detailed_citations: bool = False, disable_vector_search: bool = False, ) -> Answer: - return asyncio.run( + return get_loop().run_until_complete( self.aget_evidence( answer, k=k, @@ -641,7 +700,7 @@ def query( key_filter: bool | None = None, get_callbacks: CallbackFactory = lambda x: None, ) -> Answer: - return asyncio.run( + return get_loop().run_until_complete( self.aquery( query, k=k, diff --git a/paperqa/llms.py b/paperqa/llms.py index dea6645c2..10b5c3998 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -21,7 +21,7 @@ from .prompts import default_system_prompt from .types import Doc, Embeddable, LLMResult, Text -from .utils import batch_iter, flatten, gather_with_concurrency +from .utils import batch_iter, flatten, gather_with_concurrency, is_coroutine_callable def guess_model_type(model_name: str) -> str: @@ -167,7 +167,7 @@ def make_chain( chat_prompt = [system_message_prompt, human_message_prompt] async def execute( - data: dict, callbacks: list[Callable[[str], None]] | None = None + data: dict, callbacks: list[Callable] | None = None ) -> LLMResult: start_clock = asyncio.get_running_loop().time() result = LLMResult( @@ -184,6 +184,10 @@ async def execute( if callbacks is None: output = await self.achat(client, messages) else: + sync_callbacks = [ + f for f in callbacks if not is_coroutine_callable(f) + ] + async_callbacks = [f for f in callbacks if is_coroutine_callable(f)] completion = self.achat_iter(client, messages) # type: ignore text_result = [] async for chunk in completion: # type: ignore @@ -193,7 +197,8 @@ async def execute( asyncio.get_running_loop().time() - start_clock ) text_result.append(chunk) - [f(chunk) for f in callbacks] + [await f(chunk) for f in async_callbacks] + [f(chunk) for f in sync_callbacks] output = "".join(text_result) result.completion_count = self.count_tokens(output) result.text = output @@ -210,7 +215,7 @@ async def execute( completion_prompt = system_prompt + "\n\n" + prompt async def execute( - data: dict, callbacks: list[Callable[[str], None]] | None = None + data: dict, callbacks: list[Callable] | None = None ) -> LLMResult: start_clock = asyncio.get_running_loop().time() result = LLMResult( @@ -223,6 +228,11 @@ async def execute( if callbacks is None: output = await self.acomplete(client, formatted_prompt) else: + sync_callbacks = [ + f for f in callbacks if not is_coroutine_callable(f) + ] + async_callbacks = [f for f in callbacks if is_coroutine_callable(f)] + completion = self.acomplete_iter( # type: ignore client, formatted_prompt, @@ -235,7 +245,8 @@ async def execute( asyncio.get_running_loop().time() - start_clock ) text_result.append(chunk) - [f(chunk) for f in callbacks] + [await f(chunk) for f in async_callbacks] + [f(chunk) for f in sync_callbacks] output = "".join(text_result) result.completion_count = self.count_tokens(output) result.text = output diff --git a/paperqa/utils.py b/paperqa/utils.py index 8f2a9c600..5152a79b7 100644 --- a/paperqa/utils.py +++ b/paperqa/utils.py @@ -1,4 +1,5 @@ import asyncio +import inspect import math import re import string @@ -152,3 +153,20 @@ def flatten(iteratble: list) -> list: :return: A flattened list """ return [item for sublist in iteratble for item in sublist] + + +def get_loop() -> asyncio.AbstractEventLoop: + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop + + +def is_coroutine_callable(obj): + if inspect.isfunction(obj): + return inspect.iscoroutinefunction(obj) + elif callable(obj): + return inspect.iscoroutinefunction(obj.__call__) + return False diff --git a/paperqa/version.py b/paperqa/version.py index 16763f330..b98e98d99 100644 --- a/paperqa/version.py +++ b/paperqa/version.py @@ -1 +1 @@ -__version__ = "4.0.0-pre.1" +__version__ = "4.0.0-pre.2" diff --git a/setup.py b/setup.py index a69571ff1..d712958e3 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,6 @@ "pydantic>=2", "openai>=1", "numpy", - "nest-asyncio", "PyCryptodome", "html2text", "tiktoken>=0.4.0", diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index 57d9a4175..6f8bece4b 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -416,6 +416,12 @@ def accum(x): assert completion.seconds_to_first_token == 0 assert completion.seconds_to_last_token > 0 + # check with mixed callbacks + async def ac(x): + pass + + completion = await call(dict(animal="duck"), callbacks=[accum, ac]) + def test_docs(): docs = Docs(llm="babbage-002") @@ -679,7 +685,7 @@ async def test_langchain_vector_store(self): ) assert docs._embedding_client is not None # from docs_index default - docs.add_url( + await docs.aadd_url( "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", citation="WikiMedia Foundation, 2023, Accessed now", dockey="test", @@ -688,18 +694,18 @@ async def test_langchain_vector_store(self): # now try with JIT docs = Docs(texts_index=index, jit_texts_index=True) - docs.add_url( + await docs.aadd_url( "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", citation="WikiMedia Foundation, 2023, Accessed now", dockey="test", ) # should get cleared and rebuilt here - ev = docs.get_evidence( + ev = await docs.aget_evidence( answer=Answer(question="What is Frederick Bates's greatest accomplishment?") ) assert len(ev.context) > 0 # now with dockkey filter - docs.get_evidence( + await docs.aget_evidence( answer=Answer( question="What is Frederick Bates's greatest accomplishment?", dockey_filter=["test"], @@ -716,7 +722,7 @@ async def test_langchain_vector_store(self): class Test(IsolatedAsyncioTestCase): async def test_aquery(self): docs = Docs() - docs.add_url( + await docs.aadd_url( "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", citation="WikiMedia Foundation, 2023, Accessed now", dockey="test", @@ -727,7 +733,7 @@ async def test_aquery(self): class TestDocMatch(IsolatedAsyncioTestCase): async def test_adoc_match(self): docs = Docs() - docs.add_url( + await docs.aadd_url( "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", citation="WikiMedia Foundation, 2023, Accessed now", dockey="test",