From 8bc336dba4fe1284438c0d12eadf301fd4b8a427 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 25 Oct 2023 09:33:55 +0200 Subject: [PATCH] refactor(dslx): unify TLS and QUIC handshake options (#1378) Closes https://github.com/ooni/probe/issues/2611. --- internal/dslx/quic.go | 83 +++------------------ internal/dslx/quic_test.go | 66 +---------------- internal/dslx/tls.go | 108 ++++++++++++++-------------- internal/dslx/tls_test.go | 143 ++++++++++++++++--------------------- 4 files changed, 124 insertions(+), 276 deletions(-) diff --git a/internal/dslx/quic.go b/internal/dslx/quic.go index 3acf675ac9..327b6d4be6 100644 --- a/internal/dslx/quic.go +++ b/internal/dslx/quic.go @@ -7,9 +7,7 @@ package dslx import ( "context" "crypto/tls" - "crypto/x509" "io" - "net" "time" "github.com/ooni/probe-cli/v3/internal/logx" @@ -17,61 +15,23 @@ import ( "github.com/quic-go/quic-go" ) -// QUICHandshakeOption is an option you can pass to QUICHandshake. -type QUICHandshakeOption func(*quicHandshakeFunc) - -// QUICHandshakeOptionInsecureSkipVerify controls whether QUIC verification is enabled. -func QUICHandshakeOptionInsecureSkipVerify(value bool) QUICHandshakeOption { - return func(thf *quicHandshakeFunc) { - thf.InsecureSkipVerify = value - } -} - -// QUICHandshakeOptionRootCAs allows to configure custom root CAs. -func QUICHandshakeOptionRootCAs(value *x509.CertPool) QUICHandshakeOption { - return func(thf *quicHandshakeFunc) { - thf.RootCAs = value - } -} - -// QUICHandshakeOptionServerName allows to configure the SNI to use. -func QUICHandshakeOptionServerName(value string) QUICHandshakeOption { - return func(thf *quicHandshakeFunc) { - thf.ServerName = value - } -} - // QUICHandshake returns a function performing QUIC handshakes. -func QUICHandshake(rt Runtime, options ...QUICHandshakeOption) Func[ +func QUICHandshake(rt Runtime, options ...TLSHandshakeOption) Func[ *Endpoint, *Maybe[*QUICConnection]] { - // See https://github.com/ooni/probe/issues/2413 to understand - // why we're using nil to force netxlite to use the cached - // default Mozilla cert pool. f := &quicHandshakeFunc{ - InsecureSkipVerify: false, - RootCAs: nil, - Rt: rt, - ServerName: "", - } - for _, option := range options { - option(f) + Options: options, + Rt: rt, } return f } // quicHandshakeFunc performs QUIC handshakes. type quicHandshakeFunc struct { - // InsecureSkipVerify allows to skip TLS verification. - InsecureSkipVerify bool - - // RootCAs contains the Root CAs to use. - RootCAs *x509.CertPool + // Options contains the options. + Options []TLSHandshakeOption // Rt is the Runtime that owns us. Rt Runtime - - // ServerName is the ServerName to handshake for. - ServerName string } // Apply implements Func. @@ -80,27 +40,22 @@ func (f *quicHandshakeFunc) Apply( // create trace trace := f.Rt.NewTrace(f.Rt.IDGenerator().Add(1), f.Rt.ZeroTime(), input.Tags...) - // use defaults or user-configured overrides - serverName := f.serverName(input) + // create a suitable TLS configuration + config := tlsNewConfig(input.Address, []string{"h3"}, input.Domain, f.Rt.Logger(), f.Options...) // start the operation logger ol := logx.NewOperationLogger( f.Rt.Logger(), - "[#%d] QUICHandshake with %s SNI=%s", + "[#%d] QUICHandshake with %s SNI=%s ALPN=%v", trace.Index(), input.Address, - serverName, + config.ServerName, + config.NextProtos, ) // setup udpListener := netxlite.NewUDPListener() quicDialer := trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger()) - config := &tls.Config{ - NextProtos: []string{"h3"}, - InsecureSkipVerify: f.InsecureSkipVerify, - RootCAs: f.RootCAs, - ServerName: serverName, - } const timeout = 10 * time.Second ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() @@ -139,24 +94,6 @@ func (f *quicHandshakeFunc) Apply( } } -func (f *quicHandshakeFunc) serverName(input *Endpoint) string { - if f.ServerName != "" { - return f.ServerName - } - if input.Domain != "" { - return input.Domain - } - addr, _, err := net.SplitHostPort(input.Address) - if err == nil { - return addr - } - // Note: golang requires a ServerName and fails if it's empty. If the provided - // ServerName is an IP address, however, golang WILL NOT emit any SNI extension - // in the ClientHello, consistently with RFC 6066 Section 3 requirements. - f.Rt.Logger().Warn("TLSHandshake: cannot determine which SNI to use") - return "" -} - // QUICConnection is an established QUIC connection. If you initialize // manually, init at least the ones marked as MANDATORY. type QUICConnection struct { diff --git a/internal/dslx/quic_test.go b/internal/dslx/quic_test.go index 2d34954bae..d8a8b066ac 100644 --- a/internal/dslx/quic_test.go +++ b/internal/dslx/quic_test.go @@ -29,9 +29,6 @@ func TestQUICHandshake(t *testing.T) { f := QUICHandshake( NewMinimalRuntime(model.DiscardLogger, time.Now()), - QUICHandshakeOptionInsecureSkipVerify(true), - QUICHandshakeOptionServerName("sni"), - QUICHandshakeOptionRootCAs(certpool), ) if _, ok := f.(*quicHandshakeFunc); !ok { t.Fatal("unexpected type. Expected: quicHandshakeFunc") @@ -103,10 +100,7 @@ func TestQUICHandshake(t *testing.T) { return tt.dialer }, })) - quicHandshake := &quicHandshakeFunc{ - Rt: rt, - ServerName: tt.sni, - } + quicHandshake := QUICHandshake(rt, TLSHandshakeOptionServerName(tt.sni)) endpoint := &Endpoint{ Address: "1.2.3.4:567", Network: "udp", @@ -136,61 +130,3 @@ func TestQUICHandshake(t *testing.T) { } }) } - -/* -Test cases: -- With input SNI -- With input domain -- With input host address -- With input IP address -*/ -func TestServerNameQUIC(t *testing.T) { - t.Run("With input SNI", func(t *testing.T) { - sni := "sni" - endpoint := &Endpoint{ - Address: "example.com:123", - } - f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), ServerName: sni} - serverName := f.serverName(endpoint) - if serverName != sni { - t.Fatalf("unexpected server name: %s", serverName) - } - }) - - t.Run("With input domain", func(t *testing.T) { - domain := "domain" - endpoint := &Endpoint{ - Address: "example.com:123", - Domain: domain, - } - f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} - serverName := f.serverName(endpoint) - if serverName != domain { - t.Fatalf("unexpected server name: %s", serverName) - } - }) - - t.Run("With input host address", func(t *testing.T) { - hostaddr := "example.com" - endpoint := &Endpoint{ - Address: hostaddr + ":123", - } - f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} - serverName := f.serverName(endpoint) - if serverName != hostaddr { - t.Fatalf("unexpected server name: %s", serverName) - } - }) - - t.Run("With input IP address", func(t *testing.T) { - ip := "1.1.1.1" - endpoint := &Endpoint{ - Address: ip, - } - f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} - serverName := f.serverName(endpoint) - if serverName != "" { - t.Fatalf("unexpected server name: %s", serverName) - } - }) -} diff --git a/internal/dslx/tls.go b/internal/dslx/tls.go index 5a37685dba..6ed1a63cdc 100644 --- a/internal/dslx/tls.go +++ b/internal/dslx/tls.go @@ -12,75 +12,58 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) // TLSHandshakeOption is an option you can pass to TLSHandshake. -type TLSHandshakeOption func(*tlsHandshakeFunc) +type TLSHandshakeOption func(config *tls.Config) // TLSHandshakeOptionInsecureSkipVerify controls whether TLS verification is enabled. func TLSHandshakeOptionInsecureSkipVerify(value bool) TLSHandshakeOption { - return func(thf *tlsHandshakeFunc) { - thf.InsecureSkipVerify = value + return func(config *tls.Config) { + config.InsecureSkipVerify = value } } // TLSHandshakeOptionNextProto allows to configure the ALPN protocols. func TLSHandshakeOptionNextProto(value []string) TLSHandshakeOption { - return func(thf *tlsHandshakeFunc) { - thf.NextProto = value + return func(config *tls.Config) { + config.NextProtos = value } } // TLSHandshakeOptionRootCAs allows to configure custom root CAs. func TLSHandshakeOptionRootCAs(value *x509.CertPool) TLSHandshakeOption { - return func(thf *tlsHandshakeFunc) { - thf.RootCAs = value + return func(config *tls.Config) { + config.RootCAs = value } } // TLSHandshakeOptionServerName allows to configure the SNI to use. func TLSHandshakeOptionServerName(value string) TLSHandshakeOption { - return func(thf *tlsHandshakeFunc) { - thf.ServerName = value + return func(config *tls.Config) { + config.ServerName = value } } // TLSHandshake returns a function performing TSL handshakes. func TLSHandshake(rt Runtime, options ...TLSHandshakeOption) Func[ *TCPConnection, *Maybe[*TLSConnection]] { - // See https://github.com/ooni/probe/issues/2413 to understand - // why we're using nil to force netxlite to use the cached - // default Mozilla cert pool. f := &tlsHandshakeFunc{ - InsecureSkipVerify: false, - NextProto: []string{}, - RootCAs: nil, - Rt: rt, - ServerName: "", - } - for _, option := range options { - option(f) + Options: options, + Rt: rt, } return f } // tlsHandshakeFunc performs TLS handshakes. type tlsHandshakeFunc struct { - // InsecureSkipVerify allows to skip TLS verification. - InsecureSkipVerify bool - - // NextProto contains the ALPNs to negotiate. - NextProto []string - - // RootCAs contains the Root CAs to use. - RootCAs *x509.CertPool + // Options contains the options. + Options []TLSHandshakeOption // Rt is the Runtime that owns us. Rt Runtime - - // ServerName is the ServerName to handshake for. - ServerName string } // Apply implements Func. @@ -89,9 +72,8 @@ func (f *tlsHandshakeFunc) Apply( // keep using the same trace trace := input.Trace - // use defaults or user-configured overrides - serverName := f.serverName(input) - nextProto := f.nextProto() + // create a suitable TLS configuration + config := tlsNewConfig(input.Address, []string{"h2", "http/1.1"}, input.Domain, f.Rt.Logger(), f.Options...) // start the operation logger ol := logx.NewOperationLogger( @@ -99,20 +81,14 @@ func (f *tlsHandshakeFunc) Apply( "[#%d] TLSHandshake with %s SNI=%s ALPN=%v", trace.Index(), input.Address, - serverName, - nextProto, + config.ServerName, + config.NextProtos, ) // obtain the handshaker for use handshaker := trace.NewTLSHandshakerStdlib(f.Rt.Logger()) // setup - config := &tls.Config{ - NextProtos: nextProto, - InsecureSkipVerify: f.InsecureSkipVerify, - RootCAs: f.RootCAs, - ServerName: serverName, - } const timeout = 10 * time.Second ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() @@ -143,31 +119,51 @@ func (f *tlsHandshakeFunc) Apply( } } -func (f *tlsHandshakeFunc) serverName(input *TCPConnection) string { - if f.ServerName != "" { - return f.ServerName +// tlsNewConfig is an utility function to create a new TLS config. +// +// Arguments: +// +// - address is the endpoint address (e.g., 1.1.1.1:443); +// +// - defaultALPN contains the default to be used for configuring ALPN; +// +// - domain is the possibly empty domain to use; +// +// - logger is the logger to use; +// +// - options contains options to modify the TLS handshake defaults. +func tlsNewConfig(address string, defaultALPN []string, domain string, logger model.Logger, options ...TLSHandshakeOption) *tls.Config { + // See https://github.com/ooni/probe/issues/2413 to understand + // why we're using nil to force netxlite to use the cached + // default Mozilla cert pool. + config := &tls.Config{ + NextProtos: append([]string{}, defaultALPN...), + InsecureSkipVerify: false, + RootCAs: nil, + ServerName: tlsServerName(address, domain, logger), } - if input.Domain != "" { - return input.Domain + for _, option := range options { + option(config) + } + return config +} + +// tlsServerName is an utility function to obtina the server name from a TCPConnection. +func tlsServerName(address, domain string, logger model.Logger) string { + if domain != "" { + return domain } - addr, _, err := net.SplitHostPort(input.Address) + addr, _, err := net.SplitHostPort(address) if err == nil { return addr } // Note: golang requires a ServerName and fails if it's empty. If the provided // ServerName is an IP address, however, golang WILL NOT emit any SNI extension // in the ClientHello, consistently with RFC 6066 Section 3 requirements. - f.Rt.Logger().Warn("TLSHandshake: cannot determine which SNI to use") + logger.Warn("TLSHandshake: cannot determine which SNI to use") return "" } -func (f *tlsHandshakeFunc) nextProto() []string { - if len(f.NextProto) > 0 { - return f.NextProto - } - return []string{"h2", "http/1.1"} -} - // TLSConnection is an established TLS connection. If you initialize // manually, init at least the ones marked as MANDATORY. type TLSConnection struct { diff --git a/internal/dslx/tls_test.go b/internal/dslx/tls_test.go index 4dcce7e0e3..36df4a79ca 100644 --- a/internal/dslx/tls_test.go +++ b/internal/dslx/tls_test.go @@ -10,51 +10,66 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" ) -/* -Test cases: -- Get tlsHandshakeFunc with options -- Apply tlsHandshakeFunc: - - with EOF - - with invalid address - - with success - - with sni - - with options -*/ -func TestTLSHandshake(t *testing.T) { - t.Run("Get tlsHandshakeFunc with options", func(t *testing.T) { +func TestTLSNewConfig(t *testing.T) { + t.Run("without options", func(t *testing.T) { + config := tlsNewConfig("1.1.1.1:443", []string{"h2", "http/1.1"}, "sni", model.DiscardLogger) + + if config.InsecureSkipVerify { + t.Fatalf("unexpected %s, expected %v, got %v", "InsecureSkipVerify", false, config.InsecureSkipVerify) + } + if diff := cmp.Diff([]string{"h2", "http/1.1"}, config.NextProtos); diff != "" { + t.Fatal(diff) + } + if config.ServerName != "sni" { + t.Fatalf("unexpected %s, expected %s, got %s", "ServerName", "sni", config.ServerName) + } + if !config.RootCAs.Equal(nil) { + t.Fatalf("unexpected %s, expected %v, got %v", "RootCAs", nil, config.RootCAs) + } + }) + + t.Run("with options", func(t *testing.T) { certpool := x509.NewCertPool() certpool.AddCert(&x509.Certificate{}) - f := TLSHandshake( - NewMinimalRuntime(model.DiscardLogger, time.Now()), + config := tlsNewConfig( + "1.1.1.1:443", []string{"h2", "http/1.1"}, "sni", model.DiscardLogger, TLSHandshakeOptionInsecureSkipVerify(true), TLSHandshakeOptionNextProto([]string{"h2"}), - TLSHandshakeOptionServerName("sni"), + TLSHandshakeOptionServerName("example.domain"), TLSHandshakeOptionRootCAs(certpool), ) - var handshakeFunc *tlsHandshakeFunc - var ok bool - if handshakeFunc, ok = f.(*tlsHandshakeFunc); !ok { - t.Fatal("unexpected type. Expected: tlsHandshakeFunc") - } - if !handshakeFunc.InsecureSkipVerify { - t.Fatalf("unexpected %s, expected %v, got %v", "InsecureSkipVerify", true, false) + + if !config.InsecureSkipVerify { + t.Fatalf("unexpected %s, expected %v, got %v", "InsecureSkipVerify", true, config.InsecureSkipVerify) } - if len(handshakeFunc.NextProto) != 1 || handshakeFunc.NextProto[0] != "h2" { - t.Fatalf("unexpected %s, expected %v, got %v", "NextProto", []string{"h2"}, handshakeFunc.NextProto) + if diff := cmp.Diff([]string{"h2"}, config.NextProtos); diff != "" { + t.Fatal(diff) } - if handshakeFunc.ServerName != "sni" { - t.Fatalf("unexpected %s, expected %s, got %s", "ServerName", "sni", handshakeFunc.ServerName) + if config.ServerName != "example.domain" { + t.Fatalf("unexpected %s, expected %s, got %s", "ServerName", "example.domain", config.ServerName) } - if !handshakeFunc.RootCAs.Equal(certpool) { - t.Fatalf("unexpected %s, expected %v, got %v", "RootCAs", certpool, handshakeFunc.RootCAs) + if !config.RootCAs.Equal(certpool) { + t.Fatalf("unexpected %s, expected %v, got %v", "RootCAs", nil, config.RootCAs) } }) +} +/* +Test cases: +- Apply tlsHandshakeFunc: + - with EOF + - with invalid address + - with success + - with sni + - with options +*/ +func TestTLSHandshake(t *testing.T) { t.Run("Apply tlsHandshakeFunc", func(t *testing.T) { wasClosed := false @@ -137,11 +152,10 @@ func TestTLSHandshake(t *testing.T) { return tt.handshaker }, })) - tlsHandshake := &tlsHandshakeFunc{ - NextProto: tt.config.nextProtos, - Rt: rt, - ServerName: tt.config.sni, - } + tlsHandshake := TLSHandshake(rt, + TLSHandshakeOptionNextProto(tt.config.nextProtos), + TLSHandshakeOptionServerName(tt.config.sni), + ) idGen := &atomic.Int64{} zeroTime := time.Time{} trace := rt.NewTrace(idGen.Add(1), zeroTime) @@ -174,62 +188,27 @@ func TestTLSHandshake(t *testing.T) { /* Test cases: -- With input SNI -- With input domain -- With input host address -- With input IP address +- With domain +- With host address +- With IP address */ -func TestServerNameTLS(t *testing.T) { - t.Run("With input SNI", func(t *testing.T) { - sni := "sni" - tcpConn := TCPConnection{ - Address: "example.com:123", - } - f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - ServerName: sni, - } - serverName := f.serverName(&tcpConn) - if serverName != sni { +func TestTLSServerName(t *testing.T) { + t.Run("With domain", func(t *testing.T) { + serverName := tlsServerName("example.com:123", "domain", model.DiscardLogger) + if serverName != "domain" { t.Fatalf("unexpected server name: %s", serverName) } }) - t.Run("With input domain", func(t *testing.T) { - domain := "domain" - tcpConn := TCPConnection{ - Address: "example.com:123", - Domain: domain, - } - f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - } - serverName := f.serverName(&tcpConn) - if serverName != domain { - t.Fatalf("unexpected server name: %s", serverName) - } - }) - t.Run("With input host address", func(t *testing.T) { - hostaddr := "example.com" - tcpConn := TCPConnection{ - Address: hostaddr + ":123", - } - f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - } - serverName := f.serverName(&tcpConn) - if serverName != hostaddr { + + t.Run("With host address", func(t *testing.T) { + serverName := tlsServerName("1.1.1.1:443", "", model.DiscardLogger) + if serverName != "1.1.1.1" { t.Fatalf("unexpected server name: %s", serverName) } }) - t.Run("With input IP address", func(t *testing.T) { - ip := "1.1.1.1" - tcpConn := TCPConnection{ - Address: ip, - } - f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - } - serverName := f.serverName(&tcpConn) + + t.Run("With IP address", func(t *testing.T) { + serverName := tlsServerName("1.1.1.1", "", model.DiscardLogger) if serverName != "" { t.Fatalf("unexpected server name: %s", serverName) }