Skip to content

Commit

Permalink
Factor out stream rewind before retry
Browse files Browse the repository at this point in the history
  • Loading branch information
ksafonov-db committed Jan 29, 2025
1 parent 4bcfb0a commit 8c6d9a6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 22 deletions.
41 changes: 20 additions & 21 deletions databricks/sdk/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,29 @@ def do(self,
if isinstance(data, (str, bytes)):
data = io.BytesIO(data.encode('utf-8') if isinstance(data, str) else data)

# Only retry if the request is not a stream, or if the stream is seekable and
# we can rewind it. Otherwise, a retry would not re-send the data that was
# already read from the body.
if data is not None and not self._is_seekable_stream(data):
logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
call = self._perform
else:
if not data:
# The request is not a stream.
call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock)(self._perform)
elif self._is_seekable_stream(data):
# Keep track of the initial position of the stream so that we can rewind to it
# if we need to retry the request.
initial_data_position = data.tell()

def rewind():
logger.debug(f"Rewinding input data to offset {initial_data_position} before retry")
data.seek(initial_data_position)

call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock,
before_retry=rewind)(self._perform)
else:
# Do not retry if the stream is not seekable: it cannot be rewound, so a retry
# would not re-send the data that was already read from the stream.
logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
call = self._perform

response = call(method,
url,
Expand Down Expand Up @@ -249,12 +262,6 @@ def _perform(self,
files=None,
data=None,
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
# Keep track of the initial position of the stream so that we can rewind it if
# we need to retry the request.
initial_data_position = 0
if self._is_seekable_stream(data):
initial_data_position = data.tell()

response = self._session.request(method,
url,
params=self._fix_query_string(query),
Expand All @@ -266,16 +273,8 @@ def _perform(self,
stream=raw,
timeout=self._http_timeout_seconds)
self._record_request_log(response, raw=raw or data is not None or files is not None)

error = self._error_parser.get_api_error(response)
if error is not None:
# If the request body is a seekable stream, rewind it so that it is ready
# to be read again in case of a retry.
#
# TODO: This should be moved into a "before-retry" hook to avoid one
# unnecessary seek on the last failed retry before aborting.
if self._is_seekable_stream(data):
data.seek(initial_data_position)
raise error from None

return response
Expand Down
6 changes: 5 additions & 1 deletion databricks/sdk/retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def retried(*,
on: Sequence[Type[BaseException]] = None,
is_retryable: Callable[[BaseException], Optional[str]] = None,
timeout=timedelta(minutes=20),
clock: Clock = None):
clock: Clock = None,
before_retry: Callable = None):
has_allowlist = on is not None
has_callback = is_retryable is not None
if not (has_allowlist or has_callback) or (has_allowlist and has_callback):
Expand Down Expand Up @@ -54,6 +55,9 @@ def wrapper(*args, **kwargs):
raise err

logger.debug(f'Retrying: {retry_reason} (sleeping ~{sleep}s)')
if before_retry:
before_retry()

clock.sleep(sleep + random())
attempt += 1
raise TimeoutError(f'Timed out after {timeout}') from last_err
Expand Down

0 comments on commit 8c6d9a6

Please sign in to comment.