Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bytes argument to HumeBatchClient.submit_job() instead of only accepting file paths #121

Merged
merged 2 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions hume/_measurement/batch/hume_batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
from collections.abc import Iterable
from pathlib import Path
from typing import Any
from typing import Any, Optional, Union

from hume._common.client_base import ClientBase
from hume._common.utilities.config_utilities import serialize_configs
Expand Down Expand Up @@ -76,6 +76,7 @@ def submit_job(
callback_url: str | None = None,
notify: bool | None = None,
files: list[Path | str] | None = None,
filebytes: Optional[list[tuple[str, bytes]]] = None,
text: list[str] | None = None,
) -> BatchJob:
"""Submit a job for batch processing.
Expand All @@ -90,13 +91,14 @@ def submit_job(
callback_url (str | None): A URL to which a POST request will be sent upon job completion.
notify (bool | None): Wether an email notification should be sent upon job completion.
files (list[Path | str] | None): List of paths to files on the local disk to be processed.
filebytes (list[tuple[str, bytes]] | None): List of file bytes (raw file data) to be processed.
text (list[str] | None): List of strings (raw text) to be processed.

Returns:
BatchJob: The `BatchJob` representing the batch computation.
"""
request = self._construct_request(configs, urls, text, transcription_config, callback_url, notify)
return self._submit_job(request, files)
return self._submit_job(request, files, filebytes)

def get_job_details(self, job_id: str) -> BatchJobDetails:
"""Get details for the batch job.
Expand Down Expand Up @@ -197,6 +199,7 @@ def _submit_job(
self,
request_body: Any,
filepaths: list[Path | str] | None,
filebytes: Optional[list[tuple[str, bytes]]],
) -> BatchJob:
"""Start a job for batch processing by passing a JSON request body.

Expand All @@ -206,6 +209,7 @@ def _submit_job(
Args:
request_body (Any): JSON request body to be passed to the batch API.
filepaths (list[Path | str] | None): List of paths to files on the local disk to be processed.
filebytes (list[tuple[str, bytes]]] | None): List of bytes (raw file data) to be processed.

Raises:
HumeClientException: If the batch job fails to start.
Expand All @@ -215,14 +219,14 @@ def _submit_job(
"""
endpoint = self._build_endpoint("batch", "jobs")

if filepaths is None:
if filepaths is None and filebytes is None:
response = self._http_client.post(
endpoint,
json=request_body,
headers=self._get_client_headers(),
)
else:
form_data = self._get_multipart_form_data(request_body, filepaths)
form_data = self._get_multipart_form_data(request_body, filepaths, filebytes)
response = self._http_client.post(
endpoint,
headers=self._get_client_headers(),
Expand Down Expand Up @@ -253,26 +257,33 @@ def _submit_job(
def _get_multipart_form_data(
self,
request_body: Any,
filepaths: Iterable[Path | str],
) -> list[tuple[str, bytes | tuple[str, bytes]]]:
"""Convert a list of filepaths into a list of multipart form data.
filepaths: Optional[Iterable[Union[str, Path]]],
filebytes: Optional[Iterable[tuple[str, bytes]]],
) -> list[tuple[str, Union[bytes, tuple[str, bytes]]]]:
"""Convert a list of filepaths and/or file bytes into a list of multipart form data.

Multipart form data allows the client to attach files to the POST request,
including both the raw file bytes and the filename.

Args:
request_body (Any): JSON request body to be passed to the batch API.
filepaths (list[Path | str]): List of paths to files on the local disk to be processed.
filepaths (list[Path | str] | None): List of paths to files on the local disk to be processed.
filebytes (list[str | bytes] | None): List of bytes (raw file data) to be processed.

Returns:
list[tuple[str, bytes | tuple[str, bytes]]]: A list of tuples representing
the multipart form data for the POST request.
"""
form_data: list[tuple[str, bytes | tuple[str, bytes]]] = []
for filepath in filepaths:
path = Path(filepath)
post_file = ("file", (path.name, path.read_bytes()))
form_data.append(post_file)
form_data: list[tuple[str, Union[bytes, tuple[str, bytes]]]] = []
if filepaths is not None:
for filepath in filepaths:
path = Path(filepath)
post_file = ("file", (path.name, path.read_bytes()))
form_data.append(post_file)
if filebytes is not None:
for filebyte in filebytes:
post_file = ("file", filebyte)
form_data.append(post_file)

