Skip to content

Commit

Permalink
update clients
Browse files Browse the repository at this point in the history
  • Loading branch information
fern-api[bot] committed Jul 22, 2024
1 parent 14e6339 commit 2c90ded
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 50 deletions.
138 changes: 92 additions & 46 deletions src/hume/empathic_voice/chat/socket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@
from json.decoder import JSONDecodeError
from pathlib import Path

from hume.empathic_voice.chat.types.publish_event import PublishEvent
from hume.empathic_voice.types.pause_assistant_message import PauseAssistantMessage
from hume.empathic_voice.types.resume_assistant_message import ResumeAssistantMessage
from hume.empathic_voice.types.tool_error_message import ToolErrorMessage
from hume.empathic_voice.types.tool_response_message import ToolResponseMessage

from ..chat.types.subscribe_event import SubscribeEvent
from ..types.assistant_input import AssistantInput
from ..types.session_settings import SessionSettings
from ..types.audio_input import AudioInput
from ..types.user_input import UserInput
from ...core.pydantic_utilities import pydantic_v1
from ...core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ...core.client_wrapper import AsyncClientWrapper
from ...core.api_error import ApiError


Expand Down Expand Up @@ -43,7 +49,7 @@ def __init__(
self,
*,
websocket: websockets.WebSocketClientProtocol,
params: AsyncChatConnectOptions,
params: typing.Optional[AsyncChatConnectOptions] = None,
):
super().__init__()
self.websocket = websocket
Expand All @@ -56,31 +62,31 @@ async def __aiter__(self):
async for message in self.websocket:
yield message

async def _send(self, data: typing.Any) -> SubscribeEvent:
async def _send(self, data: typing.Any) -> None:
if isinstance(data, dict):
data = json.dumps(data)
await self.websocket.send(data)
# Mimicing the request-reply pattern and waiting for the
# response as soon as we send it
return await self.recv()

async def recv(self) -> SubscribeEvent:
data = await self.websocket.recv()
return pydantic_v1.parse_obj_as(SubscribeEvent, json.loads(data)) # type: ignore

async def send_audio_input(self, message: AudioInput) -> SubscribeEvent:
async def _send_model(self, data: PublishEvent) -> None:
await self._send(data.dict())

async def send_audio_input(self, message: AudioInput) -> None:
"""
Parameters
----------
message : AudioInput
Returns
-------
SubscribeEvent
None
"""
return await self._send(message.dict())
await self._send_model(message)

async def send_session_settings(self, message: SessionSettings) -> SubscribeEvent:
async def send_session_settings(self, message: SessionSettings) -> None:
"""
Update the EVI session settings.
Expand All @@ -90,7 +96,7 @@ async def send_session_settings(self, message: SessionSettings) -> SubscribeEven
Returns
-------
SubscribeEvent
None
"""

# Update sample rate and channels
Expand All @@ -100,31 +106,31 @@ async def send_session_settings(self, message: SessionSettings) -> SubscribeEven
if message.audio.sample_rate is not None:
self._sample_rate = message.audio.sample_rate

return await self._send(message.dict())
await self._send_model(message)

async def send_text_input(self, message: UserInput) -> SubscribeEvent:
async def send_user_input(self, message: UserInput) -> None:
"""
Parameters
----------
message : UserInput
Returns
-------
SubscribeEvent
None
"""
return await self._send(message.dict())
await self._send_model(message)

async def send_assistant_input(self, message: AssistantInput) -> SubscribeEvent:
async def send_assistant_input(self, message: AssistantInput) -> None:
"""
Parameters
----------
message : AssistantInput
Returns
-------
SubscribeEvent
None
"""
return await self._send(message.dict())
await self._send_model(message)

