Moved to MultipleChoiceQuestion/MultipleChoiceEvaluation from aviary #768

Open · wants to merge 3 commits into base: main
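For context on the migration, the grading flow after this PR centers on aviary's MultipleChoiceQuestion rather than paperqa's LitQAEvaluation. A minimal sketch of that flow, inferred from the usage in this diff (the exact aviary signatures and option semantics are assumptions, not verified against aviary itself):

import asyncio

from aviary.utils import MultipleChoiceQuestion


async def main() -> None:
    # Mirrors _make_gradable_environment below: distractors go in `options`,
    # and the ideal answer is passed separately.
    mc_question = MultipleChoiceQuestion(
        question="How many moons does Mars have?",
        options=["0", "1", "3"],  # distractors only, per this diff's usage
        ideal_answer="2",
    )
    # `grade` is awaited in GradablePaperQAEnvironment.step and unpacks into
    # an evaluation (whose .value keys into the reward mapping) plus the
    # graded answer string stored on the session.
    evaluation, graded_answer = await mc_question.grade("I think the answer is 2")
    print(evaluation.value, graded_answer)


asyncio.run(main())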
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -84,8 +84,8 @@ repos:
- aiohttp>=3.10.6 # Match pyproject.toml
- PyMuPDF>=1.24.12
- anyio
- fhaviary[llm]>=0.10.2 # Match pyproject.toml
- ldp>=0.14.5 # Match pyproject.toml
- fhaviary[llm]>=0.14 # Match pyproject.toml
- ldp>=0.17 # Match pyproject.toml
- html2text
- fh-llm-client
- httpx
9 changes: 8 additions & 1 deletion paperqa/_ldp_shims.py
@@ -15,6 +15,7 @@
"UIndexMemoryModel",
"_Memories",
"discounted_returns",
"evaluate_consensus",
"set_training_mode",
]

@@ -29,7 +30,12 @@
SimpleAgent,
SimpleAgentState,
)
from ldp.alg import Callback, ComputeTrajectoryMetricsMixin, RolloutManager
from ldp.alg import (
Callback,
ComputeTrajectoryMetricsMixin,
RolloutManager,
evaluate_consensus,
)
from ldp.graph.memory import Memory, UIndexMemoryModel
from ldp.graph.op_utils import set_training_mode
from ldp.utils import discounted_returns
@@ -48,3 +54,4 @@ class Callback: # type: ignore[no-redef]

RolloutManager = None # type: ignore[assignment,misc]
discounted_returns = None # type: ignore[assignment]
evaluate_consensus = None # type: ignore[assignment]
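This shim module keeps ldp an optional dependency: each symbol is imported when ldp is installed and bound to None otherwise, so calling a missing symbol raises a TypeError that callers (e.g. evaluate_consensus_sampling in task.py below) can convert into an actionable ImportError. A simplified sketch of the pattern as used here:

# Sketch of the optional-import pattern in _ldp_shims.py (simplified).
try:
    from ldp.alg import evaluate_consensus
except ImportError:
    evaluate_consensus = None  # type: ignore[assignment]

# Downstream code can then either probe for None up front, or call through
# and translate the resulting TypeError ("'NoneType' object is not callable")
# into an ImportError suggesting `pip install paper-qa[ldp]`.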
4 changes: 3 additions & 1 deletion paperqa/agents/env.py
@@ -11,6 +11,7 @@
ToolRequestMessage,
ToolResponseMessage,
)
from aviary.utils import MultipleChoiceQuestion
from llmclient import EmbeddingModel, LiteLLMModel

from paperqa.docs import Docs
@@ -127,10 +128,11 @@ def make_tools(self) -> list[Tool]:
)

def make_initial_state(self) -> EnvironmentState:
query: str | MultipleChoiceQuestion = self._query.query
return EnvironmentState(
docs=self._docs,
session=PQASession(
question=self._query.query,
question=query if isinstance(query, str) else query.question_prompt,
config_md5=self._query.settings.md5,
id=self._query.id,
),
9 changes: 8 additions & 1 deletion paperqa/agents/models.py
@@ -8,6 +8,7 @@
from typing import Any, ClassVar, Protocol
from uuid import UUID, uuid4

from aviary.utils import MultipleChoiceQuestion
from llmclient import LiteLLMModel, LLMModel
from pydantic import (
BaseModel,
@@ -55,7 +56,13 @@ class MismatchedModelsError(Exception):
class QueryRequest(BaseModel):
model_config = ConfigDict(extra="forbid")

query: str = ""
query: str | MultipleChoiceQuestion = Field(
default="",
description=(
"The query to be answered. Set to a multiple choice question when grading"
" (e.g. for training)."
),
)
id: UUID = Field(
default_factory=uuid4,
description="Identifier which will be propagated to the Answer object.",
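With the widened query field, callers can hand QueryRequest either a plain question string or a full MultipleChoiceQuestion; the env.py change above then renders the session question from whichever form it receives. A sketch of both forms, assuming only the fields visible in this diff:

from aviary.utils import MultipleChoiceQuestion
from paperqa.agents.models import QueryRequest

# Plain free-form query, as before this PR.
plain = QueryRequest(query="What is the role of KRAS in pancreatic cancer?")

# Multiple-choice query, used when grading (e.g. for training).
graded = QueryRequest(
    query=MultipleChoiceQuestion(
        question="How many moons does Mars have?",
        options=["0", "1", "3"],
        ideal_answer="2",
    )
)

# make_initial_state's dispatch, as shown in the env.py hunk above:
query = graded.query
question = query if isinstance(query, str) else query.question_prompt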
157 changes: 121 additions & 36 deletions paperqa/agents/task.py
@@ -10,40 +10,49 @@
import logging
import re
from abc import ABC
from collections.abc import Awaitable, Callable, Mapping, Sequence
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
from copy import deepcopy
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Self, assert_never

from aviary.core import (
TASK_DATASET_REGISTRY,
Environment,
Frame,
Messages,
TaskDataset,
ToolRequestMessage,
ToolResponseMessage,
)
from aviary.env import ENV_REGISTRY
from aviary.utils import (
DEFAULT_EVAL_MODEL_NAME,
MultipleChoiceEvaluation,
MultipleChoiceQuestion,
)
from llmclient import EmbeddingModel, LiteLLMModel, LLMModel

from paperqa._ldp_shims import ComputeTrajectoryMetricsMixin
from paperqa._ldp_shims import (
Callback,
ComputeTrajectoryMetricsMixin,
evaluate_consensus,
)
from paperqa.docs import Docs
from paperqa.litqa import (
DEFAULT_EVAL_MODEL_NAME,
DEFAULT_LABBENCH_HF_HUB_NAME,
DEFAULT_REWARD_MAPPING,
LitQAEvaluation,
read_litqa_v2_from_hub,
)
from paperqa.types import DocDetails, PQASession

from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
from .models import QueryRequest
from .search import SearchIndex, maybe_get_manifest
from .tools import Complete
from .tools import Complete, EnvironmentState

if TYPE_CHECKING:
from ldp.data_structures import Trajectory
from ldp.agent import Agent
from ldp.data_structures import Trajectory, Transition

logger = logging.getLogger(__name__)

@@ -58,26 +67,22 @@ def __init__(
llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS,
evaluation_from_answer: (
Callable[[PQASession | str], Awaitable[LitQAEvaluation]] | None
) = None,
sources: str | list[str] | None = None,
rewards: Mapping[str, float] = DEFAULT_REWARD_MAPPING,
evaluation_callback: Callable[[LitQAEvaluation], Awaitable] | None = None,
evaluation_callback: (
Callable[[MultipleChoiceEvaluation], Awaitable] | None
) = None,
**env_kwargs,
):
super().__init__(
query, docs, llm_model, summary_llm_model, embedding_model, **env_kwargs
)
self._evaluation_from_answer = evaluation_from_answer
# Enables checking an Index has the right DOI(s)
self.sources: list[str] | None = (
[sources] if isinstance(sources, str) else sources
)
self._evaluation_callback = evaluation_callback
self._rewards = rewards
self.answer = ""
self.ideal = ""

async def validate_sources(
self, manifest_or_index: dict[str, DocDetails] | SearchIndex | None = None
@@ -120,7 +125,7 @@ async def step(
self, action: ToolRequestMessage
) -> tuple[Messages, float, bool, bool]:
messages, reward, done, truncated = await super().step(action)
if not done or not self._evaluation_from_answer:
if not done or not isinstance(self._query.query, MultipleChoiceQuestion):
return messages, reward, done, truncated
# If the ensuing evaluation fails (e.g. due to OpenAI being down), we can:
# - Suppress the exception and declare the evaluation as incorrect, which can
@@ -130,23 +135,13 @@
# incorrectly reward what otherwise was a good trajectory.
# - Don't suppress the exception, which leads to the trajectory failing, and
# removes it from the learnable pool. This is the only safe default behavior.
evaluation = await self._evaluation_from_answer(self.state.session.answer)
evaluation, self.state.session.graded_answer = await self._query.query.grade(
self.state.session.answer
)
if evaluation_callback := self._evaluation_callback:
await evaluation_callback(evaluation)
self.answer = evaluation.answer or ""
self.ideal = evaluation.ideal or ""
return messages, reward + self._rewards[evaluation.value], done, truncated

def export_frame(self) -> Frame:
return Frame(
state=self.state,
info={
"query": self._query,
"answer": self.answer,
"ideal": self.ideal,
},
)

def __deepcopy__(self, memo) -> Self:
copy_state = deepcopy(self.state, memo)
# We don't know the side effects of deep copying a litellm.Router,
@@ -162,7 +157,6 @@ def __deepcopy__(self, memo) -> Self:
copy_self = type(self)(
query=deepcopy(self._query, memo), # deepcopy for _docs_name
docs=copy_state.docs,
evaluation_from_answer=self._evaluation_from_answer,
sources=self.sources,
rewards=self._rewards,
evaluation_callback=self._evaluation_callback,
@@ -182,6 +176,95 @@ def __deepcopy__(self, memo) -> Self:
)


async def evaluate_consensus_sampling(
data: Iterable[GradablePaperQAEnvironment | Frame],
num_samples: int = 1,
seed: int | None = None,
) -> tuple[dict[str, list[tuple[str, int]]], float]:
def get_question(x: GradablePaperQAEnvironment | Frame) -> str:
if isinstance(x, GradablePaperQAEnvironment):
query: str | MultipleChoiceQuestion | dict[str, Any] = x._query.query
else:
qr: QueryRequest | dict[str, Any] = x.info["query"]
query = qr.query if isinstance(qr, QueryRequest) else qr["query"]
if isinstance(query, str):
return query
if isinstance(query, MultipleChoiceQuestion):
return query.question_prompt
return query["question"]

def extract_answer(x: GradablePaperQAEnvironment | Frame) -> str:
sess: PQASession | dict[str, Any] = (
x.state.session
if isinstance(x.state, EnvironmentState)
else x.state["session"]
)
return (
sess.graded_answer
if isinstance(sess, PQASession)
else sess["graded_answer"]
) or ""

def extract_ideal(x: GradablePaperQAEnvironment | Frame) -> str:
if isinstance(x, GradablePaperQAEnvironment):
query: str | MultipleChoiceQuestion | dict[str, Any] = x._query.query
else:
qr: QueryRequest | dict[str, Any] = x.info["query"]
query = qr.query if isinstance(qr, QueryRequest) else qr["query"]
if isinstance(query, str):
raise ValueError(  # noqa: TRY004
f"We require a {MultipleChoiceQuestion.__name__} variant to extract"
" the ideal answer, not a string."
)
if isinstance(query, MultipleChoiceQuestion):
return query.ideal_answer
return query["ideal_answer"]

try:
return await evaluate_consensus(
data=data,
grouping_fn=get_question,
extract_answer_fn=extract_answer,
ideal_answer_fn=extract_ideal,
num_samples=num_samples,
seed=seed,
)
except TypeError:
raise ImportError(
"Evaluating consensus requires the 'ldp' extra for 'ldp'. Please:"
" `pip install paper-qa[ldp]`."
) from None


class StoreForConsensusSamplingCallback(Callback):
def __init__(self):
super().__init__()
self.stored: list[GradablePaperQAEnvironment | Frame] = []

async def after_transition(
self,
traj_id: str, # noqa: ARG002
agent: "Agent", # noqa: ARG002
env: Environment,
transition: "Transition",
) -> None:
if not isinstance(env, GradablePaperQAEnvironment):
raise NotImplementedError(
f"So far only handled {GradablePaperQAEnvironment} in this callback,"
f" not {type(env)}."
)
if not transition.done: # Only store once
return
self.stored.append(env)

async def evaluate_consensus_sampling(
self, num_samples: int = 1, seed: int | None = None
) -> tuple[dict[str, list[tuple[str, int]]], float]:
return await evaluate_consensus_sampling(
data=self.stored, num_samples=num_samples, seed=seed
)


class LitQATaskDataset(
TaskDataset[GradablePaperQAEnvironment], ComputeTrajectoryMetricsMixin, ABC
):
@@ -218,24 +301,26 @@ def __init__(

def _make_gradable_environment(
self,
ideal: str,
ideal_answer: str,
distractors: str | list[str],
question: str,
sources: str | list[str] | None = None,
) -> GradablePaperQAEnvironment:
qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question(
ideal=ideal,
distractors=distractors,
mc_question = MultipleChoiceQuestion(
question=question,
eval_model=self._eval_model,
options=(
distractors
if isinstance(distractors, list)
else MultipleChoiceQuestion.split_options(distractors)
),
ideal_answer=ideal_answer,
**(self._question_kwargs or {}),
)
query = self._base_query.model_copy()
query.query = qa_prompt
query.query = mc_question
return GradablePaperQAEnvironment(
query=query,
docs=self._base_docs.model_copy(),
evaluation_from_answer=evaluation_from_answer,
sources=sources,
rewards=self._rewards,
**self._env_kwargs,
@@ -338,7 +423,7 @@ def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment:
) from exc
sources.append(doi)
return self._make_gradable_environment(
ideal=self.data.iloc[idx].ideal,
ideal_answer=self.data.iloc[idx].ideal,
distractors=self.data.iloc[idx].distractors,
question=self.data.iloc[idx].question,
sources=sources,
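Putting the new pieces together, a consensus-evaluation run might look like the sketch below. SimpleAgent, RolloutManager, and StoreForConsensusSamplingCallback are all imported or defined in this diff, but the `sample_trajectories` call and the repeated-sampling setup are assumptions about ldp's API; the return shape of evaluate_consensus_sampling follows its annotation above.

from copy import deepcopy

from ldp.agent import SimpleAgent
from ldp.alg import RolloutManager

from paperqa.agents.task import StoreForConsensusSamplingCallback


async def run_consensus(envs) -> None:
    # `envs` is a list of GradablePaperQAEnvironment instances, each built
    # around a MultipleChoiceQuestion (e.g. via _make_gradable_environment).
    callback = StoreForConsensusSamplingCallback()
    rollout = RolloutManager(agent=SimpleAgent(), callbacks=[callback])

    # Roll out each environment several times so there is a population of
    # answers to vote over; the callback stores each env once it is done.
    # The diff's __deepcopy__ support is what makes the copies safe here.
    await rollout.sample_trajectories(
        environments=[deepcopy(env) for env in envs for _ in range(5)]
    )

    # Majority-vote (answer, count) pairs grouped by question, plus the
    # fraction of questions whose consensus matches the ideal answer.
    groups, accuracy = await callback.evaluate_consensus_sampling(num_samples=5)
    print(accuracy)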