Skip to content

Commit

Permalink
Websocket client for streaming API (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
gregorybchris authored Oct 6, 2022
1 parent 656dbaa commit 0712d71
Show file tree
Hide file tree
Showing 29 changed files with 560 additions and 38 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
if [ -d /poetryenvs ]; then rm -rf ~/poetryenvs; fi
poetry config virtualenvs.path ~/poetryenvs
poetry install
poetry install -E stream
- name: Run flake8
shell: bash
Expand Down Expand Up @@ -99,7 +99,7 @@ jobs:
if [ -d /poetryenvs ]; then rm -rf ~/poetryenvs; fi
poetry config virtualenvs.path ~/poetryenvs
poetry install
poetry install -E stream
- name: Run pytest
shell: bash
Expand Down
33 changes: 31 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,16 @@ Python versions 3.8 and 3.9 are supported

## Installation

```python
pip install hume
Basic installation:

```bash
$ pip install hume
```

Websocket and streaming features can be enabled with:

```bash
$ pip install hume[stream]
```

## Basic Usage
Expand Down Expand Up @@ -45,6 +53,27 @@ job = client.get_job(job_id)
print(job)
```

### Stream predictions over a websocket

> Note: `pip install hume[stream]` is required to use websocket features
```python
import asyncio

from hume import HumeStreamClient, StreamSocket
from hume.config import FaceConfig

async def main():
client = HumeStreamClient("<your-api-key>")
configs = [FaceConfig(identify_faces=True)]
async with client.connect(configs) as socket:
socket: StreamSocket
result = await socket.send_file("<your-image-filepath>")
print(result)

asyncio.run(main())
```

## Other Resources

- [Hume AI Homepage](https://hume.ai)
Expand Down
1 change: 1 addition & 0 deletions docs/stream/hume-stream-client.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: hume._stream.hume_stream_client.HumeStreamClient
1 change: 1 addition & 0 deletions docs/stream/stream-socket.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: hume._stream.stream_socket.StreamSocket
3 changes: 3 additions & 0 deletions hume/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import importlib.metadata

from hume._batch import BatchJob, BatchJobResult, BatchJobStatus, HumeBatchClient
from hume._stream import HumeStreamClient, StreamSocket
from hume._common.hume_client_error import HumeClientError
from hume._common.model_type import ModelType

Expand All @@ -14,5 +15,7 @@
"BatchJobStatus",
"HumeBatchClient",
"HumeClientError",
"HumeStreamClient",
"ModelType",
"StreamSocket",
]
23 changes: 20 additions & 3 deletions hume/_batch/hume_batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,24 @@


class HumeBatchClient(ClientBase):
"""Batch API client."""
"""Batch API client.
Example:
```python
from hume import HumeBatchClient
client = HumeBatchClient("<your-api-key>")
job = client.submit_face(["<your-image-url>"])
print(job)
print("Running...")
result = job.await_complete()
result.download_predictions("predictions.json")
print("Predictions downloaded!")
```
"""

_DEFAULT_API_TIMEOUT = 10

Expand All @@ -40,7 +57,7 @@ def get_job_result(self, job_id: str) -> BatchJobResult:
Returns:
BatchJobResult: Batch job result.
"""
endpoint = (f"{self._api_base_url}/{self._api_version}/{ApiType.BATCH.value}/jobs/{job_id}"
endpoint = (f"{self._api_http_base_url}/{self._api_version}/{ApiType.BATCH.value}/jobs/{job_id}"
f"?apikey={self._api_key}")
response = requests.get(endpoint, timeout=self._DEFAULT_API_TIMEOUT)
body = response.json()
Expand Down Expand Up @@ -178,7 +195,7 @@ def start_job(self, request_body: Any) -> BatchJob:
Returns:
BatchJob: A `BatchJob` that wraps the batch computation.
"""
endpoint = (f"{self._api_base_url}/{self._api_version}/{ApiType.BATCH.value}/jobs"
endpoint = (f"{self._api_http_base_url}/{self._api_version}/{ApiType.BATCH.value}/jobs"
f"?apikey={self._api_key}")
response = requests.post(endpoint, json=request_body, timeout=self._DEFAULT_API_TIMEOUT)

Expand Down
1 change: 1 addition & 0 deletions hume/_common/api_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ class ApiType(Enum):
"""API type."""

BATCH = "batch"
STREAM = "stream"
9 changes: 6 additions & 3 deletions hume/_common/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
class ClientBase(ABC):
"""Base class for Hume API clients."""

_API_BASE_URL = "https://api.hume.ai"
_HTTP_BASE_URL = "https://api.hume.ai"
_WS_BASE_URI = "wss://api.hume.ai"

def __init__(
self,
api_key: str,
_api_version: str = "v0",
_api_base_url: Optional[str] = None,
_api_http_base_url: Optional[str] = None,
_api_ws_base_uri: Optional[str] = None,
):
"""Construct a new Hume API client.
Expand All @@ -21,4 +23,5 @@ def __init__(
"""
self._api_key = api_key
self._api_version = _api_version
self._api_base_url = self._API_BASE_URL if _api_base_url is None else _api_base_url
self._api_http_base_url = self._HTTP_BASE_URL if _api_http_base_url is None else _api_http_base_url
self._api_ws_base_uri = self._WS_BASE_URI if _api_ws_base_uri is None else _api_ws_base_uri
9 changes: 5 additions & 4 deletions hume/_common/config/face_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class FaceConfig(JobConfigBase["FaceConfig"]):

def __init__(
self,
*,
fps_pred: Optional[float] = None,
prob_threshold: Optional[float] = None,
identify_faces: Optional[bool] = None,
Expand Down Expand Up @@ -67,8 +68,8 @@ def deserialize(cls, request_dict: Dict[str, Any]) -> "FaceConfig":
FaceConfig: Deserialized `FaceConfig` object.
"""
return cls(
fps_pred=request_dict["fps_pred"],
prob_threshold=request_dict["prob_threshold"],
identify_faces=request_dict["identify_faces"],
min_face_size=request_dict["min_face_size"],
fps_pred=request_dict.get("fps_pred"),
prob_threshold=request_dict.get("prob_threshold"),
identify_faces=request_dict.get("identify_faces"),
min_face_size=request_dict.get("min_face_size"),
)
4 changes: 2 additions & 2 deletions hume/_common/config/language_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,6 @@ def deserialize(cls, request_dict: Dict[str, Any]) -> "LanguageConfig":
LanguageConfig: Deserialized `LanguageConfig` object.
"""
return cls(
sliding_window=request_dict["sliding_window"],
identify_speakers=request_dict["identify_speakers"],
sliding_window=request_dict.get("sliding_window"),
identify_speakers=request_dict.get("identify_speakers"),
)
3 changes: 2 additions & 1 deletion hume/_common/config/prosody_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class ProsodyConfig(JobConfigBase["ProsodyConfig"]):

def __init__(
self,
*,
identify_speakers: Optional[bool] = None,
):
"""Construct a `ProsodyConfig`.
Expand Down Expand Up @@ -50,4 +51,4 @@ def deserialize(cls, request_dict: Dict[str, Any]) -> "ProsodyConfig":
Returns:
ProsodyConfig: Deserialized `ProsodyConfig` object.
"""
return cls(identify_speakers=request_dict["identify_speakers"])
return cls(identify_speakers=request_dict.get("identify_speakers"))
8 changes: 8 additions & 0 deletions hume/_stream/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Module init."""
from hume._stream.hume_stream_client import HumeStreamClient
from hume._stream.stream_socket import StreamSocket

__all__ = [
"HumeStreamClient",
"StreamSocket",
]
91 changes: 91 additions & 0 deletions hume/_stream/hume_stream_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Streaming API client."""
import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, List

from hume._common.api_type import ApiType
from hume._common.client_base import ClientBase
from hume._common.config import JobConfigBase
from hume._common.config.config_utils import config_from_model_type
from hume._common.hume_client_error import HumeClientError
from hume._common.model_type import ModelType
from hume._stream.stream_socket import StreamSocket

try:
import websockets
HAS_WEBSOCKETS = True
except ModuleNotFoundError:
HAS_WEBSOCKETS = False

logger = logging.getLogger(__name__)


class HumeStreamClient(ClientBase):
"""Streaming API client.
Example:
```python
import asyncio
from hume import HumeStreamClient, StreamSocket
from hume.config import FaceConfig
async def main():
client = HumeStreamClient("<your-api-key>")
configs = [FaceConfig(identify_faces=True)]
async with client.connect(configs) as socket:
socket: StreamSocket
result = await socket.send_file("<your-image-filepath>")
print(result)
asyncio.run(main())
```
"""

_DEFAULT_API_TIMEOUT = 10

def __init__(self, *args: Any, **kwargs: Any):
"""Construct a HumeStreamClient.
Args:
api_key (str): Hume API key.
"""
if not HAS_WEBSOCKETS:
raise HumeClientError("websockets package required to use HumeStreamClient")

super().__init__(*args, **kwargs)

@asynccontextmanager
async def connect(self, configs: List[JobConfigBase]) -> AsyncIterator:
"""Connect to the streaming API.
Args:
configs (List[JobConfigBase]): List of job configs.
"""
uri = (f"{self._api_ws_base_uri}/{self._api_version}/{ApiType.STREAM.value}/multi"
f"?apikey={self._api_key}")

try:
# pylint: disable=no-member
async with websockets.connect(uri) as protocol: # type: ignore[attr-defined]
yield StreamSocket(protocol, configs)
except websockets.exceptions.InvalidStatusCode as exc:
message = "Client initialized with invalid API key"
raise HumeClientError(message) from exc

@asynccontextmanager
async def _connect_to_models(self, configs_dict: Any) -> AsyncIterator:
"""Connect to the streaming API with a single models configuration dict.
Args:
configs_dict (Any): Models configurations dict. This should be a dict from model name
to model configuration dict. An empty dict uses the default configuration.
"""
configs = []
for model_name, config_dict in configs_dict.items():
model_type = ModelType.from_str(model_name)
config = config_from_model_type(model_type).deserialize(config_dict)
configs.append(config)

async with self.connect(configs) as websocket:
yield websocket
Loading

0 comments on commit 0712d71

Please sign in to comment.