From 0712d71c71c32008c5a988ac0133e6c574919fac Mon Sep 17 00:00:00 2001 From: Chris Gregory <8800689+gregorybchris@users.noreply.github.com> Date: Thu, 6 Oct 2022 14:56:56 -0700 Subject: [PATCH] Websocket client for streaming API (#25) --- .github/workflows/ci.yaml | 4 +- docs/index.md | 33 ++++++- docs/stream/hume-stream-client.md | 1 + docs/stream/stream-socket.md | 1 + hume/__init__.py | 3 + hume/_batch/hume_batch_client.py | 23 ++++- hume/_common/api_type.py | 1 + hume/_common/client_base.py | 9 +- hume/_common/config/face_config.py | 9 +- hume/_common/config/language_config.py | 4 +- hume/_common/config/prosody_config.py | 3 +- hume/_stream/__init__.py | 8 ++ hume/_stream/hume_stream_client.py | 91 +++++++++++++++++ hume/_stream/stream_socket.py | 108 +++++++++++++++++++++ hume/config/__init__.py | 14 +++ mkdocs.yml | 3 + poetry.lock | 69 ++++++++++++- pyproject.toml | 16 +-- tests/batch/test_batch_client.py | 1 + tests/batch/test_batch_client_service.py | 12 +-- tests/batch/test_batch_job.py | 1 + tests/batch/test_batch_job_result.py | 1 + tests/batch/test_batch_job_status.py | 1 + tests/conftest.py | 18 ++++ tests/pytest.ini | 3 + tests/stream/__init__.py | 1 + tests/stream/test_stream_client.py | 45 +++++++++ tests/stream/test_stream_client_service.py | 47 +++++++++ tests/stream/test_stream_socket.py | 68 +++++++++++++ 29 files changed, 560 insertions(+), 38 deletions(-) create mode 100644 docs/stream/hume-stream-client.md create mode 100644 docs/stream/stream-socket.md create mode 100644 hume/_stream/__init__.py create mode 100644 hume/_stream/hume_stream_client.py create mode 100644 hume/_stream/stream_socket.py create mode 100644 hume/config/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/stream/__init__.py create mode 100644 tests/stream/test_stream_client.py create mode 100644 tests/stream/test_stream_client_service.py create mode 100644 tests/stream/test_stream_socket.py diff --git a/.github/workflows/ci.yaml 
b/.github/workflows/ci.yaml index 0f6672fa..6c652a47 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -34,7 +34,7 @@ jobs: if [ -d /poetryenvs ]; then rm -rf ~/poetryenvs; fi poetry config virtualenvs.path ~/poetryenvs - poetry install + poetry install -E stream - name: Run flake8 shell: bash @@ -99,7 +99,7 @@ jobs: if [ -d /poetryenvs ]; then rm -rf ~/poetryenvs; fi poetry config virtualenvs.path ~/poetryenvs - poetry install + poetry install -E stream - name: Run pytest shell: bash diff --git a/docs/index.md b/docs/index.md index dd7b00a2..465f77d9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -6,8 +6,16 @@ Python versions 3.8 and 3.9 are supported ## Installation -```python -pip install hume +Basic installation: + +```bash +$ pip install hume +``` + +Websocket and streaming features can be enabled with: + +```bash +$ pip install hume[stream] ``` ## Basic Usage @@ -45,6 +53,27 @@ job = client.get_job(job_id) print(job) ``` +### Stream predictions over a websocket + +> Note: `pip install hume[stream]` is required to use websocket features + +```python +import asyncio + +from hume import HumeStreamClient, StreamSocket +from hume.config import FaceConfig + +async def main(): + client = HumeStreamClient("") + configs = [FaceConfig(identify_faces=True)] + async with client.connect(configs) as socket: + socket: StreamSocket + result = await socket.send_file("") + print(result) + +asyncio.run(main()) +``` + ## Other Resources - [Hume AI Homepage](https://hume.ai) diff --git a/docs/stream/hume-stream-client.md b/docs/stream/hume-stream-client.md new file mode 100644 index 00000000..90bed981 --- /dev/null +++ b/docs/stream/hume-stream-client.md @@ -0,0 +1 @@ +::: hume._stream.hume_stream_client.HumeStreamClient diff --git a/docs/stream/stream-socket.md b/docs/stream/stream-socket.md new file mode 100644 index 00000000..30f45a4e --- /dev/null +++ b/docs/stream/stream-socket.md @@ -0,0 +1 @@ +::: hume._stream.stream_socket.StreamSocket diff 
--git a/hume/__init__.py b/hume/__init__.py index b650df40..019bc46d 100644 --- a/hume/__init__.py +++ b/hume/__init__.py @@ -2,6 +2,7 @@ import importlib.metadata from hume._batch import BatchJob, BatchJobResult, BatchJobStatus, HumeBatchClient +from hume._stream import HumeStreamClient, StreamSocket from hume._common.hume_client_error import HumeClientError from hume._common.model_type import ModelType @@ -14,5 +15,7 @@ "BatchJobStatus", "HumeBatchClient", "HumeClientError", + "HumeStreamClient", "ModelType", + "StreamSocket", ] diff --git a/hume/_batch/hume_batch_client.py b/hume/_batch/hume_batch_client.py index 0be805c7..4c269de1 100644 --- a/hume/_batch/hume_batch_client.py +++ b/hume/_batch/hume_batch_client.py @@ -16,7 +16,24 @@ class HumeBatchClient(ClientBase): - """Batch API client.""" + """Batch API client. + + Example: + ```python + from hume import HumeBatchClient + + client = HumeBatchClient("") + job = client.submit_face([""]) + + print(job) + print("Running...") + + result = job.await_complete() + result.download_predictions("predictions.json") + + print("Predictions downloaded!") + ``` + """ _DEFAULT_API_TIMEOUT = 10 @@ -40,7 +57,7 @@ def get_job_result(self, job_id: str) -> BatchJobResult: Returns: BatchJobResult: Batch job result. """ - endpoint = (f"{self._api_base_url}/{self._api_version}/{ApiType.BATCH.value}/jobs/{job_id}" + endpoint = (f"{self._api_http_base_url}/{self._api_version}/{ApiType.BATCH.value}/jobs/{job_id}" f"?apikey={self._api_key}") response = requests.get(endpoint, timeout=self._DEFAULT_API_TIMEOUT) body = response.json() @@ -178,7 +195,7 @@ def start_job(self, request_body: Any) -> BatchJob: Returns: BatchJob: A `BatchJob` that wraps the batch computation. 
""" - endpoint = (f"{self._api_base_url}/{self._api_version}/{ApiType.BATCH.value}/jobs" + endpoint = (f"{self._api_http_base_url}/{self._api_version}/{ApiType.BATCH.value}/jobs" f"?apikey={self._api_key}") response = requests.post(endpoint, json=request_body, timeout=self._DEFAULT_API_TIMEOUT) diff --git a/hume/_common/api_type.py b/hume/_common/api_type.py index cbba783f..057ec6d3 100644 --- a/hume/_common/api_type.py +++ b/hume/_common/api_type.py @@ -6,3 +6,4 @@ class ApiType(Enum): """API type.""" BATCH = "batch" + STREAM = "stream" diff --git a/hume/_common/client_base.py b/hume/_common/client_base.py index 0963b08e..e824ed29 100644 --- a/hume/_common/client_base.py +++ b/hume/_common/client_base.py @@ -6,13 +6,15 @@ class ClientBase(ABC): """Base class for Hume API clients.""" - _API_BASE_URL = "https://api.hume.ai" + _HTTP_BASE_URL = "https://api.hume.ai" + _WS_BASE_URI = "wss://api.hume.ai" def __init__( self, api_key: str, _api_version: str = "v0", - _api_base_url: Optional[str] = None, + _api_http_base_url: Optional[str] = None, + _api_ws_base_uri: Optional[str] = None, ): """Construct a new Hume API client. 
@@ -21,4 +23,5 @@ def __init__( """ self._api_key = api_key self._api_version = _api_version - self._api_base_url = self._API_BASE_URL if _api_base_url is None else _api_base_url + self._api_http_base_url = self._HTTP_BASE_URL if _api_http_base_url is None else _api_http_base_url + self._api_ws_base_uri = self._WS_BASE_URI if _api_ws_base_uri is None else _api_ws_base_uri diff --git a/hume/_common/config/face_config.py b/hume/_common/config/face_config.py index d5fcb1fe..f2a14034 100644 --- a/hume/_common/config/face_config.py +++ b/hume/_common/config/face_config.py @@ -10,6 +10,7 @@ class FaceConfig(JobConfigBase["FaceConfig"]): def __init__( self, + *, fps_pred: Optional[float] = None, prob_threshold: Optional[float] = None, identify_faces: Optional[bool] = None, @@ -67,8 +68,8 @@ def deserialize(cls, request_dict: Dict[str, Any]) -> "FaceConfig": FaceConfig: Deserialized `FaceConfig` object. """ return cls( - fps_pred=request_dict["fps_pred"], - prob_threshold=request_dict["prob_threshold"], - identify_faces=request_dict["identify_faces"], - min_face_size=request_dict["min_face_size"], + fps_pred=request_dict.get("fps_pred"), + prob_threshold=request_dict.get("prob_threshold"), + identify_faces=request_dict.get("identify_faces"), + min_face_size=request_dict.get("min_face_size"), ) diff --git a/hume/_common/config/language_config.py b/hume/_common/config/language_config.py index 241d2c12..bab972d5 100644 --- a/hume/_common/config/language_config.py +++ b/hume/_common/config/language_config.py @@ -57,6 +57,6 @@ def deserialize(cls, request_dict: Dict[str, Any]) -> "LanguageConfig": LanguageConfig: Deserialized `LanguageConfig` object. 
""" return cls( - sliding_window=request_dict["sliding_window"], - identify_speakers=request_dict["identify_speakers"], + sliding_window=request_dict.get("sliding_window"), + identify_speakers=request_dict.get("identify_speakers"), ) diff --git a/hume/_common/config/prosody_config.py b/hume/_common/config/prosody_config.py index b6e5dc7e..6f4e88ea 100644 --- a/hume/_common/config/prosody_config.py +++ b/hume/_common/config/prosody_config.py @@ -10,6 +10,7 @@ class ProsodyConfig(JobConfigBase["ProsodyConfig"]): def __init__( self, + *, identify_speakers: Optional[bool] = None, ): """Construct a `ProsodyConfig`. @@ -50,4 +51,4 @@ def deserialize(cls, request_dict: Dict[str, Any]) -> "ProsodyConfig": Returns: ProsodyConfig: Deserialized `ProsodyConfig` object. """ - return cls(identify_speakers=request_dict["identify_speakers"]) + return cls(identify_speakers=request_dict.get("identify_speakers")) diff --git a/hume/_stream/__init__.py b/hume/_stream/__init__.py new file mode 100644 index 00000000..97fbdc90 --- /dev/null +++ b/hume/_stream/__init__.py @@ -0,0 +1,8 @@ +"""Module init.""" +from hume._stream.hume_stream_client import HumeStreamClient +from hume._stream.stream_socket import StreamSocket + +__all__ = [ + "HumeStreamClient", + "StreamSocket", +] diff --git a/hume/_stream/hume_stream_client.py b/hume/_stream/hume_stream_client.py new file mode 100644 index 00000000..7416c4e4 --- /dev/null +++ b/hume/_stream/hume_stream_client.py @@ -0,0 +1,91 @@ +"""Streaming API client.""" +import logging +from contextlib import asynccontextmanager +from typing import Any, AsyncIterator, List + +from hume._common.api_type import ApiType +from hume._common.client_base import ClientBase +from hume._common.config import JobConfigBase +from hume._common.config.config_utils import config_from_model_type +from hume._common.hume_client_error import HumeClientError +from hume._common.model_type import ModelType +from hume._stream.stream_socket import StreamSocket + +try: + import 
websockets + HAS_WEBSOCKETS = True +except ModuleNotFoundError: + HAS_WEBSOCKETS = False + +logger = logging.getLogger(__name__) + + +class HumeStreamClient(ClientBase): + """Streaming API client. + + Example: + ```python + import asyncio + + from hume import HumeStreamClient, StreamSocket + from hume.config import FaceConfig + + async def main(): + client = HumeStreamClient("") + configs = [FaceConfig(identify_faces=True)] + async with client.connect(configs) as socket: + socket: StreamSocket + result = await socket.send_file("") + print(result) + + asyncio.run(main()) + ``` + """ + + _DEFAULT_API_TIMEOUT = 10 + + def __init__(self, *args: Any, **kwargs: Any): + """Construct a HumeStreamClient. + + Args: + api_key (str): Hume API key. + """ + if not HAS_WEBSOCKETS: + raise HumeClientError("websockets package required to use HumeStreamClient") + + super().__init__(*args, **kwargs) + + @asynccontextmanager + async def connect(self, configs: List[JobConfigBase]) -> AsyncIterator: + """Connect to the streaming API. + + Args: + configs (List[JobConfigBase]): List of job configs. + """ + uri = (f"{self._api_ws_base_uri}/{self._api_version}/{ApiType.STREAM.value}/multi" + f"?apikey={self._api_key}") + + try: + # pylint: disable=no-member + async with websockets.connect(uri) as protocol: # type: ignore[attr-defined] + yield StreamSocket(protocol, configs) + except websockets.exceptions.InvalidStatusCode as exc: + message = "Client initialized with invalid API key" + raise HumeClientError(message) from exc + + @asynccontextmanager + async def _connect_to_models(self, configs_dict: Any) -> AsyncIterator: + """Connect to the streaming API with a single models configuration dict. + + Args: + configs_dict (Any): Models configurations dict. This should be a dict from model name + to model configuration dict. An empty dict uses the default configuration. 
+ """ + configs = [] + for model_name, config_dict in configs_dict.items(): + model_type = ModelType.from_str(model_name) + config = config_from_model_type(model_type).deserialize(config_dict) + configs.append(config) + + async with self.connect(configs) as websocket: + yield websocket diff --git a/hume/_stream/stream_socket.py b/hume/_stream/stream_socket.py new file mode 100644 index 00000000..8d544e13 --- /dev/null +++ b/hume/_stream/stream_socket.py @@ -0,0 +1,108 @@ +"""Streaming socket connection.""" +import base64 +import json +from pathlib import Path +from typing import Any, List, Union + +try: + from websockets.client import WebSocketClientProtocol + HAS_WEBSOCKETS = True +except ModuleNotFoundError: + HAS_WEBSOCKETS = False + +from hume._common.config import JobConfigBase +from hume._common.hume_client_error import HumeClientError + + +class StreamSocket: + """Streaming socket connection.""" + + def __init__( + self, + protocol: "WebSocketClientProtocol", + configs: List[JobConfigBase], + ): + """Construct a `StreamSocket`. + + Args: + protocol (WebSocketClientProtocol): Protocol instance from websockets library. + configs (List[JobConfigBase]): List of model configurations. + + Raises: + HumeClientError: If there is an error processing media over the socket connection. 
+ """ + if not HAS_WEBSOCKETS: + raise HumeClientError("websockets package required to use HumeStreamClient") + + self._configs = configs + self._protocol = protocol + + self._serialized_configs = self._serialize_configs(configs) + + @classmethod + def _serialize_configs(cls, configs: List[JobConfigBase]) -> Any: + serialized = {} + for config in configs: + model_type = config.get_model_type() + model_name = model_type.value + serialized[model_name] = config.serialize() + return serialized + + @classmethod + def _file_to_bytes(cls, filepath: Path) -> bytes: + with filepath.open('rb') as f: + return base64.b64encode(f.read()) + + def _get_predictions(self, response: str) -> Any: + try: + json_response = json.loads(response) + except json.JSONDecodeError as exc: + raise HumeClientError("Unexpected error when fetching streaming API predictions") from exc + + return json_response + + async def send_bytes_str(self, bytes_str: str) -> Any: + """Send raw bytes string on the `StreamSocket`. + + Note: Must be ascii encoded bytes. + + Args: + bytes_str (str): Raw bytes of media to send on socket connection converted to a string. + + Returns: + Any: Predictions from the streaming API. + """ + payload = { + "data": bytes_str, + "models": self._serialized_configs, + } + json_payload = json.dumps(payload) + await self._protocol.send(json_payload) + response = await self._protocol.recv() + # Cast to str because websockets can send bytes, but we will always accept JSON strings + response_str = str(response) + return self._get_predictions(response_str) + + async def send_bytes(self, bytes_data: bytes) -> Any: + """Send raw bytes on the `StreamSocket`. + + Args: + bytes_data (bytes): Raw bytes of media to send on socket connection. + + Returns: + Any: Predictions from the streaming API. + """ + bytes_str = bytes_data.decode("ascii") + return await self.send_bytes_str(bytes_str) + + async def send_file(self, filepath: Union[str, Path]) -> Any: + """Send a file on the `StreamSocket`. 
+ + Args: + filepath (Path): Path to media file to send on socket connection. + + Returns: + Any: Predictions from the streaming API. + """ + bytes_data = self._file_to_bytes(Path(filepath)) + return await self.send_bytes(bytes_data) diff --git a/hume/config/__init__.py b/hume/config/__init__.py new file mode 100644 index 00000000..dfe4404e --- /dev/null +++ b/hume/config/__init__.py @@ -0,0 +1,14 @@ +"""Module init.""" +from hume._common.config.burst_config import BurstConfig +from hume._common.config.face_config import FaceConfig +from hume._common.config.language_config import LanguageConfig +from hume._common.config.job_config_base import JobConfigBase +from hume._common.config.prosody_config import ProsodyConfig + +__all__ = [ + "BurstConfig", + "FaceConfig", + "LanguageConfig", + "JobConfigBase", + "ProsodyConfig", +] diff --git a/mkdocs.yml b/mkdocs.yml index f10c522c..de195548 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -39,6 +39,9 @@ nav: - "BatchJob": batch/batch-job.md - "BatchJobStatus": batch/batch-job-status.md - "BatchJobResult": batch/batch-job-result.md + - "Streaming Reference": + - "HumeStreamClient": stream/hume-stream-client.md + - "StreamSocket": stream/stream-socket.md markdown_extensions: - pymdownx.highlight: diff --git a/poetry.lock b/poetry.lock index f3d0430b..9992f632 100644 --- a/poetry.lock +++ b/poetry.lock @@ -661,7 +661,7 @@ python-versions = ">=3.6,<4.0" [[package]] name = "types-requests" -version = "2.28.10" +version = "2.28.11" description = "Typing stubs for requests" category = "dev" optional = false @@ -729,6 +729,14 @@ python-versions = ">=3.6" [package.extras] watchmedo = ["PyYAML (>=3.10)"] +[[package]] +name = "websockets" +version = "10.3" +description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" +category = "main" +optional = true +python-versions = ">=3.7" + [[package]] name = "wrapt" version = "1.14.1" @@ -757,10 +765,13 @@ python-versions = ">=3.7" docs = ["jaraco.packaging (>=9)", 
"jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx"] testing = ["func-timeout", "jaraco.itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] +[extras] +stream = ["websockets"] + [metadata] lock-version = "1.1" python-versions = ">=3.8,<3.10" -content-hash = "7fb6a04a0c0588ff89bf7edfa2b073ea3d6356f90cb4eb1721d5935218f0582c" +content-hash = "10bc7241ff1b6a810646f69c37aa59e0a3b655957da9a4d953d93212cef377cf" [metadata.files] astroid = [ @@ -1160,8 +1171,8 @@ tomlkit = [ {file = "tomlkit-0.11.4.tar.gz", hash = "sha256:3235a9010fae54323e727c3ac06fb720752fe6635b3426e379daec60fbd44a83"}, ] types-requests = [ - {file = "types-requests-2.28.10.tar.gz", hash = "sha256:97d8f40aa1ffe1e58c3726c77d63c182daea9a72d9f1fa2cafdea756b2a19f2c"}, - {file = "types_requests-2.28.10-py3-none-any.whl", hash = "sha256:45b485725ed58752f2b23461252f1c1ad9205b884a1e35f786bb295525a3e16a"}, + {file = "types-requests-2.28.11.tar.gz", hash = "sha256:7ee827eb8ce611b02b5117cfec5da6455365b6a575f5e3ff19f655ba603e6b4e"}, + {file = "types_requests-2.28.11-py3-none-any.whl", hash = "sha256:af5f55e803cabcfb836dad752bd6d8a0fc8ef1cd84243061c0e27dee04ccf4fd"}, ] types-setuptools = [ {file = "types-setuptools-57.4.18.tar.gz", hash = "sha256:8ee03d823fe7fda0bd35faeae33d35cb5c25b497263e6a58b34c4cfd05f40bcf"}, @@ -1210,6 +1221,56 @@ watchdog = [ {file = "watchdog-2.1.9-py3-none-win_ia64.whl", hash = "sha256:ad576a565260d8f99d97f2e64b0f97a48228317095908568a9d5c786c829d428"}, {file = "watchdog-2.1.9.tar.gz", hash = "sha256:43ce20ebb36a51f21fa376f76d1d4692452b2527ccd601950d69ed36b9e21609"}, ] +websockets = [ + {file = "websockets-10.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:661f641b44ed315556a2fa630239adfd77bd1b11cb0b9d96ed8ad90b0b1e4978"}, + {file = "websockets-10.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b529fdfa881b69fe563dbd98acce84f3e5a67df13de415e143ef053ff006d500"}, + 
{file = "websockets-10.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f351c7d7d92f67c0609329ab2735eee0426a03022771b00102816a72715bb00b"}, + {file = "websockets-10.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:379e03422178436af4f3abe0aa8f401aa77ae2487843738542a75faf44a31f0c"}, + {file = "websockets-10.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e904c0381c014b914136c492c8fa711ca4cced4e9b3d110e5e7d436d0fc289e8"}, + {file = "websockets-10.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e7e6f2d6fd48422071cc8a6f8542016f350b79cc782752de531577d35e9bd677"}, + {file = "websockets-10.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b9c77f0d1436ea4b4dc089ed8335fa141e6a251a92f75f675056dac4ab47a71e"}, + {file = "websockets-10.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e6fa05a680e35d0fcc1470cb070b10e6fe247af54768f488ed93542e71339d6f"}, + {file = "websockets-10.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2f94fa3ae454a63ea3a19f73b95deeebc9f02ba2d5617ca16f0bbdae375cda47"}, + {file = "websockets-10.3-cp310-cp310-win32.whl", hash = "sha256:6ed1d6f791eabfd9808afea1e068f5e59418e55721db8b7f3bfc39dc831c42ae"}, + {file = "websockets-10.3-cp310-cp310-win_amd64.whl", hash = "sha256:347974105bbd4ea068106ec65e8e8ebd86f28c19e529d115d89bd8cc5cda3079"}, + {file = "websockets-10.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:fab7c640815812ed5f10fbee7abbf58788d602046b7bb3af9b1ac753a6d5e916"}, + {file = "websockets-10.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:994cdb1942a7a4c2e10098d9162948c9e7b235df755de91ca33f6e0481366fdb"}, + {file = "websockets-10.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:aad5e300ab32036eb3fdc350ad30877210e2f51bceaca83fb7fef4d2b6c72b79"}, + {file = 
"websockets-10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e49ea4c1a9543d2bd8a747ff24411509c29e4bdcde05b5b0895e2120cb1a761d"}, + {file = "websockets-10.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:6ea6b300a6bdd782e49922d690e11c3669828fe36fc2471408c58b93b5535a98"}, + {file = "websockets-10.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:ef5ce841e102278c1c2e98f043db99d6755b1c58bde475516aef3a008ed7f28e"}, + {file = "websockets-10.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d1655a6fc7aecd333b079d00fb3c8132d18988e47f19740c69303bf02e9883c6"}, + {file = "websockets-10.3-cp37-cp37m-win32.whl", hash = "sha256:83e5ca0d5b743cde3d29fda74ccab37bdd0911f25bd4cdf09ff8b51b7b4f2fa1"}, + {file = "websockets-10.3-cp37-cp37m-win_amd64.whl", hash = "sha256:da4377904a3379f0c1b75a965fff23b28315bcd516d27f99a803720dfebd94d4"}, + {file = "websockets-10.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a1e15b230c3613e8ea82c9fc6941b2093e8eb939dd794c02754d33980ba81e36"}, + {file = "websockets-10.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:31564a67c3e4005f27815634343df688b25705cccb22bc1db621c781ddc64c69"}, + {file = "websockets-10.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c8d1d14aa0f600b5be363077b621b1b4d1eb3fbf90af83f9281cda668e6ff7fd"}, + {file = "websockets-10.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8fbd7d77f8aba46d43245e86dd91a8970eac4fb74c473f8e30e9c07581f852b2"}, + {file = "websockets-10.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:210aad7fdd381c52e58777560860c7e6110b6174488ef1d4b681c08b68bf7f8c"}, + {file = "websockets-10.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:6075fd24df23133c1b078e08a9b04a3bc40b31a8def4ee0b9f2c8865acce913e"}, + {file = "websockets-10.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash 
= "sha256:7f6d96fdb0975044fdd7953b35d003b03f9e2bcf85f2d2cf86285ece53e9f991"}, + {file = "websockets-10.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:c7250848ce69559756ad0086a37b82c986cd33c2d344ab87fea596c5ac6d9442"}, + {file = "websockets-10.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:28dd20b938a57c3124028680dc1600c197294da5db4292c76a0b48efb3ed7f76"}, + {file = "websockets-10.3-cp38-cp38-win32.whl", hash = "sha256:54c000abeaff6d8771a4e2cef40900919908ea7b6b6a30eae72752607c6db559"}, + {file = "websockets-10.3-cp38-cp38-win_amd64.whl", hash = "sha256:7ab36e17af592eec5747c68ef2722a74c1a4a70f3772bc661079baf4ae30e40d"}, + {file = "websockets-10.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a141de3d5a92188234afa61653ed0bbd2dde46ad47b15c3042ffb89548e77094"}, + {file = "websockets-10.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:97bc9d41e69a7521a358f9b8e44871f6cdeb42af31815c17aed36372d4eec667"}, + {file = "websockets-10.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d6353ba89cfc657a3f5beabb3b69be226adbb5c6c7a66398e17809b0ce3c4731"}, + {file = "websockets-10.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec2b0ab7edc8cd4b0eb428b38ed89079bdc20c6bdb5f889d353011038caac2f9"}, + {file = "websockets-10.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:85506b3328a9e083cc0a0fb3ba27e33c8db78341b3eb12eb72e8afd166c36680"}, + {file = "websockets-10.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8af75085b4bc0b5c40c4a3c0e113fa95e84c60f4ed6786cbb675aeb1ee128247"}, + {file = "websockets-10.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:07cdc0a5b2549bcfbadb585ad8471ebdc7bdf91e32e34ae3889001c1c106a6af"}, + {file = "websockets-10.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:5b936bf552e4f6357f5727579072ff1e1324717902127ffe60c92d29b67b7be3"}, + {file = 
"websockets-10.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e4e08305bfd76ba8edab08dcc6496f40674f44eb9d5e23153efa0a35750337e8"}, + {file = "websockets-10.3-cp39-cp39-win32.whl", hash = "sha256:bb621ec2dbbbe8df78a27dbd9dd7919f9b7d32a73fafcb4d9252fc4637343582"}, + {file = "websockets-10.3-cp39-cp39-win_amd64.whl", hash = "sha256:51695d3b199cd03098ae5b42833006a0f43dc5418d3102972addc593a783bc02"}, + {file = "websockets-10.3-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:907e8247480f287aa9bbc9391bd6de23c906d48af54c8c421df84655eef66af7"}, + {file = "websockets-10.3-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b1359aba0ff810d5830d5ab8e2c4a02bebf98a60aa0124fb29aa78cfdb8031f"}, + {file = "websockets-10.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:93d5ea0b5da8d66d868b32c614d2b52d14304444e39e13a59566d4acb8d6e2e4"}, + {file = "websockets-10.3-pp37-pypy37_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7934e055fd5cd9dee60f11d16c8d79c4567315824bacb1246d0208a47eca9755"}, + {file = "websockets-10.3-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:3eda1cb7e9da1b22588cefff09f0951771d6ee9fa8dbe66f5ae04cc5f26b2b55"}, + {file = "websockets-10.3.tar.gz", hash = "sha256:fc06cc8073c8e87072138ba1e431300e2d408f054b27047d047b549455066ff4"}, +] wrapt = [ {file = "wrapt-1.14.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3"}, {file = "wrapt-1.14.1-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef"}, diff --git a/pyproject.toml b/pyproject.toml index 39fb8b0f..d8adcd57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,12 +25,13 @@ license = "Proprietary" name = "hume" readme = "README.md" repository = "https://github.com/HumeAI/hume-python-sdk" -version = "0.1.3" +version = 
"0.1.4" [tool.poetry.dependencies] python = ">=3.8,<3.10" requests = "^2.26.0" typing-extensions = "^4.3.0" +websockets = { version = "^10.3", optional = true } [tool.poetry.dev-dependencies] covcheck = "^0.4.0" @@ -55,17 +56,20 @@ mkdocs = "^1.3.1" mkdocs-material = "^8.4.1" mkdocstrings = { version = "^0.19.0", extras = ["python"] } +[tool.poetry.extras] +stream = ["websockets"] + [build-system] build-backend = "poetry.core.masonry.api" requires = ["poetry-core>=1.0.0"] [tool.covcheck.group.unit.coverage] -branch = 68.0 -line = 82.0 +branch = 65.0 +line = 74.0 [tool.covcheck.group.service.coverage] -branch = 77.0 -line = 92.0 +branch = 72.0 +line = 81.0 [tool.flake8] ignore = "" # Required to disable default ignores @@ -78,7 +82,7 @@ disallow_untyped_defs = true ignore_missing_imports = true [tool.pylint.basic] -good-names = ["id"] +good-names = ["id", "f"] max-args = 12 max-locals = 25 notes = ["FIXME"] diff --git a/tests/batch/test_batch_client.py b/tests/batch/test_batch_client.py index f0c523ab..9ddc7b66 100644 --- a/tests/batch/test_batch_client.py +++ b/tests/batch/test_batch_client.py @@ -15,6 +15,7 @@ def batch_client(monkeypatch: MonkeyPatch) -> HumeBatchClient: return client +@pytest.mark.batch class TestHumeBatchClient: def test_face(self, batch_client: HumeBatchClient): diff --git a/tests/batch/test_batch_client_service.py b/tests/batch/test_batch_client_service.py index 3b8b3b27..a7e7b5f7 100644 --- a/tests/batch/test_batch_client_service.py +++ b/tests/batch/test_batch_client_service.py @@ -22,17 +22,7 @@ def batch_client() -> HumeBatchClient: return HumeBatchClient(api_key) -@pytest.fixture(scope="module") -def eval_data() -> EvalData: - base_url = "https://storage.googleapis.com/hume-test-data" - return { - "image-obama-face": f"{base_url}/image/obama.png", - "burst-amusement-009": f"{base_url}/audio/burst-amusement-009.mp3", - "prosody-horror-1051": f"{base_url}/audio/prosody-horror-1051.mp3", - "text-happy-place": f"{base_url}/text/happy.txt", - 
} - - +@pytest.mark.batch @pytest.mark.service class TestHumeBatchClientService: diff --git a/tests/batch/test_batch_job.py b/tests/batch/test_batch_job.py index bf8f943e..d2cc94ca 100644 --- a/tests/batch/test_batch_job.py +++ b/tests/batch/test_batch_job.py @@ -20,6 +20,7 @@ def batch_client() -> Mock: return mock_client +@pytest.mark.batch class TestBatchJob: def test_job_id(self, batch_client: Mock): diff --git a/tests/batch/test_batch_job_result.py b/tests/batch/test_batch_job_result.py index 06f95232..aa9587eb 100644 --- a/tests/batch/test_batch_job_result.py +++ b/tests/batch/test_batch_job_result.py @@ -30,6 +30,7 @@ def failed_result() -> BatchJobResult: return BatchJobResult.from_response(response) +@pytest.mark.batch class TestBatchJobResult: def test_queued_status(self, queued_result: BatchJobResult): diff --git a/tests/batch/test_batch_job_status.py b/tests/batch/test_batch_job_status.py index 544732a7..6a93833e 100644 --- a/tests/batch/test_batch_job_status.py +++ b/tests/batch/test_batch_job_status.py @@ -3,6 +3,7 @@ from hume import BatchJobStatus +@pytest.mark.batch class TestBatchJobStatus: def test_update(self): diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..ca1b1e7c --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,18 @@ +"""Conftest.""" +from typing import Dict + +import pytest + +EvalData = Dict[str, str] + + +@pytest.fixture(scope="module") +def eval_data() -> EvalData: + """Fixture for evaluation data.""" + base_url = "https://storage.googleapis.com/hume-test-data" + return { + "image-obama-face": f"{base_url}/image/obama.png", + "burst-amusement-009": f"{base_url}/audio/burst-amusement-009.mp3", + "prosody-horror-1051": f"{base_url}/audio/prosody-horror-1051.mp3", + "text-happy-place": f"{base_url}/text/happy.txt", + } diff --git a/tests/pytest.ini b/tests/pytest.ini index c1693edb..ff106a81 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -2,4 +2,7 @@ log_cli = 1 log_cli_level = INFO markers = 
+ asyncio: Uses async features + batch: Uses the Hume batch API service: Authenticates with Hume APIs using HUME_PLATFORM_API_KEY environment variable. + stream: Uses the Hume streaming API diff --git a/tests/stream/__init__.py b/tests/stream/__init__.py new file mode 100644 index 00000000..26b8dab0 --- /dev/null +++ b/tests/stream/__init__.py @@ -0,0 +1 @@ +"""Module init.""" diff --git a/tests/stream/test_stream_client.py b/tests/stream/test_stream_client.py new file mode 100644 index 00000000..87c45def --- /dev/null +++ b/tests/stream/test_stream_client.py @@ -0,0 +1,45 @@ +from contextlib import asynccontextmanager +from unittest.mock import Mock + +import pytest +import websockets +from pytest import MonkeyPatch + +from hume import HumeStreamClient, StreamSocket +from hume._common.config import FaceConfig + + +def mock_connect(uri: str): + assert uri == "wss://api.hume.ai/v0/stream/multi?apikey=0000-0000-0000-0000" + + @asynccontextmanager + async def mock_connection() -> Mock: + yield Mock() + + return mock_connection() + + +@pytest.fixture(scope="function") +def stream_client() -> HumeStreamClient: + return HumeStreamClient("0000-0000-0000-0000") + + +@pytest.mark.asyncio +@pytest.mark.stream +class TestHumeStreamClient: + + async def test_connect(self, stream_client: HumeStreamClient, monkeypatch: MonkeyPatch): + monkeypatch.setattr(websockets, "connect", mock_connect) + configs = [FaceConfig(identify_faces=True)] + async with stream_client.connect(configs) as websocket: + assert isinstance(websocket, StreamSocket) + + async def test_connect_to_models(self, stream_client: HumeStreamClient, monkeypatch: MonkeyPatch): + monkeypatch.setattr(websockets, "connect", mock_connect) + configs_dict = { + "face": { + "identify_faces": True, + }, + } + async with stream_client._connect_to_models(configs_dict) as websocket: + assert isinstance(websocket, StreamSocket) diff --git a/tests/stream/test_stream_client_service.py b/tests/stream/test_stream_client_service.py 
new file mode 100644
index 00000000..c48e53e6
--- /dev/null
+++ b/tests/stream/test_stream_client_service.py
@@ -0,0 +1,54 @@
+import logging
+import os
+from typing import Dict
+from urllib.request import urlretrieve
+
+import pytest
+from pytest import TempPathFactory
+
+from hume import HumeStreamClient, HumeClientError
+from hume._common.config import FaceConfig
+
+EvalData = Dict[str, str]
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture(scope="module")
+def stream_client() -> HumeStreamClient:
+    """Build a client from the HUME_DEV_API_KEY environment variable; fail when it is unset."""
+    # NOTE(review): the `service` marker description in pytest.ini mentions
+    # HUME_PLATFORM_API_KEY, not HUME_DEV_API_KEY -- confirm which one is intended.
+    dev_api_key = os.environ.get("HUME_DEV_API_KEY")
+    if dev_api_key is not None:
+        return HumeStreamClient(dev_api_key)
+
+    raise ValueError("Cannot construct HumeStreamClient, HUME_DEV_API_KEY variable not set.")
+
+
+@pytest.mark.asyncio
+@pytest.mark.stream
+@pytest.mark.service
+class TestHumeStreamClientService:
+    """Service-marked tests that connect through HumeStreamClient without any patching."""
+
+    async def test_run(self, eval_data: EvalData, stream_client: HumeStreamClient, tmp_path_factory: TempPathFactory):
+        """Download the sample face image, send it over the socket, and expect a non-None reply."""
+        sample_url = eval_data["image-obama-face"]
+        download_dir = tmp_path_factory.mktemp("data-dir")
+        local_filepath = download_dir / "data-file"
+        urlretrieve(sample_url, local_filepath)
+
+        face_configs = [FaceConfig(identify_faces=True)]
+        async with stream_client.connect(face_configs) as stream_socket:
+            response = await stream_socket.send_file(local_filepath)
+            assert response is not None
+
+    async def test_invalid_api_key(self):
+        """Connecting with a bogus key should raise HumeClientError with the invalid-key message."""
+        expected_error = "Client initialized with invalid API key"
+        bad_client = HumeStreamClient("invalid-api-key")
+        face_configs = [FaceConfig(identify_faces=True)]
+        with pytest.raises(HumeClientError, match=expected_error):
+            async with bad_client.connect(face_configs):
+                pass
diff --git a/tests/stream/test_stream_socket.py b/tests/stream/test_stream_socket.py
new file mode 100644
index 00000000..b287e30b
--- /dev/null
+++ b/tests/stream/test_stream_socket.py
@@ -0,0 +1,68 @@
+import json
+from unittest.mock import Mock
+
+import pytest
+from pytest import TempPathFactory
+
+from hume import StreamSocket
+from hume._common.config import FaceConfig
+
+
+@pytest.fixture(scope="function")
+def mock_protocol():
+    """Provide a Mock protocol with a payload-checking send and a canned recv."""
+    # "bW9jay1tZWRpYS1maWxl" is the base64 encoding of "mock-media-file".
+    expected_payload = {
+        "data": "bW9jay1tZWRpYS1maWxl",
+        "models": {
+            "face": {
+                "fps_pred": None,
+                "prob_threshold": None,
+                "identify_faces": True,
+                "min_face_size": None
+            },
+        },
+    }
+    canned_reply = {"face": {"predictions": "mock-predictions"}}
+
+    async def mock_send(message: str) -> None:
+        assert json.loads(message) == expected_payload
+
+    async def mock_recv() -> str:
+        return json.dumps(canned_reply)
+
+    fake_protocol = Mock()
+    fake_protocol.send = mock_send
+    fake_protocol.recv = mock_recv
+    return fake_protocol
+
+
+@pytest.mark.asyncio
+@pytest.mark.stream
+class TestStreamSocket:
+    """Each send method is expected to deliver the mock's payload and return the parsed reply."""
+
+    async def test_send_bytes_str(self, mock_protocol: Mock):
+        """Sending the encoded string should return the predictions from the canned reply."""
+        stream_socket = StreamSocket(mock_protocol, [FaceConfig(identify_faces=True)])
+        response = await stream_socket.send_bytes_str("bW9jay1tZWRpYS1maWxl")
+        assert response["face"]["predictions"] == "mock-predictions"
+
+    async def test_send_bytes(self, mock_protocol: Mock):
+        """Sending the same encoded value as bytes should return the same predictions."""
+        stream_socket = StreamSocket(mock_protocol, [FaceConfig(identify_faces=True)])
+        response = await stream_socket.send_bytes(b'bW9jay1tZWRpYS1maWxl')
+        assert response["face"]["predictions"] == "mock-predictions"
+
+    async def test_send_file(self, mock_protocol: Mock, tmp_path_factory: TempPathFactory):
+        """Sending a file that holds "mock-media-file" should return the canned predictions."""
+        face_configs = [FaceConfig(identify_faces=True)]
+        stream_socket = StreamSocket(mock_protocol, face_configs)
+
+        media_text = "mock-media-file"
+        media_filepath = tmp_path_factory.mktemp("data") / "data.txt"
+        with media_filepath.open("w") as media_file:
+            media_file.write(media_text)
+
+        response = await stream_socket.send_file(media_filepath)
+        assert response["face"]["predictions"] == "mock-predictions"