diff --git a/docs/oauth2.md b/docs/oauth2.md index f76ba9a..ba5ee59 100644 --- a/docs/oauth2.md +++ b/docs/oauth2.md @@ -29,6 +29,8 @@ Returns the authorization URL where you should redirect the user to ask for thei * `redirect_uri: str`: Your callback URI where the user will be redirected after the service prompt. * `state: str = None`: Optional string that will be returned back in the callback parameters to allow you to retrieve state information. * `scope: Optional[List[str]] = None`: Optional list of scopes to ask for. + * `code_challenge: Optional[str] = None`: Optional code_challenge in a [PKCE context](https://datatracker.ietf.org/doc/html/rfc7636). + * `code_challenge_method: Optional[Literal["plain", "S256"]] = None`: Optional code_challenge_method in a [PKCE context](https://datatracker.ietf.org/doc/html/rfc7636). * `extras_params: Optional[Dict[str, Any]] = None`: Optional dictionary containing parameters specific to the service. !!! example diff --git a/httpx_oauth/clients/franceconnect.py b/httpx_oauth/clients/franceconnect.py index 9466517..d89be6b 100644 --- a/httpx_oauth/clients/franceconnect.py +++ b/httpx_oauth/clients/franceconnect.py @@ -1,5 +1,5 @@ import secrets -from typing import Any, Dict, List, Optional, Tuple, TypedDict +from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict from httpx_oauth.errors import GetIdEmailError from httpx_oauth.oauth2 import BaseOAuth2 @@ -58,6 +58,8 @@ async def get_authorization_url( redirect_uri: str, state: Optional[str] = None, scope: Optional[List[str]] = None, + code_challenge: Optional[str] = None, + code_challenge_method: Optional[Literal["plain", "S256"]] = None, extras_params: Optional[FranceConnectOAuth2AuthorizeParams] = None, ) -> str: _extras_params = extras_params or {} @@ -67,7 +69,7 @@ async def get_authorization_url( _extras_params["nonce"] = secrets.token_urlsafe() return await super().get_authorization_url( - redirect_uri, state, scope, _extras_params + redirect_uri, state, scope, extras_params=_extras_params ) async def get_id_email(self, token: str) -> Tuple[str, Optional[str]]: diff --git a/httpx_oauth/oauth2.py b/httpx_oauth/oauth2.py index 04634ee..bfef17e 100644 --- a/httpx_oauth/oauth2.py +++ b/httpx_oauth/oauth2.py @@ -5,6 +5,7 @@ Dict, Generic, List, + Literal, Optional, Tuple, TypeVar, @@ -100,6 +101,8 @@ async def get_authorization_url( redirect_uri: str, state: Optional[str] = None, scope: Optional[List[str]] = None, + code_challenge: Optional[str] = None, + code_challenge_method: Optional[Literal["plain", "S256"]] = None, extras_params: Optional[T] = None, ) -> str: params = { @@ -116,6 +119,12 @@ async def get_authorization_url( if _scope is not None: params["scope"] = " ".join(_scope) + if code_challenge is not None: + params["code_challenge"] = code_challenge + + if code_challenge_method is not None: + params["code_challenge_method"] = code_challenge_method + if extras_params is not None: params = {**params, **extras_params} # type: ignore diff --git a/tests/test_oauth2.py b/tests/test_oauth2.py index 313aa26..475d65f 100644 --- a/tests/test_oauth2.py +++ b/tests/test_oauth2.py @@ -93,6 +93,16 @@ async def test_get_authorization_url_with_scopes(self): ) assert "scope=SCOPE1+SCOPE2+SCOPE3" in authorization_url + @pytest.mark.asyncio + async def test_get_authorization_url_with_plain_code_challenge(self): + authorization_url = await client.get_authorization_url( + REDIRECT_URI, + code_challenge="CODE_CHALLENGE", + code_challenge_method="plain", + ) + assert "code_challenge=CODE_CHALLENGE" in authorization_url + assert "code_challenge_method=plain" in authorization_url + @pytest.mark.asyncio async def test_get_authorization_url_with_extras_params(self): authorization_url = await client.get_authorization_url(