Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Handle NotFound and UnprocessableEntity errors in middleware #3327

Merged
merged 13 commits into from
Nov 5, 2024
Merged
94 changes: 54 additions & 40 deletions acapy_agent/admin/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from aiohttp_apispec import setup_aiohttp_apispec, validation_middleware
from uuid_utils import uuid4

from acapy_agent.wallet import singletons

from ..config.injection_context import InjectionContext
from ..config.logging import context_wallet_id
from ..core.event_bus import Event, EventBus
Expand All @@ -31,9 +29,11 @@
from ..transport.outbound.status import OutboundSendStatus
from ..transport.queue.basic import BasicMessageQueue
from ..utils import general as general_utils
from ..utils.extract_validation_error import extract_validation_error_message
from ..utils.stats import Collector
from ..utils.task_queue import TaskQueue
from ..version import __version__
from ..wallet import singletons
from ..wallet.anoncreds_upgrade import check_upgrade_completion_loop
from .base_server import BaseAdminServer
from .error import AdminSetupError
Expand Down Expand Up @@ -68,6 +68,8 @@
anoncreds_wallets = singletons.IsAnoncredsSingleton().wallets
in_progress_upgrades = singletons.UpgradeInProgressSingleton()

status_paths = ("/status/live", "/status/ready")


class AdminResponder(BaseResponder):
"""Handle outgoing messages from message handlers."""
Expand Down Expand Up @@ -134,44 +136,56 @@ def send_fn(self) -> Coroutine:
async def ready_middleware(request: web.BaseRequest, handler: Coroutine):
"""Only continue if application is ready to take work."""

if str(request.rel_url).rstrip("/") in (
"/status/live",
"/status/ready",
) or request.app._state.get("ready"):
try:
return await handler(request)
except (LedgerConfigError, LedgerTransactionError) as e:
# fatal, signal server shutdown
LOGGER.error("Shutdown with %s", str(e))
request.app._state["ready"] = False
request.app._state["alive"] = False
raise
except web.HTTPFound as e:
# redirect, typically / -> /api/doc
LOGGER.info("Handler redirect to: %s", e.location)
raise
except (web.HTTPUnauthorized, jwt.InvalidTokenError, InvalidTokenError) as e:
LOGGER.info(
"Unauthorized access during %s %s: %s", request.method, request.path, e
)
raise web.HTTPUnauthorized(reason=str(e)) from e
except (web.HTTPBadRequest, MultitenantManagerError) as e:
LOGGER.info("Bad request during %s %s: %s", request.method, request.path, e)
raise web.HTTPBadRequest(reason=str(e)) from e
except asyncio.CancelledError:
# redirection spawns new task and cancels old
LOGGER.debug("Task cancelled")
raise
except Exception as e:
# some other error?
LOGGER.error("Handler error with exception: %s", str(e))
import traceback

print("\n=================")
traceback.print_exc()
raise

raise web.HTTPServiceUnavailable(reason="Shutdown in progress")
is_status_check = str(request.rel_url).rstrip("/") in status_paths
is_app_ready = request.app._state.get("ready")

if not (is_status_check or is_app_ready):
raise web.HTTPServiceUnavailable(reason="Shutdown in progress")

try:
return await handler(request)
except web.HTTPFound as e:
# redirect, typically / -> /api/doc
LOGGER.info("Handler redirect to: %s", e.location)
raise
except asyncio.CancelledError:
# redirection spawns new task and cancels old
LOGGER.debug("Task cancelled")
raise
except (web.HTTPUnauthorized, jwt.InvalidTokenError, InvalidTokenError) as e:
LOGGER.info(
"Unauthorized access during %s %s: %s", request.method, request.path, e
)
raise web.HTTPUnauthorized(reason=str(e)) from e
except (web.HTTPBadRequest, MultitenantManagerError) as e:
LOGGER.info("Bad request during %s %s: %s", request.method, request.path, e)
raise web.HTTPBadRequest(reason=str(e)) from e
except (web.HTTPNotFound, StorageNotFoundError) as e:
LOGGER.info(
"Not Found error occurred during %s %s: %s",
request.method,
request.path,
e,
)
raise web.HTTPNotFound(reason=str(e)) from e
except web.HTTPUnprocessableEntity as e:
validation_error_message = extract_validation_error_message(e)
LOGGER.info(
"Unprocessable Entity occurred during %s %s: %s",
request.method,
request.path,
validation_error_message,
)
raise
except (LedgerConfigError, LedgerTransactionError) as e:
# fatal, signal server shutdown
LOGGER.critical("Shutdown with %s", str(e))
request.app._state["ready"] = False
request.app._state["alive"] = False
raise
except Exception as e:
LOGGER.exception("Handler error with exception:", exc_info=e)
raise


