diff --git a/paperqa/agents/main.py b/paperqa/agents/main.py index 3ab3c7164..95b39e87a 100644 --- a/paperqa/agents/main.py +++ b/paperqa/agents/main.py @@ -140,6 +140,7 @@ async def run_agent( async def run_fake_agent( query: QueryRequest, docs: Docs, + env_class: type[PaperQAEnvironment] = PaperQAEnvironment, on_env_reset_callback: Callable[[EnvironmentState], Awaitable] | None = None, on_env_step_callback: ( Callable[[list[Message], float, bool, bool], Awaitable] | None @@ -151,7 +152,7 @@ async def run_fake_agent( f"Max timesteps (configured {query.settings.agent.max_timesteps}) is not" " applicable with the fake agent, ignoring it." ) - env = PaperQAEnvironment(query, docs, **env_kwargs) + env = env_class(query, docs, **env_kwargs) _, tools = await env.reset() if on_env_reset_callback: await on_env_reset_callback(env.state) @@ -188,6 +189,7 @@ async def run_aviary_agent( query: QueryRequest, docs: Docs, agent: ToolSelector, + env_class: type[PaperQAEnvironment] = PaperQAEnvironment, on_env_reset_callback: Callable[[EnvironmentState], Awaitable] | None = None, on_agent_action_callback: ( Callable[[ToolRequestMessage, BaseModel], Awaitable] | None @@ -197,7 +199,7 @@ async def run_aviary_agent( ) = None, **env_kwargs, ) -> tuple[Answer, AgentStatus]: - env = PaperQAEnvironment(query, docs, **env_kwargs) + env = env_class(query, docs, **env_kwargs) done = False try: @@ -273,6 +275,7 @@ async def run_ldp_agent( query: QueryRequest, docs: Docs, agent: "Agent[SimpleAgentState]", + env_class: type[PaperQAEnvironment] = PaperQAEnvironment, on_env_reset_callback: Callable[[EnvironmentState], Awaitable] | None = None, on_agent_action_callback: "Callable[[OpResult[ToolRequestMessage], SimpleAgentState, float], Awaitable] | None" = None, # noqa: E501 on_env_step_callback: ( @@ -280,7 +283,7 @@ async def run_ldp_agent( ) = None, **env_kwargs, ) -> tuple[Answer, AgentStatus]: - env = PaperQAEnvironment(query, docs, **env_kwargs) + env = env_class(query, docs, **env_kwargs) # NOTE: don't worry about ldp import checks, because we know Settings.make_ldp_agent # has already taken place, which checks that ldp is installed