Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add callback support in settings and tools #590

Merged
merged 13 commits into from
Oct 16, 2024
28 changes: 28 additions & 0 deletions paperqa/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,19 +190,28 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
if not state.docs.docs:
raise EmptyDocsError("Not gathering evidence due to having no papers.")

if f"{self.TOOL_FN_NAME}_initialized" in self.settings.callbacks:
callback_list = self.settings.callbacks[f"{self.TOOL_FN_NAME}_initialized"]
for callback in callback_list:
await callback(state)
nadolskit marked this conversation as resolved.
Show resolved Hide resolved

logger.info(f"{self.TOOL_FN_NAME} starting for question {question!r}.")
original_question = state.answer.question
try:
# Swap out the question with the more specific question
# TODO: remove this swap, as it prevents us from supporting parallel calls
state.answer.question = question
l0 = len(state.answer.contexts)

# TODO: refactor answer out of this...
state.answer = await state.docs.aget_evidence(
query=state.answer,
settings=self.settings,
embedding_model=self.embedding_model,
summary_llm_model=self.summary_llm_model,
callbacks=self.settings.callbacks.get(
f"{self.TOOL_FN_NAME}_aget_evidence"
),
)
l1 = len(state.answer.contexts)
finally:
Expand All @@ -218,6 +227,12 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
if sorted_contexts
else ""
)

if f"{self.TOOL_FN_NAME}_completed" in self.settings.callbacks:
callback_list = self.settings.callbacks[f"{self.TOOL_FN_NAME}_completed"]
for callback in callback_list:
await callback(state)
nadolskit marked this conversation as resolved.
Show resolved Hide resolved

return f"Added {l1 - l0} pieces of evidence.{best_evidence}\n\n" + status


Expand Down Expand Up @@ -248,6 +263,12 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str:
state: Current state.
"""
logger.info(f"Generating answer for '{question}'.")

if f"{self.TOOL_FN_NAME}_initialized" in self.settings.callbacks:
callback_list = self.settings.callbacks[f"{self.TOOL_FN_NAME}_initialized"]
for callback in callback_list:
await callback(state)

# TODO: Should we allow the agent to change the question?
# self.answer.question = query
state.answer = await state.docs.aquery(
Expand All @@ -256,6 +277,7 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str:
llm_model=self.llm_model,
summary_llm_model=self.summary_llm_model,
embedding_model=self.embedding_model,
callbacks=self.settings.callbacks.get(f"{self.TOOL_FN_NAME}_aget_query"),
)

if state.answer.could_not_answer:
Expand All @@ -267,6 +289,12 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str:
answer = state.answer.answer
status = state.status
logger.info(status)

if f"{self.TOOL_FN_NAME}_completed" in self.settings.callbacks:
callback_list = self.settings.callbacks[f"{self.TOOL_FN_NAME}_completed"]
for callback in callback_list:
await callback(state)

return f"{answer} | {status}"

# NOTE: can match failure to answer or an actual answer
Expand Down
35 changes: 34 additions & 1 deletion paperqa/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import os
import pathlib
import warnings
from collections.abc import Callable, Mapping
from enum import StrEnum
from pydoc import locate
from typing import Any, ClassVar, Self, assert_never, cast
from typing import TYPE_CHECKING, Any, ClassVar, Self, assert_never, cast

import anyio
from aviary.tools import ToolSelector
Expand Down Expand Up @@ -53,6 +54,9 @@
from paperqa.utils import hexdigest, pqa_directory
from paperqa.version import __version__

if TYPE_CHECKING:
from .agents.env import EnvironmentState


class AnswerSettings(BaseModel):
model_config = ConfigDict(extra="forbid")
Expand Down Expand Up @@ -580,6 +584,35 @@ class Settings(BaseSettings):
frozen=True,
)

callbacks: Mapping[str, list[Callable[["EnvironmentState"], Any]]] = Field(
nadolskit marked this conversation as resolved.
Show resolved Hide resolved
default_factory=dict,
description="""
A mapping that associates callback names with lists of corresponding callable functions.
Each callback list contains functions that will be called with an instance of `EnvironmentState`,
representing the current state context.

Accepted callback names:
- 'gen_answer_initialized': Triggered when `GenerateAnswer.gen_answer`
is initialized.

