From eb016a0b00921e32ac0598a225a26928ccf0eb9a Mon Sep 17 00:00:00 2001 From: Sam'an Herman-Griffiths <100145229+sHermanGriffiths@users.noreply.github.com> Date: Tue, 7 May 2024 14:38:22 -0500 Subject: [PATCH] Shg/throttle (#175) * Create the `ConnectionThrottled` exception * remove `RateLimited` from the `APIErrorCode` class * Raise a ConnectionThrottled exception if an exception with the "rate_limited" code is raised * Add `Client.retry` to determine wether or not rate limited requests should be retried * Add a docstring to `retry_api_call` and implement `Client.retry` * Update the readme and `setup.py` * Remove `104` from the list of status codes that initiate a retry * Update and add tests * update the readme --- README.md | 8 ++++++++ n2y/errors.py | 26 +++++++++++++++++++------ n2y/notion.py | 5 +++++ n2y/utils.py | 14 +++++++++----- setup.py | 2 +- tests/test_utils.py | 47 +++++++++++++++++++++++++++++++-------------- 6 files changed, 76 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 50797d18..daf33170 100644 --- a/README.md +++ b/README.md @@ -390,6 +390,14 @@ Here are some features we're planning to add in the future: ## Changelog +### v0.10.1 +- Remove the `104` error status code from the list of error status codes that initate a retry in the + `retry_api_call` wrapper function. +- Add the `Client.retry` attribute to determine wether or not API calls should be retried after + being rate limited in the `retry_api_call` wrapper function. +- Remove unshared synced block warning and leave a comment explaining why it was removed. +- Create and implement the `ConnectionThrottled` exception + ### v0.10.0 - Instead of importing the logger from the `n2y.logger` module, pass it as an argument wherever necessary to allow custom loggers to be used. diff --git a/n2y/errors.py b/n2y/errors.py index 6a4758fb..61779cbf 100644 --- a/n2y/errors.py +++ b/n2y/errors.py @@ -29,6 +29,25 @@ class UseNextClass(N2YError): pass +class ConnectionThrottled(N2YError): + """ + Raised when the connection is throttled by the Notion API. + """ + + def __init__(self, response, message=None) -> None: + retry = response.headers.get("retry-after") + self.retry_after = float(retry) if retry else None + self.status = response.status_code + self.headers = response.headers + self.body = response.text + if message is None: + message = ( + "Your connection has been throttled by the Notion API for" + f" {self.retry_after} seconds. Please try again later." + ) + super().__init__(message) + + class RequestTimeoutError(N2YError): """ Exception for requests that timeout. @@ -52,9 +71,7 @@ class HTTPResponseError(N2YError): def __init__(self, response, message=None) -> None: if message is None: - message = ( - f"Request to Notion API failed with status: {response.status_code}" - ) + message = f"Request to Notion API failed with status: {response.status_code}" super().__init__(message) self.status = response.status_code self.headers = response.headers @@ -89,9 +106,6 @@ class APIErrorCode(str, Enum): This error can also indicate that the resource has not been shared with owner of the bearer token.""" - RateLimited = "rate_limited" - """This request exceeds the number of requests allowed. Slow down and try again.""" - InvalidJSON = "invalid_json" """The request body could not be decoded as JSON.""" diff --git a/n2y/notion.py b/n2y/notion.py index b1917b65..99be508a 100644 --- a/n2y/notion.py +++ b/n2y/notion.py @@ -13,6 +13,7 @@ from n2y.errors import ( APIErrorCode, APIResponseError, + ConnectionThrottled, HTTPResponseError, ObjectNotFound, PluginError, @@ -70,11 +71,13 @@ def __init__( plugins=None, export_defaults=None, logger=log, + retry=True, ): self.access_token = access_token self.media_root = media_root self.media_url = media_url self.logger = logger + self.retry = retry self.export_defaults = export_defaults or merge_default_config({}) self.base_url = "https://api.notion.com/v1/" @@ -451,6 +454,8 @@ def _parse_response(self, response, stream=False): code = None if code == APIErrorCode.ObjectNotFound: raise ObjectNotFound(response, body["message"]) + elif code == "rate_limited": + raise ConnectionThrottled(error.response) elif code and is_api_error_code(code): raise APIResponseError(response, body["message"], code) raise HTTPResponseError(error.response) diff --git a/n2y/utils.py b/n2y/utils.py index f1bc21fc..cc5cd2c5 100644 --- a/n2y/utils.py +++ b/n2y/utils.py @@ -299,23 +299,27 @@ def load_yaml(data): def retry_api_call(api_call): + """ + Retry an API call if it fails due to a rate limit or server error. Can only be used to + decorate methods of the `Client` class. + """ max_api_retries = 4 @functools.wraps(api_call) def wrapper(*args, retry_count=0, **kwargs): + client = args[0] assert "retry_count" not in kwargs, "retry_count is a reserved keyword" try: return api_call(*args, **kwargs) except HTTPResponseError as err: - should_retry = err.status in [409, 429, 500, 502, 504, 503, 104] + should_retry = err.status in [409, 429, 500, 502, 504, 503] if not should_retry: raise err elif retry_count < max_api_retries: retry_count += 1 - log = args[0].logger if args and hasattr(args[0], "logger") else logger - if "retry-after" in err.headers: + if client.retry and "retry-after" in err.headers: retry_after = float(err.headers["retry-after"]) - log.info( + client.logger.info( "This API call has been rate limited and " "will be retried in %f seconds. Attempt %d of %d.", retry_after, @@ -324,7 +328,7 @@ def wrapper(*args, retry_count=0, **kwargs): ) else: retry_after = 2 - log.info( + client.logger.info( "This API call failed and " "will be retried in %f seconds. Attempt %d of %d.", retry_after, diff --git a/setup.py b/setup.py index 7260948f..8f19350a 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ setup( name="n2y", - version="0.10.0", + version="0.10.1", description=description, long_description=description, long_description_content_type="text/x-rst", diff --git a/tests/test_utils.py b/tests/test_utils.py index d52a943b..879bc412 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,12 @@ +from datetime import datetime, timedelta, timezone from math import isclose -from pytest import raises -from datetime import datetime, timezone, timedelta import pytest -from pandoc.types import MetaMap, MetaList, MetaBool, MetaString +from pandoc.types import MetaBool, MetaList, MetaMap, MetaString +from pytest import raises from n2y.errors import APIResponseError +from n2y.notion import Client from n2y.utils import ( fromisoformat, header_id_from_text, @@ -14,6 +15,8 @@ yaml_to_meta_value, ) +foo_token = "foo_token" + def test_fromisoformat_datetime(): expected = datetime(2022, 5, 10, 19, 52, tzinfo=timezone.utc) @@ -65,20 +68,22 @@ def __init__(self, time, code): def test_retry_api_call_no_error(): + client = Client(foo_token) + @retry_api_call - def tester(): + def tester(_): return True - assert tester() + assert tester(client) def test_retry_api_call_multiple_errors(): status_code = 429 - + client = Client(foo_token) call_count = 0 @retry_api_call - def tester(time): + def tester(_, time): seconds = timedelta.total_seconds(datetime.now() - time) nonlocal call_count call_count += 1 @@ -95,15 +100,16 @@ def tester(time): assert isclose(0.51, seconds, abs_tol=0.1) return True - assert tester(datetime.now()) + assert tester(client, datetime.now()) -@pytest.mark.parametrize("status_code", [409, 429, 500, 502, 504]) -def test_retry_api_call_onc(status_code): +@pytest.mark.parametrize("status_code", [409, 429, 500, 502, 504, 503]) +def test_retry_api_call_once(status_code): call_count = 0 + client = Client(foo_token) @retry_api_call - def tester(): + def tester(_): nonlocal call_count call_count += 1 @@ -112,19 +118,32 @@ def tester(): else: return True - assert tester() + assert tester(client) assert call_count == 2 def test_retry_api_call_max_errors(): status_code = 429 + client = Client(foo_token) + + @retry_api_call + def tester(_): + raise APIResponseError(MockResponse(0.001, status_code), "", status_code) + + with raises(APIResponseError): + tester(client) + + +def test_retry_api_call_retry_false(): + status_code = 429 + client = Client(foo_token, retry=False) @retry_api_call - def tester(): + def tester(_): raise APIResponseError(MockResponse(0.001, status_code), "", status_code) with raises(APIResponseError): - tester() + tester(client) def test_yaml_to_meta_value_scalar():