async def send_file(self, filepath: Path) -> None:
"""Send a file over the voice socket.
Expand All @@ -140,29 +146,77 @@ async def send_file(self, filepath: Path) -> None:
audio_bytes = segment.raw_data
await self._send(audio_bytes)

async def send_tool_response(self, message: ToolResponseMessage) -> None:
"""
Parameters
----------
message : ToolResponseMessage
Returns
-------
None
"""
await self._send_model(message)

async def send_tool_error(self, message: ToolErrorMessage) -> None:
"""
Parameters
----------
message : ToolErrorMessage
Returns
-------
None
"""
await self._send_model(message)

async def send_pause_assistant(self, message: PauseAssistantMessage) -> None:
"""
Parameters
----------
message : PauseAssistantMessage
Returns
-------
None
"""
await self._send_model(message)

async def send_resume_assistant(self, message: ResumeAssistantMessage) -> None:
"""
Parameters
----------
message : ResumeAssistantMessage
Returns
-------
None
"""
await self._send_model(message)

class AsyncChatClientWithWebsocket:
def __init__(self, *, client_wrapper: typing.Union[AsyncClientWrapper, SyncClientWrapper]):
def __init__(self, *, client_wrapper: AsyncClientWrapper):
self.client_wrapper = client_wrapper

@asynccontextmanager
async def connect(
self, options: AsyncChatConnectOptions
self, options: typing.Optional[AsyncChatConnectOptions] = None
) -> typing.AsyncIterator[AsyncChatWSSConnection]:
query_params = httpx.QueryParams()

api_key = options.api_key or self.client_wrapper.api_key

if options.config_id is not None:
query_params = query_params.add("config_id", options.config_id)
if options.config_version is not None:
query_params = query_params.add("config_version", options.config_version)

if options.client_secret is not None and api_key is not None:
query_params = query_params.add(
"accessToken",
await self._fetch_access_token(options.client_secret, api_key),
)
api_key = options.api_key if options is not None and options.api_key else self.client_wrapper.api_key

if options is not None:
if options.config_id is not None:
query_params = query_params.add("config_id", options.config_id)
if options.config_version is not None:
query_params = query_params.add("config_version", options.config_version)

if options.client_secret is not None and api_key is not None:
query_params = query_params.add(
"accessToken",
await self._fetch_access_token(options.client_secret, api_key),
)
elif api_key is not None:
query_params = query_params.add("apiKey", api_key)

Expand All @@ -181,20 +235,12 @@ async def connect(
async def _fetch_access_token(self, client_secret: str, api_key: str) -> str:
auth = f"{api_key}:{client_secret}"
encoded_auth = base64.b64encode(auth.encode()).decode()
if isinstance(self.client_wrapper.httpx_client, httpx.AsyncClient):
_response = await self.client_wrapper.httpx_client.request(
method="POST",
url="https://api.hume.ai/oauth2-cc/token",
headers={"Authorization": f"Basic {encoded_auth}"},
data={"grant_type": "client_credentials"},
)
else:
_response = await self.client_wrapper.httpx_client.request( # type: ignore
method="POST",
url="https://api.hume.ai/oauth2-cc/token",
headers={"Authorization": f"Basic {encoded_auth}"},
data={"grant_type": "client_credentials"},
)
_response = await self.client_wrapper.httpx_client.request(
method="POST",
base_url="https://api.hume.ai/oauth2-cc/token",
headers={"Authorization": f"Basic {encoded_auth}"},
data={"grant_type": "client_credentials"},
)

if 200 <= _response.status_code < 300:
return _response.json()["access_token"]
Expand Down
6 changes: 5 additions & 1 deletion src/hume/empathic_voice/socket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
class EmpathicVoiceClientWithWebsocket(EmpathicVoiceClient):
def __init__(self, *, client_wrapper):
super().__init__(client_wrapper=client_wrapper)
self.chat = AsyncChatClientWithWebsocket(client_wrapper=client_wrapper)

@property
def chat(self):
raise NotImplementedError("The websocket at `.chat` is only available on the `AsyncHumeClient`, not this synchronous client (`HumeClient`).")


class AsyncEmpathicVoiceClientWithWebsocket(AsyncEmpathicVoiceClient):
def __init__(self, *, client_wrapper):
Expand Down
5 changes: 4 additions & 1 deletion src/hume/expression_measurement/socket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
class ExpressionMeasurementClientWithWebsocket(ExpressionMeasurementClient):
def __init__(self, *, client_wrapper):
super().__init__(client_wrapper=client_wrapper)
self.stream = AsyncStreamClientWithWebsocket(client_wrapper=client_wrapper)

@property
def stream(self):
raise NotImplementedError("The websocket at `.stream` is only available on the `AsyncHumeClient`, not this synchronous client (`HumeClient`).")

class AsyncExpressionMeasurementClientWithWebsocket(AsyncExpressionMeasurementClient):
def __init__(self, *, client_wrapper):
Expand Down
4 changes: 2 additions & 2 deletions src/hume/expression_measurement/stream/socket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..stream.types.stream_data_models import StreamDataModels
from ..stream.types.subscribe_event import SubscribeEvent
from ...core.pydantic_utilities import pydantic_v1
from ...core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ...core.client_wrapper import AsyncClientWrapper


class AsyncStreamConnectOptions(pydantic_v1.BaseModel):
Expand Down Expand Up @@ -170,7 +170,7 @@ async def send_file(


class AsyncStreamClientWithWebsocket:
def __init__(self, *, client_wrapper: typing.Union[AsyncClientWrapper, SyncClientWrapper]):
def __init__(self, *, client_wrapper: AsyncClientWrapper):
self.client_wrapper = client_wrapper

@asynccontextmanager
Expand Down

0 comments on commit 2c90ded

Please sign in to comment.