Skip to content

Commit

Permalink
Support Named Pipes in gRPC target strings (#198)
Browse files Browse the repository at this point in the history
Signed-off-by: Agustín Martínez Fayó <[email protected]>
  • Loading branch information
amartinezfayo authored Jun 9, 2022
1 parent be346a3 commit fcf03d7
Show file tree
Hide file tree
Showing 15 changed files with 270 additions and 87 deletions.
5 changes: 2 additions & 3 deletions v2/internal/test/fakeworkloadapi/workload_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"sync"
"testing"

Expand Down Expand Up @@ -51,7 +50,7 @@ func New(tb testing.TB) *WorkloadAPI {
x509BundlesChans: make(map[chan *workload.X509BundlesResponse]struct{}),
}

listener, err := net.Listen("tcp", "localhost:0")
listener, err := newListener()
require.NoError(tb, err)

server := grpc.NewServer()
Expand All @@ -63,7 +62,7 @@ func New(tb testing.TB) *WorkloadAPI {
_ = server.Serve(listener)
}()

w.addr = fmt.Sprintf("%s://%s", listener.Addr().Network(), listener.Addr().String())
w.addr = getTargetName(listener.Addr())
tb.Logf("WorkloadAPI address: %s", w.addr)
w.server = server
return w
Expand Down
17 changes: 17 additions & 0 deletions v2/internal/test/fakeworkloadapi/workload_api_posix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
//go:build !windows
// +build !windows

package fakeworkloadapi

import (
"fmt"
"net"
)

func newListener() (net.Listener, error) {
return net.Listen("tcp", "localhost:0")
}

func getTargetName(addr net.Addr) string {
return fmt.Sprintf("%s://%s", addr.Network(), addr.String())
}
23 changes: 22 additions & 1 deletion v2/internal/test/fakeworkloadapi/workload_api_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ package fakeworkloadapi
import (
"fmt"
"math/rand"
"net"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -33,12 +35,31 @@ func NewWithNamedPipeListener(tb testing.TB) *WorkloadAPI {
_ = server.Serve(listener)
}()

w.addr = listener.Addr().String()
w.addr = getTargetName(listener.Addr())
tb.Logf("WorkloadAPI address: %s", w.addr)
w.server = server
return w
}

func GetPipeName(s string) string {
return strings.TrimPrefix(s, `\\.\pipe`)
}

func init() {
rand.Seed(time.Now().UnixNano())
}

func newListener() (net.Listener, error) {
return winio.ListenPipe(fmt.Sprintf(`\\.\pipe\go-spiffe-test-pipe-%x`, rand.Uint64()), nil)
}

func getTargetName(addr net.Addr) string {
if addr.Network() == "pipe" {
// The go-winio library defines the network of a
// named pipe address as "pipe", but we use the
// "npipe" scheme for named pipes URLs.
return "npipe:" + GetPipeName(addr.String())
}

return fmt.Sprintf("%s://%s", addr.Network(), addr.String())
}
21 changes: 21 additions & 0 deletions v2/spiffetls/spiffetls_posix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//go:build !windows
// +build !windows

package spiffetls_test

import (
"github.com/spiffe/go-spiffe/v2/spiffetls"
"github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig"
)

func listenAndDialCasesOS() []listenAndDialCase {
return []listenAndDialCase{
{
name: "Wrong workload API server socket",
dialMode: spiffetls.TLSClient(tlsconfig.AuthorizeID(serverID)),
defaultWlAPIAddr: "wrong-socket-path",
dialErr: "spiffetls: cannot create X.509 source: workload endpoint socket URI must have a \"tcp\" or \"unix\" scheme",
listenErr: "spiffetls: cannot create X.509 source: workload endpoint socket URI must have a \"tcp\" or \"unix\" scheme",
},
}
}
48 changes: 22 additions & 26 deletions v2/spiffetls/spiffetls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,26 @@ type testEnv struct {
err error
}

type listenAndDialCase struct {
name string

dialMode spiffetls.DialMode
dialOption []spiffetls.DialOption

listenMode spiffetls.ListenMode
listenOption []spiffetls.ListenOption

defaultWlAPIAddr string
dialErr string
listenErr string
listenLAddr string
listenProtocol string
serverConnPeerIDErr string
clientConnPeerIDErr string
usesExternalDialer bool
usesBaseTLSConfig bool
}

func TestListenAndDial(t *testing.T) {
testEnv, cleanup := setupTestEnv(t)
defer cleanup()
Expand All @@ -67,33 +87,8 @@ func TestListenAndDial(t *testing.T) {
externalTLSConfBuffer := &bytes.Buffer{}

// Test Table
tests := []struct {
name string

dialMode spiffetls.DialMode
dialOption []spiffetls.DialOption

listenMode spiffetls.ListenMode
listenOption []spiffetls.ListenOption

defaultWlAPIAddr string
dialErr string
listenErr string
listenLAddr string
listenProtocol string
serverConnPeerIDErr string
clientConnPeerIDErr string
usesExternalDialer bool
usesBaseTLSConfig bool
}{
tests := []listenAndDialCase{
// Failure Scenarios
{
name: "Wrong workload API server socket",
dialMode: spiffetls.TLSClient(tlsconfig.AuthorizeID(serverID)),
defaultWlAPIAddr: "wrong-socket-path",
dialErr: "spiffetls: cannot create X.509 source: workload endpoint socket URI must have a tcp:// or unix:// scheme",
listenErr: "spiffetls: cannot create X.509 source: workload endpoint socket URI must have a tcp:// or unix:// scheme",
},
{
name: "No server listening",
dialMode: spiffetls.TLSClient(tlsconfig.AuthorizeID(serverID)),
Expand Down Expand Up @@ -248,6 +243,7 @@ func TestListenAndDial(t *testing.T) {
clientConnPeerIDErr: "spiffetls: no URI SANs",
},
}
tests = append(tests, listenAndDialCasesOS()...)

for _, test := range tests {
test := test
Expand Down
21 changes: 21 additions & 0 deletions v2/spiffetls/spiffetls_windows_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//go:build windows
// +build windows

package spiffetls_test

import (
"github.com/spiffe/go-spiffe/v2/spiffetls"
"github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig"
)

func listenAndDialCasesOS() []listenAndDialCase {
return []listenAndDialCase{
{
name: "Wrong workload API server socket",
dialMode: spiffetls.TLSClient(tlsconfig.AuthorizeID(serverID)),
defaultWlAPIAddr: "wrong-socket-path",
dialErr: "spiffetls: cannot create X.509 source: workload endpoint socket URI must have a \"tcp\" or \"npipe\" scheme",
listenErr: "spiffetls: cannot create X.509 source: workload endpoint socket URI must have a \"tcp\" or \"npipe\" scheme",
},
}
}
31 changes: 10 additions & 21 deletions v2/workloadapi/addr.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,22 @@ func GetDefaultAddress() (string, bool) {
// a Workload API endpoint exposed as either a Unix
// Domain Socket or TCP socket.
func ValidateAddress(addr string) error {
_, err := parseTargetFromAddr(addr)
_, err := parseTargetFromStringAddr(addr)
return err
}

// parseTargetFromAddr parses the endpoint address and returns a gRPC target
// parseTargetFromStringAddr parses the endpoint address and returns a gRPC target
// string for dialing.
func parseTargetFromAddr(addr string) (string, error) {
func parseTargetFromStringAddr(addr string) (string, error) {
u, err := url.Parse(addr)
if err != nil {
return "", errors.New("workload endpoint socket is not a valid URI: " + err.Error())
}
switch u.Scheme {
case "unix":
switch {
case u.Opaque != "":
return "", errors.New("workload endpoint unix socket URI must not be opaque")
case u.User != nil:
return "", errors.New("workload endpoint unix socket URI must not include user info")
case u.Host == "" && u.Path == "":
return "", errors.New("workload endpoint unix socket URI must include a path")
case u.RawQuery != "":
return "", errors.New("workload endpoint unix socket URI must not include query values")
case u.Fragment != "":
return "", errors.New("workload endpoint unix socket URI must not include a fragment")
}
return u.String(), nil
case "tcp":
return parseTargetFromURLAddr(u)
}

func parseTargetFromURLAddr(u *url.URL) (string, error) {
if u.Scheme == "tcp" {
switch {
case u.Opaque != "":
return "", errors.New("workload endpoint tcp socket URI must not be opaque")
Expand All @@ -74,7 +63,7 @@ func parseTargetFromAddr(addr string) (string, error) {
}

return net.JoinHostPort(ip.String(), port), nil
default:
return "", errors.New("workload endpoint socket URI must have a tcp:// or unix:// scheme")
}

return parseTargetFromURLAddrOS(u)
}
34 changes: 34 additions & 0 deletions v2/workloadapi/addr_posix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//go:build !windows
// +build !windows

package workloadapi

import (
"errors"
"net/url"
)

var (
ErrInvalidEndpointScheme = errors.New("workload endpoint socket URI must have a \"tcp\" or \"unix\" scheme")
)

func parseTargetFromURLAddrOS(u *url.URL) (string, error) {
switch u.Scheme {
case "unix":
switch {
case u.Opaque != "":
return "", errors.New("workload endpoint unix socket URI must not be opaque")
case u.User != nil:
return "", errors.New("workload endpoint unix socket URI must not include user info")
case u.Host == "" && u.Path == "":
return "", errors.New("workload endpoint unix socket URI must include a path")
case u.RawQuery != "":
return "", errors.New("workload endpoint unix socket URI must not include query values")
case u.Fragment != "":
return "", errors.New("workload endpoint unix socket URI must not include a fragment")
}
return u.String(), nil
default:
return "", ErrInvalidEndpointScheme
}
}
33 changes: 33 additions & 0 deletions v2/workloadapi/addr_posix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//go:build !windows
// +build !windows

package workloadapi

func validateAddressCasesOS() []validateAddressCase {
return []validateAddressCase{
{
addr: "unix:opaque",
err: "workload endpoint unix socket URI must not be opaque",
},
{
addr: "unix://",
err: "workload endpoint unix socket URI must include a path",
},
{
addr: "unix://foo?whatever",
err: "workload endpoint unix socket URI must not include query values",
},
{
addr: "unix://foo#whatever",
err: "workload endpoint unix socket URI must not include a fragment",
},
{
addr: "unix://john:doe@foo/path",
err: "workload endpoint unix socket URI must not include user info",
},
{
addr: "unix://foo",
err: "",
},
}
}
37 changes: 8 additions & 29 deletions v2/workloadapi/addr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ import (
"github.com/stretchr/testify/require"
)

type validateAddressCase struct {
addr string
err string
}

func TestGetDefaultAddress(t *testing.T) {
if orig, ok := os.LookupEnv(SocketEnv); ok {
defer os.Setenv(SocketEnv, orig)
Expand All @@ -27,41 +32,14 @@ func TestGetDefaultAddress(t *testing.T) {
}

func TestValidateAddress(t *testing.T) {
testCases := []struct {
addr string
err string
}{
testCases := []validateAddressCase{
{
addr: "\t",
err: "net/url: invalid control character in URL",
},
{
addr: "blah",
err: "workload endpoint socket URI must have a tcp:// or unix:// scheme",
},
{
addr: "unix:opaque",
err: "workload endpoint unix socket URI must not be opaque",
},
{
addr: "unix://",
err: "workload endpoint unix socket URI must include a path",
},
{
addr: "unix://foo?whatever",
err: "workload endpoint unix socket URI must not include query values",
},
{
addr: "unix://foo#whatever",
err: "workload endpoint unix socket URI must not include a fragment",
},
{
addr: "unix://john:doe@foo/path",
err: "workload endpoint unix socket URI must not include user info",
},
{
addr: "unix://foo",
err: "",
err: ErrInvalidEndpointScheme.Error(),
},
{
addr: "tcp:opaque",
Expand Down Expand Up @@ -100,6 +78,7 @@ func TestValidateAddress(t *testing.T) {
err: "",
},
}
testCases = append(testCases, validateAddressCasesOS()...)

for _, testCase := range testCases {
err := ValidateAddress(testCase.addr)
Expand Down
Loading

0 comments on commit fcf03d7

Please sign in to comment.