Skip to content

Commit

Permalink
Renamed to PQASession type (#653)
Browse files Browse the repository at this point in the history
  • Loading branch information
whitead authored Oct 30, 2024
1 parent 99a9e07 commit dcb4fd4
Show file tree
Hide file tree
Showing 15 changed files with 169 additions and 145 deletions.
5 changes: 3 additions & 2 deletions paperqa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from paperqa.agents import ask # noqa: E402
from paperqa.agents.main import agent_query # noqa: E402
from paperqa.agents.models import QueryRequest # noqa: E402
from paperqa.docs import Answer, Docs, print_callback # noqa: E402
from paperqa.docs import Docs, PQASession, print_callback # noqa: E402
from paperqa.llms import ( # noqa: E402
EmbeddingModel,
HybridEmbeddingModel,
Expand All @@ -23,7 +23,7 @@
embedding_model_factory,
)
from paperqa.settings import Settings, get_settings # noqa: E402
from paperqa.types import Context, Doc, DocDetails, Text # noqa: E402
from paperqa.types import Answer, Context, Doc, DocDetails, Text # noqa: E402
from paperqa.version import __version__ # noqa: E402

__all__ = [
Expand All @@ -39,6 +39,7 @@
"LiteLLMEmbeddingModel",
"LiteLLMModel",
"NumpyVectorStore",
"PQASession",
"QueryRequest",
"SentenceTransformerEmbeddingModel",
"Settings",
Expand Down
8 changes: 4 additions & 4 deletions paperqa/agents/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from paperqa.docs import Docs
from paperqa.llms import EmbeddingModel, LiteLLMModel
from paperqa.settings import Settings
from paperqa.types import Answer
from paperqa.types import PQASession
from paperqa.utils import get_year

from .models import QueryRequest
Expand Down Expand Up @@ -128,7 +128,7 @@ def make_tools(self) -> list[Tool]:
def make_initial_state(self) -> EnvironmentState:
return EnvironmentState(
docs=self._docs,
answer=Answer(
answer=PQASession(
question=self._query.query,
config_md5=self._query.settings.md5,
id=self._query.id,
Expand All @@ -145,7 +145,7 @@ async def reset(self) -> tuple[list[Message], list[Tool]]:
[
Message(
content=self._query.settings.agent.agent_prompt.format(
question=self.state.answer.question,
question=self.state.session.question,
status=self.state.status,
gen_answer_tool_name=GenerateAnswer.TOOL_FN_NAME,
),
Expand All @@ -160,7 +160,7 @@ def export_frame(self) -> Frame:
async def step(
self, action: ToolRequestMessage
) -> tuple[list[Message], float, bool, bool]:
self.state.answer.add_tokens(action) # Add usage for action if present
self.state.session.add_tokens(action) # Add usage for action if present

# If the action has empty tool_calls, the agent can later take that into account
msgs = cast(
Expand Down
4 changes: 2 additions & 2 deletions paperqa/agents/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def table_formatter(
table.add_column("Answer", style="magenta")
for obj, _ in objects:
table.add_row(
cast(AnswerResponse, obj).answer.question[:max_chars_per_column],
cast(AnswerResponse, obj).answer.answer[:max_chars_per_column],
cast(AnswerResponse, obj).session.question[:max_chars_per_column],
cast(AnswerResponse, obj).session.answer[:max_chars_per_column],
)
return table
if isinstance(example_object, Docs):
Expand Down
36 changes: 18 additions & 18 deletions paperqa/agents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Callback: # type: ignore[no-redef]

from paperqa.docs import Docs
from paperqa.settings import AgentSettings
from paperqa.types import Answer
from paperqa.types import PQASession

from .env import PaperQAEnvironment
from .helpers import litellm_get_search_query, table_formatter
Expand Down Expand Up @@ -72,13 +72,13 @@ async def agent_query(
response = await run_agent(docs, query, agent_type, **runner_kwargs)
agent_logger.debug(f"agent_response: {response}")

agent_logger.info(f"[bold blue]Answer: {response.answer.answer}[/bold blue]")
agent_logger.info(f"[bold blue]Answer: {response.session.answer}[/bold blue]")

await answers_index.add_document(
{
"file_location": str(response.answer.id),
"body": response.answer.answer,
"question": response.answer.question,
"file_location": str(response.session.id),
"body": response.session.answer,
"question": response.session.question,
},
document=response,
)
Expand Down Expand Up @@ -120,26 +120,26 @@ async def run_agent(
# Build the index once here, and then all tools won't need to rebuild it
await get_directory_index(settings=query.settings)
if isinstance(agent_type, str) and agent_type.lower() == FAKE_AGENT_TYPE:
answer, agent_status = await run_fake_agent(query, docs, **runner_kwargs)
session, agent_status = await run_fake_agent(query, docs, **runner_kwargs)
elif tool_selector_or_none := query.settings.make_aviary_tool_selector(agent_type):
answer, agent_status = await run_aviary_agent(
session, agent_status = await run_aviary_agent(
query, docs, tool_selector_or_none, **runner_kwargs
)
elif ldp_agent_or_none := await query.settings.make_ldp_agent(agent_type):
answer, agent_status = await run_ldp_agent(
session, agent_status = await run_ldp_agent(
query, docs, ldp_agent_or_none, **runner_kwargs
)
else:
raise NotImplementedError(f"Didn't yet handle agent type {agent_type}.")

if answer.could_not_answer and agent_status != AgentStatus.TRUNCATED:
if session.could_not_answer and agent_status != AgentStatus.TRUNCATED:
agent_status = AgentStatus.UNSURE
# stop after, so overall isn't reported as long-running step.
logger.info(
f"Finished agent {agent_type!r} run with question {query.query!r} and status"
f" {agent_status}."
)
return AnswerResponse(answer=answer, status=agent_status)
return AnswerResponse(session=session, status=agent_status)


async def run_fake_agent(
Expand All @@ -154,7 +154,7 @@ async def run_fake_agent(
Callable[[list[Message], float, bool, bool], Awaitable] | None
) = None,
**env_kwargs,
) -> tuple[Answer, AgentStatus]:
) -> tuple[PQASession, AgentStatus]:
if query.settings.agent.max_timesteps is not None:
logger.warning(
f"Max timesteps (configured {query.settings.agent.max_timesteps}) is not"
Expand All @@ -165,7 +165,7 @@ async def run_fake_agent(
if on_env_reset_callback:
await on_env_reset_callback(env.state)

question = env.state.answer.question
question = env.state.session.question
search_tool = next(filter(lambda x: x.info.name == PaperSearch.TOOL_FN_NAME, tools))
gather_evidence_tool = next(
filter(lambda x: x.info.name == GatherEvidence.TOOL_FN_NAME, tools)
Expand All @@ -191,7 +191,7 @@ async def step(tool: Tool, **call_kwargs) -> None:
await step(search_tool, query=search, min_year=None, max_year=None)
await step(gather_evidence_tool, question=question)
await step(generate_answer_tool, question=question)
return env.state.answer, AgentStatus.SUCCESS
return env.state.session, AgentStatus.SUCCESS


async def run_aviary_agent(
Expand All @@ -207,7 +207,7 @@ async def run_aviary_agent(
Callable[[list[Message], float, bool, bool], Awaitable] | None
) = None,
**env_kwargs,
) -> tuple[Answer, AgentStatus]:
) -> tuple[PQASession, AgentStatus]:
env = env_class(query, docs, **env_kwargs)
done = False

Expand Down Expand Up @@ -247,7 +247,7 @@ async def run_aviary_agent(
await generate_answer_tool._tool_fn(
question=query.query, state=env.state
)
return env.state.answer, AgentStatus.TRUNCATED
return env.state.session, AgentStatus.TRUNCATED
agent_state.messages += obs
for attempt in Retrying(
stop=stop_after_attempt(5),
Expand Down Expand Up @@ -278,7 +278,7 @@ async def run_aviary_agent(
except Exception:
logger.exception(f"Agent {agent} failed.")
status = AgentStatus.FAIL
return env.state.answer, status
return env.state.session, status


class LDPRolloutCallback(Callback):
Expand Down Expand Up @@ -323,7 +323,7 @@ async def run_ldp_agent(
) = None,
ldp_callback_type: type[LDPRolloutCallback] = LDPRolloutCallback,
**env_kwargs,
) -> tuple[Answer, AgentStatus]:
) -> tuple[PQASession, AgentStatus]:
env = env_class(query, docs, **env_kwargs)
# NOTE: don't worry about ldp import checks, because we know Settings.make_ldp_agent
# has already taken place, which checks that ldp is installed
Expand Down Expand Up @@ -357,7 +357,7 @@ async def run_ldp_agent(
except Exception:
logger.exception(f"Agent {agent} failed.")
status = AgentStatus.FAIL
return env.state.answer, status
return env.state.session, status


async def index_search(
Expand Down
14 changes: 8 additions & 6 deletions paperqa/agents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from paperqa.llms import LiteLLMModel, LLMModel
from paperqa.settings import Settings
from paperqa.types import Answer
from paperqa.types import PQASession
from paperqa.version import __version__

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -85,7 +85,9 @@ def set_docs_name(self, docs_name: str) -> None:


class AnswerResponse(BaseModel):
answer: Answer
model_config = ConfigDict(populate_by_name=True)

session: PQASession = Field(alias="answer")
bibtex: dict[str, str] | None = None
status: AgentStatus
timing_info: dict[str, dict[str, float]] | None = None
Expand All @@ -94,10 +96,10 @@ class AnswerResponse(BaseModel):
# about the answer, such as the number of sources used, etc.
stats: dict[str, str] | None = None

@field_validator("answer")
@field_validator("session")
def strip_answer(
cls, v: Answer, info: ValidationInfo # noqa: ARG002, N805
) -> Answer:
cls, v: PQASession, info: ValidationInfo # noqa: ARG002, N805
) -> PQASession:
# This modifies in place, this is fine
# because when a response is being constructed,
# we should be done with the Answer object
Expand All @@ -114,7 +116,7 @@ async def get_summary(self, llm_model: LLMModel | str = "gpt-4o") -> str:
)
result = await model.run_prompt(
prompt="{question}\n\n{answer}",
data={"question": self.answer.question, "answer": self.answer.answer},
data={"question": self.session.question, "answer": self.session.answer},
system_prompt=sys_prompt,
)
return result.text.strip()
Expand Down
4 changes: 2 additions & 2 deletions paperqa/agents/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef]
read_litqa_v2_from_hub,
)
from paperqa.llms import EmbeddingModel, LiteLLMModel, LLMModel
from paperqa.types import Answer
from paperqa.types import PQASession

from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
from .models import QueryRequest
Expand All @@ -69,7 +69,7 @@ def __init__(
summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS,
evaluation_from_answer: (
Callable[[Answer | str], Awaitable[LitQAEvaluation]] | None
Callable[[PQASession | str], Awaitable[LitQAEvaluation]] | None
) = None,
sources: str | list[str] | None = None,
rewards: Sequence[float] = DEFAULT_REWARD_DISTRIBUTION,
Expand Down
38 changes: 19 additions & 19 deletions paperqa/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from paperqa.docs import Docs
from paperqa.llms import EmbeddingModel, LiteLLMModel
from paperqa.settings import Settings
from paperqa.types import Answer, DocDetails
from paperqa.types import DocDetails, PQASession

from .search import get_directory_index

Expand All @@ -35,7 +35,7 @@ class EnvironmentState(BaseModel):
model_config = ConfigDict(extra="forbid")

docs: Docs
answer: Answer
session: PQASession = Field(..., alias="answer")

# SEE: https://regex101.com/r/RmuVdC/1
STATUS_SEARCH_REGEX_PATTERN: ClassVar[str] = (
Expand All @@ -51,18 +51,18 @@ def status(self) -> str:
relevant_paper_count=len(
{
c.text.doc.dockey
for c in self.answer.contexts
for c in self.session.contexts
if c.score > self.RELEVANT_SCORE_CUTOFF
}
),
evidence_count=len(
[
c
for c in self.answer.contexts
for c in self.session.contexts
if c.score > self.RELEVANT_SCORE_CUTOFF
]
),
cost=self.answer.cost,
cost=self.session.cost,
)


Expand Down Expand Up @@ -202,31 +202,31 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
)

logger.info(f"{self.TOOL_FN_NAME} starting for question {question!r}.")
original_question = state.answer.question
original_question = state.session.question
try:
# Swap out the question with the more specific question
# TODO: remove this swap, as it prevents us from supporting parallel calls
state.answer.question = question
l0 = len(state.answer.contexts)
state.session.question = question
l0 = len(state.session.contexts)

# TODO: refactor answer out of this...
state.answer = await state.docs.aget_evidence(
query=state.answer,
state.session = await state.docs.aget_evidence(
query=state.session,
settings=self.settings,
embedding_model=self.embedding_model,
summary_llm_model=self.summary_llm_model,
callbacks=self.settings.agent.callbacks.get(
f"{self.TOOL_FN_NAME}_aget_evidence"
),
)
l1 = len(state.answer.contexts)
l1 = len(state.session.contexts)
finally:
state.answer.question = original_question
state.session.question = original_question

status = state.status
logger.info(status)
sorted_contexts = sorted(
state.answer.contexts, key=lambda x: x.score, reverse=True
state.session.contexts, key=lambda x: x.score, reverse=True
)
best_evidence = (
f" Best evidence:\n\n{sorted_contexts[0].context}"
Expand Down Expand Up @@ -287,8 +287,8 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str:

# TODO: Should we allow the agent to change the question?
# self.answer.question = query
state.answer = await state.docs.aquery(
query=state.answer,
state.session = await state.docs.aquery(
query=state.session,
settings=self.settings,
llm_model=self.llm_model,
summary_llm_model=self.summary_llm_model,
Expand All @@ -298,13 +298,13 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str:
),
)

if state.answer.could_not_answer:
if state.session.could_not_answer:
if self.settings.agent.wipe_context_on_answer_failure:
state.answer.contexts = []
state.answer.context = ""
state.session.contexts = []
state.session.context = ""
answer = self.FAILED_TO_ANSWER
else:
answer = state.answer.answer
answer = state.session.answer
status = state.status
logger.info(status)

Expand Down
Loading

0 comments on commit dcb4fd4

Please sign in to comment.