From 4f4687a739a691faa1ab8b7ac5fa1652f322ff11 Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Thu, 23 Nov 2023 13:32:45 -0800 Subject: [PATCH] server/tailsql: add a query context decorator Our cgo sqlite driver takes some options on the request context. To give us a way to plumb those in, add an optional callback. --- server/tailsql/options.go | 6 ++ server/tailsql/tailsql.go | 135 ++++++++++++++++++--------------- server/tailsql/tailsql_test.go | 18 ++++- 3 files changed, 97 insertions(+), 62 deletions(-) diff --git a/server/tailsql/options.go b/server/tailsql/options.go index 7124039..a6dc0bc 100644 --- a/server/tailsql/options.go +++ b/server/tailsql/options.go @@ -82,6 +82,12 @@ type Options struct { // by the rule replaces the original string. UIRewriteRules []UIRewriteRule `json:"-"` + // If non-nil, this function is called to annotate ctx before passing it in + // to a database query for the given source. If the callback is nil, or if + // it returns nil, ctx is used unmodified. Otherwise the returned value + // replaces ctx in the query. + QueryContext func(ctx context.Context, src, query string) context.Context `json:"-"` + // If non-nil, send logs to this logger. If nil, use log.Printf. Logf logger.Logf `json:"-"` } diff --git a/server/tailsql/tailsql.go b/server/tailsql/tailsql.go index 0f8231d..128a18e 100644 --- a/server/tailsql/tailsql.go +++ b/server/tailsql/tailsql.go @@ -114,6 +114,7 @@ type Server struct { rules []UIRewriteRule authorize func(string, *apitype.WhoIsResponse) error qtimeout time.Duration + qcontext func(ctx context.Context, src, query string) context.Context logf logger.Logf mu sync.Mutex @@ -166,6 +167,7 @@ func NewServer(opts Options) (*Server, error) { rules: opts.UIRewriteRules, authorize: opts.authorize(), qtimeout: opts.QueryTimeout.Duration(), + qcontext: opts.QueryContext, logf: opts.logf(), dbs: dbs, }, nil @@ -438,76 +440,77 @@ func (s *Server) queryContext(ctx context.Context, caller, src, query string) (* defer cancel() } - return runQueryInTx(ctx, h, func(fctx context.Context, tx *sql.Tx) (_ *dbResult, err error) { - start := time.Now() - var out dbResult - defer func() { - out.Elapsed = time.Since(start) - s.logf("[tailsql] query src=%q query=%q elapsed=%v err=%v", - src, query, out.Elapsed.Round(time.Millisecond), err) - - // Record successful queries in the persistent log. But don't log - // queries to the state database itself. - if err == nil && src != s.self { - serr := s.state.LogQuery(ctx, caller, src, query) - if serr != nil { - s.logf("[tailsql] WARNING: Error logging query: %v", serr) + return runQueryInTx(s.getQueryContext(ctx, src, query), h, + func(fctx context.Context, tx *sql.Tx) (_ *dbResult, err error) { + start := time.Now() + var out dbResult + defer func() { + out.Elapsed = time.Since(start) + s.logf("[tailsql] query src=%q query=%q elapsed=%v err=%v", + src, query, out.Elapsed.Round(time.Millisecond), err) + + // Record successful queries in the persistent log. But don't log + // queries to the state database itself. + if err == nil && src != s.self { + serr := s.state.LogQuery(ctx, caller, src, query) + if serr != nil { + s.logf("[tailsql] WARNING: Error logging query: %v", serr) + } } - } - }() + }() - // Check for a named query. - if name, ok := strings.CutPrefix(query, "named:"); ok { - real, ok := lookupNamedQuery(fctx, name) - if !ok { - return nil, statusErrorf(http.StatusBadRequest, "named query %q not recognized", name) + // Check for a named query. + if name, ok := strings.CutPrefix(query, "named:"); ok { + real, ok := lookupNamedQuery(fctx, name) + if !ok { + return nil, statusErrorf(http.StatusBadRequest, "named query %q not recognized", name) + } + s.logf("[tailsql] resolved named query %q to %#q", name, real) + query = real } - s.logf("[tailsql] resolved named query %q to %#q", name, real) - query = real - } - rows, err := tx.QueryContext(fctx, query) - if err != nil { - return nil, err - } - defer rows.Close() - - cols, err := rows.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("listing column types: %w", err) - } - for _, col := range cols { - out.Columns = append(out.Columns, col.Name()) - } + rows, err := tx.QueryContext(fctx, query) + if err != nil { + return nil, err + } + defer rows.Close() - var tooMany bool - for rows.Next() && !tooMany { - if len(out.Rows) == maxRowsPerQuery { - tooMany = true - break - } else if fctx.Err() != nil { - return nil, fmt.Errorf("scanning row: %w", fctx.Err()) + cols, err := rows.ColumnTypes() + if err != nil { + return nil, fmt.Errorf("listing column types: %w", err) } - vals := make([]any, len(cols)) - vptr := make([]any, len(cols)) - for i := range cols { - vptr[i] = &vals[i] + for _, col := range cols { + out.Columns = append(out.Columns, col.Name()) } - if err := rows.Scan(vptr...); err != nil { - return nil, fmt.Errorf("scanning row: %w", err) + + var tooMany bool + for rows.Next() && !tooMany { + if len(out.Rows) == maxRowsPerQuery { + tooMany = true + break + } else if fctx.Err() != nil { + return nil, fmt.Errorf("scanning row: %w", fctx.Err()) + } + vals := make([]any, len(cols)) + vptr := make([]any, len(cols)) + for i := range cols { + vptr[i] = &vals[i] + } + if err := rows.Scan(vptr...); err != nil { + return nil, fmt.Errorf("scanning row: %w", err) + } + out.Rows = append(out.Rows, vals) } - out.Rows = append(out.Rows, vals) - } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("scanning rows: %w", err) - } - out.NumRows = len(out.Rows) + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("scanning rows: %w", err) + } + out.NumRows = len(out.Rows) - if tooMany { - return &out, errTooManyRows - } - return &out, nil - }) + if tooMany { + return &out, errTooManyRows + } + return &out, nil + }) } // queryMeta handles meta-queries for internal state. @@ -610,3 +613,13 @@ func (s *Server) getHandles() []*dbHandle { // data are only ever appended to the end. return s.dbs } + +// getQueryContext decorates ctx if necessary using the context hook for src and query. +func (s *Server) getQueryContext(ctx context.Context, src, query string) context.Context { + if s.qcontext != nil { + if qctx := s.qcontext(ctx, src, query); qctx != nil { + return qctx + } + } + return ctx +} diff --git a/server/tailsql/tailsql_test.go b/server/tailsql/tailsql_test.go index ecaf1b0..72818fb 100644 --- a/server/tailsql/tailsql_test.go +++ b/server/tailsql/tailsql_test.go @@ -19,6 +19,7 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" "github.com/tailscale/setec/client/setec" "github.com/tailscale/setec/setectest" "github.com/tailscale/tailsql/authorizer" @@ -186,6 +187,7 @@ func TestServer(t *testing.T) { }, }, } + var contextHookData [2]string s, err := tailsql.NewServer(tailsql.Options{ LocalClient: fc, UILinks: []tailsql.UILink{ @@ -193,7 +195,12 @@ func TestServer(t *testing.T) { }, UIRewriteRules: testUIRules, Authorize: authorizer.PeerCaps(nil), - Logf: t.Logf, + QueryContext: func(ctx context.Context, src, query string) context.Context { + contextHookData[0] = src + contextHookData[1] = query + return ctx + }, + Logf: t.Logf, }) if err != nil { t.Fatalf("NewServer: unexpected error: %v", err) @@ -210,6 +217,15 @@ func TestServer(t *testing.T) { defer htest.Close() cli := htest.Client() + t.Run("ContextHook", func(t *testing.T) { + q := url.Values{"q": {"select count(*) from users"}} + url := htest.URL + "?" + q.Encode() + mustGet(t, cli, url) + if diff := cmp.Diff(contextHookData, [2]string{"main", "select count(*) from users"}); diff != "" { + t.Errorf("Context hook result (-got, +want):\n%s", diff) + } + }) + t.Run("UI", func(t *testing.T) { q := make(url.Values) q.Set("q", "select location from users where name = 'alice'")