Skip to content

Commit

Permalink
fix: Allow evaluation benchmarks to pass image urls in run_controller() instead of simply passing strings (#4100)
Browse files Browse the repository at this point in the history

Co-authored-by: Xingyao Wang <[email protected]>
  • Loading branch information
adityasoni9998 and xingyaoww authored Oct 7, 2024
1 parent 9c07370 commit 0809d26
Show file tree
Hide file tree
Showing 19 changed files with 47 additions and 42 deletions.
3 changes: 2 additions & 1 deletion evaluation/EDA/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import MessageAction

game = None

Expand Down Expand Up @@ -122,7 +123,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class
Expand Down
2 changes: 1 addition & 1 deletion evaluation/agent_bench/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=FAKE_RESPONSES[metadata.agent_class],
)
Expand Down
4 changes: 2 additions & 2 deletions evaluation/aider_bench/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.runtime import Runtime

Expand Down Expand Up @@ -211,7 +211,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=FAKE_RESPONSES[metadata.agent_class],
)
Expand Down
4 changes: 2 additions & 2 deletions evaluation/biocoder/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.runtime import Runtime

Expand Down Expand Up @@ -285,7 +285,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class
Expand Down
2 changes: 1 addition & 1 deletion evaluation/bird/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def execute_sql(db_path, sql):
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class
],
Expand Down
3 changes: 2 additions & 1 deletion evaluation/browsing_delegation/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import MessageAction

# Only CodeActAgent can delegate to BrowsingAgent
SUPPORTED_AGENT_CLS = {'CodeActAgent'}
Expand Down Expand Up @@ -76,7 +77,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
)
)
Expand Down
2 changes: 1 addition & 1 deletion evaluation/gaia/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class
Expand Down
3 changes: 2 additions & 1 deletion evaluation/gorilla/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import MessageAction

AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
'CodeActAgent': codeact_user_response,
Expand Down Expand Up @@ -83,7 +84,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
metadata.agent_class
Expand Down
2 changes: 1 addition & 1 deletion evaluation/gpqa/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
metadata.agent_class
Expand Down
4 changes: 2 additions & 2 deletions evaluation/humanevalfix/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.runtime import Runtime

Expand Down Expand Up @@ -237,7 +237,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
metadata.agent_class
Expand Down
2 changes: 1 addition & 1 deletion evaluation/logic_reasoning/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
metadata.agent_class
Expand Down
5 changes: 3 additions & 2 deletions evaluation/miniwob/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,12 @@ def process_instance(

runtime = create_runtime(config, sid=env_id)
task_str = initialize_runtime(runtime)

state: State | None = asyncio.run(
run_controller(
config=config,
task_str=task_str, # take output from initialize_runtime
initial_user_action=MessageAction(
content=task_str
), # take output from initialize_runtime
runtime=runtime,
)
)
Expand Down
3 changes: 2 additions & 1 deletion evaluation/mint/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import (
CmdRunAction,
MessageAction,
)
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.runtime import Runtime
Expand Down Expand Up @@ -180,7 +181,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=fake_user_response_fn,
)
Expand Down
4 changes: 2 additions & 2 deletions evaluation/ml_bench/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.runtime import Runtime

