Skip to content

Commit

Permalink
Refactor remote api, improve error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
gmertes committed Feb 23, 2024
1 parent 130ba5b commit f6ba842
Showing 1 changed file with 54 additions and 50 deletions.
104 changes: 54 additions & 50 deletions ai_models/remote.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import sys
import time
from urllib.parse import urljoin

Expand All @@ -23,70 +24,73 @@ def __init__(self, url: str, token: str, output_file: str, input_file: str = Non
self.auth = BearerAuth(token)
self.output_file = output_file
self.input_file = input_file
self._timeout = 10
self._status = self._last_status = None
self._timeout = 300

def run(self, cfg: dict, model_args: list):
cfg.pop("remote_url", None)
cfg.pop("remote_token", None)
cfg["model_args"] = model_args

if self._upload(self.input_file) == "success":
self._submit(cfg)
# upload file
with open(self.input_file, "rb") as f:
LOG.info("Uploading input file to remote")
status, href = self._request(requests.post, "upload", data=f)

while not self._ready():
time.sleep(5)
if status != "success":
LOG.error(status)
sys.exit(1)

download(urljoin(self.url, self._href), target=self.output_file)
# submit job
status, href = self._request(requests.post, href, json=cfg)

LOG.debug("Result written to %s", self.output_file)
if status != "queued":
LOG.error(status)
sys.exit(1)

def _upload(self, file):
with open(file, "rb") as f:
LOG.info("Uploading input file to remote")
r = robust(requests.post, retry_after=self._timeout)(
urljoin(self.url, "upload"),
data=f,
auth=self.auth,
timeout=self._timeout,
).json()

LOG.debug(r)
self._uid = r["id"]
self._href = r["href"]
self._status = r["status"].lower()

return self._status

def _submit(self, data):
r = robust(requests.post, retry_after=self._timeout)(
urljoin(self.url, self._href),
json=data,
auth=self.auth,
timeout=self._timeout,
).json()
LOG.info("Job status: queued")
last_status = status

while True:
status, href = self._request(requests.get, href)

if status != last_status:
LOG.info("Job status: %s", status)
last_status = status

LOG.debug(r)
self._uid = r["id"]
self._href = r["href"]
self._status = r["status"].lower()
if status == "failed":
sys.exit(1)

return self._status
if status == "ready":
break

def _poll(self):
r = robust(requests.get, retry_after=self._timeout)(
urljoin(self.url, self._href), auth=self.auth, timeout=self._timeout
).json()
time.sleep(4)

download(urljoin(self.url, href), target=self.output_file)

LOG.debug("Result written to %s", self.output_file)

def _request(self, type, href, data=None, json=None, auth=None):
r = robust(type, retry_after=self._timeout)(
urljoin(self.url, href),
json=json,
data=data,
auth=self.auth,
timeout=self._timeout,
)

LOG.debug(r)
self._href = r["href"]
self._status = r["status"].lower()
status, href = self._update_state(r)
return status, href

if self._last_status != self._status:
LOG.info("Job status: %s", self._status)
self._last_status = self._status
def _update_state(self, response: requests.Response):
if response.status_code == 401:
return "Unauthorized Access", None

return self._status
try:
data = response.json()
href = data["href"]
status = data["status"].lower()
except:
status = f"{response.status_code} {response.url} {response.text}"
href = None

def _ready(self):
return self._poll() == "ready"
return status, href

0 comments on commit f6ba842

Please sign in to comment.