Skip to content

Commit

Permalink
refactor(getClientName): change function to method + adopt the function
Browse files Browse the repository at this point in the history
  • Loading branch information
BarcoMasile committed Feb 6, 2024
1 parent 6214f7f commit c2f7b0e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
7 changes: 3 additions & 4 deletions pkg/kratos/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func (s *Service) CheckAllowedProvider(ctx context.Context, loginFlow *kClient.L
defer span.End()

provider := updateFlowBody.UpdateLoginFlowWithOidcMethod.Provider
clientName := getClientName(loginFlow)
clientName := s.getClientName(loginFlow)

allowedProviders, err := s.authz.ListObjects(ctx, fmt.Sprintf("app:%s", clientName), "allowed_access", "provider")
if err != nil {
Expand All @@ -183,7 +183,7 @@ func (s *Service) CheckAllowedProvider(ctx context.Context, loginFlow *kClient.L
return s.contains(allowedProviders, fmt.Sprintf("%v", provider)), nil
}

func getClientName(loginFlow *kClient.LoginFlow) string {
func (s *Service) getClientName(loginFlow *kClient.LoginFlow) string {
oauth2LoginRequest := loginFlow.Oauth2LoginRequest
if oauth2LoginRequest != nil {
return oauth2LoginRequest.Client.GetClientName()
Expand All @@ -196,8 +196,7 @@ func (s *Service) FilterFlowProviderList(ctx context.Context, flow *kClient.Logi
ctx, span := s.tracer.Start(ctx, "kratos.Service.FilterFlowProviderList")
defer span.End()

loginRequest := flow.Oauth2LoginRequest
clientName := loginRequest.Client.GetClientName()
clientName := s.getClientName(flow)

allowedProviders, err := s.authz.ListObjects(ctx, fmt.Sprintf("app:%s", clientName), "allowed_access", "provider")
if err != nil {
Expand Down
15 changes: 10 additions & 5 deletions pkg/kratos/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/stretchr/testify/assert"
"io"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -734,20 +733,26 @@ func TestCheckAllowedProviderFail(t *testing.T) {

func TestGetClientNameOAuthKeeper(t *testing.T) {
loginFlow := &kClient.LoginFlow{}
service := NewService(nil, nil, nil, nil, nil, nil)

actualClientName := getClientName(loginFlow)
actualClientName := service.getClientName(loginFlow)

const expectedClientName = ""
assert.Equal(t, expectedClientName, actualClientName)
if expectedClientName != actualClientName {
t.Fatalf("Expected client name doesn't match")
}
}

func TestGetClientNameOAuth2Request(t *testing.T) {
expectedClientName := "mockClientName"
loginFlow := &kClient.LoginFlow{Oauth2LoginRequest: &kClient.OAuth2LoginRequest{Client: &kClient.OAuth2Client{ClientName: &expectedClientName}}}
service := NewService(nil, nil, nil, nil, nil, nil)

actualClientName := getClientName(loginFlow)
actualClientName := service.getClientName(loginFlow)

assert.Equal(t, expectedClientName, actualClientName)
if expectedClientName != actualClientName {
t.Fatalf("Expected client name doesn't match")
}
}

func TestFilterFlowProviderListAllowAll(t *testing.T) {
Expand Down

0 comments on commit c2f7b0e

Please sign in to comment.