Skip to content

Commit

Permalink
Feature/sotopia demo UI (#261)
Browse files Browse the repository at this point in the history
* initial

* initial ui

* merge main

* add new ui

* switch to fastAPI

* websocket check

* fix render episode error

* add  page; make a simplified page and still WIP

* [autofix.ci] apply automated fixes

* fix simplified streaming version

* semi-done character page + avatar assets

* Fixed character card styling

* [autofix.ci] apply automated fixes

* unified rendering and chat display

* updated chat character icons

* add some tags

* add typing

* temp fix

* add characters avatar to simulation

* fix episode full avatar

* go to modal config

* clean up code

* add modal streamlit app

* clean codebase except websocket

* remove repeated local css

* clean websocket

* fix get name error

* fix errors

* pre render scenario

* add custom eval

* change streamlit to dynamic path

* new uv

* revert to previous install commands

* a fix for modal

* add customized dimension

* [autofix.ci] apply automated fixes

* sort scenarios in simulation

* for demo video

* update deploy instruction

* update intro page

* update intro page

* [autofix.ci] apply automated fixes

* update intro page

* add customized dimensions

* update api link and modal environment

* move folder

* fix relative import

* update modal image build

* use uv to build environment

* change folder name

* change test

* fix modal serve

* environment change

* refactor

* fix ui

---------

Co-authored-by: Zhe Su <[email protected]>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: astrophie <[email protected]>
  • Loading branch information
4 people authored Dec 31, 2024
1 parent 1f4fb0a commit cb6b2d1
Show file tree
Hide file tree
Showing 33 changed files with 2,921 additions and 805 deletions.
2 changes: 1 addition & 1 deletion docs/pages/contribution/contribution.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ Please refer to [Dev Containers](https://containers.dev/supporting#editors) to s

You can also set up the development environment without Dev Containers. There are three things you will need to set up manually:

- Python and uv: Please start from an environment supporting Python 3.10+ and install uv using `pip install uv; uv sync --all-extra`.
- Python and uv: Please start from an environment supporting Python 3.10+ and install uv using `pip install uv; uv sync --all-extras`. (Note that this will install all the extra dependencies)
- Redis: Please refer to introduction page for the set up of Redis.
- Local LLM (optional): If you don't have access to model endpoints (e.g. OpenAI, Anthropic or others), you can use a local model. You can use Ollama, Llama.cpp, vLLM or many others which support OpenAI compatible endpoints.

Expand Down
69 changes: 34 additions & 35 deletions examples/experimental/websocket/websocket_test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sotopia.database import EnvironmentProfile, AgentProfile

import asyncio
import websockets
import aiohttp
import sys
from pathlib import Path

Expand All @@ -28,40 +28,39 @@ async def connect(self) -> None:
url_with_token = f"{self.url}?token=test_token_{self.client_id}"

try:
async with websockets.connect(url_with_token) as websocket:
print(f"Client {self.client_id}: Connected to {self.url}")

# Send initial message
# Note: You'll need to implement the logic to get agent_ids and env_id
# This is just an example structure
agent_ids = [agent.pk for agent in AgentProfile.find().all()[:2]]
env_id = EnvironmentProfile.find().all()[0].pk
start_message = {
"type": "START_SIM",
"data": {
"env_id": env_id, # Replace with actual env_id
"agent_ids": agent_ids, # Replace with actual agent_ids
},
}
await websocket.send(json.dumps(start_message))
print(f"Client {self.client_id}: Sent START_SIM message")

# Receive and process messages
while True:
try:
message = await websocket.recv()
print(
f"\nClient {self.client_id} received message:",
json.dumps(json.loads(message), indent=2),
)
assert isinstance(message, str)
await self.save_message(message)
except websockets.ConnectionClosed:
print(f"Client {self.client_id}: Connection closed")
break
except Exception as e:
print(f"Client {self.client_id} error:", str(e))
break
async with aiohttp.ClientSession() as session:
async with session.ws_connect(url_with_token) as ws:
print(f"Client {self.client_id}: Connected to {self.url}")

# Send initial message
# Note: You'll need to implement the logic to get agent_ids and env_id
# This is just an example structure
agent_ids = [agent.pk for agent in AgentProfile.find().all()[:2]]
env_id = EnvironmentProfile.find().all()[0].pk
start_message = {
"type": "START_SIM",
"data": {
"env_id": env_id, # Replace with actual env_id
"agent_ids": agent_ids, # Replace with actual agent_ids
},
}
await ws.send_json(start_message)
print(f"Client {self.client_id}: Sent START_SIM message")

# Receive and process messages
async for msg in ws:
if msg.type == aiohttp.WSMsgType.TEXT:
print(
f"\nClient {self.client_id} received message:",
json.dumps(json.loads(msg.data), indent=2),
)
await self.save_message(msg.data)
elif msg.type == aiohttp.WSMsgType.CLOSED:
print(f"Client {self.client_id}: Connection closed")
break
elif msg.type == aiohttp.WSMsgType.ERROR:
print(f"Client {self.client_id}: Connection error")
break

except Exception as e:
print(f"Client {self.client_id} connection error:", str(e))
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ examples = ["transformers", "datasets", "scipy", "torch", "pandas"]
api = [
"fastapi[standard]",
"uvicorn",
"streamlit",
"websockets",
"modal"
]
test = ["pytest", "pytest-cov", "pytest-asyncio"]

Expand Down
2 changes: 1 addition & 1 deletion scripts/modal/modal_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os

import redis
from sotopia.ui.fastapi_server import SotopiaFastAPI
from sotopia.api.fastapi_server import SotopiaFastAPI

# Create persistent volume for Redis data
redis_volume = modal.Volume.from_name("sotopia-api", create_if_missing=True)
Expand Down
8 changes: 8 additions & 0 deletions sotopia/ui/README.md → sotopia/api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
> [!CAUTION]
> Work in progress: the API endpoints are being implemented. And will be released in the future major version.
## Deploy to Modal
First you need to have a Modal account and logged in with `modal setup`

To deploy the FastAPI server to Modal, run the following command:
```bash
cd sotopia/ui/fastapi_server
modal deploy modal_api_server.py
```
## FastAPI Server

To run the FastAPI server, you can use the following command:
Expand Down
File renamed without changes.
61 changes: 7 additions & 54 deletions sotopia/ui/fastapi_server.py → sotopia/api/fastapi_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
NonStreamingSimulationStatus,
CustomEvaluationDimensionList,
CustomEvaluationDimension,
BaseEnvironmentProfile,
BaseAgentProfile,
BaseRelationshipProfile,
)
from sotopia.envs.parallel import ParallelSotopiaEnv
from sotopia.envs.evaluators import (
Expand All @@ -37,7 +40,7 @@
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, model_validator, field_validator, Field

from sotopia.ui.websocket_utils import (
from sotopia.api.websocket_utils import (
WebSocketSotopiaSimulator,
WSMessageType,
ErrorType,
Expand Down Expand Up @@ -68,56 +71,6 @@
] = {} # TODO check whether this is the correct way to store the active simulations


class RelationshipWrapper(BaseModel):
pk: str = ""
agent_1_id: str = ""
agent_2_id: str = ""
relationship: Literal[0, 1, 2, 3, 4, 5] = 0
backstory: str = ""
tag: str = ""


class AgentProfileWrapper(BaseModel):
"""
Wrapper for AgentProfile to avoid pydantic v2 issues
"""

pk: str = ""
first_name: str
last_name: str
age: int = 0
occupation: str = ""
gender: str = ""
gender_pronoun: str = ""
public_info: str = ""
big_five: str = ""
moral_values: list[str] = []
schwartz_personal_values: list[str] = []
personality_and_values: str = ""
decision_making_style: str = ""
secret: str = ""
model_id: str = ""
mbti: str = ""
tag: str = ""


class EnvironmentProfileWrapper(BaseModel):
"""
Wrapper for EnvironmentProfile to avoid pydantic v2 issues
"""

pk: str = ""
codename: str
source: str = ""
scenario: str = ""
agent_goals: list[str] = []
relationship: Literal[0, 1, 2, 3, 4, 5] = 0
age_constraint: str | None = None
occupation_constraint: str | None = None
agent_constraint: list[list[str]] | None = None
tag: str = ""


class CustomEvaluationDimensionsWrapper(BaseModel):
pk: str = ""
name: str = Field(
Expand Down Expand Up @@ -484,23 +437,23 @@ def setup_routes(self) -> None:
)(get_evaluation_dimensions)

@self.post("/scenarios/", response_model=str)
async def create_scenario(scenario: EnvironmentProfileWrapper) -> str:
async def create_scenario(scenario: BaseEnvironmentProfile) -> str:
scenario_profile = EnvironmentProfile(**scenario.model_dump())
scenario_profile.save()
pk = scenario_profile.pk
assert pk is not None
return pk

@self.post("/agents/", response_model=str)
async def create_agent(agent: AgentProfileWrapper) -> str:
async def create_agent(agent: BaseAgentProfile) -> str:
agent_profile = AgentProfile(**agent.model_dump())
agent_profile.save()
pk = agent_profile.pk
assert pk is not None
return pk

@self.post("/relationship/", response_model=str)
async def create_relationship(relationship: RelationshipWrapper) -> str:
async def create_relationship(relationship: BaseRelationshipProfile) -> str:
relationship_profile = RelationshipProfile(**relationship.model_dump())
relationship_profile.save()
pk = relationship_profile.pk
Expand Down
15 changes: 5 additions & 10 deletions sotopia/ui/websocket_utils.py → sotopia/api/websocket_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,20 +160,18 @@ def __init__(

async def arun(self) -> AsyncGenerator[dict[str, Any], None]:
# Use sotopia to run the simulation
generator = arun_one_episode(
generator = await arun_one_episode(
env=self.env,
agent_list=list(self.agents.values()),
push_to_db=False,
streaming=True,
)

# assert isinstance(
# generator, AsyncGenerator
# ), "generator should be async generator, but got {}".format(
# type(generator)
# )
assert isinstance(
generator, AsyncGenerator
), "generator should be async generator, but got {}".format(type(generator))

async for messages in await generator: # type: ignore
async for messages in generator:
reasoning, rewards = "", [0.0, 0.0]
if messages[-1][0][0] == "Evaluation":
reasoning = messages[-1][0][2].to_natural_language()
Expand All @@ -192,9 +190,6 @@ async def arun(self) -> AsyncGenerator[dict[str, Any], None]:
rewards=rewards,
rewards_prompt="",
)
# agent_profiles, parsed_messages = epilog.render_for_humans()
# if not eval_available:
# parsed_messages = parsed_messages[:-2]

yield {
"type": "messages",
Expand Down
14 changes: 13 additions & 1 deletion sotopia/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,18 @@
from redis_om import JsonModel, Migrator
from .annotators import Annotator
from .env_agent_combo_storage import EnvAgentComboStorage
from .logs import AnnotationForEpisode, EpisodeLog, NonStreamingSimulationStatus
from .logs import (
AnnotationForEpisode,
BaseEpisodeLog,
NonStreamingSimulationStatus,
EpisodeLog,
)
from .persistent_profile import (
AgentProfile,
BaseAgentProfile,
EnvironmentProfile,
BaseEnvironmentProfile,
BaseRelationshipProfile,
RelationshipProfile,
RelationshipType,
)
Expand Down Expand Up @@ -42,12 +50,16 @@

__all__ = [
"AgentProfile",
"BaseAgentProfile",
"EnvironmentProfile",
"BaseEnvironmentProfile",
"EpisodeLog",
"BaseEpisodeLog",
"NonStreamingSimulationStatus",
"EnvAgentComboStorage",
"AnnotationForEpisode",
"Annotator",
"BaseRelationshipProfile",
"RelationshipProfile",
"RelationshipType",
"RedisCommunicationMixin",
Expand Down
8 changes: 6 additions & 2 deletions sotopia/database/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
else:
from typing_extensions import Self

from pydantic import model_validator
from pydantic import model_validator, BaseModel
from redis_om import JsonModel
from redis_om.model.model import Field
from typing import Literal
Expand All @@ -17,7 +17,7 @@ class NonStreamingSimulationStatus(JsonModel):
status: Literal["Started", "Error", "Completed"]


class EpisodeLog(JsonModel):
class BaseEpisodeLog(BaseModel):
# Note that we did not validate the following constraints:
# 1. The number of turns in messages and rewards should be the same or off by 1
# 2. The agents in the messages are the same as the agetns
Expand Down Expand Up @@ -77,6 +77,10 @@ def render_for_humans(self) -> tuple[list[AgentProfile], list[str]]:
return agent_profiles, messages_and_rewards


class EpisodeLog(BaseEpisodeLog, JsonModel):
pass


class AnnotationForEpisode(JsonModel):
episode: str = Field(index=True, description="the pk id of episode log")
annotator_id: str = Field(index=True, full_text_search=True)
Expand Down
20 changes: 16 additions & 4 deletions sotopia/database/persistent_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
else:
from typing_extensions import Self

from pydantic import model_validator
from pydantic import model_validator, BaseModel
from redis_om import JsonModel
from redis_om.model.model import Field

Expand All @@ -20,7 +20,7 @@ class RelationshipType(IntEnum):
family_member = 5


class AgentProfile(JsonModel):
class BaseAgentProfile(BaseModel):
first_name: str = Field(index=True)
last_name: str = Field(index=True)
age: int = Field(index=True, default_factory=lambda: 0)
Expand All @@ -43,7 +43,11 @@ class AgentProfile(JsonModel):
)


class EnvironmentProfile(JsonModel):
class AgentProfile(BaseAgentProfile, JsonModel):
pass


class BaseEnvironmentProfile(BaseModel):
codename: str = Field(
index=True,
default_factory=lambda: "",
Expand Down Expand Up @@ -86,7 +90,11 @@ class EnvironmentProfile(JsonModel):
)


class RelationshipProfile(JsonModel):
class EnvironmentProfile(BaseEnvironmentProfile, JsonModel):
pass


class BaseRelationshipProfile(BaseModel):
agent_1_id: str = Field(index=True)
agent_2_id: str = Field(index=True)
relationship: RelationshipType = Field(
Expand All @@ -101,6 +109,10 @@ class RelationshipProfile(JsonModel):
)


class RelationshipProfile(BaseRelationshipProfile, JsonModel):
pass


class EnvironmentList(JsonModel):
name: str = Field(index=True)
environments: list[str] = Field(default_factory=lambda: [])
Expand Down
2 changes: 1 addition & 1 deletion tests/ui/test_fastapi.py → tests/api/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
CustomEvaluationDimensionList,
)
from sotopia.messages import SimpleMessage
from sotopia.ui.fastapi_server import app
from sotopia.api.fastapi_server import app
import pytest
from typing import Generator, Callable

Expand Down
Loading

0 comments on commit cb6b2d1

Please sign in to comment.