diff --git a/aio_pika/patterns/rpc.py b/aio_pika/patterns/rpc.py index bbec383e..425b9c75 100644 --- a/aio_pika/patterns/rpc.py +++ b/aio_pika/patterns/rpc.py @@ -5,6 +5,7 @@ import time from enum import Enum from functools import partial +from pydoc import locate from typing import Any, Callable, Dict, Hashable, Optional, TypeVar from aiormq.tools import awaitable @@ -404,13 +405,33 @@ class JsonRPC(RPC): CONTENT_TYPE = "application/json" def serialize(self, data: Any) -> bytes: - return self.SERIALIZER.dumps(data, ensure_ascii=False, default=repr) + return self.SERIALIZER.dumps( + data, ensure_ascii=False, default=repr).encode('ascii') + + def deserialize(self, data: Any) -> bytes: + res = super().deserialize(data) + if isinstance(res, dict) and "error" in res: + cls = locate(res['error']['type']) + if not cls: + def exception_constructor(self, message, args): + self.message = message + self.args = args + + cls = type(res['error']['type'], (Exception, ), { + '__init__': exception_constructor + }) + + res = cls(res['error']['message'], res['error']['args']) + return res def serialize_exception(self, exception: Exception) -> bytes: return self.serialize( { "error": { - "type": exception.__class__.__name__, + "type": f'{exception.__module__}.' + f'{exception.__class__.__name__}' + if hasattr(exception, '__module__') + else exception.__class__.__name__, "message": repr(exception), "args": exception.args, }, diff --git a/tests/test_rpc.py b/tests/test_rpc.py index 8582489e..fb56b760 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -7,7 +7,7 @@ from aio_pika import Message from aio_pika.exceptions import DeliveryError from aio_pika.message import IncomingMessage -from aio_pika.patterns.rpc import RPC +from aio_pika.patterns.rpc import RPC, JsonRPC from aio_pika.patterns.rpc import log as rpc_logger from tests import get_random_name @@ -19,6 +19,14 @@ def rpc_func(*, foo, bar): return {"foo": "bar"} +class CustomException(Exception): + pass + + +def rpc_raise_exception(*, foo, bar): + raise CustomException('foo bar') + + class TestCase: async def test_simple(self, channel: aio_pika.Channel): rpc = await RPC.create(channel, auto_delete=True) @@ -172,3 +180,41 @@ async def test_register_twice(self, channel: aio_pika.Channel): await rpc.unregister(rpc_func) await rpc.close() + + async def test_jsonrpc_simple(self, channel: aio_pika.Channel): + rpc = await JsonRPC.create(channel, auto_delete=True) + + await rpc.register("test.rpc", rpc_func, auto_delete=True) + + result = await rpc.proxy.test.rpc(foo=None, bar=None) + assert result == {"foo": "bar"} + + await rpc.unregister(rpc_func) + await rpc.close() + + # Close already closed + await rpc.close() + + async def test_jsonrpc_assert(self, channel: aio_pika.Channel): + rpc = await JsonRPC.create(channel, auto_delete=True) + + await rpc.register("test.rpc", rpc_func, auto_delete=True) + + with pytest.raises(AssertionError): + await rpc.proxy.test.rpc(foo=True, bar=None) + + await rpc.unregister(rpc_func) + await rpc.close() + + async def test_jsonrpc_error(self, channel: aio_pika.Channel): + rpc = await JsonRPC.create(channel, auto_delete=True) + + await rpc.register("test.rpc_error", rpc_raise_exception, auto_delete=True) + + with pytest.raises(Exception): + await rpc.proxy.test.rpc_error(foo=True, bar=None) + with pytest.raises(CustomException): + await rpc.proxy.test.rpc_error(foo=True, bar=None) + + await rpc.unregister(rpc_raise_exception) + await rpc.close()