Skip to content

Commit

Permalink
Added main entry point (#109)
Browse files Browse the repository at this point in the history
Basic entry point for running agents with an environment for tasks not derived from the task library.
  • Loading branch information
whitead authored Oct 22, 2024
1 parent 6c93bd5 commit 17d0408
Show file tree
Hide file tree
Showing 7 changed files with 406 additions and 302 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ repos:
- id: mypy
additional_dependencies:
- fastapi>=0.109 # Match pyproject.toml
- fhaviary>=0.6 # Match pyproject.toml
- fhaviary>=0.8 # Match pyproject.toml
- httpx
- litellm>=1.49.3 # Match pyproject.toml
- numpy>=1.20 # Match pyproject.toml
Expand Down
6 changes: 5 additions & 1 deletion ldp/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def named_ops(self) -> Iterable[tuple[str, Op]]:
"""Analogous to torch.nn.Module.named_parameters()."""
return _find_ops(self)

@classmethod
def from_name(cls, name: str, **kwargs) -> Agent:
    """Construct the agent registered under ``name``.

    The entry stored in ``_AGENT_REGISTRY`` for ``name`` is called with
    ``kwargs`` and its result is returned. A failed registry lookup is not
    caught here.
    """
    agent_factory = _AGENT_REGISTRY[name]
    return agent_factory(**kwargs)


class AgentConfig(BaseModel):
"""Configuration for specifying the type of agent i.e. the subclass of Agent above."""
Expand All @@ -96,7 +100,7 @@ class AgentConfig(BaseModel):
)

def construct_agent(self) -> Agent:
return _AGENT_REGISTRY[self.agent_type](**self.agent_kwargs)
return Agent.from_name(self.agent_type, **self.agent_kwargs)

def __hash__(self) -> int:
return hash(self.agent_type + json.dumps(self.agent_kwargs, sort_keys=True))
Expand Down
52 changes: 52 additions & 0 deletions ldp/alg/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,3 +475,55 @@ async def after_eval_loop(self) -> None:
await super().after_eval_loop() # Call the parent to compute means
if self.eval_means:
self._log_filtered_metrics(self.eval_means, step_type="Eval")


class TerminalPrintingCallback(Callback):
    """Callback that prints action, observation, and timing information to the terminal."""

    def __init__(self):
        # Timer reading taken in before_transition; None while no transition is being timed.
        self.start_time: float | None = None
        # try now, rather than start running and die
        try:
            from rich.pretty import pprint  # noqa: F401
        except ImportError as e:
            raise ImportError(
                f"rich is required for {type(self).__name__}. Please install it with `pip install rich`."
            ) from e

    async def before_transition(
        self,
        traj_id: str,
        agent: Agent,
        env: Environment,
        agent_state: Any,
        obs: list[Message],
    ) -> None:
        """Start the timer before each transition."""
        # perf_counter is monotonic, so the elapsed time reported later cannot be
        # skewed (or go negative) if the system wall clock is adjusted mid-transition.
        self.start_time = time.perf_counter()

    async def after_agent_get_asv(
        self,
        traj_id: str,
        action: OpResult[ToolRequestMessage],
        next_agent_state: Any,
        value: float,
    ) -> None:
        """Pretty-print the value carried by the agent's action."""
        from rich.pretty import pprint

        print("\nAction:")
        pprint(action.value, expand_all=True)

    async def after_env_step(
        self, traj_id: str, obs: list[Message], reward: float, done: bool, trunc: bool
    ) -> None:
        """Pretty-print the observation and the time elapsed since the timer was started."""
        from rich.pretty import pprint

        # Compute elapsed time; report 0.0 when the timer was never started
        if self.start_time is not None:
            elapsed_time = time.perf_counter() - self.start_time
            self.start_time = None  # Reset timer
        else:
            elapsed_time = 0.0
        print("\nObservation:")
        pprint(obs, expand_all=True)
        print(f"Elapsed time: {elapsed_time:.2f} seconds")
1 change: 1 addition & 0 deletions ldp/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def from_jsonl(cls, filename: str | os.PathLike) -> Self:
return traj

