Skip to content

Commit

Permalink
Shg/api err codes (#184)
Browse files Browse the repository at this point in the history
* Update error codes from documentation

* Update error codes from documentation

* Optimize the `APIErrorCode` and how it's used

* Flake8 Compliance

* Add types to `APIErrorCode.__new__`

* Small adjustments

* Flake8 Compliance

* Minor improvement to `APIErrorCode`
  • Loading branch information
sHermanGriffiths authored Jul 10, 2024
1 parent e3eabf4 commit a7ef203
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 77 deletions.
108 changes: 57 additions & 51 deletions n2y/errors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from enum import Enum
from enum import StrEnum


class N2YError(Exception):
Expand All @@ -10,24 +10,18 @@ class PandocASTParseError(N2YError):
Raised if there was an error parsing the AST we provided to Pandoc.
"""

pass


class PluginError(N2YError):
"""
Raised due to various errors loading a plugin.
"""

pass


class UseNextClass(N2YError):
"""
Used by plugin classes to indicate that the next class should be used instead of them.
"""

pass


class RequestTimeoutError(N2YError):
"""
Expand Down Expand Up @@ -67,81 +61,93 @@ def __init__(self, response, message, code) -> None:
self.code = code


class ConnectionThrottled(HTTPResponseError):
class ConnectionThrottled(APIResponseError):
"""
Raised when the connection is throttled by the Notion API.
"""

def __init__(self, response, message=None) -> None:
retry = response.headers.get("retry-after")
self.retry_after = float(retry) if retry else None
self.retry_after = (
float(retry) if (retry := response.headers.get("retry-after")) else 0
)
if message is None:
message = (
"Your connection has been throttled by the Notion API for"
f" {self.retry_after} seconds. Please try again later."
)
super().__init__(response, message)
super().__init__(response, message, APIErrorCode.RateLimited)


class ObjectNotFound(APIResponseError):
def __init__(self, response, message) -> None:
code = APIErrorCode.ObjectNotFound
super().__init__(response, f"{message} [{code}]", code)
self.code = code
super().__init__(response, message, APIErrorCode.ObjectNotFound)


class APIErrorCode(str, Enum):
Unauthorized = "unauthorized"
"""The bearer token is not valid."""
class APIErrorCode(StrEnum):
is_retryable: bool

RestrictedResource = "restricted_resource"
"""Given the bearer token used, the client doesn't have permission to
perform this operation."""
def __new__(cls, code: str, is_retryable: bool):
obj = str.__new__(cls, code)
obj._value_ = code
obj.is_retryable = is_retryable
cls.RetryableCodes = [ec.value for ec in cls if ec.is_retryable]
cls.NonretryableCodes = [ec.value for ec in cls if not ec.is_retryable]
return obj

ObjectNotFound = "object_not_found"
"""Given the bearer token used, the resource does not exist.
This error can also indicate that the resource has not been shared with owner
of the bearer token."""
BadGateway = "bad_gateway", True
"""Notion encountered an issue while attempting to complete this request.
Please try again."""

InvalidJSON = "invalid_json"
"""The request body could not be decoded as JSON."""
ConflictError = "conflict_error", True
"""The transaction could not be completed, potentially due to a data collision.
Make sure the parameters are up to date and try again."""

InvalidRequestURL = "invalid_request_url"
"""The request URL is not valid."""
DatabaseConnectionUnavailable = "database_connection_unavailable", True
"""Notion's database is unavailable or in an unqueryable state. Try again later."""

InvalidRequest = "invalid_request"
"""This request is not supported."""
GatewayTimeout = "gateway_timeout", True
"""Notion timed out while attempting to complete this request.
Please try again later."""

ValidationError = "validation_error"
"""The request body does not match the schema for the expected parameters."""
InternalServerError = "internal_server_error", True
"""An unexpected error occurred. Reach out to Notion support."""

ConflictError = "conflict_error"
"""The transaction could not be completed, potentially due to a data collision.
Make sure the parameters are up to date and try again."""
InvalidGrant = "invalid_grant", False
"""The authorization code or refresh token is not valid."""

InternalServerError = "internal_server_error"
"""An unexpected error occurred. Reach out to Notion support."""
InvalidJSON = "invalid_json", False
"""The request body could not be decoded as JSON."""

ServiceUnavailable = "service_unavailable"
"""Notion is unavailable. Try again later.
This can occur when the time to respond to a request takes longer than 60 seconds,
the maximum request timeout."""
InvalidRequest = "invalid_request", False
"""This request is not supported."""

GatewayTimeoutError = "gateway_timeout"
"""Notion timed out while attempting to complete this request. Please try again later."""
InvalidRequestURL = "invalid_request_url", False
"""The request URL is not valid."""

MissingVersion = "missing_version"
MissingVersion = "missing_version", False
"""The request is missing the required Notion-Version header"""

DatabaseConnectionUnavailable = "database_connection_unavailable"
"""Notion's database is unavailable or in an unqueryable state. Try again later."""
ObjectNotFound = "object_not_found", False
"""Given the bearer token used, the resource does not exist.
This error can also indicate that the resource has not been shared with owner
of the bearer token."""

RateLimited = "rate_limited", True
"""The client has sent too many requests in a given amount of time."""

RestrictedResource = "restricted_resource", False
"""Given the bearer token used, the client doesn't have permission to
perform this operation."""

def is_api_error_code(code: str) -> bool:
"""Check if given code belongs to the list of valid API error codes."""
if isinstance(code, str):
return code in (error_code.value for error_code in APIErrorCode)
return False
ServiceUnavailable = "service_unavailable", True
"""Notion is unavailable. Try again later. This can occur when the time to respond
to a request takes longer than 60 seconds, the maximum request timeout."""

Unauthorized = "unauthorized", False
"""The bearer token is not valid."""

ValidationError = "validation_error", False
"""The request body does not match the schema for the expected parameters."""


# Some of this code was taken from https://github.com/ramnes/notion-sdk-py
Expand Down
7 changes: 3 additions & 4 deletions n2y/notion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
ObjectNotFound,
PluginError,
UseNextClass,
is_api_error_code,
)
from n2y.file import File
from n2y.logger import logger as log
Expand Down Expand Up @@ -208,7 +207,7 @@ def _wrap_notion_page(self, notion_data):
# the currently favored page class. Otherwise, plugins set for one export
# will be used in another or a plugin will be set but not used.
# As we currently use the `jinjarenderpage` plugin for all pages,
# this check is most likely uneccessary at this point.
# this check is most likely unnecessary at this point.
self.pages_cache[notion_data["id"]],
"page",
):
Expand Down Expand Up @@ -455,9 +454,9 @@ def _parse_response(self, response, stream=False):
code = None
if code == APIErrorCode.ObjectNotFound:
raise ObjectNotFound(response, body["message"])
elif code == "rate_limited":
elif code == APIErrorCode.RateLimited:
raise ConnectionThrottled(error.response)
elif code and is_api_error_code(code):
elif code and code in APIErrorCode:
raise APIResponseError(response, body["message"], code)
raise HTTPResponseError(error.response)
return response.json() if not stream else response.content
Expand Down
16 changes: 10 additions & 6 deletions n2y/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from pandoc.types import Meta, MetaBool, MetaList, MetaMap, MetaString, Space, Str
from plumbum import ProcessExecutionError

