diff --git a/pkg/kratos/service.go b/pkg/kratos/service.go index a5e5b3c5e..f6dcdb99b 100644 --- a/pkg/kratos/service.go +++ b/pkg/kratos/service.go @@ -170,7 +170,7 @@ func (s *Service) CheckAllowedProvider(ctx context.Context, loginFlow *kClient.L defer span.End() provider := updateFlowBody.UpdateLoginFlowWithOidcMethod.Provider - clientName := getClientName(loginFlow) + clientName := s.GetClientName(loginFlow) allowedProviders, err := s.authz.ListObjects(ctx, fmt.Sprintf("app:%s", clientName), "allowed_access", "provider") if err != nil { @@ -183,7 +183,7 @@ func (s *Service) CheckAllowedProvider(ctx context.Context, loginFlow *kClient.L return s.contains(allowedProviders, fmt.Sprintf("%v", provider)), nil } -func getClientName(loginFlow *kClient.LoginFlow) string { +func (s *Service) GetClientName(loginFlow *kClient.LoginFlow) string { oauth2LoginRequest := loginFlow.Oauth2LoginRequest if oauth2LoginRequest != nil { return oauth2LoginRequest.Client.GetClientName() @@ -196,8 +196,7 @@ func (s *Service) FilterFlowProviderList(ctx context.Context, flow *kClient.Logi ctx, span := s.tracer.Start(ctx, "kratos.Service.FilterFlowProviderList") defer span.End() - loginRequest := flow.Oauth2LoginRequest - clientName := loginRequest.Client.GetClientName() + clientName := s.GetClientName(flow) allowedProviders, err := s.authz.ListObjects(ctx, fmt.Sprintf("app:%s", clientName), "allowed_access", "provider") if err != nil { diff --git a/pkg/kratos/service_test.go b/pkg/kratos/service_test.go index 243e29c73..562423044 100644 --- a/pkg/kratos/service_test.go +++ b/pkg/kratos/service_test.go @@ -734,8 +734,9 @@ func TestCheckAllowedProviderFail(t *testing.T) { func TestGetClientNameOAuthKeeper(t *testing.T) { loginFlow := &kClient.LoginFlow{} + service := NewService(nil, nil, nil, nil, nil, nil) - actualClientName := getClientName(loginFlow) + actualClientName := service.GetClientName(loginFlow) const expectedClientName = "" assert.Equal(t, expectedClientName, actualClientName) @@ -744,8 +745,9 @@ func TestGetClientNameOAuthKeeper(t *testing.T) { func TestGetClientNameOAuth2Request(t *testing.T) { expectedClientName := "mockClientName" loginFlow := &kClient.LoginFlow{Oauth2LoginRequest: &kClient.OAuth2LoginRequest{Client: &kClient.OAuth2Client{ClientName: &expectedClientName}}} + service := NewService(nil, nil, nil, nil, nil, nil) - actualClientName := getClientName(loginFlow) + actualClientName := service.GetClientName(loginFlow) assert.Equal(t, expectedClientName, actualClientName) }