@web.middleware
Expand Down
158 changes: 158 additions & 0 deletions acapy_agent/admin/tests/test_admin_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from typing import Optional
from unittest import IsolatedAsyncioTestCase

import jwt
import pytest
from aiohttp import ClientSession, DummyCookieJar, TCPConnector, web
from aiohttp.test_utils import unused_port
from marshmallow import ValidationError

from acapy_agent.tests import mock
from acapy_agent.wallet import singletons
Expand All @@ -16,7 +18,9 @@
from ...core.goal_code_registry import GoalCodeRegistry
from ...core.in_memory import InMemoryProfile
from ...core.protocol_registry import ProtocolRegistry
from ...multitenant.error import MultitenantManagerError
from ...storage.base import BaseStorage
from ...storage.error import StorageNotFoundError
from ...storage.record import StorageRecord
from ...storage.type import RECORD_TYPE_ACAPY_UPGRADING
from ...utils.stats import Collector
Expand Down Expand Up @@ -108,6 +112,160 @@ async def test_ready_middleware(self):
with self.assertRaises(KeyError):
await test_module.ready_middleware(request, handler)

async def test_ready_middleware_http_unauthorized(self):
"""Test handling of web.HTTPUnauthorized and related exceptions."""
with mock.patch.object(test_module, "LOGGER", mock.MagicMock()) as mock_logger:
mock_logger.info = mock.MagicMock()

request = mock.MagicMock(
method="GET",
path="/unauthorized",
app=mock.MagicMock(_state={"ready": True}),
)

# Test web.HTTPUnauthorized
handler = mock.CoroutineMock(
side_effect=web.HTTPUnauthorized(reason="Unauthorized")
)
with self.assertRaises(web.HTTPUnauthorized):
await test_module.ready_middleware(request, handler)
mock_logger.info.assert_called_with(
"Unauthorized access during %s %s: %s",
request.method,
request.path,
handler.side_effect,
)

# Test jwt.InvalidTokenError
handler = mock.CoroutineMock(
side_effect=jwt.InvalidTokenError("Invalid token")
)
with self.assertRaises(web.HTTPUnauthorized):
await test_module.ready_middleware(request, handler)
mock_logger.info.assert_called_with(
"Unauthorized access during %s %s: %s",
request.method,
request.path,
handler.side_effect,
)

# Test InvalidTokenError
handler = mock.CoroutineMock(
side_effect=test_module.InvalidTokenError("Token error")
)
with self.assertRaises(web.HTTPUnauthorized):
await test_module.ready_middleware(request, handler)
mock_logger.info.assert_called_with(
"Unauthorized access during %s %s: %s",
request.method,
request.path,
handler.side_effect,
)

async def test_ready_middleware_http_bad_request(self):
"""Test handling of web.HTTPBadRequest and MultitenantManagerError."""
with mock.patch.object(test_module, "LOGGER", mock.MagicMock()) as mock_logger:
mock_logger.info = mock.MagicMock()

request = mock.MagicMock(
method="POST",
path="/bad-request",
app=mock.MagicMock(_state={"ready": True}),
)

# Test web.HTTPBadRequest
handler = mock.CoroutineMock(
side_effect=web.HTTPBadRequest(reason="Bad request")
)
with self.assertRaises(web.HTTPBadRequest):
await test_module.ready_middleware(request, handler)
mock_logger.info.assert_called_with(
"Bad request during %s %s: %s",
request.method,
request.path,
handler.side_effect,
)

# Test MultitenantManagerError
handler = mock.CoroutineMock(
side_effect=MultitenantManagerError("Multitenant error")
)
with self.assertRaises(web.HTTPBadRequest):
await test_module.ready_middleware(request, handler)
mock_logger.info.assert_called_with(
"Bad request during %s %s: %s",
request.method,
request.path,
handler.side_effect,
)

async def test_ready_middleware_http_not_found(self):
"""Test handling of web.HTTPNotFound and StorageNotFoundError."""
with mock.patch.object(test_module, "LOGGER", mock.MagicMock()) as mock_logger:
mock_logger.info = mock.MagicMock()

