From 71c035e80d675926b41e8328daa60f72258a009e Mon Sep 17 00:00:00 2001 From: Chris Gregory <8800689+gregorybchris@users.noreply.github.com> Date: Wed, 21 Sep 2022 15:32:37 -0700 Subject: [PATCH] Utility for getting error messages from failed batch jobs (#22) --- hume/_batch/batch_job_result.py | 38 ++++++++++++++------ pyproject.toml | 6 ++-- tests/batch/data/result-response-failed.json | 23 ++++++++++++ tests/batch/test_batch_job_result.py | 12 +++++++ 4 files changed, 65 insertions(+), 14 deletions(-) create mode 100644 tests/batch/data/result-response-failed.json diff --git a/hume/_batch/batch_job_result.py b/hume/_batch/batch_job_result.py index 31ebe7a5..b623ba95 100644 --- a/hume/_batch/batch_job_result.py +++ b/hume/_batch/batch_job_result.py @@ -23,6 +23,7 @@ def __init__( predictions_url: Optional[str] = None, artifacts_url: Optional[str] = None, errors_url: Optional[str] = None, + error_message: Optional[str] = None, ): """Construct a BatchJobResult. @@ -33,23 +34,15 @@ def __init__( predictions_url (Optional[str]): URL to predictions file. artifacts_url (Optional[str]): URL to artifacts zip archive. errors_url (Optional[str]): URL to errors file. + error_message (Optional[str]): Error message for request. """ self.configs = configs self.urls = urls self.status = status + self.predictions_url = predictions_url self.artifacts_url = artifacts_url self.errors_url = errors_url - self.predictions_url = predictions_url - - def download_artifacts(self, filepath: Optional[Union[str, Path]] = None) -> None: - """Download `BatchJob` artifacts zip archive. - - Args: - filepath (Optional[Union[str, Path]]): Filepath where artifacts zip archive will be downloaded. - """ - if self.artifacts_url is None: - raise HumeClientError("Could not download job artifacts. No artifacts found on job result.") - urlretrieve(self.artifacts_url, filepath) + self.error_message = error_message def download_predictions(self, filepath: Optional[Union[str, Path]] = None) -> None: """Download `BatchJob` predictions file. @@ -61,6 +54,16 @@ def download_predictions(self, filepath: Optional[Union[str, Path]] = None) -> N raise HumeClientError("Could not download job predictions. No predictions found on job result.") urlretrieve(self.predictions_url, filepath) + def download_artifacts(self, filepath: Optional[Union[str, Path]] = None) -> None: + """Download `BatchJob` artifacts zip archive. + + Args: + filepath (Optional[Union[str, Path]]): Filepath where artifacts zip archive will be downloaded. + """ + if self.artifacts_url is None: + raise HumeClientError("Could not download job artifacts. No artifacts found on job result.") + urlretrieve(self.artifacts_url, filepath) + def download_errors(self, filepath: Optional[Union[str, Path]] = None) -> None: """Download `BatchJob` errors file. @@ -71,6 +74,14 @@ def download_errors(self, filepath: Optional[Union[str, Path]] = None) -> None: raise HumeClientError("Could not download job errors. No errors found on job result.") urlretrieve(self.errors_url, filepath) + def get_error_message(self) -> Optional[str]: + """Get any available error messages on the job. + + Returns: + Optional[str]: A string with the error message if there was an error, otherwise None. + """ + return self.error_message + @classmethod def from_response(cls, response: Any) -> "BatchJobResult": """Construct a `BatchJobResult` from a batch API job response. @@ -96,6 +107,11 @@ def from_response(cls, response: Any) -> "BatchJobResult": kwargs["errors_url"] = completed_dict["errors_url"] kwargs["predictions_url"] = completed_dict["predictions_url"] + if "failed" in response: + failed_dict = response["failed"] + if "message" in failed_dict: + kwargs["error_message"] = failed_dict["message"] + return cls( configs=configs, urls=request["urls"], diff --git a/pyproject.toml b/pyproject.toml index 0cab08db..39fb8b0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ license = "Proprietary" name = "hume" readme = "README.md" repository = "https://github.com/HumeAI/hume-python-sdk" -version = "0.1.2" +version = "0.1.3" [tool.poetry.dependencies] python = ">=3.8,<3.10" @@ -64,8 +64,8 @@ branch = 68.0 line = 82.0 [tool.covcheck.group.service.coverage] -branch = 80.0 -line = 93.0 +branch = 77.0 +line = 92.0 [tool.flake8] ignore = "" # Required to disable default ignores diff --git a/tests/batch/data/result-response-failed.json b/tests/batch/data/result-response-failed.json new file mode 100644 index 00000000..b0d94ff7 --- /dev/null +++ b/tests/batch/data/result-response-failed.json @@ -0,0 +1,23 @@ +{ + "completion_timestamp": 1663790201, + "creation_timestamp": 1663790199, + "failed": { + "message": "user 'abcde' has exceeded their usage limit" + }, + "request": { + "models": { + "face": { + "descriptions": null, + "facs": null, + "fps_pred": 3.0, + "identify_faces": false, + "min_face_size": 60.0, + "prob_threshold": 0.9900000095367432, + "save_faces": false + } + }, + "notify": false, + "urls": ["https://storage.googleapis.com/hume-test-data/image/obama.png"] + }, + "status": "FAILED" +} diff --git a/tests/batch/test_batch_job_result.py b/tests/batch/test_batch_job_result.py index 9dfcc24e..06f95232 100644 --- a/tests/batch/test_batch_job_result.py +++ b/tests/batch/test_batch_job_result.py @@ -22,6 +22,14 @@ def queued_result() -> BatchJobResult: return BatchJobResult.from_response(response) +@pytest.fixture(scope="function") +def failed_result() -> BatchJobResult: + response_filepath = Path(__file__).parent / "data" / "result-response-failed.json" + with response_filepath.open() as f: + response = json.load(f) + return BatchJobResult.from_response(response) + + class TestBatchJobResult: def test_queued_status(self, queued_result: BatchJobResult): @@ -46,3 +54,7 @@ def test_completed(self, completed_result: BatchJobResult): assert completed_result.predictions_url is not None assert completed_result.errors_url is not None assert completed_result.artifacts_url is not None + + def test_failed_message(self, failed_result: BatchJobResult): + assert failed_result.status == BatchJobStatus.FAILED + assert failed_result.get_error_message() == "user 'abcde' has exceeded their usage limit"