form_data.append(("json", json.dumps(request_body).encode("utf-8")))
return form_data
13 changes: 10 additions & 3 deletions tests/batch/test_hume_batch_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -38,6 +39,7 @@ def test_face(self, batch_client: HumeBatchClient) -> None:
"urls": [mock_url],
},
None,
None,
)

def test_burst(self, batch_client: HumeBatchClient) -> None:
Expand All @@ -54,6 +56,7 @@ def test_burst(self, batch_client: HumeBatchClient) -> None:
"urls": [mock_url],
},
None,
None,
)

def test_prosody(self, batch_client: HumeBatchClient) -> None:
Expand All @@ -72,6 +75,7 @@ def test_prosody(self, batch_client: HumeBatchClient) -> None:
"urls": [mock_url],
},
None,
None,
)

def test_language(self, batch_client: HumeBatchClient) -> None:
Expand All @@ -91,6 +95,7 @@ def test_language(self, batch_client: HumeBatchClient) -> None:
"urls": [mock_url],
},
None,
None,
)

def test_language_with_raw_text(self, batch_client: HumeBatchClient) -> None:
Expand All @@ -111,6 +116,7 @@ def test_language_with_raw_text(self, batch_client: HumeBatchClient) -> None:
"text": [mock_text],
},
None,
None,
)

def test_get_job(self, batch_client: HumeBatchClient) -> None:
Expand All @@ -133,6 +139,7 @@ def test_files(self, batch_client: HumeBatchClient) -> None:
},
},
["my-audio.mp3"],
None,
)

def test_get_multipart_form_data(self, batch_client: HumeBatchClient, tmp_path_factory: TempPathFactory) -> None:
Expand All @@ -142,9 +149,9 @@ def test_get_multipart_form_data(self, batch_client: HumeBatchClient, tmp_path_f
f.write("I can't believe this test passed!")

request_body = {"mock": "body"}
filepaths = [filepath]
# pylint: disable=protected-access
result = batch_client._get_multipart_form_data(request_body, filepaths)
filepaths: list[Path] = [filepath]
filebytes: list[tuple[str, bytes]] = []
result = batch_client._get_multipart_form_data(request_body, filepaths, filebytes)

assert result == [
("file", ("my-audio.mp3", b"I can't believe this test passed!")),
Expand Down
29 changes: 29 additions & 0 deletions tests/batch/test_service_hume_batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,35 @@ def test_local_file_upload_configure(
# rather than the nine we'd get if we used 'word' granularity.
assert len(grouped_predictions[0]["predictions"]) == 1

# test for the case where a file is passed as a byte string
def test_file_upload_bytes_configure(
self, eval_data: EvalData, batch_client: HumeBatchClient, tmp_path_factory: TempPathFactory
) -> None:
data_url = eval_data["text-happy-place"]
data_filepath = tmp_path_factory.mktemp("data-dir") / "happy.txt"
urlretrieve(data_url, data_filepath)
with data_filepath.open("rb") as f:
data_bytes = f.read()
config = LanguageConfig(granularity="sentence")
job_files_dirpath = tmp_path_factory.mktemp("job-files")
job = batch_client.submit_job([], [config], filebytes=[("happy.txt", data_bytes)])
self.check_job(job, config, LanguageConfig, job_files_dirpath, complete_config=False)

predictions = job.get_predictions()

assert len(predictions) == 1
assert predictions[0]["results"]
assert len(predictions[0]["results"]["predictions"]) == 1
assert predictions[0]["source"]["type"] == "file"
assert predictions[0]["source"]["filename"] == "happy.txt"
language_results = predictions[0]["results"]["predictions"][0]["models"]["language"]
grouped_predictions = language_results["grouped_predictions"]
assert len(grouped_predictions) == 1

# Configuring 'sentence' granularity should give us only one prediction
# rather than the nine we'd get if we used 'word' granularity.
assert len(grouped_predictions[0]["predictions"]) == 1

def check_job(
self,
job: BatchJob,
Expand Down
Loading