From dcb4fd471f110a58120e1e8ebbbe8006fbe40904 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Wed, 30 Oct 2024 15:27:56 -0700 Subject: [PATCH] Renamed to PQASession type (#653) --- paperqa/__init__.py | 5 ++- paperqa/agents/env.py | 8 ++-- paperqa/agents/helpers.py | 4 +- paperqa/agents/main.py | 36 ++++++++-------- paperqa/agents/models.py | 14 ++++--- paperqa/agents/task.py | 4 +- paperqa/agents/tools.py | 38 ++++++++--------- paperqa/docs.py | 86 +++++++++++++++++++-------------------- paperqa/litqa.py | 8 ++-- paperqa/settings.py | 4 +- paperqa/types.py | 41 +++++++++++++------ tests/conftest.py | 6 +-- tests/test_agents.py | 46 ++++++++++----------- tests/test_cli.py | 4 +- tests/test_paperqa.py | 10 ++++- 15 files changed, 169 insertions(+), 145 deletions(-) diff --git a/paperqa/__init__.py b/paperqa/__init__.py index 42242498..008b1825 100644 --- a/paperqa/__init__.py +++ b/paperqa/__init__.py @@ -9,7 +9,7 @@ from paperqa.agents import ask # noqa: E402 from paperqa.agents.main import agent_query # noqa: E402 from paperqa.agents.models import QueryRequest # noqa: E402 -from paperqa.docs import Answer, Docs, print_callback # noqa: E402 +from paperqa.docs import Docs, PQASession, print_callback # noqa: E402 from paperqa.llms import ( # noqa: E402 EmbeddingModel, HybridEmbeddingModel, @@ -23,7 +23,7 @@ embedding_model_factory, ) from paperqa.settings import Settings, get_settings # noqa: E402 -from paperqa.types import Context, Doc, DocDetails, Text # noqa: E402 +from paperqa.types import Answer, Context, Doc, DocDetails, Text # noqa: E402 from paperqa.version import __version__ # noqa: E402 __all__ = [ @@ -39,6 +39,7 @@ "LiteLLMEmbeddingModel", "LiteLLMModel", "NumpyVectorStore", + "PQASession", "QueryRequest", "SentenceTransformerEmbeddingModel", "Settings", diff --git a/paperqa/agents/env.py b/paperqa/agents/env.py index 7791bd29..f9f44a66 100644 --- a/paperqa/agents/env.py +++ b/paperqa/agents/env.py @@ -14,7 +14,7 @@ from paperqa.docs import Docs from paperqa.llms import EmbeddingModel, LiteLLMModel from paperqa.settings import Settings -from paperqa.types import Answer +from paperqa.types import PQASession from paperqa.utils import get_year from .models import QueryRequest @@ -128,7 +128,7 @@ def make_tools(self) -> list[Tool]: def make_initial_state(self) -> EnvironmentState: return EnvironmentState( docs=self._docs, - answer=Answer( + answer=PQASession( question=self._query.query, config_md5=self._query.settings.md5, id=self._query.id, @@ -145,7 +145,7 @@ async def reset(self) -> tuple[list[Message], list[Tool]]: [ Message( content=self._query.settings.agent.agent_prompt.format( - question=self.state.answer.question, + question=self.state.session.question, status=self.state.status, gen_answer_tool_name=GenerateAnswer.TOOL_FN_NAME, ), @@ -160,7 +160,7 @@ def export_frame(self) -> Frame: async def step( self, action: ToolRequestMessage ) -> tuple[list[Message], float, bool, bool]: - self.state.answer.add_tokens(action) # Add usage for action if present + self.state.session.add_tokens(action) # Add usage for action if present # If the action has empty tool_calls, the agent can later take that into account msgs = cast( diff --git a/paperqa/agents/helpers.py b/paperqa/agents/helpers.py index 85de16b7..591eb644 100644 --- a/paperqa/agents/helpers.py +++ b/paperqa/agents/helpers.py @@ -81,8 +81,8 @@ def table_formatter( table.add_column("Answer", style="magenta") for obj, _ in objects: table.add_row( - cast(AnswerResponse, obj).answer.question[:max_chars_per_column], - 
cast(AnswerResponse, obj).answer.answer[:max_chars_per_column], + cast(AnswerResponse, obj).session.question[:max_chars_per_column], + cast(AnswerResponse, obj).session.answer[:max_chars_per_column], ) return table if isinstance(example_object, Docs): diff --git a/paperqa/agents/main.py b/paperqa/agents/main.py index 18ef6763..033bde5b 100644 --- a/paperqa/agents/main.py +++ b/paperqa/agents/main.py @@ -32,7 +32,7 @@ class Callback: # type: ignore[no-redef] from paperqa.docs import Docs from paperqa.settings import AgentSettings -from paperqa.types import Answer +from paperqa.types import PQASession from .env import PaperQAEnvironment from .helpers import litellm_get_search_query, table_formatter @@ -72,13 +72,13 @@ async def agent_query( response = await run_agent(docs, query, agent_type, **runner_kwargs) agent_logger.debug(f"agent_response: {response}") - agent_logger.info(f"[bold blue]Answer: {response.answer.answer}[/bold blue]") + agent_logger.info(f"[bold blue]Answer: {response.session.answer}[/bold blue]") await answers_index.add_document( { - "file_location": str(response.answer.id), - "body": response.answer.answer, - "question": response.answer.question, + "file_location": str(response.session.id), + "body": response.session.answer, + "question": response.session.question, }, document=response, ) @@ -120,26 +120,26 @@ async def run_agent( # Build the index once here, and then all tools won't need to rebuild it await get_directory_index(settings=query.settings) if isinstance(agent_type, str) and agent_type.lower() == FAKE_AGENT_TYPE: - answer, agent_status = await run_fake_agent(query, docs, **runner_kwargs) + session, agent_status = await run_fake_agent(query, docs, **runner_kwargs) elif tool_selector_or_none := query.settings.make_aviary_tool_selector(agent_type): - answer, agent_status = await run_aviary_agent( + session, agent_status = await run_aviary_agent( query, docs, tool_selector_or_none, **runner_kwargs ) elif ldp_agent_or_none := await query.settings.make_ldp_agent(agent_type): - answer, agent_status = await run_ldp_agent( + session, agent_status = await run_ldp_agent( query, docs, ldp_agent_or_none, **runner_kwargs ) else: raise NotImplementedError(f"Didn't yet handle agent type {agent_type}.") - if answer.could_not_answer and agent_status != AgentStatus.TRUNCATED: + if session.could_not_answer and agent_status != AgentStatus.TRUNCATED: agent_status = AgentStatus.UNSURE # stop after, so overall isn't reported as long-running step. logger.info( f"Finished agent {agent_type!r} run with question {query.query!r} and status" f" {agent_status}." 
) - return AnswerResponse(answer=answer, status=agent_status) + return AnswerResponse(session=session, status=agent_status) async def run_fake_agent( @@ -154,7 +154,7 @@ async def run_fake_agent( Callable[[list[Message], float, bool, bool], Awaitable] | None ) = None, **env_kwargs, -) -> tuple[Answer, AgentStatus]: +) -> tuple[PQASession, AgentStatus]: if query.settings.agent.max_timesteps is not None: logger.warning( f"Max timesteps (configured {query.settings.agent.max_timesteps}) is not" @@ -165,7 +165,7 @@ async def run_fake_agent( if on_env_reset_callback: await on_env_reset_callback(env.state) - question = env.state.answer.question + question = env.state.session.question search_tool = next(filter(lambda x: x.info.name == PaperSearch.TOOL_FN_NAME, tools)) gather_evidence_tool = next( filter(lambda x: x.info.name == GatherEvidence.TOOL_FN_NAME, tools) @@ -191,7 +191,7 @@ async def step(tool: Tool, **call_kwargs) -> None: await step(search_tool, query=search, min_year=None, max_year=None) await step(gather_evidence_tool, question=question) await step(generate_answer_tool, question=question) - return env.state.answer, AgentStatus.SUCCESS + return env.state.session, AgentStatus.SUCCESS async def run_aviary_agent( @@ -207,7 +207,7 @@ async def run_aviary_agent( Callable[[list[Message], float, bool, bool], Awaitable] | None ) = None, **env_kwargs, -) -> tuple[Answer, AgentStatus]: +) -> tuple[PQASession, AgentStatus]: env = env_class(query, docs, **env_kwargs) done = False @@ -247,7 +247,7 @@ async def run_aviary_agent( await generate_answer_tool._tool_fn( question=query.query, state=env.state ) - return env.state.answer, AgentStatus.TRUNCATED + return env.state.session, AgentStatus.TRUNCATED agent_state.messages += obs for attempt in Retrying( stop=stop_after_attempt(5), @@ -278,7 +278,7 @@ async def run_aviary_agent( except Exception: logger.exception(f"Agent {agent} failed.") status = AgentStatus.FAIL - return env.state.answer, status + return env.state.session, status class LDPRolloutCallback(Callback): @@ -323,7 +323,7 @@ async def run_ldp_agent( ) = None, ldp_callback_type: type[LDPRolloutCallback] = LDPRolloutCallback, **env_kwargs, -) -> tuple[Answer, AgentStatus]: +) -> tuple[PQASession, AgentStatus]: env = env_class(query, docs, **env_kwargs) # NOTE: don't worry about ldp import checks, because we know Settings.make_ldp_agent # has already taken place, which checks that ldp is installed @@ -357,7 +357,7 @@ async def run_ldp_agent( except Exception: logger.exception(f"Agent {agent} failed.") status = AgentStatus.FAIL - return env.state.answer, status + return env.state.session, status async def index_search( diff --git a/paperqa/agents/models.py b/paperqa/agents/models.py index ec6619c8..4bad05d4 100644 --- a/paperqa/agents/models.py +++ b/paperqa/agents/models.py @@ -20,7 +20,7 @@ from paperqa.llms import LiteLLMModel, LLMModel from paperqa.settings import Settings -from paperqa.types import Answer +from paperqa.types import PQASession from paperqa.version import __version__ logger = logging.getLogger(__name__) @@ -85,7 +85,9 @@ def set_docs_name(self, docs_name: str) -> None: class AnswerResponse(BaseModel): - answer: Answer + model_config = ConfigDict(populate_by_name=True) + + session: PQASession = Field(alias="answer") bibtex: dict[str, str] | None = None status: AgentStatus timing_info: dict[str, dict[str, float]] | None = None @@ -94,10 +96,10 @@ class AnswerResponse(BaseModel): # about the answer, such as the number of sources used, etc. 
stats: dict[str, str] | None = None - @field_validator("answer") + @field_validator("session") def strip_answer( - cls, v: Answer, info: ValidationInfo # noqa: ARG002, N805 - ) -> Answer: + cls, v: PQASession, info: ValidationInfo # noqa: ARG002, N805 + ) -> PQASession: # This modifies in place, this is fine # because when a response is being constructed, # we should be done with the Answer object @@ -114,7 +116,7 @@ async def get_summary(self, llm_model: LLMModel | str = "gpt-4o") -> str: ) result = await model.run_prompt( prompt="{question}\n\n{answer}", - data={"question": self.answer.question, "answer": self.answer.answer}, + data={"question": self.session.question, "answer": self.session.answer}, system_prompt=sys_prompt, ) return result.text.strip() diff --git a/paperqa/agents/task.py b/paperqa/agents/task.py index d3530d9b..e02f9f5d 100644 --- a/paperqa/agents/task.py +++ b/paperqa/agents/task.py @@ -46,7 +46,7 @@ class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef] read_litqa_v2_from_hub, ) from paperqa.llms import EmbeddingModel, LiteLLMModel, LLMModel -from paperqa.types import Answer +from paperqa.types import PQASession from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment from .models import QueryRequest @@ -69,7 +69,7 @@ def __init__( summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS, embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS, evaluation_from_answer: ( - Callable[[Answer | str], Awaitable[LitQAEvaluation]] | None + Callable[[PQASession | str], Awaitable[LitQAEvaluation]] | None ) = None, sources: str | list[str] | None = None, rewards: Sequence[float] = DEFAULT_REWARD_DISTRIBUTION, diff --git a/paperqa/agents/tools.py b/paperqa/agents/tools.py index 25ba7eaa..7c85f43e 100644 --- a/paperqa/agents/tools.py +++ b/paperqa/agents/tools.py @@ -12,7 +12,7 @@ from paperqa.docs import Docs from paperqa.llms import EmbeddingModel, LiteLLMModel from paperqa.settings import Settings -from paperqa.types import Answer, DocDetails +from paperqa.types import DocDetails, PQASession from .search import get_directory_index @@ -35,7 +35,7 @@ class EnvironmentState(BaseModel): model_config = ConfigDict(extra="forbid") docs: Docs - answer: Answer + session: PQASession = Field(..., alias="answer") # SEE: https://regex101.com/r/RmuVdC/1 STATUS_SEARCH_REGEX_PATTERN: ClassVar[str] = ( @@ -51,18 +51,18 @@ def status(self) -> str: relevant_paper_count=len( { c.text.doc.dockey - for c in self.answer.contexts + for c in self.session.contexts if c.score > self.RELEVANT_SCORE_CUTOFF } ), evidence_count=len( [ c - for c in self.answer.contexts + for c in self.session.contexts if c.score > self.RELEVANT_SCORE_CUTOFF ] ), - cost=self.answer.cost, + cost=self.session.cost, ) @@ -202,16 +202,16 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str: ) logger.info(f"{self.TOOL_FN_NAME} starting for question {question!r}.") - original_question = state.answer.question + original_question = state.session.question try: # Swap out the question with the more specific question # TODO: remove this swap, as it prevents us from supporting parallel calls - state.answer.question = question - l0 = len(state.answer.contexts) + state.session.question = question + l0 = len(state.session.contexts) # TODO: refactor answer out of this... 
- state.answer = await state.docs.aget_evidence( - query=state.answer, + state.session = await state.docs.aget_evidence( + query=state.session, settings=self.settings, embedding_model=self.embedding_model, summary_llm_model=self.summary_llm_model, @@ -219,14 +219,14 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str: f"{self.TOOL_FN_NAME}_aget_evidence" ), ) - l1 = len(state.answer.contexts) + l1 = len(state.session.contexts) finally: - state.answer.question = original_question + state.session.question = original_question status = state.status logger.info(status) sorted_contexts = sorted( - state.answer.contexts, key=lambda x: x.score, reverse=True + state.session.contexts, key=lambda x: x.score, reverse=True ) best_evidence = ( f" Best evidence:\n\n{sorted_contexts[0].context}" @@ -287,8 +287,8 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str: # TODO: Should we allow the agent to change the question? # self.answer.question = query - state.answer = await state.docs.aquery( - query=state.answer, + state.session = await state.docs.aquery( + query=state.session, settings=self.settings, llm_model=self.llm_model, summary_llm_model=self.summary_llm_model, @@ -298,13 +298,13 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str: ), ) - if state.answer.could_not_answer: + if state.session.could_not_answer: if self.settings.agent.wipe_context_on_answer_failure: - state.answer.contexts = [] - state.answer.context = "" + state.session.contexts = [] + state.session.context = "" answer = self.FAILED_TO_ANSWER else: - answer = state.answer.answer + answer = state.session.answer status = state.status logger.info(status) diff --git a/paperqa/docs.py b/paperqa/docs.py index 7bbe6bc8..ce95ef1b 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -34,13 +34,13 @@ from paperqa.readers import read_doc from paperqa.settings import MaybeSettings, get_settings from paperqa.types import ( - Answer, Doc, DocDetails, DocKey, LLMResult, + PQASession, Text, - set_llm_answer_ids, + set_llm_session_ids, ) from paperqa.utils import ( gather_with_concurrency, @@ -512,13 +512,13 @@ async def retrieve_texts( def get_evidence( self, - query: Answer | str, + query: PQASession | str, exclude_text_filter: set[str] | None = None, settings: MaybeSettings = None, callbacks: list[Callable] | None = None, embedding_model: EmbeddingModel | None = None, summary_llm_model: LLMModel | None = None, - ) -> Answer: + ) -> PQASession: return get_loop().run_until_complete( self.aget_evidence( query=query, @@ -532,26 +532,26 @@ def get_evidence( async def aget_evidence( self, - query: Answer | str, + query: PQASession | str, exclude_text_filter: set[str] | None = None, settings: MaybeSettings = None, callbacks: list[Callable] | None = None, embedding_model: EmbeddingModel | None = None, summary_llm_model: LLMModel | None = None, - ) -> Answer: + ) -> PQASession: evidence_settings = get_settings(settings) answer_config = evidence_settings.answer prompt_config = evidence_settings.prompts - answer = ( - Answer(question=query, config_md5=evidence_settings.md5) + session = ( + PQASession(question=query, config_md5=evidence_settings.md5) if isinstance(query, str) else query ) if not self.docs and len(self.texts_index) == 0: - return answer + return session if embedding_model is None: embedding_model = evidence_settings.get_embedding_model() @@ -560,7 +560,7 @@ async def aget_evidence( summary_llm_model = evidence_settings.get_summary_llm() exclude_text_filter = 
exclude_text_filter or set() - exclude_text_filter |= {c.text.name for c in answer.contexts} + exclude_text_filter |= {c.text.name for c in session.contexts} _k = answer_config.evidence_k if exclude_text_filter: @@ -570,7 +570,7 @@ async def aget_evidence( if answer_config.evidence_retrieval: matches = await self.retrieve_texts( - answer.question, _k, evidence_settings, embedding_model + session.question, _k, evidence_settings, embedding_model ) else: matches = self.texts @@ -599,13 +599,13 @@ async def aget_evidence( system_prompt=prompt_config.system, ) - with set_llm_answer_ids(answer.id): + with set_llm_session_ids(session.id): results = await gather_with_concurrency( answer_config.max_concurrent_requests, [ map_fxn_summary( text=m, - question=answer.question, + question=session.question, prompt_runner=prompt_runner, extra_prompt_data={ "summary_length": answer_config.evidence_summary_length, @@ -619,20 +619,20 @@ async def aget_evidence( ) for _, llm_result in results: - answer.add_tokens(llm_result) + session.add_tokens(llm_result) - answer.contexts += [r for r, _ in results if r is not None] - return answer + session.contexts += [r for r, _ in results if r is not None] + return session def query( self, - query: Answer | str, + query: PQASession | str, settings: MaybeSettings = None, callbacks: list[Callable] | None = None, llm_model: LLMModel | None = None, summary_llm_model: LLMModel | None = None, embedding_model: EmbeddingModel | None = None, - ) -> Answer: + ) -> PQASession: return get_loop().run_until_complete( self.aquery( query, @@ -646,13 +646,13 @@ def query( async def aquery( # noqa: PLR0912 self, - query: Answer | str, + query: PQASession | str, settings: MaybeSettings = None, callbacks: list[Callable] | None = None, llm_model: LLMModel | None = None, summary_llm_model: LLMModel | None = None, embedding_model: EmbeddingModel | None = None, - ) -> Answer: + ) -> PQASession: query_settings = get_settings(settings) answer_config = query_settings.answer @@ -665,34 +665,34 @@ async def aquery( # noqa: PLR0912 if embedding_model is None: embedding_model = query_settings.get_embedding_model() - answer = ( - Answer(question=query, config_md5=query_settings.md5) + session = ( + PQASession(question=query, config_md5=query_settings.md5) if isinstance(query, str) else query ) - contexts = answer.contexts + contexts = session.contexts if not contexts: - answer = await self.aget_evidence( - answer, + session = await self.aget_evidence( + session, callbacks=callbacks, settings=settings, embedding_model=embedding_model, summary_llm_model=summary_llm_model, ) - contexts = answer.contexts + contexts = session.contexts pre_str = None if prompt_config.pre is not None: - with set_llm_answer_ids(answer.id): + with set_llm_session_ids(session.id): pre = await llm_model.run_prompt( prompt=prompt_config.pre, - data={"question": answer.question}, + data={"question": session.question}, callbacks=callbacks, name="pre", system_prompt=prompt_config.system, ) - answer.add_tokens(pre) + session.add_tokens(pre) pre_str = pre.text # sort by first score, then name @@ -737,13 +737,13 @@ async def aquery( # noqa: PLR0912 "I cannot answer this question due to insufficient information." 
) else: - with set_llm_answer_ids(answer.id): + with set_llm_session_ids(session.id): answer_result = await llm_model.run_prompt( prompt=prompt_config.qa, data={ "context": context_str, "answer_length": answer_config.answer_length, - "question": answer.question, + "question": session.question, "example_citation": prompt_config.EXAMPLE_CITATION, }, callbacks=callbacks, @@ -751,7 +751,7 @@ async def aquery( # noqa: PLR0912 system_prompt=prompt_config.system, ) answer_text = answer_result.text - answer.add_tokens(answer_result) + session.add_tokens(answer_result) # it still happens if prompt_config.EXAMPLE_CITATION in answer_text: answer_text = answer_text.replace(prompt_config.EXAMPLE_CITATION, "") @@ -772,30 +772,30 @@ async def aquery( # noqa: PLR0912 answer_text, ) - formatted_answer = f"Question: {answer.question}\n\n{answer_text}\n" + formatted_answer = f"Question: {session.question}\n\n{answer_text}\n" if bib: formatted_answer += f"\nReferences\n\n{bib_str}\n" if prompt_config.post is not None: - with set_llm_answer_ids(answer.id): + with set_llm_session_ids(session.id): post = await llm_model.run_prompt( prompt=prompt_config.post, - data=answer.model_dump(), + data=session.model_dump(), callbacks=callbacks, name="post", system_prompt=prompt_config.system, ) answer_text = post.text - answer.add_tokens(post) - formatted_answer = f"Question: {answer.question}\n\n{post}\n" + session.add_tokens(post) + formatted_answer = f"Question: {session.question}\n\n{post}\n" if bib: formatted_answer += f"\nReferences\n\n{bib_str}\n" # now at end we modify, so we could have retried earlier - answer.answer = answer_text - answer.formatted_answer = formatted_answer - answer.references = bib_str - answer.contexts = contexts - answer.context = context_str + session.answer = answer_text + session.formatted_answer = formatted_answer + session.references = bib_str + session.contexts = contexts + session.context = context_str - return answer + return session diff --git a/paperqa/litqa.py b/paperqa/litqa.py index 631b9d04..6712f473 100644 --- a/paperqa/litqa.py +++ b/paperqa/litqa.py @@ -17,7 +17,7 @@ from paperqa.llms import LiteLLMModel, LLMModel from paperqa.prompts import EVAL_PROMPT_TEMPLATE, QA_PROMPT_TEMPLATE from paperqa.settings import make_default_litellm_model_list_settings -from paperqa.types import Answer +from paperqa.types import PQASession if TYPE_CHECKING: import pandas as pd @@ -127,7 +127,7 @@ def from_question( use_unsure: bool = True, eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME, seed: int | None = None, - ) -> tuple[str, Callable[[Answer | str], Awaitable[LitQAEvaluation]]]: + ) -> tuple[str, Callable[[PQASession | str], Awaitable[LitQAEvaluation]]]: """ Create a LitQA question and an answer-to-evaluation function. 
@@ -158,8 +158,8 @@ def from_question( config=make_default_litellm_model_list_settings(eval_model), ) - async def llm_from_answer(answer: Answer | str) -> LitQAEvaluation: - if isinstance(answer, Answer): + async def llm_from_answer(answer: PQASession | str) -> LitQAEvaluation: + if isinstance(answer, PQASession): answer = answer.answer eval_chunk = await eval_model.achat( messages=[ diff --git a/paperqa/settings.py b/paperqa/settings.py index 35cb41b0..f921b6ef 100644 --- a/paperqa/settings.py +++ b/paperqa/settings.py @@ -302,9 +302,9 @@ def check_select(cls, v: str) -> str: def check_post(cls, v: str | None) -> str | None: if v is not None: # kind of a hack to get list of attributes in answer - from paperqa.types import Answer + from paperqa.types import PQASession - attrs = set(Answer.model_fields.keys()) + attrs = set(PQASession.model_fields.keys()) if not get_formatted_variables(v).issubset(attrs): raise ValueError(f"Post prompt must have input variables: {attrs}") return v diff --git a/paperqa/types.py b/paperqa/types.py index f0107071..8d0e799c 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -4,6 +4,7 @@ import logging import os import re +import warnings from collections.abc import Collection from contextlib import contextmanager from datetime import datetime @@ -40,38 +41,41 @@ logger = logging.getLogger(__name__) # A context var that will be unique to threads/processes -cvar_answer_id = contextvars.ContextVar[UUID | None]("answer_id", default=None) +cvar_session_id = contextvars.ContextVar[UUID | None]("session_id", default=None) @contextmanager -def set_llm_answer_ids(answer_id: UUID): - token = cvar_answer_id.set(answer_id) +def set_llm_session_ids(session_id: UUID): + token = cvar_session_id.set(session_id) try: yield finally: - cvar_answer_id.reset(token) + cvar_session_id.reset(token) class LLMResult(BaseModel): """A class to hold the result of a LLM completion. - To associate a group of LLMResults, you can use the `set_llm_answer_ids` context manager: + To associate a group of LLMResults, you can use the `set_llm_session_ids` context manager: ```python - my_answer_id = uuid4() - with set_llm_answer_ids(my_answer_id): + my_session_id = uuid4() + with set_llm_session_ids(my_session_id): # code that generates LLMResults pass ``` - and all the LLMResults generated within the context will have the same `answer_id`. + and all the LLMResults generated within the context will have the same `session_id`. This can be combined with LLMModels `llm_result_callback` to store all LLMResults. 
""" + model_config = ConfigDict(populate_by_name=True) + id: UUID = Field(default_factory=uuid4) - answer_id: UUID | None = Field( - default_factory=cvar_answer_id.get, + session_id: UUID | None = Field( + default_factory=cvar_session_id.get, description="A persistent ID to associate a group of LLMResults", + alias="answer_id", ) name: str | None = None prompt: str | list[dict] | None = Field( @@ -149,10 +153,10 @@ def __str__(self) -> str: return self.context -class Answer(BaseModel): - """A class to hold the answer to a question.""" +class PQASession(BaseModel): + """A class to hold session about researching/answering.""" - model_config = ConfigDict(extra="ignore") + model_config = ConfigDict(extra="ignore", populate_by_name=True) id: UUID = Field(default_factory=uuid4) question: str @@ -248,6 +252,17 @@ def could_not_answer(self) -> bool: return "cannot answer" in self.answer.lower() +# for backwards compatibility +class Answer(PQASession): + def __init__(self, *args, **kwargs): + warnings.warn( + "The 'Answer' class is deprecated and will be removed in future versions. Use 'PQASession' instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) + + class ChunkMetadata(BaseModel): """Metadata for chunking algorithm.""" diff --git a/tests/conftest.py b/tests/conftest.py index 21af2586..4c01df65 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ from paperqa.clients.crossref import CROSSREF_HEADER_KEY from paperqa.clients.semantic_scholar import SEMANTIC_SCHOLAR_HEADER_KEY from paperqa.settings import Settings -from paperqa.types import Answer +from paperqa.types import PQASession from paperqa.utils import setup_default_logs TESTS_DIR = Path(__file__).parent @@ -92,8 +92,8 @@ def agent_test_settings(agent_index_dir: Path, stub_data_dir: Path) -> Settings: @pytest.fixture -def agent_stub_answer() -> Answer: - return Answer(question="What is is a self-explanatory model?") +def agent_stub_session() -> PQASession: + return PQASession(question="What is is a self-explanatory model?") @pytest.fixture diff --git a/tests/test_agents.py b/tests/test_agents.py index ed693f88..8bde2ad5 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -44,7 +44,7 @@ ) from paperqa.docs import Docs from paperqa.settings import AgentSettings, IndexSettings, Settings -from paperqa.types import Answer, Context, Doc, Text +from paperqa.types import Context, Doc, PQASession, Text from paperqa.utils import extract_thought, get_year, md5sum @@ -251,20 +251,20 @@ async def test_agent_types( assert ( mock_open.call_count <= 1 ), "Expected one Index.open call, or possibly zero if multiprocessing tests" - assert response.answer.answer, "Answer not generated" - assert response.answer.answer != "I cannot answer", "Answer not generated" - assert response.answer.context, "No contexts were found" - assert response.answer.question == question + assert response.session.answer, "Answer not generated" + assert response.session.answer != "I cannot answer", "Answer not generated" + assert response.session.context, "No contexts were found" + assert response.session.question == question agent_llm = request.settings.agent.agent_llm # TODO: once LDP can track tokens, we can remove this check if agent_type not in {FAKE_AGENT_TYPE, SimpleAgent}: assert ( - response.answer.token_counts[agent_llm][0] > 1000 + response.session.token_counts[agent_llm][0] > 1000 ), "Expected many prompt tokens" assert ( - response.answer.token_counts[agent_llm][1] > 50 + 
response.session.token_counts[agent_llm][1] > 50 ), "Expected many completion tokens" - assert response.answer.cost > 0, "Expected nonzero cost" + assert response.session.cost > 0, "Expected nonzero cost" @pytest.mark.asyncio @@ -360,7 +360,7 @@ async def test_timeout(agent_test_settings: Settings, agent_type: str | type) -> ) # ensure that GenerateAnswerTool was called assert response.status == AgentStatus.TRUNCATED, "Agent did not timeout" - assert "I cannot answer" in response.answer.answer + assert "I cannot answer" in response.session.answer @pytest.mark.flaky(reruns=2, only_rerun=["AssertionError"]) @@ -389,7 +389,7 @@ async def test_propagate_options(agent_test_settings: Settings) -> None: ) response = await agent_query(query, agent_type=FAKE_AGENT_TYPE) assert response.status == AgentStatus.SUCCESS, "Agent did not succeed" - result = response.answer + result = response.session assert len(result.answer) > 200, "Answer did not return any results" assert "###" in result.answer, "Answer did not propagate system prompt" assert ( @@ -463,7 +463,7 @@ async def test_agent_sharing_state( agent_test_settings.agent.callbacks = callbacks # type: ignore[assignment] - answer = Answer(question="What is is a self-explanatory model?") + answer = PQASession(question="What is is a self-explanatory model?") query = QueryRequest(query=answer.question, settings=agent_test_settings) env_state = EnvironmentState(docs=Docs(), answer=answer) built_index = await get_directory_index(settings=agent_test_settings) @@ -695,7 +695,7 @@ def test_query_request_docs_name_serialized() -> None: def test_answers_are_striped() -> None: """Test that answers are striped.""" - answer = Answer( + session = PQASession( question="What is the meaning of life?", contexts=[ Context( @@ -715,12 +715,12 @@ def test_answers_are_striped() -> None: ) ], ) - response = AnswerResponse(answer=answer, bibtex={}, status="success") + response = AnswerResponse(session=session, bibtex={}, status="success") - assert response.answer.contexts[0].text.embedding is None - assert response.answer.contexts[0].text.text == "" # type: ignore[unreachable,unused-ignore] - assert response.answer.contexts[0].text.doc is not None - assert response.answer.contexts[0].text.doc.embedding is None + assert response.session.contexts[0].text.embedding is None + assert response.session.contexts[0].text.text == "" # type: ignore[unreachable,unused-ignore] + assert response.session.contexts[0].text.doc is not None + assert response.session.contexts[0].text.doc.embedding is None # make sure it serializes response.model_dump_json() @@ -778,12 +778,12 @@ async def test_deepcopy_env(agent_test_settings: Settings) -> None: ) _, _, done, _ = await env.step(gen_answer_action) assert done - assert not env.state.answer.could_not_answer - assert env.state.answer.used_contexts + assert not env.state.session.could_not_answer + assert env.state.session.used_contexts _, _, done, _ = await env_copy.step(gen_answer_action) assert done - assert not env_copy.state.answer.could_not_answer - assert env_copy.state.answer.used_contexts - assert sorted(env.state.answer.used_contexts) == sorted( - env_copy.state.answer.used_contexts + assert not env_copy.state.session.could_not_answer + assert env_copy.state.session.used_contexts + assert sorted(env.state.session.used_contexts) == sorted( + env_copy.state.session.used_contexts ) diff --git a/tests/test_cli.py b/tests/test_cli.py index f6f4f5b1..50f25259 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -56,10 +56,10 @@ def 
test_cli_ask(agent_index_dir: Path, stub_data_dir: Path) -> None: response = ask( "How can you use XAI for chemical property prediction?", settings=settings ) - assert response.answer.formatted_answer + assert response.session.formatted_answer search_result = search_query( - " ".join(response.answer.formatted_answer.split()[:5]), + " ".join(response.session.formatted_answer.split()[:5]), "answers", settings, ) diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index fbaae73c..303abc2b 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -19,6 +19,7 @@ DocDetails, Docs, NumpyVectorStore, + PQASession, Settings, Text, print_callback, @@ -507,7 +508,7 @@ async def test_docs_lifecycle(subtests: SubTests, stub_data_dir: Path) -> None: def test_evidence(docs_fixture) -> None: debug_settings = Settings.from_name("debug") evidence = docs_fixture.get_evidence( - Answer(question="What does XAI stand for?"), + PQASession(question="What does XAI stand for?"), settings=debug_settings, ).contexts assert len(evidence) >= debug_settings.answer.evidence_k @@ -527,7 +528,7 @@ def test_json_evidence(docs_fixture) -> None: " question (integer out of 10)." ) evidence = docs_fixture.get_evidence( - Answer(question="Who wrote this article?"), + PQASession(question="Who wrote this article?"), settings=settings, ).contexts assert evidence[0].author_name @@ -1129,3 +1130,8 @@ def test_case_insensitive_matching(): assert strings_similarity("my test sentence", "My test sentence") == 1.0 assert strings_similarity("a b c d e", "a b c f") == 0.5 assert strings_similarity("A B c d e", "a b c f") == 0.5 + + +def test_answer_rename(): + answer = Answer(question="") + assert isinstance(answer, PQASession)
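
A minimal sketch of the backwards-compatibility story this patch sets up, assuming the deprecation shim in `paperqa/types.py` and the `populate_by_name`/`alias="answer"` configuration on `AnswerResponse` land exactly as shown above (the question string below is illustrative only):

```python
import warnings

from paperqa import PQASession
from paperqa.agents.models import AnswerResponse
from paperqa.types import Answer  # deprecated subclass kept for backwards compatibility

# New-style construction uses the renamed type directly.
session = PQASession(question="What is a self-explanatory model?")

# Old-style construction still works, but now emits a DeprecationWarning and
# produces an instance of the PQASession subclass (see test_answer_rename above).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = Answer(question="What is a self-explanatory model?")
assert isinstance(legacy, PQASession)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# Because AnswerResponse.session carries alias="answer" with populate_by_name=True,
# both the old and the new keyword validate into the same field.
old_style = AnswerResponse(answer=session, status="success")
new_style = AnswerResponse(session=session, status="success")
assert old_style.session.question == new_style.session.question
```

Payloads serialized before the rename (keyed on `"answer"`, or `"answer_id"` for `LLMResult`) should therefore continue to load, while new code can standardize on `session` and `session_id`.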