From 762c57b9bfa14cc30bf5538007f116b676a50172 Mon Sep 17 00:00:00 2001 From: Kirill Safonov <122353021+ksafonov-db@users.noreply.github.com> Date: Wed, 29 Jan 2025 15:20:47 +0100 Subject: [PATCH] [Internal] Extract "before retry" handler, use it to rewind the stream (#878) ## What changes are proposed in this pull request? - Introduce a separate handler to be called before we retry the API call. This ensures the handler is called both when (1) we receive an error response we want to retry on and (2) when a low-level connection exception is thrown. - Rewind the stream to the initial position in this handler (if applicable). ## How is this tested? Existing tests. **ALWAYS ANSWER THIS QUESTION:** Answer with "N/A" if tests are not applicable to your PR (e.g. if the PR only modifies comments). Do not be afraid of answering "Not tested" if the PR has not been tested. Being clear about what has been done and not done provides important context to the reviewers. --- databricks/sdk/_base_client.py | 41 +++++++++++++++++----------------- databricks/sdk/retries.py | 6 ++++- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/databricks/sdk/_base_client.py b/databricks/sdk/_base_client.py index e61dd39c3..f0950f656 100644 --- a/databricks/sdk/_base_client.py +++ b/databricks/sdk/_base_client.py @@ -159,16 +159,29 @@ def do(self, if isinstance(data, (str, bytes)): data = io.BytesIO(data.encode('utf-8') if isinstance(data, str) else data) - # Only retry if the request is not a stream or if the stream is seekable and - # we can rewind it. This is necessary to avoid bugs where the retry doesn't - # re-read already read data from the body. - if data is not None and not self._is_seekable_stream(data): - logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}") - call = self._perform - else: + if not data: + # The request is not a stream. 
call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds), is_retryable=self._is_retryable, clock=self._clock)(self._perform) + elif self._is_seekable_stream(data): + # Keep track of the initial position of the stream so that we can rewind to it + # if we need to retry the request. + initial_data_position = data.tell() + + def rewind(): + logger.debug(f"Rewinding input data to offset {initial_data_position} before retry") + data.seek(initial_data_position) + + call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds), + is_retryable=self._is_retryable, + clock=self._clock, + before_retry=rewind)(self._perform) + else: + # Do not retry if the stream is not seekable. This is necessary to avoid bugs + # where the retry doesn't re-read already read data from the stream. + logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}") + call = self._perform response = call(method, url, @@ -249,12 +262,6 @@ def _perform(self, files=None, data=None, auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None): - # Keep track of the initial position of the stream so that we can rewind it if - # we need to retry the request. - initial_data_position = 0 - if self._is_seekable_stream(data): - initial_data_position = data.tell() - response = self._session.request(method, url, params=self._fix_query_string(query), @@ -266,16 +273,8 @@ def _perform(self, stream=raw, timeout=self._http_timeout_seconds) self._record_request_log(response, raw=raw or data is not None or files is not None) - error = self._error_parser.get_api_error(response) if error is not None: - # If the request body is a seekable stream, rewind it so that it is ready - # to be read again in case of a retry. - # - # TODO: This should be moved into a "before-retry" hook to avoid one - # unnecessary seek on the last failed retry before aborting. 
- if self._is_seekable_stream(data): - data.seek(initial_data_position) raise error from None return response diff --git a/databricks/sdk/retries.py b/databricks/sdk/retries.py index b98c54281..4f55087ea 100644 --- a/databricks/sdk/retries.py +++ b/databricks/sdk/retries.py @@ -13,7 +13,8 @@ def retried(*, on: Sequence[Type[BaseException]] = None, is_retryable: Callable[[BaseException], Optional[str]] = None, timeout=timedelta(minutes=20), - clock: Clock = None): + clock: Clock = None, + before_retry: Callable = None): has_allowlist = on is not None has_callback = is_retryable is not None if not (has_allowlist or has_callback) or (has_allowlist and has_callback): @@ -54,6 +55,9 @@ def wrapper(*args, **kwargs): raise err logger.debug(f'Retrying: {retry_reason} (sleeping ~{sleep}s)') + if before_retry: + before_retry() + clock.sleep(sleep + random()) attempt += 1 raise TimeoutError(f'Timed out after {timeout}') from last_err