def compute_discounted_returns(self, discount: float = 1.0) -> list[float]:
"""Compute the discounted returns for each step in the trajectory."""
return discounted_returns(
rewards=[step.reward for step in self.steps],
terminated=[step.truncated for step in self.steps],
Expand Down
70 changes: 70 additions & 0 deletions ldp/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import argparse
import asyncio
import pickle
from contextlib import suppress
from os import PathLike
from pathlib import Path

from aviary.env import Environment

from ldp.agent import Agent
from ldp.alg.callbacks import TerminalPrintingCallback
from ldp.alg.rollout import RolloutManager


def get_or_make_agent(agent: Agent | str | PathLike) -> Agent:
    """Resolve ``agent`` into an agent object.

    Args:
        agent: An ``Agent`` instance (returned unchanged), a name to look up via
            ``Agent.from_name``, or a path to a pickled agent file.

    Returns:
        The given instance, the agent built from the name, or the unpickled object.

    Raises:
        ValueError: If ``agent`` is not an instance, the name lookup raised
            ``KeyError``, and the value does not point to an existing file.
    """
    if isinstance(agent, Agent):
        return agent

    if isinstance(agent, str):
        # A KeyError here means the name was not resolved; fall through and
        # treat the string as a filesystem path instead.
        with suppress(KeyError):
            return Agent.from_name(agent)

    path = Path(agent)
    # is_file() rather than exists(): a directory must produce the ValueError
    # below instead of failing later inside open().
    if not path.is_file():
        raise ValueError(f"Could not resolve agent: {agent}")

    # SECURITY: unpickling can execute arbitrary code. Only pass paths to
    # files that come from a trusted source.
    with path.open("rb") as f:
        return pickle.load(f)  # noqa: S301


def get_or_make_environment(environment: Environment | str, task: str) -> Environment:
    """Resolve ``environment`` into an environment object.

    An ``Environment`` instance is returned unchanged. A string is passed to
    ``Environment.from_name`` together with ``task``; if that raises
    ``KeyError``, or the value is neither an instance nor a string, a
    ``ValueError`` listing ``Environment.available()`` is raised.
    """
    if isinstance(environment, Environment):
        return environment

    if isinstance(environment, str):
        try:
            return Environment.from_name(environment, task=task)
        except KeyError:
            # Lookup failed; report it through the ValueError below.
            pass

    raise ValueError(
        f"Could not resolve environment: {environment}. Available environments: {Environment.available()}"
    )


async def main(
    task: str,
    environment: Environment | str,
    agent: Agent | str | PathLike = "SimpleAgent",
):
    """Sample trajectories for ``task`` and print each step to the terminal.

    The agent is resolved with ``get_or_make_agent``; environments are produced
    on demand by a factory that calls ``get_or_make_environment``. The sampled
    trajectories are discarded; output comes from ``TerminalPrintingCallback``.
    """

    def make_environment() -> Environment:
        # Called by the rollout manager whenever it needs an environment.
        return get_or_make_environment(environment, task)

    resolved_agent = get_or_make_agent(agent)
    manager = RolloutManager(
        agent=resolved_agent, callbacks=[TerminalPrintingCallback()]
    )
    await manager.sample_trajectories(environment_factory=make_environment)


if __name__ == "__main__":
    # Command-line entry point: positional task, required --env, optional --agent.
    cli = argparse.ArgumentParser()
    cli.add_argument("task", help="Task to prompt environment with.")
    cli.add_argument(
        "--env", required=True, help="Environment to sample trajectories from."
    )
    cli.add_argument(
        "--agent", default="SimpleAgent", help="Agent to sample trajectories with."
    )
    cli_args = cli.parse_args()

    asyncio.run(
        main(task=cli_args.task, environment=cli_args.env, agent=cli_args.agent)
    )
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ classifiers = [
dependencies = [
"aiofiles",
"dm-tree",
"fhaviary>=0.6", # For MalformedMessageError
"fhaviary>=0.8", # For from_task
"httpx",
"litellm>=1.40.15", # For LITELLM_LOG addition
"networkx[default]~=3.4", # Pin for pydot fix
Expand Down
Loading

0 comments on commit 17d0408

Please sign in to comment.