from n2y.errors import HTTPResponseError, PandocASTParseError
from n2y.errors import (
APIErrorCode,
APIResponseError,
ConnectionThrottled,
PandocASTParseError,
)
from n2y.logger import logger

# see https://pandoc.org/MANUAL.html#exit-codes
Expand Down Expand Up @@ -311,14 +316,13 @@ def wrapper(*args, retry_count=0, **kwargs):
assert "retry_count" not in kwargs, "retry_count is a reserved keyword"
try:
return api_call(*args, **kwargs)
except HTTPResponseError as err:
should_retry = err.status in [409, 429, 500, 502, 504, 503]
if not should_retry:
except APIResponseError as err:
if err.code not in APIErrorCode.RetryableCodes:
raise err
elif retry_count < max_api_retries:
retry_count += 1
if client.retry and "retry-after" in err.headers:
retry_after = float(err.headers["retry-after"])
if client.retry and isinstance(err, ConnectionThrottled):
retry_after = err.retry_after
client.logger.info(
"This API call has been rate limited and "
"will be retried in %f seconds. Attempt %d of %d.",
Expand Down
30 changes: 14 additions & 16 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pandoc.types import MetaBool, MetaList, MetaMap, MetaString
from pytest import raises

from n2y.errors import APIResponseError, ConnectionThrottled
from n2y.errors import APIErrorCode, APIResponseError, ConnectionThrottled
from n2y.notion import Client
from n2y.utils import (
fromisoformat,
Expand All @@ -16,6 +16,7 @@
)

foo_token = "foo_token"
rate_limited_status_code = 429


def test_fromisoformat_datetime():
Expand Down Expand Up @@ -61,10 +62,10 @@ def test_page_id_from_share_link():


class MockResponse:
def __init__(self, time, code):
def __init__(self, time, status_code):
self.headers = {"retry-after": time}
self.text = ""
self.status_code = code
self.status_code = status_code


def test_retry_api_call_no_error():
Expand All @@ -78,7 +79,6 @@ def tester(_):


def test_retry_api_call_multiple_errors():
status_code = 429
client = Client(foo_token)
call_count = 0

Expand All @@ -89,22 +89,22 @@ def tester(_, time):
call_count += 1

if call_count == 1:
raise ConnectionThrottled(MockResponse(0.05, status_code))
raise ConnectionThrottled(MockResponse(0.05, rate_limited_status_code))
elif call_count == 2:
assert isclose(0.05, seconds, abs_tol=0.1)
raise ConnectionThrottled(MockResponse(0.23, status_code))
raise ConnectionThrottled(MockResponse(0.23, rate_limited_status_code))
elif call_count == 3:
assert isclose(0.35, seconds, abs_tol=0.1)
raise ConnectionThrottled(MockResponse(0.16, status_code))
raise ConnectionThrottled(MockResponse(0.16, rate_limited_status_code))
elif call_count == 4:
assert isclose(0.51, seconds, abs_tol=0.1)
return True

assert tester(client, datetime.now())


@pytest.mark.parametrize("status_code", [409, 429, 500, 502, 504, 503])
def test_retry_api_call_once(status_code):
@pytest.mark.parametrize("code", APIErrorCode.RetryableCodes)
def test_retry_api_call_once(code):
call_count = 0
client = Client(foo_token)

Expand All @@ -114,9 +114,9 @@ def tester(_):
call_count += 1

if call_count == 1:
if status_code == 429:
raise ConnectionThrottled(MockResponse(0.001, status_code))
raise APIResponseError(MockResponse(0.001, status_code), "", status_code)
if code == APIErrorCode.RateLimited:
raise ConnectionThrottled(MockResponse(0.001, rate_limited_status_code))
raise APIResponseError(MockResponse(0.001, 500), "", code)
else:
return True

Expand All @@ -125,24 +125,22 @@ def tester(_):


def test_retry_api_call_max_errors():
status_code = 429
client = Client(foo_token)

@retry_api_call
def tester(_):
raise ConnectionThrottled(MockResponse(0.001, status_code))
raise ConnectionThrottled(MockResponse(0.001, rate_limited_status_code))

with raises(ConnectionThrottled):
tester(client)


def test_retry_api_call_retry_false():
status_code = 429
client = Client(foo_token, retry=False)

@retry_api_call
def tester(_):
raise ConnectionThrottled(MockResponse(0.001, status_code))
raise ConnectionThrottled(MockResponse(0.001, rate_limited_status_code))

with raises(ConnectionThrottled):
tester(client)
Expand Down

0 comments on commit a7ef203

Please sign in to comment.