Skip to content

Commit

Permalink
Fix backward compatibility of LocalServerCallbackPath (#237)
Browse files Browse the repository at this point in the history
  • Loading branch information
int128 authored Jan 26, 2025
1 parent 69eb49a commit 8dec00b
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 11 deletions.
6 changes: 3 additions & 3 deletions e2e_test/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestHappyPath(t *testing.T) {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/") {
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
Expand Down Expand Up @@ -106,7 +106,7 @@ func TestRedirectURLHostname(t *testing.T) {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "http", "127.0.0.1", "/") {
if !assertRedirectURI(t, req.RedirectURI, "http", "127.0.0.1", "") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
Expand Down Expand Up @@ -177,7 +177,7 @@ func TestSuccessRedirect(t *testing.T) {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/") {
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
Expand Down
2 changes: 1 addition & 1 deletion e2e_test/pkce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestPKCE(t *testing.T) {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/") {
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
Expand Down
2 changes: 1 addition & 1 deletion e2e_test/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestTLS(t *testing.T) {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "https", "localhost", "/") {
if !assertRedirectURI(t, req.RedirectURI, "https", "localhost", "") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
Expand Down
4 changes: 0 additions & 4 deletions oauth2cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ type Config struct {

// Callback path of the local server.
// If your provider requires a specific path of the redirect URL, set it here.
// Default to "/".
LocalServerCallbackPath string

// Response HTML body on authorization completed.
Expand Down Expand Up @@ -124,9 +123,6 @@ func (cfg *Config) validateAndSetDefaults() error {
}
cfg.State = state
}
if cfg.LocalServerCallbackPath == "" {
cfg.LocalServerCallbackPath = "/"
}
if cfg.LocalServerMiddleware == nil {
cfg.LocalServerMiddleware = noopMiddleware
}
Expand Down
8 changes: 6 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,17 @@ type localServerHandler struct {
}

func (h *localServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
callbackPath := h.config.LocalServerCallbackPath
if callbackPath == "" {
callbackPath = "/"
}
q := r.URL.Query()
switch {
case r.Method == "GET" && r.URL.Path == h.config.LocalServerCallbackPath && q.Get("error") != "":
case r.Method == "GET" && r.URL.Path == callbackPath && q.Get("error") != "":
h.onceRespCh.Do(func() {
h.respCh <- h.handleErrorResponse(w, r)
})
case r.Method == "GET" && r.URL.Path == h.config.LocalServerCallbackPath && q.Get("code") != "":
case r.Method == "GET" && r.URL.Path == callbackPath && q.Get("code") != "":
h.onceRespCh.Do(func() {
h.respCh <- h.handleCodeResponse(w, r)
})
Expand Down

0 comments on commit 8dec00b

Please sign in to comment.