Skip to content

Commit

Permalink
Add callback support in settings and tools (#590)
Browse files Browse the repository at this point in the history
Co-authored-by: James Braza <[email protected]>
  • Loading branch information
nadolskit and jamesbraza authored Oct 16, 2024
1 parent a21e199 commit 6357cae
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 5 deletions.
51 changes: 51 additions & 0 deletions paperqa/agents/tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Base classes for tools, implemented in a functional manner."""

import asyncio
import inspect
import logging
import re
Expand Down Expand Up @@ -190,19 +191,33 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
if not state.docs.docs:
raise EmptyDocsError("Not gathering evidence due to having no papers.")

if f"{self.TOOL_FN_NAME}_initialized" in self.settings.agent.callbacks:
await asyncio.gather(
*(
c(state)
for c in self.settings.agent.callbacks[
f"{self.TOOL_FN_NAME}_initialized"
]
)
)

logger.info(f"{self.TOOL_FN_NAME} starting for question {question!r}.")
original_question = state.answer.question
try:
# Swap out the question with the more specific question
# TODO: remove this swap, as it prevents us from supporting parallel calls
state.answer.question = question
l0 = len(state.answer.contexts)

# TODO: refactor answer out of this...
state.answer = await state.docs.aget_evidence(
query=state.answer,
settings=self.settings,
embedding_model=self.embedding_model,
summary_llm_model=self.summary_llm_model,
callbacks=self.settings.agent.callbacks.get(
f"{self.TOOL_FN_NAME}_aget_evidence"
),
)
l1 = len(state.answer.contexts)
finally:
Expand All @@ -218,6 +233,17 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
if sorted_contexts
else ""
)

if f"{self.TOOL_FN_NAME}_completed" in self.settings.agent.callbacks:
await asyncio.gather(
*(
callback(state)
for callback in self.settings.agent.callbacks[
f"{self.TOOL_FN_NAME}_completed"
]
)
)

return f"Added {l1 - l0} pieces of evidence.{best_evidence}\n\n" + status


Expand Down Expand Up @@ -248,6 +274,17 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str:
state: Current state.
"""
logger.info(f"Generating answer for '{question}'.")

if f"{self.TOOL_FN_NAME}_initialized" in self.settings.agent.callbacks:
await asyncio.gather(
*(
callback(state)
for callback in self.settings.agent.callbacks[
f"{self.TOOL_FN_NAME}_initialized"
]
)
)

# TODO: Should we allow the agent to change the question?
# self.answer.question = query
state.answer = await state.docs.aquery(
Expand All @@ -256,6 +293,9 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str:
llm_model=self.llm_model,
summary_llm_model=self.summary_llm_model,
embedding_model=self.embedding_model,
callbacks=self.settings.agent.callbacks.get(
f"{self.TOOL_FN_NAME}_aget_query"
),
)

if state.answer.could_not_answer:
Expand All @@ -267,6 +307,17 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str:
answer = state.answer.answer
status = state.status
logger.info(status)

if f"{self.TOOL_FN_NAME}_completed" in self.settings.agent.callbacks:
await asyncio.gather(
*(
callback(state)
for callback in self.settings.agent.callbacks[
f"{self.TOOL_FN_NAME}_completed"
]
)
)

return f"{answer} | {status}"

# NOTE: can match failure to answer or an actual answer
Expand Down
35 changes: 34 additions & 1 deletion paperqa/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import os
import pathlib
import warnings
from collections.abc import Callable, Mapping
from enum import StrEnum
from pydoc import locate
from typing import Any, ClassVar, Self, assert_never, cast
from typing import TYPE_CHECKING, Any, ClassVar, Self, assert_never, cast

import anyio
from aviary.tools import ToolSelector
Expand Down Expand Up @@ -53,6 +54,9 @@
from paperqa.utils import hexdigest, pqa_directory
from paperqa.version import __version__

if TYPE_CHECKING:
from .agents.env import EnvironmentState


class AnswerSettings(BaseModel):
model_config = ConfigDict(extra="forbid")
Expand Down Expand Up @@ -441,6 +445,35 @@ class AgentSettings(BaseModel):
)
index: IndexSettings = Field(default_factory=IndexSettings)

