From d82b3397da7b0c451fb15510f3c18e9157df783c Mon Sep 17 00:00:00 2001 From: Josh Goldberg Date: Thu, 4 Apr 2024 13:02:52 -0700 Subject: [PATCH 1/2] add filebytes upload method --- hume/_measurement/batch/hume_batch_client.py | 37 ++++++++++++------- tests/batch/test_service_hume_batch_client.py | 29 +++++++++++++++ 2 files changed, 53 insertions(+), 13 deletions(-) diff --git a/hume/_measurement/batch/hume_batch_client.py b/hume/_measurement/batch/hume_batch_client.py index 3d64311f..4daad40a 100644 --- a/hume/_measurement/batch/hume_batch_client.py +++ b/hume/_measurement/batch/hume_batch_client.py @@ -5,7 +5,7 @@ import json from collections.abc import Iterable from pathlib import Path -from typing import Any +from typing import Any, Optional, Union from hume._common.client_base import ClientBase from hume._common.utilities.config_utilities import serialize_configs @@ -76,6 +76,7 @@ def submit_job( callback_url: str | None = None, notify: bool | None = None, files: list[Path | str] | None = None, + filebytes: Optional[list[tuple[str, bytes]]] = None, text: list[str] | None = None, ) -> BatchJob: """Submit a job for batch processing. @@ -90,13 +91,14 @@ def submit_job( callback_url (str | None): A URL to which a POST request will be sent upon job completion. notify (bool | None): Wether an email notification should be sent upon job completion. files (list[Path | str] | None): List of paths to files on the local disk to be processed. + filebytes (Optional[list[tuple[str, bytes]]]): List of file bytes (raw file data) to be processed. text (list[str] | None): List of strings (raw text) to be processed. Returns: BatchJob: The `BatchJob` representing the batch computation. """ request = self._construct_request(configs, urls, text, transcription_config, callback_url, notify) - return self._submit_job(request, files) + return self._submit_job(request, files, filebytes) def get_job_details(self, job_id: str) -> BatchJobDetails: """Get details for the batch job. @@ -197,6 +199,7 @@ def _submit_job( self, request_body: Any, filepaths: list[Path | str] | None, + filebytes: Optional[list[tuple[str, bytes]]], ) -> BatchJob: """Start a job for batch processing by passing a JSON request body. @@ -206,6 +209,7 @@ def _submit_job( Args: request_body (Any): JSON request body to be passed to the batch API. filepaths (list[Path | str] | None): List of paths to files on the local disk to be processed. + filebytes (list[tuple[str, bytes]]] | None): List of bytes (raw file data) to be processed. Raises: HumeClientException: If the batch job fails to start. @@ -215,14 +219,14 @@ def _submit_job( """ endpoint = self._build_endpoint("batch", "jobs") - if filepaths is None: + if filepaths is None and filebytes is None: response = self._http_client.post( endpoint, json=request_body, headers=self._get_client_headers(), ) else: - form_data = self._get_multipart_form_data(request_body, filepaths) + form_data = self._get_multipart_form_data(request_body, filepaths, filebytes) response = self._http_client.post( endpoint, headers=self._get_client_headers(), @@ -253,26 +257,33 @@ def _submit_job( def _get_multipart_form_data( self, request_body: Any, - filepaths: Iterable[Path | str], - ) -> list[tuple[str, bytes | tuple[str, bytes]]]: - """Convert a list of filepaths into a list of multipart form data. + filepaths: Optional[list[Union[str, Path]]], + filebytes: Optional[list[tuple[str, bytes]]], + ) -> list[tuple[str, Union[bytes, tuple[str, bytes]]]]: + """Convert a list of filepaths and/or file bytes into a list of multipart form data. Multipart form data allows the client to attach files to the POST request, including both the raw file bytes and the filename. Args: request_body (Any): JSON request body to be passed to the batch API. - filepaths (list[Path | str]): List of paths to files on the local disk to be processed. + filepaths (list[Path | str] | None): List of paths to files on the local disk to be processed. + filebytes (list[str | bytes] | None): List of bytes (raw file data) to be processed. Returns: list[tuple[str, bytes | tuple[str, bytes]]]: A list of tuples representing the multipart form data for the POST request. """ - form_data: list[tuple[str, bytes | tuple[str, bytes]]] = [] - for filepath in filepaths: - path = Path(filepath) - post_file = ("file", (path.name, path.read_bytes())) - form_data.append(post_file) + form_data: list[tuple[str, Union[bytes, tuple[str, bytes]]]] = [] + if filepaths is not None: + for filepath in filepaths: + path = Path(filepath) + post_file = ("file", (path.name, path.read_bytes())) + form_data.append(post_file) + if filebytes is not None: + for filebyte in filebytes: + post_file = ("file", filebyte) + form_data.append(post_file) form_data.append(("json", json.dumps(request_body).encode("utf-8"))) return form_data diff --git a/tests/batch/test_service_hume_batch_client.py b/tests/batch/test_service_hume_batch_client.py index 08af72d9..6f6ef5d4 100644 --- a/tests/batch/test_service_hume_batch_client.py +++ b/tests/batch/test_service_hume_batch_client.py @@ -174,6 +174,35 @@ def test_local_file_upload_configure( # rather than the nine we'd get if we used 'word' granularity. assert len(grouped_predictions[0]["predictions"]) == 1 + # test for the case where a file is passed as a byte string + def test_file_upload_bytes_configure( + self, eval_data: EvalData, batch_client: HumeBatchClient, tmp_path_factory: TempPathFactory + ) -> None: + data_url = eval_data["text-happy-place"] + data_filepath = tmp_path_factory.mktemp("data-dir") / "happy.txt" + urlretrieve(data_url, data_filepath) + with data_filepath.open("rb") as f: + data_bytes = f.read() + config = LanguageConfig(granularity="sentence") + job_files_dirpath = tmp_path_factory.mktemp("job-files") + job = batch_client.submit_job([], [config], filebytes=[("happy.txt", data_bytes)]) + self.check_job(job, config, LanguageConfig, job_files_dirpath, complete_config=False) + + predictions = job.get_predictions() + + assert len(predictions) == 1 + assert predictions[0]["results"] + assert len(predictions[0]["results"]["predictions"]) == 1 + assert predictions[0]["source"]["type"] == "file" + assert predictions[0]["source"]["filename"] == "happy.txt" + language_results = predictions[0]["results"]["predictions"][0]["models"]["language"] + grouped_predictions = language_results["grouped_predictions"] + assert len(grouped_predictions) == 1 + + # Configuring 'sentence' granularity should give us only one prediction + # rather than the nine we'd get if we used 'word' granularity. + assert len(grouped_predictions[0]["predictions"]) == 1 + def check_job( self, job: BatchJob, From 0f76c2ffe5f10a55435480df5eb89d7e7d1fd4d9 Mon Sep 17 00:00:00 2001 From: Chris Gregory Date: Fri, 14 Jun 2024 04:13:55 -0700 Subject: [PATCH 2/2] Update to latest docstrings --- hume/_measurement/batch/hume_batch_client.py | 6 +++--- tests/batch/test_hume_batch_client.py | 13 ++++++++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/hume/_measurement/batch/hume_batch_client.py b/hume/_measurement/batch/hume_batch_client.py index 4daad40a..3b842486 100644 --- a/hume/_measurement/batch/hume_batch_client.py +++ b/hume/_measurement/batch/hume_batch_client.py @@ -91,7 +91,7 @@ def submit_job( callback_url (str | None): A URL to which a POST request will be sent upon job completion. notify (bool | None): Wether an email notification should be sent upon job completion. files (list[Path | str] | None): List of paths to files on the local disk to be processed. - filebytes (Optional[list[tuple[str, bytes]]]): List of file bytes (raw file data) to be processed. + filebytes (list[tuple[str, bytes]] | None): List of file bytes (raw file data) to be processed. text (list[str] | None): List of strings (raw text) to be processed. Returns: @@ -257,8 +257,8 @@ def _submit_job( def _get_multipart_form_data( self, request_body: Any, - filepaths: Optional[list[Union[str, Path]]], - filebytes: Optional[list[tuple[str, bytes]]], + filepaths: Optional[Iterable[Union[str, Path]]], + filebytes: Optional[Iterable[tuple[str, bytes]]], ) -> list[tuple[str, Union[bytes, tuple[str, bytes]]]]: """Convert a list of filepaths and/or file bytes into a list of multipart form data. diff --git a/tests/batch/test_hume_batch_client.py b/tests/batch/test_hume_batch_client.py index af8f08d4..19e0b5a9 100644 --- a/tests/batch/test_hume_batch_client.py +++ b/tests/batch/test_hume_batch_client.py @@ -1,3 +1,4 @@ +from pathlib import Path from unittest.mock import MagicMock import pytest @@ -38,6 +39,7 @@ def test_face(self, batch_client: HumeBatchClient) -> None: "urls": [mock_url], }, None, + None, ) def test_burst(self, batch_client: HumeBatchClient) -> None: @@ -54,6 +56,7 @@ def test_burst(self, batch_client: HumeBatchClient) -> None: "urls": [mock_url], }, None, + None, ) def test_prosody(self, batch_client: HumeBatchClient) -> None: @@ -72,6 +75,7 @@ def test_prosody(self, batch_client: HumeBatchClient) -> None: "urls": [mock_url], }, None, + None, ) def test_language(self, batch_client: HumeBatchClient) -> None: @@ -91,6 +95,7 @@ def test_language(self, batch_client: HumeBatchClient) -> None: "urls": [mock_url], }, None, + None, ) def test_language_with_raw_text(self, batch_client: HumeBatchClient) -> None: @@ -111,6 +116,7 @@ def test_language_with_raw_text(self, batch_client: HumeBatchClient) -> None: "text": [mock_text], }, None, + None, ) def test_get_job(self, batch_client: HumeBatchClient) -> None: @@ -133,6 +139,7 @@ def test_files(self, batch_client: HumeBatchClient) -> None: }, }, ["my-audio.mp3"], + None, ) def test_get_multipart_form_data(self, batch_client: HumeBatchClient, tmp_path_factory: TempPathFactory) -> None: @@ -142,9 +149,9 @@ def test_get_multipart_form_data(self, batch_client: HumeBatchClient, tmp_path_f f.write("I can't believe this test passed!") request_body = {"mock": "body"} - filepaths = [filepath] - # pylint: disable=protected-access - result = batch_client._get_multipart_form_data(request_body, filepaths) + filepaths: list[Path] = [filepath] + filebytes: list[tuple[str, bytes]] = [] + result = batch_client._get_multipart_form_data(request_body, filepaths, filebytes) assert result == [ ("file", ("my-audio.mp3", b"I can't believe this test passed!")),