Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Smaller refactors split from #1427 #1592

Merged
merged 10 commits into from
Sep 4, 2024
7 changes: 7 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,13 @@ func (l *ListenConfig) UnmarshalText(data []byte) error {

*l = strings.Split(addresses, ",")

// Prefix all ports with :
for i, addr := range *l {
if !strings.ContainsRune(addr, ':') {
(*l)[i] = ":" + addr
}
}

return nil
}

Expand Down
4 changes: 2 additions & 2 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ bootstrapDns:
err := l.UnmarshalText([]byte("55,:56"))
Expect(err).Should(Succeed())
Expect(*l).Should(HaveLen(2))
Expect(*l).Should(ContainElements("55", ":56"))
Expect(*l).Should(ContainElements(":55", ":56"))
})
})
})
Expand Down Expand Up @@ -958,7 +958,7 @@ bootstrapDns:
})

func defaultTestFileConfig(config *Config) {
Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"}))
Expect(config.Ports.DNS).Should(Equal(ListenConfig{":55553", ":55554", "[::1]:55555"}))
Expect(config.Upstreams.Init.Strategy).Should(Equal(InitStrategyFailOnError))
Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky"))
Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3))
Expand Down
16 changes: 12 additions & 4 deletions helpertest/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -31,20 +32,27 @@ const (
DS = dns.Type(dns.TypeDS)
)

// GetIntPort returns an port for the current testing
// GetIntPort returns a port for the current testing
// process by adding the current ginkgo parallel process to
// the base port and returning it as int
// the base port and returning it as int.
func GetIntPort(port int) int {
return port + ginkgo.GinkgoParallelProcess()
}

// GetStringPort returns an port for the current testing
// GetStringPort returns a port for the current testing
// process by adding the current ginkgo parallel process to
// the base port and returning it as string
// the base port and returning it as string.
func GetStringPort(port int) string {
return fmt.Sprintf("%d", GetIntPort(port))
}

// GetHostPort returns a host:port string for the current testing
// process by adding the current ginkgo parallel process to
// the base port and returning it as string.
func GetHostPort(host string, port int) string {
return net.JoinHostPort(host, GetStringPort(port))
}

// TempFile creates temp file with passed data
func TempFile(data string) *os.File {
f, err := os.CreateTemp("", "prefix")
Expand Down
2 changes: 1 addition & 1 deletion querylog/database_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func (d *DatabaseWriter) Write(entry *LogEntry) {
EffectiveTLDP: eTLD,
Answer: entry.Answer,
ResponseCode: entry.ResponseCode,
Hostname: util.HostnameString(),
Hostname: entry.BlockyInstance,
}

d.lock.Lock()
Expand Down
2 changes: 1 addition & 1 deletion querylog/file_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func createQueryLogRow(logEntry *LogEntry) []string {
logEntry.ResponseCode,
logEntry.ResponseType,
logEntry.QuestionType,
util.HostnameString(),
logEntry.BlockyInstance,
}
}

Expand Down
3 changes: 1 addition & 2 deletions querylog/logger_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"strings"

"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/util"
"github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -40,7 +39,7 @@ func LogEntryFields(entry *LogEntry) logrus.Fields {
"question_type": entry.QuestionType,
"answer": entry.Answer,
"duration_ms": entry.DurationMs,
"hostname": util.HostnameString(),
"instance": entry.BlockyInstance,
})
}

