Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add callback support in settings and tools #590

Merged
merged 13 commits into from
Oct 16, 2024
20 changes: 20 additions & 0 deletions paperqa/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,16 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
if sorted_contexts
else ""
)

# named this way to support lifecycle callbacks in the future
# gather_evidence_init, generate_evidence_progress, etc.
if "gather_evidence_completed" in self.settings.callbacks:
nadolskit marked this conversation as resolved.
Show resolved Hide resolved
callback = self.settings.callbacks["gather_evidence_completed"]
if inspect.iscoroutinefunction(callback):
await callback(state)
nadolskit marked this conversation as resolved.
Show resolved Hide resolved
else:
callback(state)

return f"Added {l1 - l0} pieces of evidence.{best_evidence}\n\n" + status


Expand Down Expand Up @@ -267,6 +277,16 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str:
answer = state.answer.answer
status = state.status
logger.info(status)

# named this way to support lifecycle callbacks in the future
# generate_answer_init, generate_answer_progress, etc.
if "generate_answer_completed" in self.settings.callbacks:
callback = self.settings.callbacks["generate_answer_completed"]
if inspect.iscoroutinefunction(callback):
await callback(state)
else:
callback(state)

return f"{answer} | {status}"

# NOTE: can match failure to answer or an actual answer
Expand Down
23 changes: 22 additions & 1 deletion paperqa/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import os
import pathlib
import warnings
from collections.abc import Callable, Mapping
from enum import StrEnum
from pydoc import locate
from typing import Any, ClassVar, Self, assert_never, cast
from typing import TYPE_CHECKING, Any, ClassVar, Self, assert_never, cast

import anyio
from aviary.tools import ToolSelector
Expand Down Expand Up @@ -572,6 +573,26 @@ class Settings(BaseSettings):
frozen=True,
)

# imported here to avoid circular ref
if TYPE_CHECKING:
from .agents.env import EnvironmentState
nadolskit marked this conversation as resolved.
Show resolved Hide resolved
callbacks: Mapping[str, Callable[["EnvironmentState"], Any]] = Field(
jamesbraza marked this conversation as resolved.
Show resolved Hide resolved
nadolskit marked this conversation as resolved.
Show resolved Hide resolved
default_factory=dict,
description="""
A mapping that associates callback names with their corresponding callable functions.
Each callback will be called with an instance of `EnvironmentState`, representing
the current state context.

The callback functions can be synchronous or asynchronous, and are used to trigger specific
actions after key operations in the agent lifecycle.

Accepted callback names:
- 'generate_answer_completed': Triggered after `GenerateAnswer.gen_answer` generates an answer.
- 'gather_evidence_completed': Triggered after `GatherEvidence.gather_evidence` gathers evidence.

""",
)

@model_validator(mode="after")
def _deprecated_field(self) -> Self:
for deprecated_field_name, new_name in (
Expand Down
38 changes: 33 additions & 5 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
from pathlib import Path
from typing import Any, cast
from unittest.mock import patch
from unittest.mock import AsyncMock, Mock, patch
from uuid import uuid4

import ldp.agent
Expand Down Expand Up @@ -421,10 +421,10 @@ async def test_gather_evidence_rejects_empty_docs(
), "Agent should have hit its max timesteps"


@pytest.mark.flaky(reruns=3, only_rerun=["AssertionError", "EmptyDocsError"])
@pytest.mark.parametrize("callback_type", ["none", "sync", "async"])
@pytest.mark.asyncio
async def test_agent_sharing_state(
agent_test_settings: Settings, subtests: SubTests
agent_test_settings: Settings, subtests: SubTests, callback_type: str
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be callback_type: str | None

) -> None:
agent_test_settings.agent.search_count = 3 # Keep low for speed
agent_test_settings.answer.evidence_k = 2
Expand All @@ -433,6 +433,24 @@ async def test_agent_sharing_state(
summary_llm_model = agent_test_settings.get_summary_llm()
embedding_model = agent_test_settings.get_embedding_model()

callbacks = {}
if callback_type == "sync":
generate_answer_callback = Mock()
gather_evidence_callback = Mock()
callbacks = {
"generate_answer_completed": generate_answer_callback,
"gather_evidence_completed": gather_evidence_callback,
}
elif callback_type == "async":
agenerate_answer_callback = AsyncMock()
agather_evidence_callback = AsyncMock()
callbacks = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure it really matters, but it's more future proof to use callbacks.update here instead of redefining

"generate_answer_completed": agenerate_answer_callback,
"gather_evidence_completed": agather_evidence_callback,
}

agent_test_settings.callbacks = callbacks

answer = Answer(question="What is is a self-explanatory model?")
query = QueryRequest(query=answer.question, settings=agent_test_settings)
env_state = EnvironmentState(docs=Docs(), answer=answer)
Expand All @@ -455,8 +473,7 @@ async def test_agent_sharing_state(
assert env_state.docs.docs, "Search did not add any papers"
mock_save_index.assert_not_awaited(), "Search shouldn't try to update the index"
assert all(
(isinstance(d, Doc) or issubclass(d, Doc)) # type: ignore[unreachable]
for d in env_state.docs.docs.values()
isinstance(d, Doc) for d in env_state.docs.docs.values()
), "Document type or DOI propagation failure"

with subtests.test(msg=GatherEvidence.__name__):
Expand All @@ -470,6 +487,11 @@ async def test_agent_sharing_state(
await gather_evidence_tool.gather_evidence(answer.question, state=env_state)
assert answer.contexts, "Evidence did not return any results"

if callback_type == "sync":
gather_evidence_callback.assert_called_with(env_state)
elif callback_type == "async":
agather_evidence_callback.assert_awaited_once_with(env_state)

with subtests.test(msg=f"{GenerateAnswer.__name__} working"):
generate_answer_tool = GenerateAnswer(
settings=agent_test_settings,
Expand All @@ -478,6 +500,12 @@ async def test_agent_sharing_state(
embedding_model=embedding_model,
)
result = await generate_answer_tool.gen_answer(answer.question, state=env_state)

if callback_type == "sync":
generate_answer_callback.assert_called_with(env_state)
elif callback_type == "async":
agenerate_answer_callback.assert_awaited_once_with(env_state)

assert re.search(
pattern=EnvironmentState.STATUS_SEARCH_REGEX_PATTERN, string=result
)
Expand Down