Skip to content

Commit

Permalink
Making hide_old_env_states accessible to standard agents (#146)
Browse files Browse the repository at this point in the history
  • Loading branch information
sidnarayanan authored Nov 6, 2024
1 parent 3790617 commit f22decf
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 8 deletions.
9 changes: 8 additions & 1 deletion ldp/agent/react_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ class ReActAgent(BaseModel, Agent[SimpleAgentState]):
),
)

hide_old_env_states: bool = Field(
default=False,
description="See SimpleAgentState.hide_old_env_states.",
)

@classmethod
def make_act_agent(cls, **kwargs) -> Self:
single_prompt = kwargs.pop("single_prompt", False)
Expand Down Expand Up @@ -139,7 +144,9 @@ def __init__(self, **kwargs):
)

async def init_state(self, tools: list[Tool]) -> SimpleAgentState:
    """Build the initial SimpleAgentState for this agent.

    Args:
        tools: Tools to place in the initial state.

    Returns:
        A SimpleAgentState holding the given tools and carrying this agent's
        hide_old_env_states setting.
    """
    # Only one return may remain here: a leftover `SimpleAgentState(tools=tools)`
    # return ahead of this one would make the flag below unreachable.
    return SimpleAgentState(
        tools=tools, hide_old_env_states=self.hide_old_env_states
    )

@staticmethod
def after_retry_failure_log(retry_state: RetryCallState):
Expand Down
25 changes: 21 additions & 4 deletions ldp/agent/simple_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,16 @@ class SimpleAgentState(BaseModel):
messages: list[ToolRequestMessage | ToolResponseMessage | Message] = Field(
default_factory=list
)
hide_old_env_states: bool = Field(
default=False,
description="Whether to hide old EnvStateMessages.",
)

def get_next_state(
self,
obs: list[Message] | None = None,
tools: list[Tool] | None = None,
hide_old_env_states: bool = False,
hide_old_env_states: bool | None = None,
**kwargs,
) -> Self:
"""
Expand All @@ -41,14 +45,19 @@ def get_next_state(
obs: Optional observation messages to use in creating the next state.
tools: Optional list of tools available to the agent. If unspecified, these
should be pulled from the prior_state.
hide_old_env_states: Whether to hide old environment states in the messages.
This is useful for reducing context window usage.
hide_old_env_states: Optional override of self.hide_old_env_states.
kwargs: Additional keyword arguments to pass to this class's constructor.
Returns:
The next agent state (which is not an in-place change to self).
"""
old_messages = self.messages

hide_old_env_states = (
hide_old_env_states
if hide_old_env_states is not None
else self.hide_old_env_states
)
if hide_old_env_states:
old_messages = [
HiddenEnvStateMessage() if isinstance(m, EnvStateMessage) else m
Expand All @@ -58,6 +67,7 @@ def get_next_state(
return type(self)(
tools=tools if tools is not None else self.tools,
messages=old_messages + (obs or []),
hide_old_env_states=hide_old_env_states,
**kwargs,
)

Expand Down Expand Up @@ -89,13 +99,20 @@ class SimpleAgent(BaseModel, Agent[SimpleAgentState]):
),
)

hide_old_env_states: bool = Field(
default=False,
description="See SimpleAgentState.hide_old_env_states.",
)

def __init__(self, **kwargs):
    """Initialize the agent and create the ops stored on its private attributes.

    Args:
        **kwargs: Forwarded unchanged to the parent initializer.
    """
    super().__init__(**kwargs)
    # ConfigOp parameterized with dict, configured from the llm_model field.
    self._config_op = ConfigOp[dict](config=self.llm_model)
    self._llm_call_op = LLMCallOp()

async def init_state(self, tools: list[Tool]) -> SimpleAgentState:
    """Build the initial SimpleAgentState for this agent.

    Args:
        tools: Tools to place in the initial state.

    Returns:
        A SimpleAgentState holding the given tools and carrying this agent's
        hide_old_env_states setting.
    """
    # Only one return may remain here: a leftover `SimpleAgentState(tools=tools)`
    # return ahead of this one would make the flag below unreachable.
    return SimpleAgentState(
        tools=tools, hide_old_env_states=self.hide_old_env_states
    )

