diff --git a/auth.go b/auth.go index 8a24239..829b2cb 100644 --- a/auth.go +++ b/auth.go @@ -7,7 +7,9 @@ package tscaddy import ( "fmt" + "net" "net/http" + "reflect" "strings" "github.com/caddyserver/caddy/v2" @@ -40,6 +42,54 @@ func (Auth) CaddyModule() caddy.ModuleInfo { } } +// findTsnetListener recursively searches ln for wrapped or embedded net.Listeners +// until it finds a tsnetListener or runs out. +// ok indicates if a tsnetListener was found. +// +// In the future consider alternative approach if Caddy supports unwrapping listeners. +// See discussion in https://github.com/tailscale/caddy-tailscale/pull/70 +func findTsnetListener(ln net.Listener) (_ tsnetListener, ok bool) { + if ln == nil { + return nil, false + } + + // if ln is a tsnetListener, return it. + if tsn, ok := ln.(tsnetListener); ok { + return tsn, true + } + + // if ln is a wrappedListener, unwrap it. + if wl, ok := ln.(wrappedListener); ok { + return findTsnetListener(wl.Unwrap()) + } + + // if ln has an embedded net.Listener field, unwrap it. + s := reflect.ValueOf(ln) + if s.Kind() == reflect.Ptr { + s = s.Elem() + } + if s.Kind() != reflect.Struct { + return nil, false + } + + innerLn := s.FieldByName("Listener") + if innerLn.IsZero() { + // no more child/embedded listeners left + return nil, false + } + + // if the "Listener" field is a net.Listener, use it. + if wl, ok := innerLn.Interface().(net.Listener); ok { + return findTsnetListener(wl) + } + return nil, false +} + +// wrappedListener is implemented by types that wrap net.Listeners. +type wrappedListener interface { + Unwrap() net.Listener +} + // client returns the tailscale LocalClient for the TailscaleAuth module. // If the LocalClient has not already been configured, the provided request will be used to // lookup the tailscale node that serviced the request, and get the associated LocalClient. @@ -52,7 +102,7 @@ func (ta *Auth) client(r *http.Request) (*tailscale.LocalClient, error) { // server. server := r.Context().Value(caddyhttp.ServerCtxKey).(*caddyhttp.Server) for _, listener := range server.Listeners() { - if tsl, ok := listener.(tsnetListener); ok { + if tsl, ok := findTsnetListener(listener); ok { var err error ta.localclient, err = tsl.Server().LocalClient() if err != nil { diff --git a/module.go b/module.go index 7ed0f0c..48fd3e9 100644 --- a/module.go +++ b/module.go @@ -309,6 +309,13 @@ type tsnetServerListener struct { net.Listener } +func (t *tsnetServerListener) Unwrap() net.Listener { + if t == nil { + return nil + } + return t.Listener +} + func (t *tsnetServerListener) Close() error { if err := t.Listener.Close(); err != nil { return err