Skip to content

Commit

Permalink
Utility for getting error messages from failed batch jobs (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
gregorybchris authored Sep 21, 2022
1 parent e0c1ba1 commit 71c035e
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 14 deletions.
38 changes: 27 additions & 11 deletions hume/_batch/batch_job_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
predictions_url: Optional[str] = None,
artifacts_url: Optional[str] = None,
errors_url: Optional[str] = None,
error_message: Optional[str] = None,
):
"""Construct a BatchJobResult.
Expand All @@ -33,23 +34,15 @@ def __init__(
predictions_url (Optional[str]): URL to predictions file.
artifacts_url (Optional[str]): URL to artifacts zip archive.
errors_url (Optional[str]): URL to errors file.
error_message (Optional[str]): Error message for request.
"""
self.configs = configs
self.urls = urls
self.status = status
self.predictions_url = predictions_url
self.artifacts_url = artifacts_url
self.errors_url = errors_url
self.predictions_url = predictions_url

def download_artifacts(self, filepath: Optional[Union[str, Path]] = None) -> None:
"""Download `BatchJob` artifacts zip archive.
Args:
filepath (Optional[Union[str, Path]]): Filepath where artifacts zip archive will be downloaded.
"""
if self.artifacts_url is None:
raise HumeClientError("Could not download job artifacts. No artifacts found on job result.")
urlretrieve(self.artifacts_url, filepath)
self.error_message = error_message

def download_predictions(self, filepath: Optional[Union[str, Path]] = None) -> None:
"""Download `BatchJob` predictions file.
Expand All @@ -61,6 +54,16 @@ def download_predictions(self, filepath: Optional[Union[str, Path]] = None) -> N
raise HumeClientError("Could not download job predictions. No predictions found on job result.")
urlretrieve(self.predictions_url, filepath)

def download_artifacts(self, filepath: Optional[Union[str, Path]] = None) -> None:
"""Download `BatchJob` artifacts zip archive.
Args:
filepath (Optional[Union[str, Path]]): Filepath where artifacts zip archive will be downloaded.
"""
if self.artifacts_url is None:
raise HumeClientError("Could not download job artifacts. No artifacts found on job result.")
urlretrieve(self.artifacts_url, filepath)

def download_errors(self, filepath: Optional[Union[str, Path]] = None) -> None:
"""Download `BatchJob` errors file.
Expand All @@ -71,6 +74,14 @@ def download_errors(self, filepath: Optional[Union[str, Path]] = None) -> None:
raise HumeClientError("Could not download job errors. No errors found on job result.")
urlretrieve(self.errors_url, filepath)

def get_error_message(self) -> Optional[str]:
"""Get any available error messages on the job.
Returns:
Optional[str]: A string with the error message if there was an error, otherwise None.
"""
return self.error_message

@classmethod
def from_response(cls, response: Any) -> "BatchJobResult":
"""Construct a `BatchJobResult` from a batch API job response.
Expand All @@ -96,6 +107,11 @@ def from_response(cls, response: Any) -> "BatchJobResult":
kwargs["errors_url"] = completed_dict["errors_url"]
kwargs["predictions_url"] = completed_dict["predictions_url"]

if "failed" in response:
failed_dict = response["failed"]
if "message" in failed_dict:
kwargs["error_message"] = failed_dict["message"]

return cls(
configs=configs,
urls=request["urls"],
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ license = "Proprietary"
name = "hume"
readme = "README.md"
repository = "https://github.com/HumeAI/hume-python-sdk"
version = "0.1.2"
version = "0.1.3"

[tool.poetry.dependencies]
python = ">=3.8,<3.10"
Expand Down Expand Up @@ -64,8 +64,8 @@ branch = 68.0
line = 82.0

[tool.covcheck.group.service.coverage]
branch = 80.0
line = 93.0
branch = 77.0
line = 92.0

[tool.flake8]
ignore = "" # Required to disable default ignores
Expand Down
23 changes: 23 additions & 0 deletions tests/batch/data/result-response-failed.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"completion_timestamp": 1663790201,
"creation_timestamp": 1663790199,
"failed": {
"message": "user 'abcde' has exceeded their usage limit"
},
"request": {
"models": {
"face": {
"descriptions": null,
"facs": null,
"fps_pred": 3.0,
"identify_faces": false,
"min_face_size": 60.0,
"prob_threshold": 0.9900000095367432,
"save_faces": false
}
},
"notify": false,
"urls": ["https://storage.googleapis.com/hume-test-data/image/obama.png"]
},
"status": "FAILED"
}
12 changes: 12 additions & 0 deletions tests/batch/test_batch_job_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ def queued_result() -> BatchJobResult:
return BatchJobResult.from_response(response)


@pytest.fixture(scope="function")
def failed_result() -> BatchJobResult:
response_filepath = Path(__file__).parent / "data" / "result-response-failed.json"
with response_filepath.open() as f:
response = json.load(f)
return BatchJobResult.from_response(response)


class TestBatchJobResult:

def test_queued_status(self, queued_result: BatchJobResult):
Expand All @@ -46,3 +54,7 @@ def test_completed(self, completed_result: BatchJobResult):
assert completed_result.predictions_url is not None
assert completed_result.errors_url is not None
assert completed_result.artifacts_url is not None

def test_failed_message(self, failed_result: BatchJobResult):
assert failed_result.status == BatchJobStatus.FAILED
assert failed_result.get_error_message() == "user 'abcde' has exceeded their usage limit"

0 comments on commit 71c035e

Please sign in to comment.