From c44bac4f47a6d9043cb226263064261d45f8055c Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Tue, 2 Apr 2024 20:19:20 -0400 Subject: [PATCH 01/10] fix(server): typo causing HTTPS router to be used for HTTP server --- server/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/server.go b/server/server.go index d3f4d4c88..3383dd7a6 100644 --- a/server/server.go +++ b/server/server.go @@ -501,7 +501,7 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) { ReadTimeout: readTimeout, ReadHeaderTimeout: readHeaderTimeout, WriteTimeout: writeTimeout, - Handler: s.httpsMux, + Handler: s.httpMux, } if err := srv.Serve(listener); err != nil { From 6afa03c519a6c3e403188b21c7ae22c77a812a48 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Fri, 29 Mar 2024 19:29:40 -0400 Subject: [PATCH 02/10] refactor(server): simplify HTTP router setup --- server/server.go | 22 ++++--------------- server/server_endpoints.go | 45 ++++++++++++++------------------------ 2 files changed, 20 insertions(+), 47 deletions(-) diff --git a/server/server.go b/server/server.go index 3383dd7a6..a42c6df8b 100644 --- a/server/server.go +++ b/server/server.go @@ -50,7 +50,6 @@ type Server struct { queryResolver resolver.ChainedResolver cfg *config.Config httpMux *chi.Mux - httpsMux *chi.Mux cert tls.Certificate } @@ -117,19 +116,11 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err return nil, fmt.Errorf("server creation failed: %w", err) } - httpRouter := createHTTPRouter(cfg) - httpsRouter := createHTTPSRouter(cfg) - httpListeners, httpsListeners, err := createHTTPListeners(cfg) if err != nil { return nil, err } - if len(httpListeners) != 0 || len(httpsListeners) != 0 { - metrics.Start(httpRouter, cfg.Prometheus) - metrics.Start(httpsRouter, cfg.Prometheus) - } - metrics.RegisterEventListeners() bootstrap, err := resolver.NewBootstrap(ctx, cfg) @@ -156,25 +147,20 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err cfg: cfg, httpListeners: httpListeners, httpsListeners: httpsListeners, - httpMux: httpRouter, - httpsMux: httpsRouter, cert: cert, } server.printConfiguration() server.registerDNSHandlers(ctx) - err = server.registerAPIEndpoints(httpRouter) + openAPIImpl, err := server.createOpenAPIInterfaceImpl() if err != nil { return nil, err } - err = server.registerAPIEndpoints(httpsRouter) - - if err != nil { - return nil, err - } + server.httpMux = createHTTPRouter(cfg, openAPIImpl) + server.registerDoHEndpoints(server.httpMux) return server, err } @@ -518,7 +504,7 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) { logger().Infof("https server is up and running on addr/port %s", address) server := http.Server{ - Handler: s.httpsMux, + Handler: s.httpMux, ReadTimeout: readTimeout, ReadHeaderTimeout: readHeaderTimeout, WriteTimeout: writeTimeout, diff --git a/server/server_endpoints.go b/server/server_endpoints.go index 0302ab744..264f2f19b 100644 --- a/server/server_endpoints.go +++ b/server/server_endpoints.go @@ -10,6 +10,7 @@ import ( "net/http" "time" + "github.com/0xERR0R/blocky/metrics" "github.com/0xERR0R/blocky/resolver" "github.com/0xERR0R/blocky/api" @@ -37,10 +38,13 @@ const ( func secureHeader(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("strict-transport-security", "max-age=63072000") - w.Header().Set("x-frame-options", "DENY") - w.Header().Set("x-content-type-options", "nosniff") - w.Header().Set("x-xss-protection", "1; mode=block") + if r.TLS != nil { + w.Header().Set("strict-transport-security", "max-age=63072000") + w.Header().Set("x-frame-options", "DENY") + w.Header().Set("x-content-type-options", "nosniff") + w.Header().Set("x-xss-protection", "1; mode=block") + } + next.ServeHTTP(w, r) }) } @@ -64,24 +68,15 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil } -func (s *Server) registerAPIEndpoints(router *chi.Mux) error { +func (s *Server) registerDoHEndpoints(router *chi.Mux) { const pathDohQuery = "/dns-query" - openAPIImpl, err := s.createOpenAPIInterfaceImpl() - if err != nil { - return err - } - - api.RegisterOpenAPIEndpoints(router, openAPIImpl) - router.Get(pathDohQuery, s.dohGetRequestHandler) router.Get(pathDohQuery+"/", s.dohGetRequestHandler) router.Get(pathDohQuery+"/{clientID}", s.dohGetRequestHandler) router.Post(pathDohQuery, s.dohPostRequestHandler) router.Post(pathDohQuery+"/", s.dohPostRequestHandler) router.Post(pathDohQuery+"/{clientID}", s.dohPostRequestHandler) - - return nil } func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) { @@ -177,27 +172,15 @@ func (s *Server) Query( return s.resolve(ctx, req) } -func createHTTPSRouter(cfg *config.Config) *chi.Mux { +func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) *chi.Mux { router := chi.NewRouter() configureSecureHeaderHandler(router) - registerHandlers(cfg, router) - - return router -} - -func createHTTPRouter(cfg *config.Config) *chi.Mux { - router := chi.NewRouter() - - registerHandlers(cfg, router) - - return router -} - -func registerHandlers(cfg *config.Config, router *chi.Mux) { configureCorsHandler(router) + api.RegisterOpenAPIEndpoints(router, openAPIImpl) + configureDebugHandler(router) configureDocsHandler(router) @@ -205,6 +188,10 @@ func registerHandlers(cfg *config.Config, router *chi.Mux) { configureStaticAssetsHandler(router) configureRootHandler(cfg, router) + + metrics.Start(router, cfg.Prometheus) + + return router } func configureDocsHandler(router *chi.Mux) { From ae83f2e90dfd24142b0015d4657716ea5af013c8 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Fri, 29 Mar 2024 19:27:26 -0400 Subject: [PATCH 03/10] refactor(server): deduplicate `tls.Config` setup --- server/server.go | 56 ++++++++++++++++++++++++------------------- server/server_test.go | 2 +- 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/server/server.go b/server/server.go index a42c6df8b..c8cf05edd 100644 --- a/server/server.go +++ b/server/server.go @@ -50,7 +50,7 @@ type Server struct { queryResolver resolver.ChainedResolver cfg *config.Config httpMux *chi.Mux - cert tls.Certificate + tlsCfg *tls.Config } func logger() *logrus.Entry { @@ -98,20 +98,38 @@ func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) { return } +func newTLSConfig(cfg *config.Config) (*tls.Config, error) { + var cert tls.Certificate + + cert, err := retrieveCertificate(cfg) + if err != nil { + return nil, fmt.Errorf("can't retrieve cert: %w", err) + } + + // #nosec G402 // See TLSVersion.validate + res := &tls.Config{ + MinVersion: uint16(cfg.MinTLSServeVer), + CipherSuites: tlsCipherSuites(), + Certificates: []tls.Certificate{cert}, + } + + return res, nil +} + // NewServer creates new server instance with passed config // //nolint:funlen func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) { - var cert tls.Certificate + var tlsCfg *tls.Config if len(cfg.Ports.HTTPS) > 0 || len(cfg.Ports.TLS) > 0 { - cert, err = retrieveCertificate(cfg) + tlsCfg, err = newTLSConfig(cfg) if err != nil { - return nil, fmt.Errorf("can't retrieve cert: %w", err) + return nil, err } } - dnsServers, err := createServers(cfg, cert) + dnsServers, err := createServers(cfg, tlsCfg) if err != nil { return nil, fmt.Errorf("server creation failed: %w", err) } @@ -147,7 +165,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err cfg: cfg, httpListeners: httpListeners, httpsListeners: httpsListeners, - cert: cert, + tlsCfg: tlsCfg, } server.printConfiguration() @@ -165,7 +183,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err return server, err } -func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, error) { +func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error) { var dnsServers []*dns.Server var err *multierror.Error @@ -187,7 +205,7 @@ func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, err addServers(createUDPServer, cfg.Ports.DNS), addServers(createTCPServer, cfg.Ports.DNS), addServers(func(address string) (*dns.Server, error) { - return createTLSServer(cfg, address, cert) + return createTLSServer(address, tlsCfg) }, cfg.Ports.TLS)) return dnsServers, err.ErrorOrNil() @@ -222,17 +240,12 @@ func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener, return listeners, nil } -func createTLSServer(cfg *config.Config, address string, cert tls.Certificate) (*dns.Server, error) { +func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) { return &dns.Server{ - Addr: address, - Net: "tcp-tls", - //nolint:gosec - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: uint16(cfg.MinTLSServeVer), - CipherSuites: tlsCipherSuites(), - }, - Handler: dns.NewServeMux(), + Addr: address, + Net: "tcp-tls", + TLSConfig: tlsCfg, + Handler: dns.NewServeMux(), NotifyStartedFunc: func() { logger().Infof("TLS server is up and running on address %s", address) }, @@ -508,12 +521,7 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) { ReadTimeout: readTimeout, ReadHeaderTimeout: readHeaderTimeout, WriteTimeout: writeTimeout, - //nolint:gosec - TLSConfig: &tls.Config{ - MinVersion: uint16(s.cfg.MinTLSServeVer), - CipherSuites: tlsCipherSuites(), - Certificates: []tls.Certificate{s.cert}, - }, + TLSConfig: s.tlsCfg, } if err := server.ServeTLS(listener, "", ""); err != nil { diff --git a/server/server_test.go b/server/server_test.go index fc0fde069..1ea4cfc02 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -745,7 +745,7 @@ var _ = Describe("Running DNS server", func() { } sut, err := NewServer(ctx, &cfg) Expect(err).Should(Succeed()) - Expect(sut.cert.Certificate).ShouldNot(BeNil()) + Expect(sut.tlsCfg.Certificates).ShouldNot(BeEmpty()) }) }) }) From 39ae0885b83a97a420e7dfba48fec87b7d6b65a3 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Tue, 2 Apr 2024 20:10:30 -0400 Subject: [PATCH 04/10] refactor(server): setup TLS listeners manually to remove `ServeTLS` use --- server/server.go | 32 +++++++++++++++++++++----------- server/server_test.go | 8 ++++---- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/server/server.go b/server/server.go index c8cf05edd..8765af83a 100644 --- a/server/server.go +++ b/server/server.go @@ -50,7 +50,6 @@ type Server struct { queryResolver resolver.ChainedResolver cfg *config.Config httpMux *chi.Mux - tlsCfg *tls.Config } func logger() *logrus.Entry { @@ -117,8 +116,6 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) { } // NewServer creates new server instance with passed config -// -//nolint:funlen func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) { var tlsCfg *tls.Config @@ -134,7 +131,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err return nil, fmt.Errorf("server creation failed: %w", err) } - httpListeners, httpsListeners, err := createHTTPListeners(cfg) + httpListeners, httpsListeners, err := createHTTPListeners(cfg, tlsCfg) if err != nil { return nil, err } @@ -165,7 +162,6 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err cfg: cfg, httpListeners: httpListeners, httpsListeners: httpsListeners, - tlsCfg: tlsCfg, } server.printConfiguration() @@ -211,13 +207,15 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error return dnsServers, err.ErrorOrNil() } -func createHTTPListeners(cfg *config.Config) (httpListeners, httpsListeners []net.Listener, err error) { - httpListeners, err = newListeners("http", cfg.Ports.HTTP) +func createHTTPListeners( + cfg *config.Config, tlsCfg *tls.Config, +) (httpListeners, httpsListeners []net.Listener, err error) { + httpListeners, err = newTCPListeners("http", cfg.Ports.HTTP) if err != nil { return nil, nil, err } - httpsListeners, err = newListeners("https", cfg.Ports.HTTPS) + httpsListeners, err = newTLSListeners("https", cfg.Ports.HTTPS, tlsCfg) if err != nil { return nil, nil, err } @@ -225,7 +223,7 @@ func createHTTPListeners(cfg *config.Config) (httpListeners, httpsListeners []ne return httpListeners, httpsListeners, nil } -func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) { +func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) { listeners := make([]net.Listener, 0, len(addresses)) for _, address := range addresses { @@ -240,6 +238,19 @@ func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener, return listeners, nil } +func newTLSListeners(proto string, addresses config.ListenConfig, tlsCfg *tls.Config) ([]net.Listener, error) { + listeners, err := newTCPListeners(proto, addresses) + if err != nil { + return nil, err + } + + for i, inner := range listeners { + listeners[i] = tls.NewListener(inner, tlsCfg) + } + + return listeners, nil +} + func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) { return &dns.Server{ Addr: address, @@ -521,10 +532,9 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) { ReadTimeout: readTimeout, ReadHeaderTimeout: readHeaderTimeout, WriteTimeout: writeTimeout, - TLSConfig: s.tlsCfg, } - if err := server.ServeTLS(listener, "", ""); err != nil { + if err := server.Serve(listener); err != nil { errCh <- fmt.Errorf("start https listener failed: %w", err) } }() diff --git a/server/server_test.go b/server/server_test.go index 1ea4cfc02..60207d0ac 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/base64" - "fmt" "io" "net" "net/http" @@ -741,11 +740,12 @@ var _ = Describe("Running DNS server", func() { cfg.KeyFile = "" cfg.CertFile = "" cfg.Ports = config.Ports{ - HTTPS: []string{fmt.Sprintf(":%d", GetIntPort(httpsBasePort)+100)}, + HTTPS: []string{":0"}, } - sut, err := NewServer(ctx, &cfg) + + sut, err := newTLSConfig(&cfg) Expect(err).Should(Succeed()) - Expect(sut.tlsCfg.Certificates).ShouldNot(BeEmpty()) + Expect(sut.Certificates).ShouldNot(BeEmpty()) }) }) }) From d97ac838d8de6e84d8ec95d07b5625e97e96664f Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Tue, 2 Apr 2024 20:11:10 -0400 Subject: [PATCH 05/10] refactor(server): deduplicate HTTP server setup with new `httpServer` --- server/http.go | 48 ++++++++++++++++++++++++++ server/server.go | 89 ++++++++++++++++++++---------------------------- 2 files changed, 85 insertions(+), 52 deletions(-) create mode 100644 server/http.go diff --git a/server/http.go b/server/http.go new file mode 100644 index 000000000..78f7fe0df --- /dev/null +++ b/server/http.go @@ -0,0 +1,48 @@ +package server + +import ( + "context" + "net" + "net/http" + "time" +) + +type httpServer struct { + inner http.Server + + name string +} + +func newHTTPServer(name string, handler http.Handler) *httpServer { + const ( + readHeaderTimeout = 20 * time.Second + readTimeout = 20 * time.Second + writeTimeout = 20 * time.Second + ) + + return &httpServer{ + inner: http.Server{ + ReadTimeout: readTimeout, + ReadHeaderTimeout: readHeaderTimeout, + WriteTimeout: writeTimeout, + + Handler: handler, + }, + + name: name, + } +} + +func (s *httpServer) String() string { + return s.name +} + +func (s *httpServer) Serve(ctx context.Context, l net.Listener) error { + go func() { + <-ctx.Done() + + s.inner.Close() + }() + + return s.inner.Serve(l) +} diff --git a/server/server.go b/server/server.go index 8765af83a..844734a73 100644 --- a/server/server.go +++ b/server/server.go @@ -27,6 +27,7 @@ import ( "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/redis" "github.com/0xERR0R/blocky/resolver" + "github.com/0xERR0R/blocky/util" "github.com/google/uuid" "github.com/hashicorp/go-multierror" @@ -44,12 +45,11 @@ const ( // Server controls the endpoints for DNS and HTTP type Server struct { - dnsServers []*dns.Server - httpListeners []net.Listener - httpsListeners []net.Listener - queryResolver resolver.ChainedResolver - cfg *config.Config - httpMux *chi.Mux + dnsServers []*dns.Server + queryResolver resolver.ChainedResolver + cfg *config.Config + + servers map[net.Listener]*httpServer } func logger() *logrus.Entry { @@ -116,6 +116,8 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) { } // NewServer creates new server instance with passed config +// +//nolint:funlen func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) { var tlsCfg *tls.Config @@ -157,11 +159,11 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err } server = &Server{ - dnsServers: dnsServers, - queryResolver: queryResolver, - cfg: cfg, - httpListeners: httpListeners, - httpsListeners: httpsListeners, + dnsServers: dnsServers, + queryResolver: queryResolver, + cfg: cfg, + + servers: make(map[net.Listener]*httpServer), } server.printConfiguration() @@ -173,8 +175,24 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err return nil, err } - server.httpMux = createHTTPRouter(cfg, openAPIImpl) - server.registerDoHEndpoints(server.httpMux) + httpRouter := createHTTPRouter(cfg, openAPIImpl) + server.registerDoHEndpoints(httpRouter) + + if len(cfg.Ports.HTTP) != 0 { + srv := newHTTPServer("http", httpRouter) + + for _, l := range httpListeners { + server.servers[l] = srv + } + } + + if len(cfg.Ports.HTTPS) != 0 { + srv := newHTTPServer("https", httpRouter) + + for _, l := range httpsListeners { + server.servers[l] = srv + } + } return server, err } @@ -480,12 +498,6 @@ func toMB(b uint64) uint64 { return b / bytesInKB / bytesInKB } -const ( - readHeaderTimeout = 20 * time.Second - readTimeout = 20 * time.Second - writeTimeout = 20 * time.Second -) - // Start starts the server func (s *Server) Start(ctx context.Context, errCh chan<- error) { logger().Info("Starting server") @@ -500,42 +512,15 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) { }() } - for i, listener := range s.httpListeners { - listener := listener - address := s.cfg.Ports.HTTP[i] + for listener, srv := range s.servers { + listener, srv := listener, srv go func() { - logger().Infof("http server is up and running on addr/port %s", address) - - srv := &http.Server{ - ReadTimeout: readTimeout, - ReadHeaderTimeout: readHeaderTimeout, - WriteTimeout: writeTimeout, - Handler: s.httpMux, - } - - if err := srv.Serve(listener); err != nil { - errCh <- fmt.Errorf("start http listener failed: %w", err) - } - }() - } + logger().Infof("%s server is up and running on addr/port %s", srv, listener.Addr()) - for i, listener := range s.httpsListeners { - listener := listener - address := s.cfg.Ports.HTTPS[i] - - go func() { - logger().Infof("https server is up and running on addr/port %s", address) - - server := http.Server{ - Handler: s.httpMux, - ReadTimeout: readTimeout, - ReadHeaderTimeout: readHeaderTimeout, - WriteTimeout: writeTimeout, - } - - if err := server.Serve(listener); err != nil { - errCh <- fmt.Errorf("start https listener failed: %w", err) + err := srv.Serve(ctx, listener) + if err != nil { + errCh <- fmt.Errorf("%s on %s: %w", srv, listener.Addr(), err) } }() } From 28b319dbf3105c3eecde0ead8a913fde3939e161 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Tue, 2 Apr 2024 22:06:09 -0400 Subject: [PATCH 06/10] refactor(server): move middleware setup to `httpServer` --- server/http.go | 50 +++++++++++++++++++++++++++++++++++++- server/server_endpoints.go | 36 --------------------------- 2 files changed, 49 insertions(+), 37 deletions(-) diff --git a/server/http.go b/server/http.go index 78f7fe0df..cac0e8102 100644 --- a/server/http.go +++ b/server/http.go @@ -5,6 +5,9 @@ import ( "net" "net/http" "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/cors" ) type httpServer struct { @@ -26,7 +29,7 @@ func newHTTPServer(name string, handler http.Handler) *httpServer { ReadHeaderTimeout: readHeaderTimeout, WriteTimeout: writeTimeout, - Handler: handler, + Handler: withCommonMiddleware(handler), }, name: name, @@ -46,3 +49,48 @@ func (s *httpServer) Serve(ctx context.Context, l net.Listener) error { return s.inner.Serve(l) } + +func withCommonMiddleware(inner http.Handler) *chi.Mux { + // Middleware must be defined before routes, so + // create a new router and mount the inner handler + mux := chi.NewMux() + + mux.Use( + secureHeadersMiddleware, + newCORSMiddleware(), + ) + + mux.Mount("/", inner) + + return mux +} + +type httpMiddleware = func(http.Handler) http.Handler + +func secureHeadersMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.TLS != nil { + w.Header().Set("strict-transport-security", "max-age=63072000") + w.Header().Set("x-frame-options", "DENY") + w.Header().Set("x-content-type-options", "nosniff") + w.Header().Set("x-xss-protection", "1; mode=block") + } + + next.ServeHTTP(w, r) + }) +} + +func newCORSMiddleware() httpMiddleware { + const corsMaxAge = 5 * time.Minute + + options := cors.Options{ + AllowCredentials: true, + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, + AllowedMethods: []string{"GET", "POST"}, + AllowedOrigins: []string{"*"}, + ExposedHeaders: []string{"Link"}, + MaxAge: int(corsMaxAge.Seconds()), + } + + return cors.New(options).Handler +} diff --git a/server/server_endpoints.go b/server/server_endpoints.go index 264f2f19b..1fb3db602 100644 --- a/server/server_endpoints.go +++ b/server/server_endpoints.go @@ -8,7 +8,6 @@ import ( "io" "net" "net/http" - "time" "github.com/0xERR0R/blocky/metrics" "github.com/0xERR0R/blocky/resolver" @@ -23,7 +22,6 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" - "github.com/go-chi/cors" "github.com/miekg/dns" ) @@ -33,22 +31,8 @@ const ( dnsContentType = "application/dns-message" htmlContentType = "text/html; charset=UTF-8" yamlContentType = "text/yaml" - corsMaxAge = 5 * time.Minute ) -func secureHeader(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.TLS != nil { - w.Header().Set("strict-transport-security", "max-age=63072000") - w.Header().Set("x-frame-options", "DENY") - w.Header().Set("x-content-type-options", "nosniff") - w.Header().Set("x-xss-protection", "1; mode=block") - } - - next.ServeHTTP(w, r) - }) -} - func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, err error) { bControl, err := resolver.GetFromChainWithType[api.BlockingControl](s.queryResolver) if err != nil { @@ -175,10 +159,6 @@ func (s *Server) Query( func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) *chi.Mux { router := chi.NewRouter() - configureSecureHeaderHandler(router) - - configureCorsHandler(router) - api.RegisterOpenAPIEndpoints(router, openAPIImpl) configureDebugHandler(router) @@ -269,22 +249,6 @@ func logAndResponseWithError(err error, message string, writer http.ResponseWrit } } -func configureSecureHeaderHandler(router *chi.Mux) { - router.Use(secureHeader) -} - func configureDebugHandler(router *chi.Mux) { router.Mount("/debug", middleware.Profiler()) } - -func configureCorsHandler(router *chi.Mux) { - crs := cors.New(cors.Options{ - AllowedOrigins: []string{"*"}, - AllowedMethods: []string{"GET", "POST"}, - AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, - ExposedHeaders: []string{"Link"}, - AllowCredentials: true, - MaxAge: int(corsMaxAge.Seconds()), - }) - router.Use(crs.Handler) -} From 6240a5da22bda4dafdfa9c500826925a45826cd5 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Wed, 28 Aug 2024 19:35:43 -0400 Subject: [PATCH 07/10] refactor: move `createSelfSignedCert` to util Make it available outside of `server`, for use in `service` tests. --- server/server.go | 109 +------------------------------------------ util/tls.go | 118 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 108 deletions(-) create mode 100644 util/tls.go diff --git a/server/server.go b/server/server.go index 844734a73..4dc588eb2 100644 --- a/server/server.go +++ b/server/server.go @@ -1,19 +1,10 @@ package server import ( - "bytes" "context" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" "crypto/tls" - "crypto/x509" - "encoding/pem" "errors" "fmt" - "math" - "math/big" - mrand "math/rand" "net" "net/http" "runtime" @@ -81,7 +72,7 @@ type NewServerFunc func(address string) (*dns.Server, error) func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) { if cfg.CertFile == "" && cfg.KeyFile == "" { - cert, err = createSelfSignedCert() + cert, err = util.CreateSelfSignedCert() if err != nil { return tls.Certificate{}, fmt.Errorf("unable to generate self-signed certificate: %w", err) } @@ -304,104 +295,6 @@ func createUDPServer(address string) (*dns.Server, error) { }, nil } -//nolint:funlen -func createSelfSignedCert() (tls.Certificate, error) { - // Create CA - ca := &x509.Certificate{ - //nolint:gosec - SerialNumber: big.NewInt(int64(mrand.Intn(math.MaxInt))), - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(caExpiryYears, 0, 0), - IsCA: true, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - BasicConstraintsValid: true, - } - - caPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return tls.Certificate{}, err - } - - caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey) - if err != nil { - return tls.Certificate{}, err - } - - caPEM := new(bytes.Buffer) - if err = pem.Encode(caPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: caBytes, - }); err != nil { - return tls.Certificate{}, err - } - - caPrivKeyPEM := new(bytes.Buffer) - - b, err := x509.MarshalECPrivateKey(caPrivKey) - if err != nil { - return tls.Certificate{}, err - } - - if err = pem.Encode(caPrivKeyPEM, &pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: b, - }); err != nil { - return tls.Certificate{}, err - } - - // Create certificate - cert := &x509.Certificate{ - //nolint:gosec - SerialNumber: big.NewInt(int64(mrand.Intn(math.MaxInt))), - DNSNames: []string{"*"}, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(certExpiryYears, 0, 0), - SubjectKeyId: []byte{1, 2, 3, 4, 6}, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature, - } - - certPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return tls.Certificate{}, err - } - - certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey) - if err != nil { - return tls.Certificate{}, err - } - - certPEM := new(bytes.Buffer) - if err = pem.Encode(certPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - }); err != nil { - return tls.Certificate{}, err - } - - certPrivKeyPEM := new(bytes.Buffer) - - b, err = x509.MarshalECPrivateKey(certPrivKey) - if err != nil { - return tls.Certificate{}, err - } - - if err = pem.Encode(certPrivKeyPEM, &pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: b, - }); err != nil { - return tls.Certificate{}, err - } - - keyPair, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes()) - if err != nil { - return tls.Certificate{}, err - } - - return keyPair, nil -} - func createQueryResolver( ctx context.Context, cfg *config.Config, diff --git a/util/tls.go b/util/tls.go new file mode 100644 index 000000000..a5ebcc414 --- /dev/null +++ b/util/tls.go @@ -0,0 +1,118 @@ +package util + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "math" + "math/big" + mrand "math/rand" + "time" +) + +const ( + caExpiryYears = 10 + certExpiryYears = 5 +) + +//nolint:funlen +func CreateSelfSignedCert() (tls.Certificate, error) { + // Create CA + ca := &x509.Certificate{ + //nolint:gosec + SerialNumber: big.NewInt(int64(mrand.Intn(math.MaxInt))), + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(caExpiryYears, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + caPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, err + } + + caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey) + if err != nil { + return tls.Certificate{}, err + } + + caPEM := new(bytes.Buffer) + if err = pem.Encode(caPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: caBytes, + }); err != nil { + return tls.Certificate{}, err + } + + caPrivKeyPEM := new(bytes.Buffer) + + b, err := x509.MarshalECPrivateKey(caPrivKey) + if err != nil { + return tls.Certificate{}, err + } + + if err = pem.Encode(caPrivKeyPEM, &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: b, + }); err != nil { + return tls.Certificate{}, err + } + + // Create certificate + cert := &x509.Certificate{ + //nolint:gosec + SerialNumber: big.NewInt(int64(mrand.Intn(math.MaxInt))), + DNSNames: []string{"*"}, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(certExpiryYears, 0, 0), + SubjectKeyId: []byte{1, 2, 3, 4, 6}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + + certPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, err + } + + certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey) + if err != nil { + return tls.Certificate{}, err + } + + certPEM := new(bytes.Buffer) + if err = pem.Encode(certPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }); err != nil { + return tls.Certificate{}, err + } + + certPrivKeyPEM := new(bytes.Buffer) + + b, err = x509.MarshalECPrivateKey(certPrivKey) + if err != nil { + return tls.Certificate{}, err + } + + if err = pem.Encode(certPrivKeyPEM, &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: b, + }); err != nil { + return tls.Certificate{}, err + } + + keyPair, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes()) + if err != nil { + return tls.Certificate{}, err + } + + return keyPair, nil +} From 1f8bc4fcf49275794110fb57b4bc3a1c21e69a4a Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Fri, 30 Aug 2024 12:56:58 -0400 Subject: [PATCH 08/10] refactor: cleanup TLS self-signed cert generation It's now actually a self-signed cert, instead of using a CA no one will ever see. --- server/server.go | 2 +- util/tls.go | 113 +++++++++++++---------------------------------- util/tls_test.go | 45 +++++++++++++++++++ 3 files changed, 77 insertions(+), 83 deletions(-) create mode 100644 util/tls_test.go diff --git a/server/server.go b/server/server.go index 4dc588eb2..0ca7217f2 100644 --- a/server/server.go +++ b/server/server.go @@ -72,7 +72,7 @@ type NewServerFunc func(address string) (*dns.Server, error) func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) { if cfg.CertFile == "" && cfg.KeyFile == "" { - cert, err = util.CreateSelfSignedCert() + cert, err = util.TLSGenerateSelfSignedCert([]string{"blocky.invalid", "*"}) if err != nil { return tls.Certificate{}, fmt.Errorf("unable to generate self-signed certificate: %w", err) } diff --git a/util/tls.go b/util/tls.go index a5ebcc414..78857a9e5 100644 --- a/util/tls.go +++ b/util/tls.go @@ -1,118 +1,67 @@ package util import ( - "bytes" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/tls" "crypto/x509" - "encoding/pem" - "math" + "crypto/x509/pkix" + "fmt" "math/big" - mrand "math/rand" "time" ) const ( - caExpiryYears = 10 - certExpiryYears = 5 + certSerialMaxBits = 128 + certExpiryYears = 5 ) -//nolint:funlen -func CreateSelfSignedCert() (tls.Certificate, error) { - // Create CA - ca := &x509.Certificate{ - //nolint:gosec - SerialNumber: big.NewInt(int64(mrand.Intn(math.MaxInt))), - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(caExpiryYears, 0, 0), - IsCA: true, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - BasicConstraintsValid: true, - } - - caPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) +// TLSGenerateSelfSignedCert returns a new self-signed cert for the given domains. +// +// Being self-signed, no client will trust this certificate. +func TLSGenerateSelfSignedCert(domains []string) (tls.Certificate, error) { + serialMax := new(big.Int).Lsh(big.NewInt(1), certSerialMaxBits) + serial, err := rand.Int(rand.Reader, serialMax) if err != nil { return tls.Certificate{}, err } - caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey) - if err != nil { - return tls.Certificate{}, err - } + template := &x509.Certificate{ + SerialNumber: serial, - caPEM := new(bytes.Buffer) - if err = pem.Encode(caPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: caBytes, - }); err != nil { - return tls.Certificate{}, err - } + Subject: pkix.Name{Organization: []string{"Blocky"}}, + DNSNames: domains, - caPrivKeyPEM := new(bytes.Buffer) - - b, err := x509.MarshalECPrivateKey(caPrivKey) - if err != nil { - return tls.Certificate{}, err - } - - if err = pem.Encode(caPrivKeyPEM, &pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: b, - }); err != nil { - return tls.Certificate{}, err - } + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(certExpiryYears, 0, 0), - // Create certificate - cert := &x509.Certificate{ - //nolint:gosec - SerialNumber: big.NewInt(int64(mrand.Intn(math.MaxInt))), - DNSNames: []string{"*"}, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(certExpiryYears, 0, 0), - SubjectKeyId: []byte{1, 2, 3, 4, 6}, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature, + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, } - certPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { - return tls.Certificate{}, err + return tls.Certificate{}, fmt.Errorf("unable to generate private key: %w", err) } - certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey) + der, err := x509.CreateCertificate(rand.Reader, template, template, &privKey.PublicKey, privKey) if err != nil { - return tls.Certificate{}, err + return tls.Certificate{}, fmt.Errorf("cert creation from template failed: %w", err) } - certPEM := new(bytes.Buffer) - if err = pem.Encode(certPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - }); err != nil { - return tls.Certificate{}, err - } - - certPrivKeyPEM := new(bytes.Buffer) - - b, err = x509.MarshalECPrivateKey(certPrivKey) + // Parse the generated DER back into a useable cert + // This avoids needing to do it for each TLS handshake (see tls.Certificate.Leaf comment) + cert, err := x509.ParseCertificate(der) if err != nil { - return tls.Certificate{}, err - } - - if err = pem.Encode(certPrivKeyPEM, &pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: b, - }); err != nil { - return tls.Certificate{}, err + return tls.Certificate{}, fmt.Errorf("generated cert DER could not be parsed: %w", err) } - keyPair, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes()) - if err != nil { - return tls.Certificate{}, err + tlsCert := tls.Certificate{ + Certificate: [][]byte{der}, + PrivateKey: privKey, + Leaf: cert, } - return keyPair, nil + return tlsCert, nil } diff --git a/util/tls_test.go b/util/tls_test.go new file mode 100644 index 000000000..4fb4e0f55 --- /dev/null +++ b/util/tls_test.go @@ -0,0 +1,45 @@ +package util + +import ( + "crypto/x509" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("TLS Util", func() { + Describe("TLSGenerateSelfSignedCert", func() { + It("returns a good value", func() { + const domain = "whatever.test.blocky.invalid" + + cert, err := TLSGenerateSelfSignedCert([]string{domain}) + Expect(err).Should(Succeed()) + + Expect(cert.Certificate).ShouldNot(BeEmpty()) + + By("having the right Leaf", func() { + fromDER, err := x509.ParseCertificate(cert.Certificate[0]) + Expect(err).Should(Succeed()) + + Expect(cert.Leaf).Should(Equal(fromDER)) + }) + + By("being valid as self-signed for server TLS on the given domain", func() { + pool := x509.NewCertPool() + pool.AddCert(cert.Leaf) + + chain, err := cert.Leaf.Verify(x509.VerifyOptions{ + DNSName: domain, + Roots: pool, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + }) + Expect(err).Should(Succeed()) + Expect(chain).Should(Equal([][]*x509.Certificate{{cert.Leaf}})) + }) + + By("mentioning Blocky", func() { + Expect(cert.Leaf.Subject.Organization).Should(Equal([]string{"Blocky"})) + }) + }) + }) +}) From a4f853af8e0a95712c4c79c28ecf63dbf28c2893 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Thu, 29 Aug 2024 20:04:16 -0400 Subject: [PATCH 09/10] refactor: make QueryLoggingResolver read the hostname on creation Remove that code from util and the use of globals. Partly move away from calling the value a "hostname" since it's more of an instance ID, especially if you run more than one per host. --- querylog/database_writer.go | 2 +- querylog/file_writer.go | 2 +- querylog/logger_writer.go | 3 +- querylog/logger_writer_test.go | 1 - querylog/writer.go | 1 + resolver/query_logging_resolver.go | 46 +++++++++++++++++++----- resolver/query_logging_resolver_test.go | 22 ++++++++++-- server/server.go | 4 ++- util/hostname.go | 40 --------------------- util/hostname_test.go | 48 ------------------------- 10 files changed, 63 insertions(+), 106 deletions(-) delete mode 100644 util/hostname.go delete mode 100644 util/hostname_test.go diff --git a/querylog/database_writer.go b/querylog/database_writer.go index 678ff811f..54ae3a6c1 100644 --- a/querylog/database_writer.go +++ b/querylog/database_writer.go @@ -176,7 +176,7 @@ func (d *DatabaseWriter) Write(entry *LogEntry) { EffectiveTLDP: eTLD, Answer: entry.Answer, ResponseCode: entry.ResponseCode, - Hostname: util.HostnameString(), + Hostname: entry.BlockyInstance, } d.lock.Lock() diff --git a/querylog/file_writer.go b/querylog/file_writer.go index 1cb00cab9..545e99d6d 100644 --- a/querylog/file_writer.go +++ b/querylog/file_writer.go @@ -115,7 +115,7 @@ func createQueryLogRow(logEntry *LogEntry) []string { logEntry.ResponseCode, logEntry.ResponseType, logEntry.QuestionType, - util.HostnameString(), + logEntry.BlockyInstance, } } diff --git a/querylog/logger_writer.go b/querylog/logger_writer.go index 786bc6bc4..c2ec0822c 100644 --- a/querylog/logger_writer.go +++ b/querylog/logger_writer.go @@ -5,7 +5,6 @@ import ( "strings" "github.com/0xERR0R/blocky/log" - "github.com/0xERR0R/blocky/util" "github.com/sirupsen/logrus" ) @@ -40,7 +39,7 @@ func LogEntryFields(entry *LogEntry) logrus.Fields { "question_type": entry.QuestionType, "answer": entry.Answer, "duration_ms": entry.DurationMs, - "hostname": util.HostnameString(), + "instance": entry.BlockyInstance, }) } diff --git a/querylog/logger_writer_test.go b/querylog/logger_writer_test.go index 38b136c54..b5053641f 100644 --- a/querylog/logger_writer_test.go +++ b/querylog/logger_writer_test.go @@ -51,7 +51,6 @@ var _ = Describe("LoggerWriter", func() { Expect(fields).Should(HaveKeyWithValue("duration_ms", entry.DurationMs)) Expect(fields).Should(HaveKeyWithValue("question_type", entry.QuestionType)) Expect(fields).Should(HaveKeyWithValue("response_code", entry.ResponseCode)) - Expect(fields).Should(HaveKey("hostname")) Expect(fields).ShouldNot(HaveKey("client_names")) Expect(fields).ShouldNot(HaveKey("question_name")) diff --git a/querylog/writer.go b/querylog/writer.go index 0d83534af..e68483ac0 100644 --- a/querylog/writer.go +++ b/querylog/writer.go @@ -15,6 +15,7 @@ type LogEntry struct { QuestionType string QuestionName string Answer string + BlockyInstance string } type Writer interface { diff --git a/resolver/query_logging_resolver.go b/resolver/query_logging_resolver.go index 456d98c73..44eb92e1f 100644 --- a/resolver/query_logging_resolver.go +++ b/resolver/query_logging_resolver.go @@ -2,6 +2,10 @@ package resolver import ( "context" + "errors" + "fmt" + "os" + "strings" "time" "github.com/0xERR0R/blocky/config" @@ -25,8 +29,9 @@ type QueryLoggingResolver struct { NextResolver typed - logChan chan *querylog.LogEntry - writer querylog.Writer + logChan chan *querylog.LogEntry + writer querylog.Writer + instanceID string } func GetQueryLoggingWriter(ctx context.Context, cfg config.QueryLog) (querylog.Writer, error) { @@ -58,7 +63,7 @@ func GetQueryLoggingWriter(ctx context.Context, cfg config.QueryLog) (querylog.W } // NewQueryLoggingResolver returns a new resolver instance -func NewQueryLoggingResolver(ctx context.Context, cfg config.QueryLog) *QueryLoggingResolver { +func NewQueryLoggingResolver(ctx context.Context, cfg config.QueryLog) (*QueryLoggingResolver, error) { logger := log.PrefixedLog(queryLoggingResolverType) var writer querylog.Writer @@ -86,14 +91,20 @@ func NewQueryLoggingResolver(ctx context.Context, cfg config.QueryLog) *QueryLog cfg.Type = config.QueryLogTypeConsole } + instanceID, err := readInstanceID("/etc/hostname") + if err != nil { + return nil, err + } + logChan := make(chan *querylog.LogEntry, logChanCap) resolver := QueryLoggingResolver{ configurable: withConfig(&cfg), typed: withType(queryLoggingResolverType), - logChan: logChan, - writer: writer, + logChan: logChan, + writer: writer, + instanceID: instanceID, } go resolver.writeLog(ctx) @@ -103,7 +114,7 @@ func NewQueryLoggingResolver(ctx context.Context, cfg config.QueryLog) *QueryLog go resolver.periodicCleanUp(ctx) } - return &resolver + return &resolver, nil } // triggers periodically cleanup of old log files @@ -170,9 +181,10 @@ func (r *QueryLoggingResolver) createLogEntry(request *model.Request, response * start time.Time, durationMs int64, ) *querylog.LogEntry { entry := querylog.LogEntry{ - Start: start, - ClientIP: "0.0.0.0", - ClientNames: []string{"none"}, + Start: start, + ClientIP: "0.0.0.0", + ClientNames: []string{"none"}, + BlockyInstance: r.instanceID, } for _, f := range r.cfg.Fields { @@ -227,3 +239,19 @@ func (r *QueryLoggingResolver) writeLog(ctx context.Context) { } } } + +func readInstanceID(file string) (string, error) { + // Prefer /etc/hostname over os.Hostname to allow easy differentiation in a Docker Swarm + // See details in https://github.com/0xERR0R/blocky/pull/756 + hn, fErr := os.ReadFile(file) + if fErr == nil { + return strings.TrimSpace(string(hn)), nil + } + + hostname, osErr := os.Hostname() + if osErr == nil { + return hostname, nil + } + + return "", fmt.Errorf("cannot determine instance ID: %w", errors.Join(fErr, osErr)) +} diff --git a/resolver/query_logging_resolver_test.go b/resolver/query_logging_resolver_test.go index ce8c8978d..f979afde6 100644 --- a/resolver/query_logging_resolver_test.go +++ b/resolver/query_logging_resolver_test.go @@ -42,6 +42,7 @@ var _ = Describe("QueryLoggingResolver", func() { var ( sut *QueryLoggingResolver sutConfig config.QueryLog + err error m *mockResolver tmpDir *TmpFolder mockRType ResponseType @@ -61,8 +62,6 @@ var _ = Describe("QueryLoggingResolver", func() { ctx, cancelFn = context.WithCancel(context.Background()) DeferCleanup(cancelFn) - var err error - sutConfig, err = config.WithDefaults[config.QueryLog]() Expect(err).Should(Succeed()) @@ -76,7 +75,8 @@ var _ = Describe("QueryLoggingResolver", func() { sutConfig.SetDefaults() // not called when using a struct literal } - sut = NewQueryLoggingResolver(ctx, sutConfig) + sut, err = NewQueryLoggingResolver(ctx, sutConfig) + Expect(err).Should(Succeed()) m = &mockResolver{ ResolveFn: func(context.Context, *Request) (*Response, error) { @@ -441,6 +441,22 @@ var _ = Describe("QueryLoggingResolver", func() { }) }) }) + + Describe("Hostname function tests", func() { + It("should use the given file if it exists", func() { + expected := "TestName" + tmpFile := NewTmpFolder("hostname").CreateStringFile("filetest1", expected+" \n") + + Expect(readInstanceID(tmpFile.Path)).Should(Equal(expected)) + }) + + It("should fallback to os.Hostname", func() { + expected, err := os.Hostname() + Expect(err).Should(Succeed()) + + Expect(readInstanceID("/var/empty/nonexistent")).Should(Equal(expected)) + }) + }) }) func readCsv(file string) ([][]string, error) { diff --git a/server/server.go b/server/server.go index 0ca7217f2..b33b85433 100644 --- a/server/server.go +++ b/server/server.go @@ -304,12 +304,14 @@ func createQueryResolver( upstreamTree, utErr := resolver.NewUpstreamTreeResolver(ctx, cfg.Upstreams, bootstrap) blocking, blErr := resolver.NewBlockingResolver(ctx, cfg.Blocking, redisClient, bootstrap) clientNames, cnErr := resolver.NewClientNamesResolver(ctx, cfg.ClientLookup, cfg.Upstreams, bootstrap) + queryLogging, qlErr := resolver.NewQueryLoggingResolver(ctx, cfg.QueryLog) condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(ctx, cfg.Conditional, cfg.Upstreams, bootstrap) hostsFile, hfErr := resolver.NewHostsFileResolver(ctx, cfg.HostsFile, bootstrap) err := multierror.Append( multierror.Prefix(utErr, "upstream tree resolver: "), multierror.Prefix(blErr, "blocking resolver: "), + multierror.Prefix(qlErr, "query logging resolver: "), multierror.Prefix(cnErr, "client names resolver: "), multierror.Prefix(cuErr, "conditional upstream resolver: "), multierror.Prefix(hfErr, "hosts file resolver: "), @@ -324,7 +326,7 @@ func createQueryResolver( resolver.NewECSResolver(cfg.ECS), clientNames, resolver.NewEDEResolver(cfg.EDE), - resolver.NewQueryLoggingResolver(ctx, cfg.QueryLog), + queryLogging, resolver.NewMetricsResolver(cfg.Prometheus), resolver.NewRewriterResolver(cfg.CustomDNS.RewriterConfig, resolver.NewCustomDNSResolver(cfg.CustomDNS)), hostsFile, diff --git a/util/hostname.go b/util/hostname.go deleted file mode 100644 index 81bb033e1..000000000 --- a/util/hostname.go +++ /dev/null @@ -1,40 +0,0 @@ -package util - -import ( - "os" - "strings" -) - -//nolint:gochecknoglobals -var ( - hostname string - hostnameErr error -) - -const hostnameFile string = "/etc/hostname" - -//nolint:gochecknoinits -func init() { - getHostname(hostnameFile) -} - -// Direct replacement for os.Hostname -func Hostname() (string, error) { - return hostname, hostnameErr -} - -// Only return the hostname(may be empty if there was an error) -func HostnameString() string { - return hostname -} - -func getHostname(location string) { - hostname, hostnameErr = os.Hostname() - - if hn, err := os.ReadFile(location); err == nil { - hostname = strings.TrimSpace(string(hn)) - hostnameErr = nil - - return - } -} diff --git a/util/hostname_test.go b/util/hostname_test.go deleted file mode 100644 index 5d0274cb1..000000000 --- a/util/hostname_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package util - -import ( - "os" - "strings" - - "github.com/0xERR0R/blocky/helpertest" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" -) - -var _ = Describe("Hostname function tests", func() { - When("file is present", func() { - var tmpDir *helpertest.TmpFolder - - BeforeEach(func() { - tmpDir = helpertest.NewTmpFolder("hostname") - }) - - It("should be used", func() { - tmpFile := tmpDir.CreateStringFile("filetest1", "TestName ") - - getHostname(tmpFile.Path) - - fhn, err := os.ReadFile(tmpFile.Path) - Expect(err).Should(Succeed()) - - hn, err := Hostname() - Expect(err).Should(Succeed()) - - Expect(hn).Should(Equal(strings.TrimSpace(string(fhn)))) - }) - }) - - When("file is not present", func() { - It("should use os.Hostname", func() { - getHostname("/does-not-exist") - - _, err := Hostname() - Expect(err).Should(Succeed()) - - ohn, err := os.Hostname() - Expect(err).Should(Succeed()) - - Expect(HostnameString()).Should(Equal(ohn)) - }) - }) -}) From 380906d1e1477ce66e5eeb27844fb9c3edd9a186 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Sat, 30 Mar 2024 18:02:27 -0400 Subject: [PATCH 10/10] refactor: add `:` prefix to ports during config unmarshaling --- config/config.go | 7 +++++++ config/config_test.go | 4 ++-- helpertest/helper.go | 16 ++++++++++++---- server/server.go | 12 ++---------- server/server_test.go | 17 +++++++++-------- 5 files changed, 32 insertions(+), 24 deletions(-) diff --git a/config/config.go b/config/config.go index f512e17ce..8890b996e 100644 --- a/config/config.go +++ b/config/config.go @@ -171,6 +171,13 @@ func (l *ListenConfig) UnmarshalText(data []byte) error { *l = strings.Split(addresses, ",") + // Prefix all ports with : + for i, addr := range *l { + if !strings.ContainsRune(addr, ':') { + (*l)[i] = ":" + addr + } + } + return nil } diff --git a/config/config_test.go b/config/config_test.go index be75775d7..dc28c0b78 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -462,7 +462,7 @@ bootstrapDns: err := l.UnmarshalText([]byte("55,:56")) Expect(err).Should(Succeed()) Expect(*l).Should(HaveLen(2)) - Expect(*l).Should(ContainElements("55", ":56")) + Expect(*l).Should(ContainElements(":55", ":56")) }) }) }) @@ -958,7 +958,7 @@ bootstrapDns: }) func defaultTestFileConfig(config *Config) { - Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"})) + Expect(config.Ports.DNS).Should(Equal(ListenConfig{":55553", ":55554", "[::1]:55555"})) Expect(config.Upstreams.Init.Strategy).Should(Equal(InitStrategyFailOnError)) Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky")) Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3)) diff --git a/helpertest/helper.go b/helpertest/helper.go index cd5415bcc..8a3460875 100644 --- a/helpertest/helper.go +++ b/helpertest/helper.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "net" "net/http" "net/http/httptest" "os" @@ -31,20 +32,27 @@ const ( DS = dns.Type(dns.TypeDS) ) -// GetIntPort returns an port for the current testing +// GetIntPort returns a port for the current testing // process by adding the current ginkgo parallel process to -// the base port and returning it as int +// the base port and returning it as int. func GetIntPort(port int) int { return port + ginkgo.GinkgoParallelProcess() } -// GetStringPort returns an port for the current testing +// GetStringPort returns a port for the current testing // process by adding the current ginkgo parallel process to -// the base port and returning it as string +// the base port and returning it as string. func GetStringPort(port int) string { return fmt.Sprintf("%d", GetIntPort(port)) } +// GetHostPort returns a host:port string for the current testing +// process by adding the current ginkgo parallel process to +// the base port and returning it as string. +func GetHostPort(host string, port int) string { + return net.JoinHostPort(host, GetStringPort(port)) +} + // TempFile creates temp file with passed data func TempFile(data string) *os.File { f, err := os.CreateTemp("", "prefix") diff --git a/server/server.go b/server/server.go index b33b85433..7404541ad 100644 --- a/server/server.go +++ b/server/server.go @@ -60,14 +60,6 @@ func tlsCipherSuites() []uint16 { return tlsCipherSuites } -func getServerAddress(addr string) string { - if !strings.Contains(addr, ":") { - addr = fmt.Sprintf(":%s", addr) - } - - return addr -} - type NewServerFunc func(address string) (*dns.Server, error) func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) { @@ -195,7 +187,7 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error addServers := func(newServer NewServerFunc, addresses config.ListenConfig) error { for _, address := range addresses { - server, err := newServer(getServerAddress(address)) + server, err := newServer(address) if err != nil { return err } @@ -236,7 +228,7 @@ func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listene listeners := make([]net.Listener, 0, len(addresses)) for _, address := range addresses { - listener, err := net.Listen("tcp", getServerAddress(address)) + listener, err := net.Listen("tcp", address) if err != nil { return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err) } diff --git a/server/server_test.go b/server/server_test.go index 60207d0ac..753c87afd 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/base64" + "fmt" "io" "net" "net/http" @@ -43,7 +44,7 @@ var ( ) var _ = BeforeSuite(func() { - baseURL = "http://localhost:" + GetStringPort(httpBasePort) + "/" + baseURL = fmt.Sprintf("http://%s/", GetHostPort("localhost", httpBasePort)) queryURL = baseURL + "dns-query" var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream ctx, cancelFn := context.WithCancel(context.Background()) @@ -146,10 +147,10 @@ var _ = BeforeSuite(func() { }, Ports: config.Ports{ - DNS: config.ListenConfig{GetStringPort(dnsBasePort)}, - TLS: config.ListenConfig{GetStringPort(tlsBasePort)}, - HTTP: config.ListenConfig{GetStringPort(httpBasePort)}, - HTTPS: config.ListenConfig{GetStringPort(httpsBasePort)}, + DNS: config.ListenConfig{GetHostPort("", dnsBasePort)}, + TLS: config.ListenConfig{GetHostPort("", tlsBasePort)}, + HTTP: config.ListenConfig{GetHostPort("", httpBasePort)}, + HTTPS: config.ListenConfig{GetHostPort("", httpsBasePort)}, }, CertFile: certPem.Path, KeyFile: keyPem.Path, @@ -633,7 +634,7 @@ var _ = Describe("Running DNS server", func() { }, Blocking: config.Blocking{BlockType: "zeroIp"}, Ports: config.Ports{ - DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)}, + DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)}, }, }) @@ -677,7 +678,7 @@ var _ = Describe("Running DNS server", func() { }, Blocking: config.Blocking{BlockType: "zeroIp"}, Ports: config.Ports{ - DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)}, + DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)}, }, }) @@ -751,7 +752,7 @@ var _ = Describe("Running DNS server", func() { }) func requestServer(request *dns.Msg) *dns.Msg { - conn, err := net.Dial("udp", ":"+GetStringPort(dnsBasePort)) + conn, err := net.Dial("udp", GetHostPort("", dnsBasePort)) if err != nil { Log().Fatal("could not connect to server: ", err) }