Expand Down
1 change: 0 additions & 1 deletion querylog/logger_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ var _ = Describe("LoggerWriter", func() {
Expect(fields).Should(HaveKeyWithValue("duration_ms", entry.DurationMs))
Expect(fields).Should(HaveKeyWithValue("question_type", entry.QuestionType))
Expect(fields).Should(HaveKeyWithValue("response_code", entry.ResponseCode))
Expect(fields).Should(HaveKey("hostname"))

Expect(fields).ShouldNot(HaveKey("client_names"))
Expect(fields).ShouldNot(HaveKey("question_name"))
Expand Down
1 change: 1 addition & 0 deletions querylog/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type LogEntry struct {
QuestionType string
QuestionName string
Answer string
BlockyInstance string
}

type Writer interface {
Expand Down
46 changes: 37 additions & 9 deletions resolver/query_logging_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ package resolver

import (
"context"
"errors"
"fmt"
"os"
"strings"
"time"

"github.com/0xERR0R/blocky/config"
Expand All @@ -25,8 +29,9 @@ type QueryLoggingResolver struct {
NextResolver
typed

logChan chan *querylog.LogEntry
writer querylog.Writer
logChan chan *querylog.LogEntry
writer querylog.Writer
instanceID string
}

func GetQueryLoggingWriter(ctx context.Context, cfg config.QueryLog) (querylog.Writer, error) {
Expand Down Expand Up @@ -58,7 +63,7 @@ func GetQueryLoggingWriter(ctx context.Context, cfg config.QueryLog) (querylog.W
}

// NewQueryLoggingResolver returns a new resolver instance
func NewQueryLoggingResolver(ctx context.Context, cfg config.QueryLog) *QueryLoggingResolver {
func NewQueryLoggingResolver(ctx context.Context, cfg config.QueryLog) (*QueryLoggingResolver, error) {
logger := log.PrefixedLog(queryLoggingResolverType)

var writer querylog.Writer
Expand Down Expand Up @@ -86,14 +91,20 @@ func NewQueryLoggingResolver(ctx context.Context, cfg config.QueryLog) *QueryLog
cfg.Type = config.QueryLogTypeConsole
}

instanceID, err := readInstanceID("/etc/hostname")
if err != nil {
return nil, err
}

logChan := make(chan *querylog.LogEntry, logChanCap)

resolver := QueryLoggingResolver{
configurable: withConfig(&cfg),
typed: withType(queryLoggingResolverType),

logChan: logChan,
writer: writer,
logChan: logChan,
writer: writer,
instanceID: instanceID,
}

go resolver.writeLog(ctx)
Expand All @@ -103,7 +114,7 @@ func NewQueryLoggingResolver(ctx context.Context, cfg config.QueryLog) *QueryLog
go resolver.periodicCleanUp(ctx)
}

return &resolver
return &resolver, nil
}

// triggers periodically cleanup of old log files
Expand Down Expand Up @@ -170,9 +181,10 @@ func (r *QueryLoggingResolver) createLogEntry(request *model.Request, response *
start time.Time, durationMs int64,
) *querylog.LogEntry {
entry := querylog.LogEntry{
Start: start,
ClientIP: "0.0.0.0",
ClientNames: []string{"none"},
Start: start,
ClientIP: "0.0.0.0",
ClientNames: []string{"none"},
BlockyInstance: r.instanceID,
}

for _, f := range r.cfg.Fields {
Expand Down Expand Up @@ -227,3 +239,19 @@ func (r *QueryLoggingResolver) writeLog(ctx context.Context) {
}
}
}

func readInstanceID(file string) (string, error) {
// Prefer /etc/hostname over os.Hostname to allow easy differentiation in a Docker Swarm
// See details in https://github.com/0xERR0R/blocky/pull/756
hn, fErr := os.ReadFile(file)
if fErr == nil {
return strings.TrimSpace(string(hn)), nil
}

hostname, osErr := os.Hostname()
if osErr == nil {
return hostname, nil
}

return "", fmt.Errorf("cannot determine instance ID: %w", errors.Join(fErr, osErr))
}
22 changes: 19 additions & 3 deletions resolver/query_logging_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ var _ = Describe("QueryLoggingResolver", func() {
var (
sut *QueryLoggingResolver
sutConfig config.QueryLog
err error
m *mockResolver
tmpDir *TmpFolder
mockRType ResponseType
Expand All @@ -61,8 +62,6 @@ var _ = Describe("QueryLoggingResolver", func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)

var err error

sutConfig, err = config.WithDefaults[config.QueryLog]()
Expect(err).Should(Succeed())

Expand All @@ -76,7 +75,8 @@ var _ = Describe("QueryLoggingResolver", func() {
sutConfig.SetDefaults() // not called when using a struct literal
}

sut = NewQueryLoggingResolver(ctx, sutConfig)
sut, err = NewQueryLoggingResolver(ctx, sutConfig)
Expect(err).Should(Succeed())

m = &mockResolver{
ResolveFn: func(context.Context, *Request) (*Response, error) {
Expand Down Expand Up @@ -441,6 +441,22 @@ var _ = Describe("QueryLoggingResolver", func() {
})
})
})

Describe("Hostname function tests", func() {
It("should use the given file if it exists", func() {
expected := "TestName"
tmpFile := NewTmpFolder("hostname").CreateStringFile("filetest1", expected+" \n")

Expect(readInstanceID(tmpFile.Path)).Should(Equal(expected))
})

It("should fallback to os.Hostname", func() {
expected, err := os.Hostname()
Expect(err).Should(Succeed())

Expect(readInstanceID("/var/empty/nonexistent")).Should(Equal(expected))
})
})
})

func readCsv(file string) ([][]string, error) {
Expand Down
96 changes: 96 additions & 0 deletions server/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package server

import (
"context"
"net"
"net/http"
"time"

"github.com/go-chi/chi/v5"
"github.com/go-chi/cors"
)

type httpServer struct {
inner http.Server

name string
}

func newHTTPServer(name string, handler http.Handler) *httpServer {
const (
readHeaderTimeout = 20 * time.Second
readTimeout = 20 * time.Second
writeTimeout = 20 * time.Second
)

return &httpServer{
inner: http.Server{
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHeaderTimeout,
WriteTimeout: writeTimeout,

Handler: withCommonMiddleware(handler),
},

name: name,
}
}

func (s *httpServer) String() string {
return s.name
}

func (s *httpServer) Serve(ctx context.Context, l net.Listener) error {
go func() {
<-ctx.Done()

s.inner.Close()
}()

return s.inner.Serve(l)
}

func withCommonMiddleware(inner http.Handler) *chi.Mux {
// Middleware must be defined before routes, so
// create a new router and mount the inner handler
mux := chi.NewMux()

mux.Use(
secureHeadersMiddleware,
newCORSMiddleware(),
)

mux.Mount("/", inner)

return mux
}

type httpMiddleware = func(http.Handler) http.Handler

func secureHeadersMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.TLS != nil {
w.Header().Set("strict-transport-security", "max-age=63072000")
w.Header().Set("x-frame-options", "DENY")
w.Header().Set("x-content-type-options", "nosniff")
w.Header().Set("x-xss-protection", "1; mode=block")
}

next.ServeHTTP(w, r)
})
}

func newCORSMiddleware() httpMiddleware {
const corsMaxAge = 5 * time.Minute

options := cors.Options{
AllowCredentials: true,
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
AllowedMethods: []string{"GET", "POST"},
AllowedOrigins: []string{"*"},
ExposedHeaders: []string{"Link"},
MaxAge: int(corsMaxAge.Seconds()),
}

return cors.New(options).Handler
}
Loading
Loading