Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OAuth2 PKCE support #2048

Merged
merged 1 commit into from
Sep 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions internal/http/request/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (
SessionIDContextKey
CSRFContextKey
OAuth2StateContextKey
OAuth2CodeVerifierContextKey
FlashMessageContextKey
FlashErrorMessageContextKey
PocketRequestTokenContextKey
Expand Down Expand Up @@ -94,6 +95,10 @@ func OAuth2State(r *http.Request) string {
return getContextStringValue(r, OAuth2StateContextKey)
}

func OAuth2CodeVerifier(r *http.Request) string {
return getContextStringValue(r, OAuth2CodeVerifierContextKey)
}

// FlashMessage returns the message message if any.
func FlashMessage(r *http.Request) string {
return getContextStringValue(r, FlashMessageContextKey)
Expand Down
5 changes: 3 additions & 2 deletions internal/model/app_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
type SessionData struct {
CSRF string `json:"csrf"`
OAuth2State string `json:"oauth2_state"`
OAuth2CodeVerifier string `json:"oauth2_code_verifier"`
FlashMessage string `json:"flash_message"`
FlashErrorMessage string `json:"flash_error_message"`
Language string `json:"language"`
Expand All @@ -22,8 +23,8 @@ type SessionData struct {
}

func (s SessionData) String() string {
return fmt.Sprintf(`CSRF=%q, OAuth2State=%q, FlashMsg=%q, FlashErrMsg=%q, Lang=%q, Theme=%q, PocketTkn=%q`,
s.CSRF, s.OAuth2State, s.FlashMessage, s.FlashErrorMessage, s.Language, s.Theme, s.PocketRequestToken)
return fmt.Sprintf(`CSRF=%q, OAuth2State=%q, OAuth2CodeVerifier=%q, FlashMsg=%q, FlashErrMsg=%q, Lang=%q, Theme=%q, PocketTkn=%q`,
s.CSRF, s.OAuth2State, s.OAuth2CodeVerifier, s.FlashMessage, s.FlashErrorMessage, s.Language, s.Theme, s.PocketRequestToken)
}

// Value converts the session data to JSON.
Expand Down
54 changes: 54 additions & 0 deletions internal/oauth2/authorization.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package oauth2 // import "miniflux.app/v2/internal/oauth2"

import (
"crypto/sha256"
"encoding/base64"
"io"

"golang.org/x/oauth2"

"miniflux.app/v2/internal/crypto"
)

type Authorization struct {
url string
state string
codeVerifier string
}

func (u *Authorization) RedirectURL() string {
return u.url
}

func (u *Authorization) State() string {
return u.state
}

func (u *Authorization) CodeVerifier() string {
return u.codeVerifier
}

func GenerateAuthorization(config *oauth2.Config) *Authorization {
codeVerifier := crypto.GenerateRandomStringHex(32)

sha2 := sha256.New()
io.WriteString(sha2, codeVerifier)
codeChallenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil))

state := crypto.GenerateRandomStringHex(24)

authUrl := config.AuthCodeURL(
state,
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
)

return &Authorization{
url: authUrl,
state: state,
codeVerifier: codeVerifier,
}
}
44 changes: 20 additions & 24 deletions internal/oauth2/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,30 @@ type googleProvider struct {
redirectURL string
}

func (g *googleProvider) GetUserExtraKey() string {
return "google_id"
func NewGoogleProvider(clientID, clientSecret, redirectURL string) *googleProvider {
return &googleProvider{clientID: clientID, clientSecret: clientSecret, redirectURL: redirectURL}
}

func (g *googleProvider) GetConfig() *oauth2.Config {
return &oauth2.Config{
RedirectURL: g.redirectURL,
ClientID: g.clientID,
ClientSecret: g.clientSecret,
Scopes: []string{"email"},
Endpoint: oauth2.Endpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://accounts.google.com/o/oauth2/token",
},
}
}

func (g *googleProvider) GetRedirectURL(state string) string {
return g.config().AuthCodeURL(state)
func (g *googleProvider) GetUserExtraKey() string {
return "google_id"
}

func (g *googleProvider) GetProfile(ctx context.Context, code string) (*Profile, error) {
conf := g.config()
token, err := conf.Exchange(ctx, code)
func (g *googleProvider) GetProfile(ctx context.Context, code, codeVerifier string) (*Profile, error) {
conf := g.GetConfig()
token, err := conf.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", codeVerifier))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -67,20 +80,3 @@ func (g *googleProvider) PopulateUserWithProfileID(user *model.User, profile *Pr
func (g *googleProvider) UnsetUserProfileID(user *model.User) {
user.GoogleID = ""
}

func (g *googleProvider) config() *oauth2.Config {
return &oauth2.Config{
RedirectURL: g.redirectURL,
ClientID: g.clientID,
ClientSecret: g.clientSecret,
Scopes: []string{"email"},
Endpoint: oauth2.Endpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://accounts.google.com/o/oauth2/token",
},
}
}

func newGoogleProvider(clientID, clientSecret, redirectURL string) *googleProvider {
return &googleProvider{clientID: clientID, clientSecret: clientSecret, redirectURL: redirectURL}
}
8 changes: 2 additions & 6 deletions internal/oauth2/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@ import (
"miniflux.app/v2/internal/logger"
)

// Manager handles OAuth2 providers.
type Manager struct {
providers map[string]Provider
}

// FindProvider returns the given provider.
func (m *Manager) FindProvider(name string) (Provider, error) {
if provider, found := m.providers[name]; found {
return provider, nil
Expand All @@ -24,18 +22,16 @@ func (m *Manager) FindProvider(name string) (Provider, error) {
return nil, errors.New("oauth2 provider not found")
}

// AddProvider add a new OAuth2 provider.
func (m *Manager) AddProvider(name string, provider Provider) {
m.providers[name] = provider
}

// NewManager returns a new Manager.
func NewManager(ctx context.Context, clientID, clientSecret, redirectURL, oidcDiscoveryEndpoint string) *Manager {
m := &Manager{providers: make(map[string]Provider)}
m.AddProvider("google", newGoogleProvider(clientID, clientSecret, redirectURL))
m.AddProvider("google", NewGoogleProvider(clientID, clientSecret, redirectURL))

if oidcDiscoveryEndpoint != "" {
if genericOidcProvider, err := newOidcProvider(ctx, clientID, clientSecret, redirectURL, oidcDiscoveryEndpoint); err != nil {
if genericOidcProvider, err := NewOidcProvider(ctx, clientID, clientSecret, redirectURL, oidcDiscoveryEndpoint); err != nil {
logger.Error("[OAuth2] failed to initialize OIDC provider: %v", err)
} else {
m.AddProvider("oidc", genericOidcProvider)
Expand Down
44 changes: 20 additions & 24 deletions internal/oauth2/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,32 @@ type oidcProvider struct {
provider *oidc.Provider
}

func NewOidcProvider(ctx context.Context, clientID, clientSecret, redirectURL, discoveryEndpoint string) (*oidcProvider, error) {
provider, err := oidc.NewProvider(ctx, discoveryEndpoint)
if err != nil {
return nil, err
}

return &oidcProvider{clientID: clientID, clientSecret: clientSecret, redirectURL: redirectURL, provider: provider}, nil
}

func (o *oidcProvider) GetUserExtraKey() string {
return "openid_connect_id"
}

func (o *oidcProvider) GetRedirectURL(state string) string {
return o.config().AuthCodeURL(state)
func (o *oidcProvider) GetConfig() *oauth2.Config {
return &oauth2.Config{
RedirectURL: o.redirectURL,
ClientID: o.clientID,
ClientSecret: o.clientSecret,
Scopes: []string{"openid", "email"},
Endpoint: o.provider.Endpoint(),
}
}

func (o *oidcProvider) GetProfile(ctx context.Context, code string) (*Profile, error) {
conf := o.config()
token, err := conf.Exchange(ctx, code)
func (o *oidcProvider) GetProfile(ctx context.Context, code, codeVerifier string) (*Profile, error) {
conf := o.GetConfig()
token, err := conf.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", codeVerifier))
if err != nil {
return nil, err
}
Expand All @@ -54,22 +69,3 @@ func (o *oidcProvider) PopulateUserWithProfileID(user *model.User, profile *Prof
func (o *oidcProvider) UnsetUserProfileID(user *model.User) {
user.OpenIDConnectID = ""
}

func (o *oidcProvider) config() *oauth2.Config {
return &oauth2.Config{
RedirectURL: o.redirectURL,
ClientID: o.clientID,
ClientSecret: o.clientSecret,
Scopes: []string{"openid", "email"},
Endpoint: o.provider.Endpoint(),
}
}

func newOidcProvider(ctx context.Context, clientID, clientSecret, redirectURL, discoveryEndpoint string) (*oidcProvider, error) {
provider, err := oidc.NewProvider(ctx, discoveryEndpoint)
if err != nil {
return nil, err
}

return &oidcProvider{clientID: clientID, clientSecret: clientSecret, redirectURL: redirectURL, provider: provider}, nil
}
6 changes: 4 additions & 2 deletions internal/oauth2/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@ package oauth2 // import "miniflux.app/v2/internal/oauth2"
import (
"context"

"golang.org/x/oauth2"

"miniflux.app/v2/internal/model"
)

// Provider is an interface for OAuth2 providers.
type Provider interface {
GetConfig() *oauth2.Config
GetUserExtraKey() string
GetRedirectURL(state string) string
GetProfile(ctx context.Context, code string) (*Profile, error)
GetProfile(ctx context.Context, code, codeVerifier string) (*Profile, error)
PopulateUserCreationWithProfileID(user *model.UserCreationRequest, profile *Profile)
PopulateUserWithProfileID(user *model.User, profile *Profile)
UnsetUserProfileID(user *model.User)
Expand Down
2 changes: 1 addition & 1 deletion internal/storage/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (s *Storage) createAppSession(session *model.Session) (*model.Session, erro
}

// UpdateAppSessionField updates only one session field.
func (s *Storage) UpdateAppSessionField(sessionID, field string, value interface{}) error {
func (s *Storage) UpdateAppSessionField(sessionID, field string, value any) error {
query := `
UPDATE
sessions
Expand Down
3 changes: 2 additions & 1 deletion internal/ui/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (m *middleware) handleAppSession(next http.Handler) http.Handler {
return
}

html.BadRequest(w, r, errors.New("Invalid or missing CSRF"))
html.BadRequest(w, r, errors.New("invalid or missing CSRF"))
return
}
}
Expand All @@ -103,6 +103,7 @@ func (m *middleware) handleAppSession(next http.Handler) http.Handler {
ctx = context.WithValue(ctx, request.SessionIDContextKey, session.ID)
ctx = context.WithValue(ctx, request.CSRFContextKey, session.Data.CSRF)
ctx = context.WithValue(ctx, request.OAuth2StateContextKey, session.Data.OAuth2State)
ctx = context.WithValue(ctx, request.OAuth2CodeVerifierContextKey, session.Data.OAuth2CodeVerifier)
ctx = context.WithValue(ctx, request.FlashMessageContextKey, session.Data.FlashMessage)
ctx = context.WithValue(ctx, request.FlashErrorMessageContextKey, session.Data.FlashErrorMessage)
ctx = context.WithValue(ctx, request.UserLanguageContextKey, session.Data.Language)
Expand Down
5 changes: 3 additions & 2 deletions internal/ui/oauth2_callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package ui // import "miniflux.app/v2/internal/ui"

import (
"crypto/subtle"
"errors"
"net/http"

Expand Down Expand Up @@ -38,7 +39,7 @@ func (h *handler) oauth2Callback(w http.ResponseWriter, r *http.Request) {
}

state := request.QueryStringParam(r, "state", "")
if state == "" || state != request.OAuth2State(r) {
if subtle.ConstantTimeCompare([]byte(state), []byte(request.OAuth2State(r))) == 0 {
logger.Error(`[OAuth2] Invalid state value: got "%s" instead of "%s"`, state, request.OAuth2State(r))
html.Redirect(w, r, route.Path(h.router, "login"))
return
Expand All @@ -51,7 +52,7 @@ func (h *handler) oauth2Callback(w http.ResponseWriter, r *http.Request) {
return
}

profile, err := authProvider.GetProfile(r.Context(), code)
profile, err := authProvider.GetProfile(r.Context(), code, request.OAuth2CodeVerifier(r))
if err != nil {
logger.Error("[OAuth2] %v", err)
html.Redirect(w, r, route.Path(h.router, "login"))
Expand Down
8 changes: 7 additions & 1 deletion internal/ui/oauth2_redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"miniflux.app/v2/internal/http/response/html"
"miniflux.app/v2/internal/http/route"
"miniflux.app/v2/internal/logger"
"miniflux.app/v2/internal/oauth2"
"miniflux.app/v2/internal/ui/session"
)

Expand All @@ -30,5 +31,10 @@ func (h *handler) oauth2Redirect(w http.ResponseWriter, r *http.Request) {
return
}

html.Redirect(w, r, authProvider.GetRedirectURL(sess.NewOAuth2State()))
auth := oauth2.GenerateAuthorization(authProvider.GetConfig())

sess.SetOAuth2State(auth.State())
sess.SetOAuth2CodeVerifier(auth.CodeVerifier())

html.Redirect(w, r, auth.RedirectURL())
}
10 changes: 5 additions & 5 deletions internal/ui/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package session // import "miniflux.app/v2/internal/ui/session"

import (
"miniflux.app/v2/internal/crypto"
"miniflux.app/v2/internal/storage"
)

Expand All @@ -14,11 +13,12 @@ type Session struct {
sessionID string
}

// NewOAuth2State generates a new OAuth2 state and stores the value into the database.
func (s *Session) NewOAuth2State() string {
state := crypto.GenerateRandomString(32)
func (s *Session) SetOAuth2State(state string) {
s.store.UpdateAppSessionField(s.sessionID, "oauth2_state", state)
return state
}

func (s *Session) SetOAuth2CodeVerifier(codeVerfier string) {
s.store.UpdateAppSessionField(s.sessionID, "oauth2_code_verifier", codeVerfier)
}

// NewFlashMessage creates a new flash message.
Expand Down
Loading