diff --git a/internal/dslx/quic.go b/internal/dslx/quic.go index 3acf675ac..91d0b8eeb 100644 --- a/internal/dslx/quic.go +++ b/internal/dslx/quic.go @@ -7,9 +7,7 @@ package dslx import ( "context" "crypto/tls" - "crypto/x509" "io" - "net" "time" "github.com/ooni/probe-cli/v3/internal/logx" @@ -17,61 +15,23 @@ import ( "github.com/quic-go/quic-go" ) -// QUICHandshakeOption is an option you can pass to QUICHandshake. -type QUICHandshakeOption func(*quicHandshakeFunc) - -// QUICHandshakeOptionInsecureSkipVerify controls whether QUIC verification is enabled. -func QUICHandshakeOptionInsecureSkipVerify(value bool) QUICHandshakeOption { - return func(thf *quicHandshakeFunc) { - thf.InsecureSkipVerify = value - } -} - -// QUICHandshakeOptionRootCAs allows to configure custom root CAs. -func QUICHandshakeOptionRootCAs(value *x509.CertPool) QUICHandshakeOption { - return func(thf *quicHandshakeFunc) { - thf.RootCAs = value - } -} - -// QUICHandshakeOptionServerName allows to configure the SNI to use. -func QUICHandshakeOptionServerName(value string) QUICHandshakeOption { - return func(thf *quicHandshakeFunc) { - thf.ServerName = value - } -} - // QUICHandshake returns a function performing QUIC handshakes. -func QUICHandshake(rt Runtime, options ...QUICHandshakeOption) Func[ +func QUICHandshake(rt Runtime, options ...TLSHandshakeOption) Func[ *Endpoint, *Maybe[*QUICConnection]] { - // See https://github.com/ooni/probe/issues/2413 to understand - // why we're using nil to force netxlite to use the cached - // default Mozilla cert pool. f := &quicHandshakeFunc{ - InsecureSkipVerify: false, - RootCAs: nil, - Rt: rt, - ServerName: "", - } - for _, option := range options { - option(f) + Options: options, + Rt: rt, } return f } // quicHandshakeFunc performs QUIC handshakes. type quicHandshakeFunc struct { - // InsecureSkipVerify allows to skip TLS verification. - InsecureSkipVerify bool - - // RootCAs contains the Root CAs to use. - RootCAs *x509.CertPool + // Options contains the options. + Options []TLSHandshakeOption - // Rt is the Runtime that owns us. + // Rt is the runtime that owns us. Rt Runtime - - // ServerName is the ServerName to handshake for. - ServerName string } // Apply implements Func. @@ -80,27 +40,22 @@ func (f *quicHandshakeFunc) Apply( // create trace trace := f.Rt.NewTrace(f.Rt.IDGenerator().Add(1), f.Rt.ZeroTime(), input.Tags...) - // use defaults or user-configured overrides - serverName := f.serverName(input) + // create a suitable TLS configuration + config := tlsNewConfig(input.Address, []string{"h3"}, input.Domain, f.Rt.Logger(), f.Options...) // start the operation logger ol := logx.NewOperationLogger( f.Rt.Logger(), - "[#%d] QUICHandshake with %s SNI=%s", + "[#%d] QUICHandshake with %s SNI=%s ALPN=%v", trace.Index(), input.Address, - serverName, + config.ServerName, + config.NextProtos, ) // setup udpListener := netxlite.NewUDPListener() quicDialer := trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger()) - config := &tls.Config{ - NextProtos: []string{"h3"}, - InsecureSkipVerify: f.InsecureSkipVerify, - RootCAs: f.RootCAs, - ServerName: serverName, - } const timeout = 10 * time.Second ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() @@ -139,24 +94,6 @@ func (f *quicHandshakeFunc) Apply( } } -func (f *quicHandshakeFunc) serverName(input *Endpoint) string { - if f.ServerName != "" { - return f.ServerName - } - if input.Domain != "" { - return input.Domain - } - addr, _, err := net.SplitHostPort(input.Address) - if err == nil { - return addr - } - // Note: golang requires a ServerName and fails if it's empty. If the provided - // ServerName is an IP address, however, golang WILL NOT emit any SNI extension - // in the ClientHello, consistently with RFC 6066 Section 3 requirements. - f.Rt.Logger().Warn("TLSHandshake: cannot determine which SNI to use") - return "" -} - // QUICConnection is an established QUIC connection. If you initialize // manually, init at least the ones marked as MANDATORY. type QUICConnection struct { diff --git a/internal/dslx/quic_test.go b/internal/dslx/quic_test.go index 2d34954ba..9512783a7 100644 --- a/internal/dslx/quic_test.go +++ b/internal/dslx/quic_test.go @@ -3,7 +3,6 @@ package dslx import ( "context" "crypto/tls" - "crypto/x509" "io" "testing" "time" @@ -16,28 +15,12 @@ import ( /* Test cases: -- Get quicHandshakeFunc with options - Apply quicHandshakeFunc: - with EOF - success - with sni */ func TestQUICHandshake(t *testing.T) { - t.Run("Get quicHandshakeFunc with options", func(t *testing.T) { - certpool := x509.NewCertPool() - certpool.AddCert(&x509.Certificate{}) - - f := QUICHandshake( - NewMinimalRuntime(model.DiscardLogger, time.Now()), - QUICHandshakeOptionInsecureSkipVerify(true), - QUICHandshakeOptionServerName("sni"), - QUICHandshakeOptionRootCAs(certpool), - ) - if _, ok := f.(*quicHandshakeFunc); !ok { - t.Fatal("unexpected type. Expected: quicHandshakeFunc") - } - }) - t.Run("Apply quicHandshakeFunc", func(t *testing.T) { wasClosed := false plainConn := &mocks.QUICEarlyConnection{ @@ -103,10 +86,7 @@ func TestQUICHandshake(t *testing.T) { return tt.dialer }, })) - quicHandshake := &quicHandshakeFunc{ - Rt: rt, - ServerName: tt.sni, - } + quicHandshake := QUICHandshake(rt, TLSHandshakeOptionServerName(tt.sni)) endpoint := &Endpoint{ Address: "1.2.3.4:567", Network: "udp", @@ -136,61 +116,3 @@ func TestQUICHandshake(t *testing.T) { } }) } - -/* -Test cases: -- With input SNI -- With input domain -- With input host address -- With input IP address -*/ -func TestServerNameQUIC(t *testing.T) { - t.Run("With input SNI", func(t *testing.T) { - sni := "sni" - endpoint := &Endpoint{ - Address: "example.com:123", - } - f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), ServerName: sni} - serverName := f.serverName(endpoint) - if serverName != sni { - t.Fatalf("unexpected server name: %s", serverName) - } - }) - - t.Run("With input domain", func(t *testing.T) { - domain := "domain" - endpoint := &Endpoint{ - Address: "example.com:123", - Domain: domain, - } - f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} - serverName := f.serverName(endpoint) - if serverName != domain { - t.Fatalf("unexpected server name: %s", serverName) - } - }) - - t.Run("With input host address", func(t *testing.T) { - hostaddr := "example.com" - endpoint := &Endpoint{ - Address: hostaddr + ":123", - } - f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} - serverName := f.serverName(endpoint) - if serverName != hostaddr { - t.Fatalf("unexpected server name: %s", serverName) - } - }) - - t.Run("With input IP address", func(t *testing.T) { - ip := "1.1.1.1" - endpoint := &Endpoint{ - Address: ip, - } - f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} - serverName := f.serverName(endpoint) - if serverName != "" { - t.Fatalf("unexpected server name: %s", serverName) - } - }) -}