diff --git a/aiohttp_jwt/permissions.py b/aiohttp_jwt/permissions.py index 6adb2e43..94c411fb 100644 --- a/aiohttp_jwt/permissions.py +++ b/aiohttp_jwt/permissions.py @@ -19,7 +19,15 @@ def match_all(required, provided): def login_required(func): async def wrapped(*args, **kwargs): request = args[-1] - assert isinstance(request, web.Request) + + if isinstance(request, web.View): + request = request.request + + if not isinstance(request, web.BaseRequest): # pragma: no cover + raise RuntimeError( + 'Incorrect usage of decorator.' + 'Expect web.BaseRequest as an argument') + request_property = __config[__REQUEST_IDENT] if not request.get(request_property): @@ -43,7 +51,15 @@ def check_permissions( def scopes_checker(func): async def wrapped(*args, **kwargs): request = args[-1] - assert isinstance(request, web.Request) + + if isinstance(request, web.View): + request = request.request + + if not isinstance(request, web.BaseRequest): # pragma: no cover + raise RuntimeError( + 'Incorrect usage of decorator.' + 'Expect web.BaseRequest as an argument') + request_property = __config[__REQUEST_IDENT] payload = request.get(request_property) diff --git a/tests/conftest.py b/tests/conftest.py index 2fd4b017..715c4687 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,7 +22,7 @@ def token(fake_payload, secret): @pytest.fixture def create_app(secret): - def factory(routes, *args, **kwargs): + def factory(routes=tuple(), views=tuple(), *args, **kwargs): defaults = {'secret_or_pub_key': secret} app = web.Application( middlewares=[ @@ -39,5 +39,8 @@ def factory(routes, *args, **kwargs): for path, handler in routes: app.router.add_get(path, handler) + for path, view in views: + app.router.add_view(path, view) + return app return factory diff --git a/tests/test_permissions.py b/tests/test_permissions.py index eccfd2e5..cdea19ea 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -32,6 +32,22 @@ async def handler(self, request): assert 'Authorization required' in response.reason +async def test_login_required_view( + create_app, fake_payload, aiohttp_client, secret): + class App(web.View): + @login_required + async def get(self): + return web.json_response({}) + + views = (('/foo', App),) + client = await aiohttp_client( + create_app(views=views, credentials_required=False)) + + response = await client.get('/foo') + assert response.status == 401 + assert 'Authorization required' in response.reason + + async def test_check_permissions( create_app, fake_payload, aiohttp_client, secret): token = jwt.encode({**fake_payload, 'scopes': ['view']}, secret) @@ -65,6 +81,24 @@ async def handler(self, request): assert response.status == 200 +async def test_check_permissions_view( + create_app, fake_payload, aiohttp_client, secret): + token = jwt.encode({**fake_payload, 'scopes': ['view']}, secret) + + class App(web.View): + @check_permissions(['view']) + async def get(self): + return web.json_response({}) + + views = (('/foo', App),) + client = await aiohttp_client( + create_app(views=views, credentials_required=False)) + response = await client.get('/foo', headers={ + 'Authorization': 'Bearer {}'.format(token.decode('utf-8')) + }) + assert response.status == 200 + + async def test_insufficient_scopes( create_app, fake_payload, aiohttp_client, secret): token = jwt.encode({**fake_payload, 'scopes': ['view']}, secret)