Skip to content

Commit

Permalink
server: replace use of setec.Watcher with setec.Updater
Browse files Browse the repository at this point in the history
  • Loading branch information
creachadair committed Sep 24, 2024
1 parent ac7081c commit 44d71cd
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 61 deletions.
6 changes: 4 additions & 2 deletions server/tailsql/internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package tailsql

import (
"context"
"database/sql"
"os"
"testing"
Expand Down Expand Up @@ -90,13 +91,14 @@ func TestOptions(t *testing.T) {

// Test that we can populate options from the config.
t.Run("Options", func(t *testing.T) {
dbs, err := opts.openSources(nil)
dbs, err := opts.openSources(context.Background(), nil)
if err != nil {
t.Fatalf("Options: unexpected error: %v", err)
}

// The handles should be equinumerous and in the same order as the config.
for i, h := range dbs {
for i, u := range dbs {
h := u.Get()
if got, want := h.Source(), opts.Sources[i].Source; got != want {
t.Errorf("Database %d: got src %q, want %q", i+1, got, want)
}
Expand Down
76 changes: 34 additions & 42 deletions server/tailsql/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,33 +112,58 @@ func (o Options) checkQuery() func(Query) (Query, error) {
// openSources opens database handles to each of the sources defined by o.
// Sources that require secrets will get them from store.
// Precondition: All the sources of o have already been validated.
func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) {
func (o Options) openSources(ctx context.Context, store *setec.Store) ([]*setec.Updater[*dbHandle], error) {
if len(o.Sources) == 0 {
return nil, nil
}

srcs := make([]*dbHandle, len(o.Sources))
srcs := make([]*setec.Updater[*dbHandle], len(o.Sources))
for i, spec := range o.Sources {
if spec.Label == "" {
spec.Label = "(unidentified database)"
}

// Case 1: A programmatic source.
if spec.DB != nil {
srcs[i] = &dbHandle{
srcs[i] = setec.StaticUpdater(&dbHandle{
src: spec.Source,
label: spec.Label,
named: spec.Named,
db: spec.DB,
})
continue
}

// Case 2: A database managed by database/sql, with a secret from setec.
if spec.Secret != "" {
// We actually only maintain a single value, that is updated in-place.
h := &dbHandle{src: spec.Source, label: spec.Label, named: spec.Named}
u, err := setec.NewUpdater(ctx, store, spec.Secret, func(secret []byte) (*dbHandle, error) {
db, err := openAndPing(spec.Driver, string(secret))
if err != nil {
return nil, err
}
o.logf()("[tailsql] opened new connection for source %q", spec.Source)
h.mu.Lock()
defer h.mu.Unlock()
if h.db != nil {
h.db.Close() // close the active handle
}
if up := h.checkUpdate(); up != nil {
up.newDB.Close() // close a previous pending update
}
h.db = sqlDB{DB: db}
return h, nil
})
if err != nil {
return nil, err
}
srcs[i] = u
continue
}

// Case 2: A database managed by database/sql.
//
// Resolve the connection string.
// Case 3: A database managed by database/sql, with a fixed URL.
var connString string
var w setec.Watcher
switch {
case spec.URL != "":
connString = spec.URL
Expand All @@ -148,9 +173,6 @@ func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) {
return nil, fmt.Errorf("read key file for %q: %w", spec.Source, err)
}
connString = strings.TrimSpace(string(data))
case spec.Secret != "":
w = store.Watcher(spec.Secret)
connString = string(w.Get())
default:
panic("unexpected: no connection source is defined after validation")
}
Expand All @@ -160,16 +182,13 @@ func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) {
if err != nil {
return nil, err
}
srcs[i] = &dbHandle{
srcs[i] = setec.StaticUpdater(&dbHandle{
src: spec.Source,
driver: spec.Driver,
label: spec.Label,
named: spec.Named,
db: sqlDB{DB: db},
}
if spec.Secret != "" {
go srcs[i].handleUpdates(spec.Secret, w, o.logf())
}
})
}
return srcs, nil
}
Expand Down Expand Up @@ -325,33 +344,6 @@ type dbHandle struct {
named map[string]string
}

// handleUpdates polls w indefinitely for updates to the connection string for
// h, and reopens the database with the new string when a new value arrives.
// This method should be called in a goroutine.
func (h *dbHandle) handleUpdates(name string, w setec.Watcher, logf logger.Logf) {
logf("[tailsql] starting updater for secret %q", name)
for range w.Ready() {
// N.B. Don't log the secret value itself. It's fine to log the name of
// the secret and the source, those are already in the config.
connString := string(w.Get())
db, err := openAndPing(h.driver, connString)
if err != nil {
logf("WARNING: opening new database for %q: %v", h.src, err)
continue
}
logf("[tailsql] opened new connection for source %q", h.src)
h.mu.Lock()
// Close the existing active handle.
h.db.Close()
// If there's a pending update, close it too.
if up := h.checkUpdate(); up != nil {
up.newDB.Close()
}
h.db = sqlDB{DB: db}
h.mu.Unlock()
}
}

// checkUpdate returns nil if there is no pending update, otherwise it swaps
// out the pending database update and returns it.
func (h *dbHandle) checkUpdate() *dbUpdate {
Expand Down
26 changes: 15 additions & 11 deletions server/tailsql/tailsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ import (
"time"
"unicode/utf8"

"github.com/tailscale/setec/client/setec"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/types/logger"
"tailscale.com/util/httpm"
Expand Down Expand Up @@ -119,7 +120,7 @@ type Server struct {
logf logger.Logf

mu sync.Mutex
dbs []*dbHandle
dbs []*setec.Updater[*dbHandle]
}

// NewServer constructs a new server with the given Options.
Expand All @@ -134,7 +135,7 @@ func NewServer(opts Options) (*Server, error) {
return nil, fmt.Errorf("have %d named secrets but no secret store", len(sec))
}

dbs, err := opts.openSources(opts.SecretStore)
dbs, err := opts.openSources(context.Background(), opts.SecretStore)
if err != nil {
return nil, fmt.Errorf("opening sources: %w", err)
}
Expand All @@ -143,14 +144,14 @@ func NewServer(opts Options) (*Server, error) {
return nil, fmt.Errorf("local state: %w", err)
}
if state != nil && opts.LocalSource != "" {
dbs = append(dbs, &dbHandle{
dbs = append(dbs, setec.StaticUpdater(&dbHandle{
src: opts.LocalSource,
label: "tailsql local state",
db: state,
named: map[string]string{
"schema": `select * from sqlite_schema`,
},
})
}))
}

if opts.Metrics != nil {
Expand Down Expand Up @@ -192,18 +193,18 @@ func (s *Server) SetSource(source string, db Queryable, opts *DBOptions) bool {
s.mu.Lock()
defer s.mu.Unlock()

for _, src := range s.dbs {
if src.Source() == source {
for _, u := range s.dbs {
if src := u.Get(); src.Source() == source {
src.swap(db, opts)
return true
}
}
s.dbs = append(s.dbs, &dbHandle{
s.dbs = append(s.dbs, setec.StaticUpdater(&dbHandle{
db: db,
src: source,
label: opts.label(),
named: opts.namedQueries(),
})
}))
return false
}

Expand Down Expand Up @@ -613,12 +614,15 @@ func (s *Server) getHandles() []*dbHandle {
s.mu.Lock()
defer s.mu.Unlock()

out := make([]*dbHandle, len(s.dbs))

// Check for pending updates.
for _, h := range s.dbs {
h.tryUpdate()
for i, u := range s.dbs {
out[i] = u.Get()
out[i].tryUpdate()
}

// It is safe to return the slice because we never remove any elements, new
// data are only ever appended to the end.
return s.dbs
return out
}
63 changes: 57 additions & 6 deletions server/tailsql/tailsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ package tailsql_test
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"html"
"html/template"
"io"
"math/rand/v2"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -128,10 +130,17 @@ var testUIRules = []tailsql.UIRewriteRule{
}

func TestSecrets(t *testing.T) {
// Register a fake driver so we can probe for connection URLs.
// We have to use a new name each time, because there is no way to
// unregister and duplicate names trigger a panic.
driver := new(fakeDriver)
driverName := fmt.Sprintf("%s-driver-%d", t.Name(), rand.Int())
sql.Register(driverName, driver)
t.Logf("Test driver name is %q", driverName)

const secretName = "connection-string"
url, _ := mustInitSQLite(t)
db := setectest.NewDB(t, nil)
db.MustPut(db.Superuser, secretName, url)
db.MustPut(db.Superuser, secretName, "string 1")

ss := setectest.NewServer(t, db, nil)
hs := httptest.NewServer(ss.Mux)
Expand All @@ -141,17 +150,23 @@ func TestSecrets(t *testing.T) {
Sources: []tailsql.DBSpec{{
Source: "test",
Label: "Test Database",
Driver: "sqlite",
Driver: driverName,
Secret: secretName,
}},
RoutePrefix: "/tsql",
}

// Verify we found the expected secret names in the options.
secrets, err := opts.CheckSources()
if err != nil {
t.Fatalf("Invalid sources: %v", err)
}

tick := setectest.NewFakeTicker()
st, err := setec.NewStore(context.Background(), setec.StoreConfig{
Client: setec.Client{Server: hs.URL},
Secrets: secrets,
Client: setec.Client{Server: hs.URL},
Secrets: secrets,
PollTicker: tick,
})
if err != nil {
t.Fatalf("Creating setec store: %v", err)
Expand All @@ -162,7 +177,28 @@ func TestSecrets(t *testing.T) {
if err != nil {
t.Fatalf("Creating tailsql server: %v", err)
}
ts.Close()
ss.Mux.Handle("/tsql/", ts.NewMux()) // so we can call /meta below
defer ts.Close()

// After opening the server, the database should have the initial secret
// value provided on initialization.
if got, want := driver.OpenedURL, "string 1"; got != want {
t.Errorf("Initial URL: got %q, want %q", got, want)
}

// Update the secret.
db.MustActivate(db.Superuser, secretName, db.MustPut(db.Superuser, secretName, "string 2"))
tick.Poll()

// Make the database fetch the latest value.
if _, err := hs.Client().Get(hs.URL + "/tsql/meta"); err != nil {
t.Errorf("Get tailsql meta: %v", err)
}

// After the update, the database should have the new secret value.
if got, want := driver.OpenedURL, "string 2"; got != want {
t.Errorf("Updated URL: got %q, want %q", got, want)
}
}

func TestServer(t *testing.T) {
Expand Down Expand Up @@ -567,3 +603,18 @@ func TestRoutePrefix(t *testing.T) {
}
})
}

type fakeDriver struct {
OpenedURL string
}

func (f *fakeDriver) Open(url string) (driver.Conn, error) {
f.OpenedURL = url
return fakeConn{}, nil
}

// fakeConn is a fake implementation of driver.Conn to satisfy the interface,
// it will panic if actually used.
type fakeConn struct{ driver.Conn }

func (fakeConn) Close() error { return nil }

0 comments on commit 44d71cd

Please sign in to comment.