From 6357caee73de336721aa91c0bae814ae322f363e Mon Sep 17 00:00:00 2001 From: Tyler Nadolski <122555266+nadolskit@users.noreply.github.com> Date: Wed, 16 Oct 2024 08:13:42 -0700 Subject: [PATCH] Add callback support in settings and tools (#590) Co-authored-by: James Braza --- paperqa/agents/tools.py | 51 +++++++++++++++++++++++++++++++++++++++++ paperqa/settings.py | 35 +++++++++++++++++++++++++++- tests/test_agents.py | 34 +++++++++++++++++++++++---- 3 files changed, 115 insertions(+), 5 deletions(-) diff --git a/paperqa/agents/tools.py b/paperqa/agents/tools.py index 553fd643f..25ba7eaad 100644 --- a/paperqa/agents/tools.py +++ b/paperqa/agents/tools.py @@ -1,5 +1,6 @@ """Base classes for tools, implemented in a functional manner.""" +import asyncio import inspect import logging import re @@ -190,6 +191,16 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str: if not state.docs.docs: raise EmptyDocsError("Not gathering evidence due to having no papers.") + if f"{self.TOOL_FN_NAME}_initialized" in self.settings.agent.callbacks: + await asyncio.gather( + *( + c(state) + for c in self.settings.agent.callbacks[ + f"{self.TOOL_FN_NAME}_initialized" + ] + ) + ) + logger.info(f"{self.TOOL_FN_NAME} starting for question {question!r}.") original_question = state.answer.question try: @@ -197,12 +208,16 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str: # TODO: remove this swap, as it prevents us from supporting parallel calls state.answer.question = question l0 = len(state.answer.contexts) + # TODO: refactor answer out of this... 
state.answer = await state.docs.aget_evidence( query=state.answer, settings=self.settings, embedding_model=self.embedding_model, summary_llm_model=self.summary_llm_model, + callbacks=self.settings.agent.callbacks.get( + f"{self.TOOL_FN_NAME}_aget_evidence" + ), ) l1 = len(state.answer.contexts) finally: @@ -218,6 +233,17 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str: if sorted_contexts else "" ) + + if f"{self.TOOL_FN_NAME}_completed" in self.settings.agent.callbacks: + await asyncio.gather( + *( + callback(state) + for callback in self.settings.agent.callbacks[ + f"{self.TOOL_FN_NAME}_completed" + ] + ) + ) + return f"Added {l1 - l0} pieces of evidence.{best_evidence}\n\n" + status @@ -248,6 +274,17 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str: state: Current state. """ logger.info(f"Generating answer for '{question}'.") + + if f"{self.TOOL_FN_NAME}_initialized" in self.settings.agent.callbacks: + await asyncio.gather( + *( + callback(state) + for callback in self.settings.agent.callbacks[ + f"{self.TOOL_FN_NAME}_initialized" + ] + ) + ) + # TODO: Should we allow the agent to change the question? 
# self.answer.question = query state.answer = await state.docs.aquery( @@ -256,6 +293,9 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str: llm_model=self.llm_model, summary_llm_model=self.summary_llm_model, embedding_model=self.embedding_model, + callbacks=self.settings.agent.callbacks.get( + f"{self.TOOL_FN_NAME}_aget_query" + ), ) if state.answer.could_not_answer: @@ -267,6 +307,17 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str: answer = state.answer.answer status = state.status logger.info(status) + + if f"{self.TOOL_FN_NAME}_completed" in self.settings.agent.callbacks: + await asyncio.gather( + *( + callback(state) + for callback in self.settings.agent.callbacks[ + f"{self.TOOL_FN_NAME}_completed" + ] + ) + ) + return f"{answer} | {status}" # NOTE: can match failure to answer or an actual answer diff --git a/paperqa/settings.py b/paperqa/settings.py index 110ba55fb..a40920fcb 100644 --- a/paperqa/settings.py +++ b/paperqa/settings.py @@ -3,9 +3,10 @@ import os import pathlib import warnings +from collections.abc import Callable, Mapping from enum import StrEnum from pydoc import locate -from typing import Any, ClassVar, Self, assert_never, cast +from typing import TYPE_CHECKING, Any, ClassVar, Self, assert_never, cast import anyio from aviary.tools import ToolSelector @@ -53,6 +54,9 @@ from paperqa.utils import hexdigest, pqa_directory from paperqa.version import __version__ +if TYPE_CHECKING: + from .agents.env import EnvironmentState + class AnswerSettings(BaseModel): model_config = ConfigDict(extra="forbid") @@ -441,6 +445,35 @@ class AgentSettings(BaseModel): ) index: IndexSettings = Field(default_factory=IndexSettings) + callbacks: Mapping[str, list[Callable[["EnvironmentState"], Any]]] = Field( + default_factory=dict, + description=""" + A mapping that associates callback names with lists of corresponding callable functions. 
+ Each callback list contains functions that will be called with an instance of `EnvironmentState`, + representing the current state context. + + Accepted callback names: + - 'gen_answer_initialized': Triggered when `GenerateAnswer.gen_answer` + is initialized. + + - 'gen_answer_aget_query': LLM callbacks to execute in the prompt runner + as part of `GenerateAnswer.gen_answer`. + + - 'gen_answer_completed': Triggered after `GenerateAnswer.gen_answer` + successfully generates an answer. + + - 'gather_evidence_initialized': Triggered when `GatherEvidence.gather_evidence` + is initialized. + + - 'gather_evidence_aget_evidence': LLM callbacks to execute in the prompt runner + as part of `GatherEvidence.gather_evidence`. + + - 'gather_evidence_completed': Triggered after `GatherEvidence.gather_evidence` + completes evidence gathering. + """, + exclude=True, + ) + @field_validator("tool_names") @classmethod def validate_tool_names(cls, v: set[str] | None) -> set[str] | None: diff --git a/tests/test_agents.py b/tests/test_agents.py index 397212ba6..92eb6da65 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -10,7 +10,7 @@ import time from pathlib import Path from typing import Any, cast -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from uuid import uuid4 import ldp.agent @@ -421,10 +421,11 @@ async def test_gather_evidence_rejects_empty_docs( ), "Agent should have hit its max timesteps" +@pytest.mark.parametrize("callback_type", [None, "async"]) @pytest.mark.flaky(reruns=3, only_rerun=["AssertionError", "EmptyDocsError"]) @pytest.mark.asyncio async def test_agent_sharing_state( - agent_test_settings: Settings, subtests: SubTests + agent_test_settings: Settings, subtests: SubTests, callback_type: str ) -> None: agent_test_settings.agent.search_count = 3 # Keep low for speed agent_test_settings.answer.evidence_k = 2 @@ -433,6 +434,22 @@ async def test_agent_sharing_state( summary_llm_model = agent_test_settings.get_summary_llm() 
embedding_model = agent_test_settings.get_embedding_model() + callbacks = {} + if callback_type == "async": + gen_answer_initialized_callback = AsyncMock() + gen_answer_completed_callback = AsyncMock() + gather_evidence_initialized_callback = AsyncMock() + gather_evidence_completed_callback = AsyncMock() + + callbacks = { + "gen_answer_initialized": [gen_answer_initialized_callback], + "gen_answer_completed": [gen_answer_completed_callback], + "gather_evidence_initialized": [gather_evidence_initialized_callback], + "gather_evidence_completed": [gather_evidence_completed_callback], + } + + agent_test_settings.agent.callbacks = callbacks # type: ignore[assignment] + answer = Answer(question="What is is a self-explanatory model?") query = QueryRequest(query=answer.question, settings=agent_test_settings) env_state = EnvironmentState(docs=Docs(), answer=answer) @@ -455,8 +472,7 @@ async def test_agent_sharing_state( assert env_state.docs.docs, "Search did not add any papers" mock_save_index.assert_not_awaited(), "Search shouldn't try to update the index" assert all( - (isinstance(d, Doc) or issubclass(d, Doc)) # type: ignore[unreachable] - for d in env_state.docs.docs.values() + isinstance(d, Doc) for d in env_state.docs.docs.values() ), "Document type or DOI propagation failure" with subtests.test(msg=GatherEvidence.__name__): @@ -468,6 +484,11 @@ async def test_agent_sharing_state( embedding_model=embedding_model, ) await gather_evidence_tool.gather_evidence(answer.question, state=env_state) + + if callback_type == "async": + gather_evidence_initialized_callback.assert_awaited_once_with(env_state) + gather_evidence_completed_callback.assert_awaited_once_with(env_state) + assert answer.contexts, "Evidence did not return any results" with subtests.test(msg=f"{GenerateAnswer.__name__} working"): @@ -478,6 +499,11 @@ async def test_agent_sharing_state( embedding_model=embedding_model, ) result = await generate_answer_tool.gen_answer(answer.question, state=env_state) + + 
if callback_type == "async": + gen_answer_initialized_callback.assert_awaited_once_with(env_state) + gen_answer_completed_callback.assert_awaited_once_with(env_state) + assert re.search( pattern=EnvironmentState.STATUS_SEARCH_REGEX_PATTERN, string=result )