Skip to content

Commit

Permalink
[Internal] Extract "before retry" handler, use it to rewind the stream (
Browse files Browse the repository at this point in the history
#878)

## What changes are proposed in this pull request?

- Introduce a separate handler to be called before we retry the API
call. This will make sure handler is called both when (1) we receive an
error response we want to retry on and (2) when low-level connection
exception is thrown.
- Rewind the stream to the initial position in this handler (if
applicable).

## How is this tested?

Existing tests.

**ALWAYS ANSWER THIS QUESTION:** Answer with "N/A" if tests are not
applicable
to your PR (e.g. if the PR only modifies comments). Do not be afraid of 
answering "Not tested" if the PR has not been tested. Being clear about
what
has been done and not done provides important context to the reviewers.
  • Loading branch information
ksafonov-db authored Jan 29, 2025
1 parent 4bcfb0a commit 762c57b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 22 deletions.
41 changes: 20 additions & 21 deletions databricks/sdk/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,29 @@ def do(self,
if isinstance(data, (str, bytes)):
data = io.BytesIO(data.encode('utf-8') if isinstance(data, str) else data)

# Only retry if the request is not a stream or if the stream is seekable and
# we can rewind it. This is necessary to avoid bugs where the retry doesn't
# re-read already read data from the body.
if data is not None and not self._is_seekable_stream(data):
logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
call = self._perform
else:
if not data:
# The request is not a stream.
call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock)(self._perform)
elif self._is_seekable_stream(data):
# Keep track of the initial position of the stream so that we can rewind to it
# if we need to retry the request.
initial_data_position = data.tell()

def rewind():
logger.debug(f"Rewinding input data to offset {initial_data_position} before retry")
data.seek(initial_data_position)

call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock,
before_retry=rewind)(self._perform)
else:
# Do not retry if the stream is not seekable. This is necessary to avoid bugs
# where the retry doesn't re-read already read data from the stream.
logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
call = self._perform

response = call(method,
url,
Expand Down Expand Up @@ -249,12 +262,6 @@ def _perform(self,
files=None,
data=None,
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
# Keep track of the initial position of the stream so that we can rewind it if
# we need to retry the request.
initial_data_position = 0
if self._is_seekable_stream(data):
initial_data_position = data.tell()

response = self._session.request(method,
url,
params=self._fix_query_string(query),
Expand All @@ -266,16 +273,8 @@ def _perform(self,
stream=raw,
timeout=self._http_timeout_seconds)
self._record_request_log(response, raw=raw or data is not None or files is not None)

error = self._error_parser.get_api_error(response)
if error is not None:
# If the request body is a seekable stream, rewind it so that it is ready
# to be read again in case of a retry.
#
# TODO: This should be moved into a "before-retry" hook to avoid one
# unnecessary seek on the last failed retry before aborting.
if self._is_seekable_stream(data):
data.seek(initial_data_position)
raise error from None

return response
Expand Down
6 changes: 5 additions & 1 deletion databricks/sdk/retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def retried(*,
on: Sequence[Type[BaseException]] = None,
is_retryable: Callable[[BaseException], Optional[str]] = None,
timeout=timedelta(minutes=20),
clock: Clock = None):
clock: Clock = None,
before_retry: Callable = None):
has_allowlist = on is not None
has_callback = is_retryable is not None
if not (has_allowlist or has_callback) or (has_allowlist and has_callback):
Expand Down Expand Up @@ -54,6 +55,9 @@ def wrapper(*args, **kwargs):
raise err

logger.debug(f'Retrying: {retry_reason} (sleeping ~{sleep}s)')
if before_retry:
before_retry()

clock.sleep(sleep + random())
attempt += 1
raise TimeoutError(f'Timed out after {timeout}') from last_err
Expand Down

0 comments on commit 762c57b

Please sign in to comment.