Skip to content

Commit

Permalink
parallelize the callback calling
Browse files Browse the repository at this point in the history
  • Loading branch information
nadolskit committed Oct 16, 2024
1 parent 68320d8 commit eaab9e7
Showing 1 changed file with 31 additions and 10 deletions.
41 changes: 31 additions & 10 deletions paperqa/agents/tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Base classes for tools, implemented in a functional manner."""

import asyncio
import inspect
import logging
import re
Expand Down Expand Up @@ -191,7 +192,12 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
raise EmptyDocsError("Not gathering evidence due to having no papers.")

if f"{self.TOOL_FN_NAME}_initialized" in self.settings.callbacks:
await asyncio.gather(*(c(state) for c in self.settings.callbacks[f"{self.TOOL_FN_NAME}_initialized"]))
await asyncio.gather(
*(
c(state)
for c in self.settings.callbacks[f"{self.TOOL_FN_NAME}_initialized"]
)
)

logger.info(f"{self.TOOL_FN_NAME} starting for question {question!r}.")
original_question = state.answer.question
Expand Down Expand Up @@ -227,9 +233,14 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
)

if f"{self.TOOL_FN_NAME}_completed" in self.settings.callbacks:
callback_list = self.settings.callbacks[f"{self.TOOL_FN_NAME}_completed"]
for callback in callback_list:
await callback(state)
await asyncio.gather(
*(
callback(state)
for callback in self.settings.callbacks[
f"{self.TOOL_FN_NAME}_completed"
]
)
)

return f"Added {l1 - l0} pieces of evidence.{best_evidence}\n\n" + status

Expand Down Expand Up @@ -263,9 +274,14 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str:
logger.info(f"Generating answer for '{question}'.")

if f"{self.TOOL_FN_NAME}_initialized" in self.settings.callbacks:
callback_list = self.settings.callbacks[f"{self.TOOL_FN_NAME}_initialized"]
for callback in callback_list:
await callback(state)
await asyncio.gather(
*(
callback(state)
for callback in self.settings.callbacks[
f"{self.TOOL_FN_NAME}_initialized"
]
)
)

# TODO: Should we allow the agent to change the question?
# self.answer.question = query
Expand All @@ -289,9 +305,14 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str:
logger.info(status)

if f"{self.TOOL_FN_NAME}_completed" in self.settings.callbacks:
callback_list = self.settings.callbacks[f"{self.TOOL_FN_NAME}_completed"]
for callback in callback_list:
await callback(state)
await asyncio.gather(
*(
callback(state)
for callback in self.settings.callbacks[
f"{self.TOOL_FN_NAME}_completed"
]
)
)

return f"{answer} | {status}"

Expand Down

0 comments on commit eaab9e7

Please sign in to comment.