Skip to content

Commit

Permalink
Expose env class to run_agent functions (#611)
Browse files Browse the repository at this point in the history
  • Loading branch information
mskarlin authored Oct 18, 2024
1 parent e5908d4 commit 0a1167a
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions paperqa/agents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ async def run_agent(
async def run_fake_agent(
query: QueryRequest,
docs: Docs,
env_class: type[PaperQAEnvironment] = PaperQAEnvironment,
on_env_reset_callback: Callable[[EnvironmentState], Awaitable] | None = None,
on_env_step_callback: (
Callable[[list[Message], float, bool, bool], Awaitable] | None
Expand All @@ -151,7 +152,7 @@ async def run_fake_agent(
f"Max timesteps (configured {query.settings.agent.max_timesteps}) is not"
" applicable with the fake agent, ignoring it."
)
env = PaperQAEnvironment(query, docs, **env_kwargs)
env = env_class(query, docs, **env_kwargs)
_, tools = await env.reset()
if on_env_reset_callback:
await on_env_reset_callback(env.state)
Expand Down Expand Up @@ -188,6 +189,7 @@ async def run_aviary_agent(
query: QueryRequest,
docs: Docs,
agent: ToolSelector,
env_class: type[PaperQAEnvironment] = PaperQAEnvironment,
on_env_reset_callback: Callable[[EnvironmentState], Awaitable] | None = None,
on_agent_action_callback: (
Callable[[ToolRequestMessage, BaseModel], Awaitable] | None
Expand All @@ -197,7 +199,7 @@ async def run_aviary_agent(
) = None,
**env_kwargs,
) -> tuple[Answer, AgentStatus]:
env = PaperQAEnvironment(query, docs, **env_kwargs)
env = env_class(query, docs, **env_kwargs)
done = False

try:
Expand Down Expand Up @@ -273,14 +275,15 @@ async def run_ldp_agent(
query: QueryRequest,
docs: Docs,
agent: "Agent[SimpleAgentState]",
env_class: type[PaperQAEnvironment] = PaperQAEnvironment,
on_env_reset_callback: Callable[[EnvironmentState], Awaitable] | None = None,
on_agent_action_callback: "Callable[[OpResult[ToolRequestMessage], SimpleAgentState, float], Awaitable] | None" = None, # noqa: E501
on_env_step_callback: (
Callable[[list[Message], float, bool, bool], Awaitable] | None
) = None,
**env_kwargs,
) -> tuple[Answer, AgentStatus]:
env = PaperQAEnvironment(query, docs, **env_kwargs)
env = env_class(query, docs, **env_kwargs)
# NOTE: don't worry about ldp import checks, because we know Settings.make_ldp_agent
# has already taken place, which checks that ldp is installed

Expand Down

0 comments on commit 0a1167a

Please sign in to comment.