Skip to content

Commit

Permalink
Shg/throttle (#175)
Browse files Browse the repository at this point in the history
* Create the `ConnectionThrottled` exception

* remove `RateLimited` from the `APIErrorCode` class

* Raise a ConnectionThrottled exception if an exception with the "rate_limited" code is raised

* Add `Client.retry` to determine wether or not rate limited requests should be retried

* Add a docstring to `retry_api_call` and implement `Client.retry`

* Update the readme and `setup.py`

* Remove `104` from the list of status codes that initiate a retry

* Update and add tests

* update the readme
  • Loading branch information
sHermanGriffiths authored May 7, 2024
1 parent 3976087 commit eb016a0
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 26 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,14 @@ Here are some features we're planning to add in the future:

## Changelog

### v0.10.1
- Remove the `104` error status code from the list of error status codes that initate a retry in the
`retry_api_call` wrapper function.
- Add the `Client.retry` attribute to determine wether or not API calls should be retried after
being rate limited in the `retry_api_call` wrapper function.
- Remove unshared synced block warning and leave a comment explaining why it was removed.
- Create and implement the `ConnectionThrottled` exception

### v0.10.0
- Instead of importing the logger from the `n2y.logger` module, pass it as an argument wherever
necessary to allow custom loggers to be used.
Expand Down
26 changes: 20 additions & 6 deletions n2y/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,25 @@ class UseNextClass(N2YError):
pass


class ConnectionThrottled(N2YError):
"""
Raised when the connection is throttled by the Notion API.
"""

def __init__(self, response, message=None) -> None:
retry = response.headers.get("retry-after")
self.retry_after = float(retry) if retry else None
self.status = response.status_code
self.headers = response.headers
self.body = response.text
if message is None:
message = (
"Your connection has been throttled by the Notion API for"
f" {self.retry_after} seconds. Please try again later."
)
super().__init__(message)


class RequestTimeoutError(N2YError):
"""
Exception for requests that timeout.
Expand All @@ -52,9 +71,7 @@ class HTTPResponseError(N2YError):

def __init__(self, response, message=None) -> None:
if message is None:
message = (
f"Request to Notion API failed with status: {response.status_code}"
)
message = f"Request to Notion API failed with status: {response.status_code}"
super().__init__(message)
self.status = response.status_code
self.headers = response.headers
Expand Down Expand Up @@ -89,9 +106,6 @@ class APIErrorCode(str, Enum):
This error can also indicate that the resource has not been shared with owner
of the bearer token."""

RateLimited = "rate_limited"
"""This request exceeds the number of requests allowed. Slow down and try again."""

InvalidJSON = "invalid_json"
"""The request body could not be decoded as JSON."""

Expand Down
5 changes: 5 additions & 0 deletions n2y/notion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from n2y.errors import (
APIErrorCode,
APIResponseError,
ConnectionThrottled,
HTTPResponseError,
ObjectNotFound,
PluginError,
Expand Down Expand Up @@ -70,11 +71,13 @@ def __init__(
plugins=None,
export_defaults=None,
logger=log,
retry=True,
):
self.access_token = access_token
self.media_root = media_root
self.media_url = media_url
self.logger = logger
self.retry = retry
self.export_defaults = export_defaults or merge_default_config({})

self.base_url = "https://api.notion.com/v1/"
Expand Down Expand Up @@ -451,6 +454,8 @@ def _parse_response(self, response, stream=False):
code = None
if code == APIErrorCode.ObjectNotFound:
raise ObjectNotFound(response, body["message"])
elif code == "rate_limited":
raise ConnectionThrottled(error.response)
elif code and is_api_error_code(code):
raise APIResponseError(response, body["message"], code)
raise HTTPResponseError(error.response)
Expand Down
14 changes: 9 additions & 5 deletions n2y/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,23 +299,27 @@ def load_yaml(data):


def retry_api_call(api_call):
"""
Retry an API call if it fails due to a rate limit or server error. Can only be used to
decorate methods of the `Client` class.
"""
max_api_retries = 4

@functools.wraps(api_call)
def wrapper(*args, retry_count=0, **kwargs):
client = args[0]
assert "retry_count" not in kwargs, "retry_count is a reserved keyword"
try:
return api_call(*args, **kwargs)
except HTTPResponseError as err:
should_retry = err.status in [409, 429, 500, 502, 504, 503, 104]
should_retry = err.status in [409, 429, 500, 502, 504, 503]
if not should_retry:
raise err
elif retry_count < max_api_retries:
retry_count += 1
log = args[0].logger if args and hasattr(args[0], "logger") else logger
if "retry-after" in err.headers:
if client.retry and "retry-after" in err.headers:
retry_after = float(err.headers["retry-after"])
log.info(
client.logger.info(
"This API call has been rate limited and "
"will be retried in %f seconds. Attempt %d of %d.",
retry_after,
Expand All @@ -324,7 +328,7 @@ def wrapper(*args, retry_count=0, **kwargs):
)
else:
retry_after = 2
log.info(
client.logger.info(
"This API call failed and "
"will be retried in %f seconds. Attempt %d of %d.",
retry_after,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

setup(
name="n2y",
version="0.10.0",
version="0.10.1",
description=description,
long_description=description,
long_description_content_type="text/x-rst",
Expand Down
47 changes: 33 additions & 14 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from datetime import datetime, timedelta, timezone
from math import isclose
from pytest import raises
from datetime import datetime, timezone, timedelta

import pytest
from pandoc.types import MetaMap, MetaList, MetaBool, MetaString
from pandoc.types import MetaBool, MetaList, MetaMap, MetaString
from pytest import raises

from n2y.errors import APIResponseError
from n2y.notion import Client
from n2y.utils import (
fromisoformat,
header_id_from_text,
Expand All @@ -14,6 +15,8 @@
yaml_to_meta_value,
)

foo_token = "foo_token"


def test_fromisoformat_datetime():
expected = datetime(2022, 5, 10, 19, 52, tzinfo=timezone.utc)
Expand Down Expand Up @@ -65,20 +68,22 @@ def __init__(self, time, code):


def test_retry_api_call_no_error():
client = Client(foo_token)

@retry_api_call
def tester():
def tester(_):
return True

assert tester()
assert tester(client)


def test_retry_api_call_multiple_errors():
status_code = 429

client = Client(foo_token)
call_count = 0

@retry_api_call
def tester(time):
def tester(_, time):
seconds = timedelta.total_seconds(datetime.now() - time)
nonlocal call_count
call_count += 1
Expand All @@ -95,15 +100,16 @@ def tester(time):
assert isclose(0.51, seconds, abs_tol=0.1)
return True

assert tester(datetime.now())
assert tester(client, datetime.now())


@pytest.mark.parametrize("status_code", [409, 429, 500, 502, 504])
def test_retry_api_call_onc(status_code):
@pytest.mark.parametrize("status_code", [409, 429, 500, 502, 504, 503])
def test_retry_api_call_once(status_code):
call_count = 0
client = Client(foo_token)

@retry_api_call
def tester():
def tester(_):
nonlocal call_count
call_count += 1

Expand All @@ -112,19 +118,32 @@ def tester():
else:
return True

assert tester()
assert tester(client)
assert call_count == 2


def test_retry_api_call_max_errors():
status_code = 429
client = Client(foo_token)

@retry_api_call
def tester(_):
raise APIResponseError(MockResponse(0.001, status_code), "", status_code)

with raises(APIResponseError):
tester(client)


def test_retry_api_call_retry_false():
status_code = 429
client = Client(foo_token, retry=False)

@retry_api_call
def tester():
def tester(_):
raise APIResponseError(MockResponse(0.001, status_code), "", status_code)

with raises(APIResponseError):
tester()
tester(client)


def test_yaml_to_meta_value_scalar():
Expand Down

0 comments on commit eb016a0

Please sign in to comment.