Expand Down Expand Up @@ -242,7 +242,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
metadata.agent_class
Expand Down
5 changes: 3 additions & 2 deletions evaluation/swe_bench/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation, ErrorObservation
from openhands.events.serialization.event import event_to_dict
from openhands.runtime.runtime import Runtime
Expand Down Expand Up @@ -365,6 +365,7 @@ def process_instance(
logger.info(f'Starting evaluation for instance {instance.instance_id}.')

runtime = create_runtime(config, sid=instance.instance_id)

try:
initialize_runtime(runtime, instance)

Expand All @@ -374,7 +375,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class
Expand Down
4 changes: 2 additions & 2 deletions evaluation/toolqa/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.runtime import Runtime

Expand Down Expand Up @@ -109,7 +109,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=instruction,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class
Expand Down
2 changes: 1 addition & 1 deletion evaluation/webarena/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def process_instance(
state: State | None = asyncio.run(
run_controller(
config=config,
task_str=task_str,
initial_user_action=MessageAction(content=task_str),
runtime=runtime,
)
)
Expand Down
18 changes: 10 additions & 8 deletions openhands/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def create_runtime(

async def run_controller(
config: AppConfig,
task_str: str,
initial_user_action: Action,
sid: str | None = None,
runtime: Runtime | None = None,
agent: Agent | None = None,
Expand All @@ -96,7 +96,7 @@ async def run_controller(
Args:
config: The app config.
task_str: The task to run. It can be a string.
initial_user_action: An Action object containing initial user input
runtime: (optional) A runtime for the agent to run on.
agent: (optional) A agent to run.
exit_on_message: quit if agent asks for a message from user (optional)
Expand Down Expand Up @@ -146,11 +146,13 @@ async def run_controller(
if controller is not None:
controller.agent_task = asyncio.create_task(controller.start_step_loop())

assert isinstance(task_str, str), f'task_str must be a string, got {type(task_str)}'
assert isinstance(
initial_user_action, Action
), f'initial user actions must be an Action, got {type(initial_user_action)}'
# Logging
logger.info(
f'Agent Controller Initialized: Running agent {agent.name}, model '
f'{agent.llm.config.model}, with task: "{task_str}"'
f'{agent.llm.config.model}, with actions: {initial_user_action}'
)

# start event is a MessageAction with the task, either resumed or new
Expand All @@ -166,8 +168,8 @@ async def run_controller(
EventSource.USER,
)
elif initial_state is None:
# init with the provided task
event_stream.add_event(MessageAction(content=task_str), EventSource.USER)
# init with the provided actions
event_stream.add_event(initial_user_action, EventSource.USER)

async def on_event(event: Event):
if isinstance(event, AgentStateChangedObservation):
Expand Down Expand Up @@ -224,7 +226,7 @@ def generate_sid(config: AppConfig, session_name: str | None = None) -> str:
task_str = read_task_from_stdin()
else:
raise ValueError('No task provided. Please specify a task through -t, -f.')

initial_user_action: MessageAction = MessageAction(content=task_str)
# Load the app config
# this will load config from config.toml in the current directory
# as well as from the environment variables
Expand Down Expand Up @@ -253,7 +255,7 @@ def generate_sid(config: AppConfig, session_name: str | None = None) -> str:
asyncio.run(
run_controller(
config=config,
task_str=task_str,
initial_user_action=initial_user_action,
sid=sid,
)
)
17 changes: 7 additions & 10 deletions tests/integration/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
from openhands.core.config import load_app_config
from openhands.core.main import run_controller
from openhands.core.schema import AgentState
from openhands.events.action import (
AgentFinishAction,
AgentRejectAction,
)
from openhands.events.action import AgentFinishAction, AgentRejectAction, MessageAction
from openhands.events.observation.browse import BrowserOutputObservation
from openhands.events.observation.delegate import AgentDelegateObservation
from openhands.runtime import get_runtime_cls
Expand Down Expand Up @@ -90,7 +87,7 @@ def test_write_simple_script(current_test_name: str) -> None:
task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."

final_state: State | None = asyncio.run(
run_controller(CONFIG, task, exit_on_message=True)
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
)
validate_final_state(final_state, current_test_name)

Expand Down Expand Up @@ -136,7 +133,7 @@ def test_edits(current_test_name: str):
# Execute the task
task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
final_state: State | None = asyncio.run(
run_controller(CONFIG, task, exit_on_message=True)
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
)
validate_final_state(final_state, current_test_name)

Expand All @@ -160,7 +157,7 @@ def test_ipython(current_test_name: str):
# Execute the task
task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
final_state: State | None = asyncio.run(
run_controller(CONFIG, task, exit_on_message=True)
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
)
validate_final_state(final_state, current_test_name)

Expand All @@ -185,7 +182,7 @@ def test_simple_task_rejection(current_test_name: str):
# the workspace is not a git repo
task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
final_state: State | None = asyncio.run(
run_controller(CONFIG, task, exit_on_message=True)
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
)
validate_final_state(final_state, current_test_name)
assert isinstance(final_state.history.get_last_action(), AgentRejectAction)
Expand All @@ -200,7 +197,7 @@ def test_ipython_module(current_test_name: str):
# Execute the task
task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
final_state: State | None = asyncio.run(
run_controller(CONFIG, task, exit_on_message=True)
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
)
validate_final_state(final_state, current_test_name)

Expand All @@ -226,7 +223,7 @@ def test_browse_internet(current_test_name: str):
# Execute the task
task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'
final_state: State | None = asyncio.run(
run_controller(CONFIG, task, exit_on_message=True)
run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
)
validate_final_state(final_state, current_test_name)

Expand Down

0 comments on commit 0809d26

Please sign in to comment.