@compute_graph()
async def get_asv(
Expand Down
8 changes: 7 additions & 1 deletion ldp/agent/tree_of_thoughts_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,20 @@ class TreeofThoughtsAgent(BaseModel, Agent[SimpleAgentState]):
default=lambda x, y: f"Proposal prompt for input: {x}, current path: {y}",
description="Function to format proposal prompt template.",
)
hide_old_env_states: bool = Field(
default=False,
description="See SimpleAgentState.hide_old_env_states.",
)

def __init__(self, **kwargs):
    """Initialize the agent and create the ops stored on its private attributes.

    Args:
        **kwargs: Forwarded unchanged to the parent initializer.
    """
    super().__init__(**kwargs)
    # FxnOp wrapping the prepend_sys callable.
    self._prepend_op = FxnOp(prepend_sys)
    self._llm_call_op = LLMCallOp()

async def init_state(self, tools: list[Tool]) -> SimpleAgentState:
    """Build the initial SimpleAgentState for this agent.

    Args:
        tools: Tools to place in the initial state.

    Returns:
        A SimpleAgentState holding the given tools and carrying this agent's
        hide_old_env_states setting.
    """
    # Only one return may remain here: a leftover `SimpleAgentState(tools=tools)`
    # return ahead of this one would make the flag below unreachable.
    return SimpleAgentState(
        tools=tools, hide_old_env_states=self.hide_old_env_states
    )

