From 44d71cd015b43b28f3f26f650e7486b76b66f96b Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Mon, 23 Sep 2024 12:48:57 -0700 Subject: [PATCH] server: replace use of setec.Watcher with setec.Updater --- server/tailsql/internal_test.go | 6 ++- server/tailsql/options.go | 76 +++++++++++++++------------------ server/tailsql/tailsql.go | 26 ++++++----- server/tailsql/tailsql_test.go | 63 ++++++++++++++++++++++++--- 4 files changed, 110 insertions(+), 61 deletions(-) diff --git a/server/tailsql/internal_test.go b/server/tailsql/internal_test.go index d0389ea..270618e 100644 --- a/server/tailsql/internal_test.go +++ b/server/tailsql/internal_test.go @@ -4,6 +4,7 @@ package tailsql import ( + "context" "database/sql" "os" "testing" @@ -90,13 +91,14 @@ func TestOptions(t *testing.T) { // Test that we can populate options from the config. t.Run("Options", func(t *testing.T) { - dbs, err := opts.openSources(nil) + dbs, err := opts.openSources(context.Background(), nil) if err != nil { t.Fatalf("Options: unexpected error: %v", err) } // The handles should be equinumerous and in the same order as the config. - for i, h := range dbs { + for i, u := range dbs { + h := u.Get() if got, want := h.Source(), opts.Sources[i].Source; got != want { t.Errorf("Database %d: got src %q, want %q", i+1, got, want) } diff --git a/server/tailsql/options.go b/server/tailsql/options.go index 513e90f..b0ba261 100644 --- a/server/tailsql/options.go +++ b/server/tailsql/options.go @@ -112,12 +112,12 @@ func (o Options) checkQuery() func(Query) (Query, error) { // openSources opens database handles to each of the sources defined by o. // Sources that require secrets will get them from store. // Precondition: All the sources of o have already been validated. -func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) { +func (o Options) openSources(ctx context.Context, store *setec.Store) ([]*setec.Updater[*dbHandle], error) { if len(o.Sources) == 0 { return nil, nil } - srcs := make([]*dbHandle, len(o.Sources)) + srcs := make([]*setec.Updater[*dbHandle], len(o.Sources)) for i, spec := range o.Sources { if spec.Label == "" { spec.Label = "(unidentified database)" @@ -125,20 +125,45 @@ func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) { // Case 1: A programmatic source. if spec.DB != nil { - srcs[i] = &dbHandle{ + srcs[i] = setec.StaticUpdater(&dbHandle{ src: spec.Source, label: spec.Label, named: spec.Named, db: spec.DB, + }) + continue + } + + // Case 2: A database managed by database/sql, with a secret from setec. + if spec.Secret != "" { + // We actually only maintain a single value, that is updated in-place. + h := &dbHandle{src: spec.Source, label: spec.Label, named: spec.Named} + u, err := setec.NewUpdater(ctx, store, spec.Secret, func(secret []byte) (*dbHandle, error) { + db, err := openAndPing(spec.Driver, string(secret)) + if err != nil { + return nil, err + } + o.logf()("[tailsql] opened new connection for source %q", spec.Source) + h.mu.Lock() + defer h.mu.Unlock() + if h.db != nil { + h.db.Close() // close the active handle + } + if up := h.checkUpdate(); up != nil { + up.newDB.Close() // close a previous pending update + } + h.db = sqlDB{DB: db} + return h, nil + }) + if err != nil { + return nil, err } + srcs[i] = u continue } - // Case 2: A database managed by database/sql. - // - // Resolve the connection string. + // Case 3: A database managed by database/sql, with a fixed URL. var connString string - var w setec.Watcher switch { case spec.URL != "": connString = spec.URL @@ -148,9 +173,6 @@ func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) { return nil, fmt.Errorf("read key file for %q: %w", spec.Source, err) } connString = strings.TrimSpace(string(data)) - case spec.Secret != "": - w = store.Watcher(spec.Secret) - connString = string(w.Get()) default: panic("unexpected: no connection source is defined after validation") } @@ -160,16 +182,13 @@ func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) { if err != nil { return nil, err } - srcs[i] = &dbHandle{ + srcs[i] = setec.StaticUpdater(&dbHandle{ src: spec.Source, driver: spec.Driver, label: spec.Label, named: spec.Named, db: sqlDB{DB: db}, - } - if spec.Secret != "" { - go srcs[i].handleUpdates(spec.Secret, w, o.logf()) - } + }) } return srcs, nil } @@ -325,33 +344,6 @@ type dbHandle struct { named map[string]string } -// handleUpdates polls w indefinitely for updates to the connection string for -// h, and reopens the database with the new string when a new value arrives. -// This method should be called in a goroutine. -func (h *dbHandle) handleUpdates(name string, w setec.Watcher, logf logger.Logf) { - logf("[tailsql] starting updater for secret %q", name) - for range w.Ready() { - // N.B. Don't log the secret value itself. It's fine to log the name of - // the secret and the source, those are already in the config. - connString := string(w.Get()) - db, err := openAndPing(h.driver, connString) - if err != nil { - logf("WARNING: opening new database for %q: %v", h.src, err) - continue - } - logf("[tailsql] opened new connection for source %q", h.src) - h.mu.Lock() - // Close the existing active handle. - h.db.Close() - // If there's a pending update, close it too. - if up := h.checkUpdate(); up != nil { - up.newDB.Close() - } - h.db = sqlDB{DB: db} - h.mu.Unlock() - } -} - // checkUpdate returns nil if there is no pending update, otherwise it swaps // out the pending database update and returns it. func (h *dbHandle) checkUpdate() *dbUpdate { diff --git a/server/tailsql/tailsql.go b/server/tailsql/tailsql.go index 8c519eb..247896a 100644 --- a/server/tailsql/tailsql.go +++ b/server/tailsql/tailsql.go @@ -68,6 +68,7 @@ import ( "time" "unicode/utf8" + "github.com/tailscale/setec/client/setec" "tailscale.com/client/tailscale/apitype" "tailscale.com/types/logger" "tailscale.com/util/httpm" @@ -119,7 +120,7 @@ type Server struct { logf logger.Logf mu sync.Mutex - dbs []*dbHandle + dbs []*setec.Updater[*dbHandle] } // NewServer constructs a new server with the given Options. @@ -134,7 +135,7 @@ func NewServer(opts Options) (*Server, error) { return nil, fmt.Errorf("have %d named secrets but no secret store", len(sec)) } - dbs, err := opts.openSources(opts.SecretStore) + dbs, err := opts.openSources(context.Background(), opts.SecretStore) if err != nil { return nil, fmt.Errorf("opening sources: %w", err) } @@ -143,14 +144,14 @@ func NewServer(opts Options) (*Server, error) { return nil, fmt.Errorf("local state: %w", err) } if state != nil && opts.LocalSource != "" { - dbs = append(dbs, &dbHandle{ + dbs = append(dbs, setec.StaticUpdater(&dbHandle{ src: opts.LocalSource, label: "tailsql local state", db: state, named: map[string]string{ "schema": `select * from sqlite_schema`, }, - }) + })) } if opts.Metrics != nil { @@ -192,18 +193,18 @@ func (s *Server) SetSource(source string, db Queryable, opts *DBOptions) bool { s.mu.Lock() defer s.mu.Unlock() - for _, src := range s.dbs { - if src.Source() == source { + for _, u := range s.dbs { + if src := u.Get(); src.Source() == source { src.swap(db, opts) return true } } - s.dbs = append(s.dbs, &dbHandle{ + s.dbs = append(s.dbs, setec.StaticUpdater(&dbHandle{ db: db, src: source, label: opts.label(), named: opts.namedQueries(), - }) + })) return false } @@ -613,12 +614,15 @@ func (s *Server) getHandles() []*dbHandle { s.mu.Lock() defer s.mu.Unlock() + out := make([]*dbHandle, len(s.dbs)) + // Check for pending updates. - for _, h := range s.dbs { - h.tryUpdate() + for i, u := range s.dbs { + out[i] = u.Get() + out[i].tryUpdate() } // It is safe to return the slice because we never remove any elements, new // data are only ever appended to the end. - return s.dbs + return out } diff --git a/server/tailsql/tailsql_test.go b/server/tailsql/tailsql_test.go index d1dac6c..069806e 100644 --- a/server/tailsql/tailsql_test.go +++ b/server/tailsql/tailsql_test.go @@ -6,11 +6,13 @@ package tailsql_test import ( "context" "database/sql" + "database/sql/driver" "errors" "fmt" "html" "html/template" "io" + "math/rand/v2" "net/http" "net/http/httptest" "net/url" @@ -128,10 +130,17 @@ var testUIRules = []tailsql.UIRewriteRule{ } func TestSecrets(t *testing.T) { + // Register a fake driver so we can probe for connection URLs. + // We have to use a new name each time, because there is no way to + // unregister and duplicate names trigger a panic. + driver := new(fakeDriver) + driverName := fmt.Sprintf("%s-driver-%d", t.Name(), rand.Int()) + sql.Register(driverName, driver) + t.Logf("Test driver name is %q", driverName) + const secretName = "connection-string" - url, _ := mustInitSQLite(t) db := setectest.NewDB(t, nil) - db.MustPut(db.Superuser, secretName, url) + db.MustPut(db.Superuser, secretName, "string 1") ss := setectest.NewServer(t, db, nil) hs := httptest.NewServer(ss.Mux) @@ -141,17 +150,23 @@ func TestSecrets(t *testing.T) { Sources: []tailsql.DBSpec{{ Source: "test", Label: "Test Database", - Driver: "sqlite", + Driver: driverName, Secret: secretName, }}, + RoutePrefix: "/tsql", } + + // Verify we found the expected secret names in the options. secrets, err := opts.CheckSources() if err != nil { t.Fatalf("Invalid sources: %v", err) } + + tick := setectest.NewFakeTicker() st, err := setec.NewStore(context.Background(), setec.StoreConfig{ - Client: setec.Client{Server: hs.URL}, - Secrets: secrets, + Client: setec.Client{Server: hs.URL}, + Secrets: secrets, + PollTicker: tick, }) if err != nil { t.Fatalf("Creating setec store: %v", err) @@ -162,7 +177,28 @@ func TestSecrets(t *testing.T) { if err != nil { t.Fatalf("Creating tailsql server: %v", err) } - ts.Close() + ss.Mux.Handle("/tsql/", ts.NewMux()) // so we can call /meta below + defer ts.Close() + + // After opening the server, the database should have the initial secret + // value provided on initialization. + if got, want := driver.OpenedURL, "string 1"; got != want { + t.Errorf("Initial URL: got %q, want %q", got, want) + } + + // Update the secret. + db.MustActivate(db.Superuser, secretName, db.MustPut(db.Superuser, secretName, "string 2")) + tick.Poll() + + // Make the database fetch the latest value. + if _, err := hs.Client().Get(hs.URL + "/tsql/meta"); err != nil { + t.Errorf("Get tailsql meta: %v", err) + } + + // After the update, the database should have the new secret value. + if got, want := driver.OpenedURL, "string 2"; got != want { + t.Errorf("Updated URL: got %q, want %q", got, want) + } } func TestServer(t *testing.T) { @@ -567,3 +603,18 @@ func TestRoutePrefix(t *testing.T) { } }) } + +type fakeDriver struct { + OpenedURL string +} + +func (f *fakeDriver) Open(url string) (driver.Conn, error) { + f.OpenedURL = url + return fakeConn{}, nil +} + +// fakeConn is a fake implementation of driver.Conn to satisfy the interface, +// it will panic if actually used. +type fakeConn struct{ driver.Conn } + +func (fakeConn) Close() error { return nil }