Skip to content

Commit

Permalink
Removed nest_asyncio (#227)
Browse files Browse the repository at this point in the history
* Removed nest_asyncio

* Updated README for sync notes
  • Loading branch information
whitead authored Jan 23, 2024
1 parent 25d75de commit 89d3eab
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 32 deletions.
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,32 @@ print(answer.formatted_answer)

The answer object has the following attributes: `formatted_answer`, `answer` (answer alone), `question`, and `context` (the summaries of passages found for answer).

### Async

paper-qa is written to be used asynchronously. The synchronous API is just a wrapper around the async. Here are the methods and their async equivalents:

| Sync | Async |
| --- | --- |
| `Docs.add` | `Docs.aadd` |
| `Docs.add_file` | `Docs.aadd_file` |
| `Docs.add_url` | `Docs.aadd_url` |
| `Docs.get_evidence` | `Docs.aget_evidence` |
| `Docs.query` | `Docs.aquery` |

The synchronous versions just call the async versions in an event loop. Most modern Python environments support async natively (including Jupyter notebooks!). So you can do this in a Jupyter Notebook:

```py
from paperqa import Docs

my_docs = ...  # get a list of paths

docs = Docs()
for d in my_docs:
    await docs.aadd(d)

answer = await docs.aquery("What manufacturing challenges are unique to bispecific antibodies?")
```

### Adding Documents

`add` will add from paths. You can also use `add_file` (expects a file object) or `add_url` to work with other sources.
Expand Down
97 changes: 78 additions & 19 deletions paperqa/docs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import nest_asyncio # isort:skip
import asyncio
import os
import re
import tempfile
Expand Down Expand Up @@ -35,6 +33,7 @@
)
from .utils import (
gather_with_concurrency,
get_loop,
guess_is_4xx,
maybe_is_html,
maybe_is_pdf,
Expand All @@ -44,9 +43,6 @@
strip_citations,
)

# Apply the patch to allow nested loops
nest_asyncio.apply()


class Docs(BaseModel):
"""A collection of documents to be used for answering questions."""
Expand Down Expand Up @@ -251,6 +247,25 @@ def add_file(
docname: str | None = None,
dockey: DocKey | None = None,
chunk_chars: int = 3000,
) -> str | None:
loop = get_loop()
return loop.run_until_complete(
self.aadd_file(
file,
citation=citation,
docname=docname,
dockey=dockey,
chunk_chars=chunk_chars,
)
)

