Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server/tailsql: add a query context decorator #20

Merged
merged 1 commit into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions server/tailsql/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ type Options struct {
// by the rule replaces the original string.
UIRewriteRules []UIRewriteRule `json:"-"`

// If non-nil, this function is called to annotate ctx before passing it in
// to a database query for the given source. If the callback is nil, or if
// it returns nil, ctx is used unmodified. Otherwise the returned value
// replaces ctx in the query.
QueryContext func(ctx context.Context, src, query string) context.Context `json:"-"`

// If non-nil, send logs to this logger. If nil, use log.Printf.
Logf logger.Logf `json:"-"`
}
Expand Down
135 changes: 74 additions & 61 deletions server/tailsql/tailsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ type Server struct {
rules []UIRewriteRule
authorize func(string, *apitype.WhoIsResponse) error
qtimeout time.Duration
qcontext func(ctx context.Context, src, query string) context.Context
logf logger.Logf

mu sync.Mutex
Expand Down Expand Up @@ -166,6 +167,7 @@ func NewServer(opts Options) (*Server, error) {
rules: opts.UIRewriteRules,
authorize: opts.authorize(),
qtimeout: opts.QueryTimeout.Duration(),
qcontext: opts.QueryContext,
logf: opts.logf(),
dbs: dbs,
}, nil
Expand Down Expand Up @@ -438,76 +440,77 @@ func (s *Server) queryContext(ctx context.Context, caller, src, query string) (*
defer cancel()
}

return runQueryInTx(ctx, h, func(fctx context.Context, tx *sql.Tx) (_ *dbResult, err error) {
start := time.Now()
var out dbResult
defer func() {
out.Elapsed = time.Since(start)
s.logf("[tailsql] query src=%q query=%q elapsed=%v err=%v",
src, query, out.Elapsed.Round(time.Millisecond), err)

// Record successful queries in the persistent log. But don't log
// queries to the state database itself.
if err == nil && src != s.self {
serr := s.state.LogQuery(ctx, caller, src, query)
if serr != nil {
s.logf("[tailsql] WARNING: Error logging query: %v", serr)
return runQueryInTx(s.getQueryContext(ctx, src, query), h,
func(fctx context.Context, tx *sql.Tx) (_ *dbResult, err error) {
start := time.Now()
var out dbResult
defer func() {
out.Elapsed = time.Since(start)
s.logf("[tailsql] query src=%q query=%q elapsed=%v err=%v",
src, query, out.Elapsed.Round(time.Millisecond), err)

// Record successful queries in the persistent log. But don't log
// queries to the state database itself.
if err == nil && src != s.self {
serr := s.state.LogQuery(ctx, caller, src, query)
if serr != nil {
s.logf("[tailsql] WARNING: Error logging query: %v", serr)
}
}
}
}()
}()

// Check for a named query.
if name, ok := strings.CutPrefix(query, "named:"); ok {
real, ok := lookupNamedQuery(fctx, name)
if !ok {
return nil, statusErrorf(http.StatusBadRequest, "named query %q not recognized", name)
// Check for a named query.
if name, ok := strings.CutPrefix(query, "named:"); ok {
real, ok := lookupNamedQuery(fctx, name)
if !ok {
return nil, statusErrorf(http.StatusBadRequest, "named query %q not recognized", name)
}
s.logf("[tailsql] resolved named query %q to %#q", name, real)
query = real
}
s.logf("[tailsql] resolved named query %q to %#q", name, real)
query = real
}

rows, err := tx.QueryContext(fctx, query)
if err != nil {
return nil, err
}
defer rows.Close()

cols, err := rows.ColumnTypes()
if err != nil {
return nil, fmt.Errorf("listing column types: %w", err)
}
for _, col := range cols {
out.Columns = append(out.Columns, col.Name())
}
rows, err := tx.QueryContext(fctx, query)
if err != nil {
return nil, err
}
defer rows.Close()

var tooMany bool
for rows.Next() && !tooMany {
if len(out.Rows) == maxRowsPerQuery {
tooMany = true
break
} else if fctx.Err() != nil {
return nil, fmt.Errorf("scanning row: %w", fctx.Err())
cols, err := rows.ColumnTypes()
if err != nil {
return nil, fmt.Errorf("listing column types: %w", err)
}
vals := make([]any, len(cols))
vptr := make([]any, len(cols))
for i := range cols {
vptr[i] = &vals[i]
for _, col := range cols {
out.Columns = append(out.Columns, col.Name())
}
if err := rows.Scan(vptr...); err != nil {
return nil, fmt.Errorf("scanning row: %w", err)

var tooMany bool
for rows.Next() && !tooMany {
if len(out.Rows) == maxRowsPerQuery {
tooMany = true
break
} else if fctx.Err() != nil {
return nil, fmt.Errorf("scanning row: %w", fctx.Err())
}
vals := make([]any, len(cols))
vptr := make([]any, len(cols))
for i := range cols {
vptr[i] = &vals[i]
}
if err := rows.Scan(vptr...); err != nil {
return nil, fmt.Errorf("scanning row: %w", err)
}
out.Rows = append(out.Rows, vals)
}
out.Rows = append(out.Rows, vals)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("scanning rows: %w", err)
}
out.NumRows = len(out.Rows)
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("scanning rows: %w", err)
}
out.NumRows = len(out.Rows)

if tooMany {
return &out, errTooManyRows
}
return &out, nil
})
if tooMany {
return &out, errTooManyRows
}
return &out, nil
})
}

// queryMeta handles meta-queries for internal state.
Expand Down Expand Up @@ -610,3 +613,13 @@ func (s *Server) getHandles() []*dbHandle {
// data are only ever appended to the end.
return s.dbs
}

// getQueryContext decorates ctx if necessary using the context hook for src and query.
func (s *Server) getQueryContext(ctx context.Context, src, query string) context.Context {
if s.qcontext != nil {
if qctx := s.qcontext(ctx, src, query); qctx != nil {
return qctx
}
}
return ctx
}
18 changes: 17 additions & 1 deletion server/tailsql/tailsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"strings"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/tailscale/setec/client/setec"
"github.com/tailscale/setec/setectest"
"github.com/tailscale/tailsql/authorizer"
Expand Down Expand Up @@ -186,14 +187,20 @@ func TestServer(t *testing.T) {
},
},
}
var contextHookData [2]string
s, err := tailsql.NewServer(tailsql.Options{
LocalClient: fc,
UILinks: []tailsql.UILink{
{Anchor: testAnchor, URL: testURL},
},
UIRewriteRules: testUIRules,
Authorize: authorizer.PeerCaps(nil),
Logf: t.Logf,
QueryContext: func(ctx context.Context, src, query string) context.Context {
contextHookData[0] = src
contextHookData[1] = query
return ctx
},
Logf: t.Logf,
})
if err != nil {
t.Fatalf("NewServer: unexpected error: %v", err)
Expand All @@ -210,6 +217,15 @@ func TestServer(t *testing.T) {
defer htest.Close()
cli := htest.Client()

t.Run("ContextHook", func(t *testing.T) {
q := url.Values{"q": {"select count(*) from users"}}
url := htest.URL + "?" + q.Encode()
mustGet(t, cli, url)
if diff := cmp.Diff(contextHookData, [2]string{"main", "select count(*) from users"}); diff != "" {
t.Errorf("Context hook result (-got, +want):\n%s", diff)
}
})

t.Run("UI", func(t *testing.T) {
q := make(url.Values)
q.Set("q", "select location from users where name = 'alice'")
Expand Down