callbacks: Mapping[str, list[Callable[["EnvironmentState"], Any]]] = Field(
default_factory=dict,
description="""
A mapping that associates callback names with lists of corresponding callable functions.
Each callback list contains functions that will be called with an instance of `EnvironmentState`,
representing the current state context.
Accepted callback names:
- 'gen_answer_initialized': Triggered when `GenerateAnswer.gen_answer`
is initialized.
- 'gen_answer_aget_query': LLM callbacks to execute in the prompt runner
as part of `GenerateAnswer.gen_answer`.
- 'gen_answer_completed': Triggered after `GenerateAnswer.gen_answer`
successfully generates an answer.
- 'gather_evidence_initialized': Triggered when `GatherEvidence.gather_evidence`
is initialized.
- 'gather_evidence_aget_evidence': LLM callbacks to execute in the prompt runner
as part of `GatherEvidence.gather_evidence`.
- 'gather_evidence_completed': Triggered after `GatherEvidence.gather_evidence`
completes evidence gathering.
""",
exclude=True,
)

@field_validator("tool_names")
@classmethod
def validate_tool_names(cls, v: set[str] | None) -> set[str] | None:
Expand Down
34 changes: 30 additions & 4 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
from pathlib import Path
from typing import Any, cast
from unittest.mock import patch
from unittest.mock import AsyncMock, patch
from uuid import uuid4

import ldp.agent
Expand Down Expand Up @@ -421,10 +421,11 @@ async def test_gather_evidence_rejects_empty_docs(
), "Agent should have hit its max timesteps"


@pytest.mark.parametrize("callback_type", [None, "async"])
@pytest.mark.flaky(reruns=3, only_rerun=["AssertionError", "EmptyDocsError"])
@pytest.mark.asyncio
async def test_agent_sharing_state(
agent_test_settings: Settings, subtests: SubTests
agent_test_settings: Settings, subtests: SubTests, callback_type: str
) -> None:
agent_test_settings.agent.search_count = 3 # Keep low for speed
agent_test_settings.answer.evidence_k = 2
Expand All @@ -433,6 +434,22 @@ async def test_agent_sharing_state(
summary_llm_model = agent_test_settings.get_summary_llm()
embedding_model = agent_test_settings.get_embedding_model()

callbacks = {}
if callback_type == "async":
gen_answer_initialized_callback = AsyncMock()
gen_answer_completed_callback = AsyncMock()
gather_evidence_initialized_callback = AsyncMock()
gather_evidence_completed_callback = AsyncMock()

callbacks = {
"gen_answer_initialized": [gen_answer_initialized_callback],
"gen_answer_completed": [gen_answer_completed_callback],
"gather_evidence_initialized": [gather_evidence_initialized_callback],
"gather_evidence_completed": [gather_evidence_completed_callback],
}

agent_test_settings.agent.callbacks = callbacks # type: ignore[assignment]

answer = Answer(question="What is is a self-explanatory model?")
query = QueryRequest(query=answer.question, settings=agent_test_settings)
env_state = EnvironmentState(docs=Docs(), answer=answer)
Expand All @@ -455,8 +472,7 @@ async def test_agent_sharing_state(
assert env_state.docs.docs, "Search did not add any papers"
mock_save_index.assert_not_awaited(), "Search shouldn't try to update the index"
assert all(
(isinstance(d, Doc) or issubclass(d, Doc)) # type: ignore[unreachable]
for d in env_state.docs.docs.values()
isinstance(d, Doc) for d in env_state.docs.docs.values()
), "Document type or DOI propagation failure"

with subtests.test(msg=GatherEvidence.__name__):
Expand All @@ -468,6 +484,11 @@ async def test_agent_sharing_state(
embedding_model=embedding_model,
)
await gather_evidence_tool.gather_evidence(answer.question, state=env_state)

if callback_type == "async":
gather_evidence_initialized_callback.assert_awaited_once_with(env_state)
gather_evidence_completed_callback.assert_awaited_once_with(env_state)

assert answer.contexts, "Evidence did not return any results"

with subtests.test(msg=f"{GenerateAnswer.__name__} working"):
Expand All @@ -478,6 +499,11 @@ async def test_agent_sharing_state(
embedding_model=embedding_model,
)
result = await generate_answer_tool.gen_answer(answer.question, state=env_state)

if callback_type == "async":
gen_answer_initialized_callback.assert_awaited_once_with(env_state)
gen_answer_completed_callback.assert_awaited_once_with(env_state)

assert re.search(
pattern=EnvironmentState.STATUS_SEARCH_REGEX_PATTERN, string=result
)
Expand Down

0 comments on commit 6357cae

Please sign in to comment.