Skip to content

Commit

Permalink
feat: enable evaluation of EpisodeLog
Browse files Browse the repository at this point in the history
  • Loading branch information
JXZhou authored and JXZhou committed Jan 2, 2025
1 parent 396db1a commit 448293f
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 18 deletions.
3 changes: 2 additions & 1 deletion examples/experimental/sotopia_original_replica/origin.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ input_channels = ["Jane:moderator", "Jack:moderator"]
agent_mapping = {"moderator:Jane" = "Jane", "moderator:Jack" = "Jack"}
scenario = "Two friends are sitting in a cafe and catching up with each other's lives."
max_turns = 2
push_to_db = true
push_to_db = false
will_eval = true

[[nodes]]
node_name = "Jack"
Expand Down
25 changes: 25 additions & 0 deletions sotopia/experimental/agents/evaluators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
from .logs import EpisodeLog


class BaseEvaluator(ABC):
def __init__(self):
pass

@abstractmethod
def evaluate(self, epilog: EpisodeLog) -> tuple[float, str]:
"""
evaluate an episode, returns the score and reward prompt
"""
pass


class DummyEvaluator(BaseEvaluator):
def __init__(self):
super().__init__()

def evaluate(self, epilog: EpisodeLog) -> tuple[float, str]:
"""
evaluate an episode, returns the score and reward prompt
"""
return 0.0, "No evaluation implemented"
33 changes: 26 additions & 7 deletions sotopia/experimental/agents/moderator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from .datamodels import AgentAction, Observation
from .logs import EpisodeLog, AgentProfile
from .evaluators import DummyEvaluator
from sotopia.messages import ActionType


Expand Down Expand Up @@ -46,6 +47,8 @@ def __init__(
],
max_turns: int = 20,
push_to_db: bool = False,
will_eval: bool = False,
evaluator: str = "DummyEvaluator",
):
super().__init__(
input_channel_types=[
Expand All @@ -55,7 +58,7 @@ def __init__(
(output_channel, Observation) for output_channel in output_channels
],
redis_url=redis_url,
node_name=node_name
node_name=node_name,
)
self.observation_queue: asyncio.Queue[AgentAction] = asyncio.Queue()
self.task_scheduler: asyncio.Task[None] | None = None
Expand All @@ -77,7 +80,9 @@ def __init__(
self.message_history: list[tuple[str, str, str]] = [
("Environment", "Environment", self.scenario)
]
self.push_to_db = push_to_db
self.push_to_db: bool = push_to_db
self.will_eval: bool = will_eval
self.evaluator: str = evaluator

if self.action_order == "round-robin":
pass
Expand Down Expand Up @@ -178,15 +183,29 @@ async def booting(self) -> None:
self.current_agent_index += 1

async def wrap_up_and_stop(self) -> None:
if self.push_to_db:
await self.save()
epilog = await self.save()
if self.will_eval:
epilog = await self.eval(epilog)
await asyncio.sleep(0.5)
print("stopping all agents")
print("result of this episode:\n", epilog)
await self.r.publish(
"shutdown:moderator",
"shutdown",
)

async def eval(self, epilog: EpisodeLog) -> EpisodeLog:
"""
evaluate the episode
"""
if self.evaluator == "DummyEvaluator":
evaluator = DummyEvaluator()
reward, reward_prompt = evaluator.evaluate(epilog)
epilog.rewards = [reward]
epilog.rewards_prompt = reward_prompt
if self.push_to_db:
epilog.save()
return epilog

async def save(self) -> EpisodeLog:
"""
save the EpisodeLog to redis
Expand All @@ -200,8 +219,8 @@ async def save(self) -> EpisodeLog:
rewards=None,
rewards_prompt=None,
)
epilog.save()
print(epilog.model_dump_json(indent=2))
if self.push_to_db:
epilog.save()
return epilog

async def aact(self, agent_action: AgentAction) -> Observations | None:
Expand Down
10 changes: 0 additions & 10 deletions test.py

This file was deleted.

0 comments on commit 448293f

Please sign in to comment.