Ability to zero-shot gen_answer (#658)
jamesbraza authored Nov 2, 2024
1 parent 77be14f · commit bcaf43d
Showing 6 changed files with 41 additions and 8 deletions.
.pre-commit-config.yaml (2 changes: 1 addition & 1 deletion)

@@ -86,7 +86,7 @@ repos:
           - aiohttp
           - coredis
           - fhaviary[llm]>=0.8.2 # Match pyproject.toml
-          - ldp>=0.9 # Match pyproject.toml
+          - ldp>=0.12 # Match pyproject.toml
           - html2text
           - httpx
           - limits
paperqa/docs.py (3 changes: 1 addition & 2 deletions)

@@ -672,8 +672,7 @@ async def aquery(  # noqa: PLR0912
         )

         contexts = session.contexts
-
-        if not contexts:
+        if answer_config.get_evidence_if_no_contexts and not contexts:
             session = await self.aget_evidence(
                 session,
                 callbacks=callbacks,
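In effect, aquery now consults the new AnswerSettings flag (added below in paperqa/settings.py) before lazily gathering evidence. A minimal sketch of the zero-shot path this enables; the question is illustrative, and passing settings= to Docs.aquery is assumed to match paper-qa's usual query entry points:

    import asyncio

    from paperqa import Docs, Settings

    async def main() -> None:
        settings = Settings()
        # Clear the opt-out flag added in this commit so aquery skips
        # aget_evidence and gen_answer responds zero-shot (no contexts).
        settings.answer.get_evidence_if_no_contexts = False
        session = await Docs().aquery("How do transformers work?", settings=settings)
        print(session.answer)

    asyncio.run(main())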
paperqa/settings.py (7 changes: 7 additions & 0 deletions)

@@ -93,6 +93,13 @@ class AnswerSettings(BaseModel):
         default=False,
         description="Whether to cite background information provided by model.",
     )
+    get_evidence_if_no_contexts: bool = Field(
+        default=True,
+        description=(
+            "Opt-out flag for allowing answer generation to lazily gather evidence if"
+            " called before evidence was gathered."
+        ),
+    )

     @model_validator(mode="after")
     def _deprecated_field(self) -> Self:
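The field defaults to True, so existing callers keep the lazy evidence-gathering behavior unless they opt out. A minimal sketch of the toggle; the settings.answer attribute path matches the test change below:

    from paperqa import Settings

    settings = Settings()
    # Default from this commit: lazily gather evidence if none was collected.
    assert settings.answer.get_evidence_if_no_contexts is True
    # Opt out: answer generation proceeds even with zero contexts.
    settings.answer.get_evidence_if_no_contexts = False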
pyproject.toml (2 changes: 1 addition & 1 deletion)

@@ -55,7 +55,7 @@ datasets = [
     "datasets",
 ]
 ldp = [
-    "ldp>=0.9", # For alg namespace grouping
+    "ldp>=0.12", # For StoreTrajectoriesCallback
 ]
 local = [
     "sentence-transformers",
tests/test_task.py (31 changes: 29 additions & 2 deletions)

@@ -5,8 +5,9 @@
 import pytest
 from aviary.env import TASK_DATASET_REGISTRY, TaskConfig, TaskDataset
 from ldp.agent import SimpleAgent
-from ldp.alg.callbacks import MeanMetricsCallback
+from ldp.alg.callbacks import MeanMetricsCallback, StoreTrajectoriesCallback
 from ldp.alg.runners import Evaluator, EvaluatorConfig
+from pytest_subtests import SubTests

 from paperqa import Docs, QueryRequest, Settings
 from paperqa.agents import get_directory_index
@@ -16,6 +17,7 @@
     LitQAv2TaskDataset,
     LitQAv2TaskSplit,
 )
+from paperqa.agents.tools import GenerateAnswer


 @pytest.fixture(name="base_query_request")
@@ -106,7 +108,9 @@ async def test_can_validate_stub_dataset_sources(
     )

     @pytest.mark.asyncio
-    async def test_evaluation(self, base_query_request: QueryRequest) -> None:
+    async def test_evaluation(
+        self, subtests: SubTests, base_query_request: QueryRequest
+    ) -> None:
         await get_directory_index(settings=base_query_request.settings)  # Build
         docs = Docs()
         # Why are we constructing a TaskConfig here using a serialized QueryRequest and
@@ -150,6 +154,29 @@ async def test_evaluation(self, base_query_request: QueryRequest) -> None:
             isinstance(metrics_callback.eval_means["reward"], float) > 0
         ), "Expected some wins"

+        with subtests.test(msg="zero-shot"):
+            # Confirm we can just directly call gen_answer
+            base_query_request.settings.agent.tool_names = {
+                GenerateAnswer.gen_answer.__name__
+            }
+            base_query_request.settings.answer.get_evidence_if_no_contexts = False
+            dataset = LitQAv2TaskDataset(base_query=base_query_request)
+            dataset.data = dataset.data[:2]  # Save the world: just use two questions
+            storage_callback = StoreTrajectoriesCallback()
+            evaluator = Evaluator(
+                config=EvaluatorConfig(batch_size=len(dataset), max_rollout_steps=2),
+                agent=SimpleAgent(),
+                dataset=dataset,
+                callbacks=[storage_callback],
+            )
+            await evaluator.evaluate()
+            for traj in storage_callback.eval_trajectories:
+                for step in traj.steps:
+                    assert all(
+                        tc.function.name == GenerateAnswer.gen_answer.__name__
+                        for tc in step.action.value.tool_calls
+                    )
+
     @pytest.mark.vcr
     @pytest.mark.asyncio
     async def test_tool_failure(self, base_query_request: QueryRequest) -> None:
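Outside the test harness, the zero-shot setup above reduces to two settings tweaks. A sketch using only names that appear in this diff; the rest is illustrative:

    from paperqa import Settings
    from paperqa.agents.tools import GenerateAnswer

    settings = Settings()
    # Limit the agent to the single gen_answer tool...
    settings.agent.tool_names = {GenerateAnswer.gen_answer.__name__}
    # ...and disable lazy evidence gathering so the run is truly zero-shot.
    settings.answer.get_evidence_if_no_contexts = False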
uv.lock (4 changes: 2 additions & 2 deletions)

(Generated lockfile; diff not rendered.)