- 'gen_answer_aget_query': LLM callbacks to execute in the prompt runner
as part of `GenerateAnswer.gen_answer`.

- 'gen_answer_completed': Triggered after `GenerateAnswer.gen_answer`
successfully generates an answer.

- 'gather_evidence_initialized': Triggered when `GatherEvidence.gather_evidence`
is initialized.

- 'gather_evidence_aget_evidence: LLM callbacks to execute in the prompt runner
as part of `GatherEvidence.gather_evidence`.

- 'gather_evidence_completed': Triggered after `GatherEvidence.gather_evidence`
completes evidence gathering.
""",
exclude=True,
)

@model_validator(mode="after")
def _deprecated_field(self) -> Self:
for deprecated_field_name, new_name in (
Expand Down
34 changes: 30 additions & 4 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
from pathlib import Path
from typing import Any, cast
from unittest.mock import patch
from unittest.mock import AsyncMock, patch
from uuid import uuid4

import ldp.agent
Expand Down Expand Up @@ -421,10 +421,11 @@ async def test_gather_evidence_rejects_empty_docs(
), "Agent should have hit its max timesteps"


@pytest.mark.parametrize("callback_type", [None, "async"])
@pytest.mark.flaky(reruns=3, only_rerun=["AssertionError", "EmptyDocsError"])
@pytest.mark.asyncio
async def test_agent_sharing_state(
agent_test_settings: Settings, subtests: SubTests
agent_test_settings: Settings, subtests: SubTests, callback_type: str
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be callback_type: str | None

) -> None:
agent_test_settings.agent.search_count = 3 # Keep low for speed
agent_test_settings.answer.evidence_k = 2
Expand All @@ -433,6 +434,22 @@ async def test_agent_sharing_state(
summary_llm_model = agent_test_settings.get_summary_llm()
embedding_model = agent_test_settings.get_embedding_model()

callbacks = {}
if callback_type == "async":
gen_answer_initialized_callback = AsyncMock()
gen_answer_completed_callback = AsyncMock()
gather_evidence_initialized_callback = AsyncMock()
gather_evidence_completed_callback = AsyncMock()

callbacks = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure it really matters, but it's more future proof to use callbacks.update here instead of redefining

"gen_answer_initialized": [gen_answer_initialized_callback],
"gen_answer_completed": [gen_answer_completed_callback],
"gather_evidence_initialized": [gather_evidence_initialized_callback],
"gather_evidence_completed": [gather_evidence_completed_callback],
}

agent_test_settings.callbacks = callbacks # type: ignore[assignment]

answer = Answer(question="What is is a self-explanatory model?")
query = QueryRequest(query=answer.question, settings=agent_test_settings)
env_state = EnvironmentState(docs=Docs(), answer=answer)
Expand All @@ -455,8 +472,7 @@ async def test_agent_sharing_state(
assert env_state.docs.docs, "Search did not add any papers"
mock_save_index.assert_not_awaited(), "Search shouldn't try to update the index"
assert all(
(isinstance(d, Doc) or issubclass(d, Doc)) # type: ignore[unreachable]
for d in env_state.docs.docs.values()
isinstance(d, Doc) for d in env_state.docs.docs.values()
), "Document type or DOI propagation failure"

with subtests.test(msg=GatherEvidence.__name__):
Expand All @@ -468,6 +484,11 @@ async def test_agent_sharing_state(
embedding_model=embedding_model,
)
await gather_evidence_tool.gather_evidence(answer.question, state=env_state)

if callback_type == "async":
gather_evidence_initialized_callback.assert_awaited_once_with(env_state)
gather_evidence_completed_callback.assert_awaited_once_with(env_state)

assert answer.contexts, "Evidence did not return any results"

with subtests.test(msg=f"{GenerateAnswer.__name__} working"):
Expand All @@ -478,6 +499,11 @@ async def test_agent_sharing_state(
embedding_model=embedding_model,
)
result = await generate_answer_tool.gen_answer(answer.question, state=env_state)

if callback_type == "async":
gen_answer_initialized_callback.assert_awaited_once_with(env_state)
gen_answer_completed_callback.assert_awaited_once_with(env_state)

assert re.search(
pattern=EnvironmentState.STATUS_SEARCH_REGEX_PATTERN, string=result
)
Expand Down
Loading