@compute_graph()
async def get_asv( # type: ignore[override]
Expand Down
22 changes: 20 additions & 2 deletions ldp/graph/common_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,26 @@ async def forward(
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = LLMModel.TOOL_CHOICE_REQUIRED,
) -> Message:
"""Calls the LLM.
Args:
config: Configuration passed to LLMModel.
msgs: Input messages to prompt model with.
tools: A list of Tools that the model may call, if supported.
tool_choice: Configures how the model should choose a tool.
Can be a Tool or a string; see here for string options:
https://platform.openai.com/docs/guides/function-calling#configuring-function-calling-behavior-using-the-tool_choice-parameter
NOTE: if `tools` is None or empty, this parameter is ignored.
Returns:
Output message from the model.
"""
model = LLMModel(config=config)

if not tools:
# if no tools are provided, tool_choice must be 'none'
tool_choice = "none"

result = await model.call(messages=msgs, tools=tools, tool_choice=tool_choice)
if result.messages is None:
raise ValueError("No messages returned")
Expand Down Expand Up @@ -284,8 +302,8 @@ async def compute_logprob(
if raw_log_p is None or self.num_samples_partition_estimate == 0:
return None

# TODO: Try using n completions from a single API call. Need to modify LLMModel.call to do this, since
# it currently only checks completion.choices[0]. Would reduce cost for long prompts.
# TODO: possibly move to MultipleCompletionLLMModel here, though we need to check that the estimates
# are consistent - not sure we'd be sampling from the same distribution as N independent samples.
# TODO: think about whether sampling params besides temperature need to be accounted for, like top_p
results = await asyncio.gather(*[
model.call(temperature=1, **model_kwargs)
Expand Down
106 changes: 106 additions & 0 deletions tests/cassettes/TestSimpleAgent.test_hide_old_env_states.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
interactions:
- request:
body:
'{"messages": [{"role": "user", "content": ""}], "model": "gpt-4o-2024-08-06",
"temperature": 0.1}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- "97"
content-type:
- application/json
host:
- api.openai.com
user-agent:
- AsyncOpenAI/Python 1.52.2
x-stainless-arch:
- x64
x-stainless-async:
- async:asyncio
x-stainless-lang:
- python
x-stainless-os:
- Linux
x-stainless-package-version:
- 1.52.2
x-stainless-raw-response:
- "true"
x-stainless-retry-count:
- "0"
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.11.0rc1
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAA4xSwWrcMBC9+yumOq+Ld+PsdvdSQhNIQgkEAjmUYhRpbKuVNUIat92E/fcge3e9
oSn0osN7857eG+YlAxBGiw0I1UpWnbf5xT1+/XJ59/ywmi/Kurs4C4+3Uf9qG3OlpZglBT39QMUH
1UdFnbfIhtxIq4CSMbnOV2fFel6ul4uB6EijTbLGc15SvigWZV58yovlXtiSURjFBr5lAAAvw5si
Oo1/xAaK2QHpMEbZoNgchwBEIJsQIWM0kaVjMZtIRY7RDamv0Vr6ANf0G5R0cAOjALbUA5OW28+n
woB1H2XK7Xpr9/jumMRS4wM9xT1/xGvjTGyrgDKSS79GJi8GdpcBfB8a929KCB+o81wx/USXDFej
m5g2PHHrPcfE0k7wfDl7x6vSyNLYeLIvoaRqUU/Kabmy14ZOiOyk8d9Z3vMeWxvX/I/9RCiFnlFX
PqA26m3faSxgOr9/jR03PAQWcRsZu6o2rsHggxkvoPZVea7q81KjRJHtslcAAAD//wMAhkW/KQoD
AAA=
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8de6e1adf9fbfb40-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Wed, 06 Nov 2024 17:42:42 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=IV1Bcm.OiOhx5Fueebsi7W6vfkykRRtCudT2Nx2Si.8-1730914962-1.0.1.1-gOxkIRQ2IVoWcd1Zz84mSVVwBHV5.OKb837EPZuheau5Qz5mzTUKUZbirerOuWHYGvTe3mf9h3oXoV6SISWa2Q;
path=/; expires=Wed, 06-Nov-24 18:12:42 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=Cu15wqjKcy0EqDJ_H8EAhxyN9YVdxEDkGLA6vmynwxo-1730914962519-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
openai-organization:
- future-house-xr4tdh
openai-processing-ms:
- "350"
openai-version:
- "2020-10-01"
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- "10000"
x-ratelimit-limit-tokens:
- "30000000"
x-ratelimit-remaining-requests:
- "9999"
x-ratelimit-remaining-tokens:
- "29999983"
x-ratelimit-reset-requests:
- 6ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_654554dd6c32e59d7f3e4e0ff5dafa77
status:
code: 200
message: OK
version: 1
28 changes: 28 additions & 0 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import networkx as nx
import pytest
from aviary.core import DummyEnv, Message, Tool, ToolCall, ToolRequestMessage
from aviary.message import EnvStateMessage
from httpx import ASGITransport, AsyncClient
from pydantic import BaseModel, Field

Expand All @@ -24,6 +25,7 @@
make_simple_agent_server,
)
from ldp.agent.interactive_agent import InteractiveAgent
from ldp.agent.simple_agent import HiddenEnvStateMessage
from ldp.alg import to_network
from ldp.graph import LLMCallOp, Memory, OpResult, eval_mode
from ldp.graph.gradient_estimators import llm_straight_through_estimator as llm_ste
Expand Down Expand Up @@ -135,6 +137,7 @@ async def test_dummyenv(self, dummy_env: DummyEnv, model_name: str) -> None:
# Check serialization after get_asv runs to ensure private
# Ops aren't included
assert agent.model_dump() == {
"hide_old_env_states": False,
"llm_model": {"model": model_name, "temperature": 0.1},
"sys_prompt": None,
}
Expand Down Expand Up @@ -201,6 +204,31 @@ async def test_agent_grad(self, dummy_env: DummyEnv, model_name: str) -> None:
output_dir / f"TestSimpleAgent.test_agent_grad.{model_name}.png",
)

@pytest.mark.asyncio
@pytest.mark.vcr
async def test_hide_old_env_states(self) -> None:
    """Check that an agent built with hide_old_env_states=True hides only old env states."""
    agent = SimpleAgent(hide_old_env_states=True)
    state = await agent.init_state(tools=[])

    # Step the agent twice, feeding a fresh EnvStateMessage each time and
    # keeping the state produced by each step.
    states = []
    for _ in range(2):
        _, state, _ = await agent.get_asv(state, [EnvStateMessage(content="")])
        states.append(state)
    first_state, second_state = states

    # Each step contributes one EnvStateMessage and one model response.
    assert len(first_state.messages) == 2
    assert len(second_state.messages) == 4

    hidden, carried_response, latest_env_state = second_state.messages[:3]

    # The first step's env state is replaced, while its model response is kept.
    assert isinstance(hidden, HiddenEnvStateMessage)
    assert first_state.messages[1].content == carried_response.content

    # The most recent EnvStateMessage must not have been hidden.
    assert isinstance(latest_env_state, EnvStateMessage)
    assert not isinstance(latest_env_state, HiddenEnvStateMessage)


class TestMemoryAgent:
# # On 5/14/2024, claude 3 opus would not follow its past memories
Expand Down

0 comments on commit f22decf

Please sign in to comment.