request = mock.MagicMock(
method="GET",
path="/not-found",
app=mock.MagicMock(_state={"ready": True}),
)

# Test web.HTTPNotFound
handler = mock.CoroutineMock(side_effect=web.HTTPNotFound(reason="Not found"))
with self.assertRaises(web.HTTPNotFound):
await test_module.ready_middleware(request, handler)
mock_logger.info.assert_called_with(
"Not Found error occurred during %s %s: %s",
request.method,
request.path,
handler.side_effect,
)

# Test StorageNotFoundError
handler = mock.CoroutineMock(
side_effect=StorageNotFoundError("Item not found")
)
with self.assertRaises(web.HTTPNotFound):
await test_module.ready_middleware(request, handler)
mock_logger.info.assert_called_with(
"Not Found error occurred during %s %s: %s",
request.method,
request.path,
handler.side_effect,
)

async def test_ready_middleware_http_unprocessable_entity(self):
"""Test handling of web.HTTPUnprocessableEntity with nested ValidationError."""
with mock.patch.object(test_module, "LOGGER", mock.MagicMock()) as mock_logger:
mock_logger.info = mock.MagicMock()
# Mock the extract_validation_error_message function
with mock.patch.object(
test_module, "extract_validation_error_message"
) as mock_extract:
mock_extract.return_value = {"field": ["Invalid input"]}

request = mock.MagicMock(
method="POST",
path="/unprocessable",
app=mock.MagicMock(_state={"ready": True}),
)

# Create a HTTPUnprocessableEntity exception with a nested ValidationError
validation_error = ValidationError({"field": ["Invalid input"]})
http_error = web.HTTPUnprocessableEntity(reason="Unprocessable Entity")
http_error.__cause__ = validation_error

handler = mock.CoroutineMock(side_effect=http_error)
with self.assertRaises(web.HTTPUnprocessableEntity):
await test_module.ready_middleware(request, handler)
mock_extract.assert_called_once_with(http_error)
mock_logger.info.assert_called_with(
"Unprocessable Entity occurred during %s %s: %s",
request.method,
request.path,
mock_extract.return_value,
)

def get_admin_server(
self, settings: Optional[dict] = None, context: Optional[InjectionContext] = None
) -> AdminServer:
Expand Down
2 changes: 1 addition & 1 deletion acapy_agent/askar/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def init_ledger_pool(self):
read_only = bool(self.settings.get("ledger.read_only", False))
socks_proxy = self.settings.get("ledger.socks_proxy")
if read_only:
LOGGER.error("Note: setting ledger to read-only mode")
LOGGER.warning("Note: setting ledger to read-only mode")
genesis_transactions = self.settings.get("ledger.genesis_transactions")
cache = self.context.injector.inject_or(BaseCache)
self.ledger_pool = IndyVdrLedgerPool(
Expand Down
2 changes: 1 addition & 1 deletion acapy_agent/askar/profile_anon.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def init_ledger_pool(self):
read_only = bool(self.settings.get("ledger.read_only", False))
socks_proxy = self.settings.get("ledger.socks_proxy")
if read_only:
LOGGER.error("Note: setting ledger to read-only mode")
LOGGER.warning("Note: setting ledger to read-only mode")
genesis_transactions = self.settings.get("ledger.genesis_transactions")
cache = self.context.injector.inject_or(BaseCache)
self.ledger_pool = IndyVdrLedgerPool(
Expand Down
2 changes: 1 addition & 1 deletion acapy_agent/protocols/out_of_band/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ async def handle_use_did_method(
did_method = PEER4 if did_peer_4 else PEER2
my_info = await self.oob.fetch_invitation_reuse_did(did_method)
if not my_info:
LOGGER.warn("No invitation DID found, creating new DID")
LOGGER.warning("No invitation DID found, creating new DID")

if not my_info:
did_metadata = (
Expand Down
16 changes: 16 additions & 0 deletions acapy_agent/utils/extract_validation_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Extract validation error messages from nested exceptions."""

from aiohttp.web import HTTPUnprocessableEntity
from marshmallow.exceptions import ValidationError


def extract_validation_error_message(exc: HTTPUnprocessableEntity) -> str:
"""Extract marshmallow error message from a nested UnprocessableEntity exception."""
visited = set()
current_exc = exc
while current_exc and current_exc not in visited:
visited.add(current_exc)
if isinstance(current_exc, ValidationError):
return current_exc.messages
current_exc = current_exc.__cause__ or current_exc.__context__
return exc.reason
Loading
Loading