async def aadd_file(
self,
file: BinaryIO,
citation: str | None = None,
docname: str | None = None,
dockey: DocKey | None = None,
chunk_chars: int = 3000,
) -> str | None:
"""Add a document to the collection."""
# just put in temp file and use existing method
Expand All @@ -263,7 +278,7 @@ def add_file(
with tempfile.NamedTemporaryFile(suffix=suffix) as f:
f.write(file.read())
f.seek(0)
return self.add(
return await self.aadd(
Path(f.name),
citation=citation,
docname=docname,
Expand All @@ -278,14 +293,33 @@ def add_url(
docname: str | None = None,
dockey: DocKey | None = None,
chunk_chars: int = 3000,
) -> str | None:
loop = get_loop()
return loop.run_until_complete(
self.aadd_url(
url,
citation=citation,
docname=docname,
dockey=dockey,
chunk_chars=chunk_chars,
)
)

async def aadd_url(
self,
url: str,
citation: str | None = None,
docname: str | None = None,
dockey: DocKey | None = None,
chunk_chars: int = 3000,
) -> str | None:
"""Add a document to the collection."""
import urllib.request

with urllib.request.urlopen(url) as f:
# need to wrap to enable seek
file = BytesIO(f.read())
return self.add_file(
return await self.aadd_file(
file,
citation=citation,
docname=docname,
Expand All @@ -301,6 +335,27 @@ def add(
disable_check: bool = False,
dockey: DocKey | None = None,
chunk_chars: int = 3000,
) -> str | None:
loop = get_loop()
return loop.run_until_complete(
self.aadd(
path,
citation=citation,
docname=docname,
disable_check=disable_check,
dockey=dockey,
chunk_chars=chunk_chars,
)
)

async def aadd(
self,
path: Path,
citation: str | None = None,
docname: str | None = None,
disable_check: bool = False,
dockey: DocKey | None = None,
chunk_chars: int = 3000,
) -> str | None:
"""Add a document to the collection."""
if dockey is None:
Expand All @@ -317,9 +372,7 @@ def add(
texts = read_doc(path, fake_doc, chunk_chars=chunk_chars, overlap=100)
if len(texts) == 0:
raise ValueError(f"Could not read document {path}. Is it empty?")
chain_result = asyncio.run(
cite_chain(dict(text=texts[0].text), None),
)
chain_result = await cite_chain(dict(text=texts[0].text), None)
citation = chain_result.text
if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation:
citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}"
Expand Down Expand Up @@ -353,14 +406,22 @@ def add(
raise ValueError(
f"This does not look like a text document: {path}. Path disable_check to ignore this error."
)
if self.add_texts(texts, doc):
if await self.aadd_texts(texts, doc):
return docname
return None

def add_texts(self, texts: list[Text], doc: Doc) -> bool:
    """Blocking counterpart of ``aadd_texts``.

    Obtains an event loop from ``get_loop()`` and runs
    ``self.aadd_texts(texts, doc)`` on it until it completes, returning
    whatever that coroutine returns.
    """
    return get_loop().run_until_complete(self.aadd_texts(texts, doc))

async def aadd_texts(
self,
texts: list[Text],
doc: Doc,
) -> bool:
"""Add chunked texts to the collection. This is useful if you have already chunked the texts yourself.
Expand All @@ -376,16 +437,14 @@ def add_texts(
t.name = t.name.replace(doc.docname, new_docname)
doc.docname = new_docname
if texts[0].embedding is None:
text_embeddings = asyncio.run(
self.texts_index.embedding_model.embed_documents(
self._embedding_client, [t.text for t in texts]
)
text_embeddings = await self.texts_index.embedding_model.embed_documents(
self._embedding_client, [t.text for t in texts]
)
for i, t in enumerate(texts):
t.embedding = text_embeddings[i]
if doc.embedding is None:
doc.embedding = asyncio.run(
self.docs_index.embedding_model.embed_documents(
doc.embedding = (
await self.docs_index.embedding_model.embed_documents(
self._embedding_client, [doc.citation]
)
)[0]
Expand Down Expand Up @@ -493,7 +552,7 @@ def get_evidence(
detailed_citations: bool = False,
disable_vector_search: bool = False,
) -> Answer:
return asyncio.run(
return get_loop().run_until_complete(
self.aget_evidence(
answer,
k=k,
Expand Down Expand Up @@ -641,7 +700,7 @@ def query(
key_filter: bool | None = None,
get_callbacks: CallbackFactory = lambda x: None,
) -> Answer:
return asyncio.run(
return get_loop().run_until_complete(
self.aquery(
query,
k=k,
Expand Down
21 changes: 16 additions & 5 deletions paperqa/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .prompts import default_system_prompt
from .types import Doc, Embeddable, LLMResult, Text
from .utils import batch_iter, flatten, gather_with_concurrency
from .utils import batch_iter, flatten, gather_with_concurrency, is_coroutine_callable


def guess_model_type(model_name: str) -> str:
Expand Down Expand Up @@ -167,7 +167,7 @@ def make_chain(
chat_prompt = [system_message_prompt, human_message_prompt]

async def execute(
data: dict, callbacks: list[Callable[[str], None]] | None = None
data: dict, callbacks: list[Callable] | None = None
) -> LLMResult:
start_clock = asyncio.get_running_loop().time()
result = LLMResult(
Expand All @@ -184,6 +184,10 @@ async def execute(
if callbacks is None:
output = await self.achat(client, messages)
else:
sync_callbacks = [
f for f in callbacks if not is_coroutine_callable(f)
]
async_callbacks = [f for f in callbacks if is_coroutine_callable(f)]
completion = self.achat_iter(client, messages) # type: ignore
text_result = []
async for chunk in completion: # type: ignore
Expand All @@ -193,7 +197,8 @@ async def execute(
asyncio.get_running_loop().time() - start_clock
)
text_result.append(chunk)
[f(chunk) for f in callbacks]
[await f(chunk) for f in async_callbacks]
[f(chunk) for f in sync_callbacks]
output = "".join(text_result)
result.completion_count = self.count_tokens(output)
result.text = output
Expand All @@ -210,7 +215,7 @@ async def execute(
completion_prompt = system_prompt + "\n\n" + prompt

async def execute(
data: dict, callbacks: list[Callable[[str], None]] | None = None
data: dict, callbacks: list[Callable] | None = None
) -> LLMResult:
start_clock = asyncio.get_running_loop().time()
result = LLMResult(
Expand All @@ -223,6 +228,11 @@ async def execute(
if callbacks is None:
output = await self.acomplete(client, formatted_prompt)
else:
sync_callbacks = [
f for f in callbacks if not is_coroutine_callable(f)
]
async_callbacks = [f for f in callbacks if is_coroutine_callable(f)]

completion = self.acomplete_iter( # type: ignore
client,
formatted_prompt,
Expand All @@ -235,7 +245,8 @@ async def execute(
asyncio.get_running_loop().time() - start_clock
)
text_result.append(chunk)
[f(chunk) for f in callbacks]
[await f(chunk) for f in async_callbacks]
[f(chunk) for f in sync_callbacks]
output = "".join(text_result)
result.completion_count = self.count_tokens(output)
result.text = output
Expand Down
18 changes: 18 additions & 0 deletions paperqa/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import inspect
import math
import re
import string
Expand Down Expand Up @@ -152,3 +153,20 @@ def flatten(iteratble: list) -> list:
:return: A flattened list
"""
return [item for sublist in iteratble for item in sublist]


def get_loop() -> asyncio.AbstractEventLoop:
    """Return an open event loop for the current thread.

    The current loop is reused when one is available and still open.
    Otherwise a new loop is created, installed as the current loop via
    ``asyncio.set_event_loop`` and returned.

    Returns:
        An event loop that is not closed, so it can be handed to
        ``run_until_complete``.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # asyncio reports "no current event loop" with RuntimeError.
        loop = None
    # A loop that was closed earlier would make run_until_complete fail
    # with "Event loop is closed", so replace it just like a missing one.
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop


def is_coroutine_callable(obj) -> bool:
    """Tell whether calling ``obj`` produces a coroutine that must be awaited.

    Args:
        obj: Any object; typically a callback.

    Returns:
        ``True`` when ``obj`` is an ``async def`` function, a bound
        ``async def`` method, or a callable instance whose ``__call__`` is
        ``async def``. ``False`` for synchronous callables and for objects
        that are not callable at all.
    """
    # Ask inspect about the object itself first: this covers plain functions
    # and also bound methods, which are not "functions" and whose __call__
    # is an opaque method-wrapper that inspect cannot classify.
    if inspect.iscoroutinefunction(obj):
        return True
    # Callable instances: the async-ness lives on their __call__ method.
    if not inspect.isfunction(obj) and callable(obj):
        return inspect.iscoroutinefunction(obj.__call__)
    return False
2 changes: 1 addition & 1 deletion paperqa/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "4.0.0-pre.1"
__version__ = "4.0.0-pre.2"
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
"pydantic>=2",
"openai>=1",
"numpy",
"nest-asyncio",
"PyCryptodome",
"html2text",
"tiktoken>=0.4.0",
Expand Down
18 changes: 12 additions & 6 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,12 @@ def accum(x):
assert completion.seconds_to_first_token == 0
assert completion.seconds_to_last_token > 0

# check with mixed callbacks
async def ac(x):
pass

completion = await call(dict(animal="duck"), callbacks=[accum, ac])


def test_docs():
docs = Docs(llm="babbage-002")
Expand Down Expand Up @@ -679,7 +685,7 @@ async def test_langchain_vector_store(self):
)
assert docs._embedding_client is not None # from docs_index default

docs.add_url(
await docs.aadd_url(
"https://en.wikipedia.org/wiki/Frederick_Bates_(politician)",
citation="WikiMedia Foundation, 2023, Accessed now",
dockey="test",
Expand All @@ -688,18 +694,18 @@ async def test_langchain_vector_store(self):

# now try with JIT
docs = Docs(texts_index=index, jit_texts_index=True)
docs.add_url(
await docs.aadd_url(
"https://en.wikipedia.org/wiki/Frederick_Bates_(politician)",
citation="WikiMedia Foundation, 2023, Accessed now",
dockey="test",
)
# should get cleared and rebuilt here
ev = docs.get_evidence(
ev = await docs.aget_evidence(
answer=Answer(question="What is Frederick Bates's greatest accomplishment?")
)
assert len(ev.context) > 0
# now with dockkey filter
docs.get_evidence(
await docs.aget_evidence(
answer=Answer(
question="What is Frederick Bates's greatest accomplishment?",
dockey_filter=["test"],
Expand All @@ -716,7 +722,7 @@ async def test_langchain_vector_store(self):
class Test(IsolatedAsyncioTestCase):
async def test_aquery(self):
docs = Docs()
docs.add_url(
await docs.aadd_url(
"https://en.wikipedia.org/wiki/Frederick_Bates_(politician)",
citation="WikiMedia Foundation, 2023, Accessed now",
dockey="test",
Expand All @@ -727,7 +733,7 @@ async def test_aquery(self):
class TestDocMatch(IsolatedAsyncioTestCase):
async def test_adoc_match(self):
docs = Docs()
docs.add_url(
await docs.aadd_url(
"https://en.wikipedia.org/wiki/Frederick_Bates_(politician)",
citation="WikiMedia Foundation, 2023, Accessed now",
dockey="test",
Expand Down

0 comments on commit 89d3eab

Please sign in to comment.