Skip to content
This repository has been archived by the owner on Jun 14, 2024. It is now read-only.

Commit

Permalink
fix: Clean providers list when oauth config is disabled (#13)
Browse files Browse the repository at this point in the history
<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

When `.oauth.yaml` file contains `enabled: false`, the oauth endpoints
return the configured provider info. This PR fixes it and set an empty
list when `enabled=False`

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [X] Bug fix (non-breaking change which fixes an issue)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

Running locally and HF spaces

- [ ] Test A
- [ ] Test B

**Checklist**

- [X] I followed the style guidelines of this project
- [X] I did a self-review of my code
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
frascuchon and pre-commit-ci[bot] authored Feb 5, 2024
1 parent ec79728 commit 0a6b1ea
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/argilla_server/apis/v1/handlers/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@

@router.get("/providers", response_model=Providers)
def list_providers(_request: Request) -> Providers:
items = [Provider(name=provider_name) for provider_name in settings.oauth.providers]
if not settings.oauth.enabled:
return Providers(items=[])

return Providers(items=items)
return Providers(items=[Provider(name=provider_name) for provider_name in settings.oauth.providers])


@router.get("/providers/{provider}/authentication")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def __init__(
self.enabled = enabled
self.allow_http_redirect = allow_http_redirect
self.allowed_workspaces = allowed_workspaces or []

self._providers = providers or []

if self.allow_http_redirect:
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/api/v1/test_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ async def test_list_providers_with_oauth_disabled(
assert response.status_code == 200
assert response.json() == {"items": []}

async def test_list_provider_with_oauth_disabled_from_settings(
self, async_client: AsyncClient, owner_auth_header: dict, default_oauth_settings: OAuth2Settings
):
default_oauth_settings.enabled = False
with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings):
response = await async_client.get("/api/v1/oauth2/providers", headers=owner_auth_header)
assert response.status_code == 200
assert response.json() == {"items": []}

async def test_list_providers(
self, async_client: AsyncClient, owner_auth_header: dict, default_oauth_settings: OAuth2Settings
):
Expand Down Expand Up @@ -103,6 +112,19 @@ async def test_provider_authentication_with_oauth_disabled(
)
assert response.status_code == 404

async def test_provider_authentication_with_oauth_disabled_and_provider_defined(
self,
async_client: AsyncClient,
owner_auth_header: dict,
default_oauth_settings: OAuth2Settings,
):
default_oauth_settings.enabled = False
with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings):
response = await async_client.get(
"/api/v1/oauth2/providers/huggingface/authentication", headers=owner_auth_header
)
assert response.status_code == 404

async def test_provider_authentication_with_invalid_provider(
self, async_client: AsyncClient, owner_auth_header: dict, default_oauth_settings: OAuth2Settings
):
Expand Down

0 comments on commit 0a6b1ea

Please sign in to comment.