Skip to content

Commit

Permalink
go: sqle,remotesrv: Implement sql.Session lifecycle callbacks for sql…
Browse files Browse the repository at this point in the history
….Contexts used in remotesrv RPCs.

This PR changes each RPC invocation against the gRPC and HTTP servers
implementing remotesapi and cluster replication to create a sql.Context
which lives the duration of the call. The Session for that call gets
SessionCommand{Begin,End} and SessionEnd lifecycle callbacks so that it
can participate in GC safepoint rendezvous appropriately.

Previously the remotesrv.DBCache and the user/password remotesapi
authentication implementation would simply create new sql.Contexts
whenever they needed them. There could be multiple sql.Contexts for the
single server call.
  • Loading branch information
reltuk committed Jan 28, 2025
1 parent 5b55d32 commit a6b1a26
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 8 deletions.
11 changes: 8 additions & 3 deletions go/cmd/dolt/commands/sqlserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -559,22 +559,27 @@ func ConfigureServices(
}

listenaddr := fmt.Sprintf(":%d", port)
sqlContextInterceptor := sqle.SqlContextServerInterceptor{
Factory: sqlEngine.NewDefaultContext,
}
args := remotesrv.ServerArgs{
Logger: logrus.NewEntry(lgr),
ReadOnly: apiReadOnly || serverConfig.ReadOnly(),
HttpListenAddr: listenaddr,
GrpcListenAddr: listenaddr,
ConcurrencyControl: remotesapi.PushConcurrencyControl_PUSH_CONCURRENCY_CONTROL_ASSERT_WORKING_SET,
Options: sqlContextInterceptor.Options(),
HttpInterceptor: sqlContextInterceptor.HTTP(nil),
}
var err error
args.FS = sqlEngine.FileSystem()
args.DBCache, err = sqle.RemoteSrvDBCache(sqlEngine.NewDefaultContext, sqle.DoNotCreateUnknownDatabases)
args.DBCache, err = sqle.RemoteSrvDBCache(sqle.GetInterceptorSqlContext, sqle.DoNotCreateUnknownDatabases)
if err != nil {
lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err)
return err
}

authenticator := newAccessController(sqlEngine.NewDefaultContext, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb)
authenticator := newAccessController(sqle.GetInterceptorSqlContext, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb)
args = sqle.WithUserPasswordAuth(args, authenticator)
args.TLSConfig = serverConf.TLSConfig

Expand Down Expand Up @@ -636,7 +641,7 @@ func ConfigureServices(
lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err)
return err
}
clusterController.RegisterGrpcServices(sqlEngine.NewDefaultContext, clusterRemoteSrv.srv.GrpcServer())
clusterController.RegisterGrpcServices(sqle.GetInterceptorSqlContext, clusterRemoteSrv.srv.GrpcServer())

clusterRemoteSrv.lis, err = clusterRemoteSrv.srv.Listeners()
if err != nil {
Expand Down
11 changes: 8 additions & 3 deletions go/libraries/doltcore/sqle/cluster/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -688,9 +688,14 @@ func (c *Controller) RemoteSrvServerArgs(ctxFactory func(context.Context) (*sql.
listenaddr := c.RemoteSrvListenAddr()
args.HttpListenAddr = listenaddr
args.GrpcListenAddr = listenaddr
args.Options = c.ServerOptions()
ctxInterceptor := sqle.SqlContextServerInterceptor{
Factory: ctxFactory,
}
args.Options = append(args.Options, ctxInterceptor.Options()...)
args.Options = append(args.Options, c.ServerOptions()...)
args.HttpInterceptor = ctxInterceptor.HTTP(args.HttpInterceptor)
var err error
args.DBCache, err = sqle.RemoteSrvDBCache(ctxFactory, sqle.CreateUnknownDatabases)
args.DBCache, err = sqle.RemoteSrvDBCache(sqle.GetInterceptorSqlContext, sqle.CreateUnknownDatabases)
if err != nil {
return remotesrv.ServerArgs{}, err
}
Expand All @@ -699,7 +704,7 @@ func (c *Controller) RemoteSrvServerArgs(ctxFactory func(context.Context) (*sql.

keyID := creds.PubKeyToKID(c.pub)
keyIDStr := creds.B32CredsEncoding.EncodeToString(keyID)
args.HttpInterceptor = JWKSHandlerInterceptor(keyIDStr, c.pub)
args.HttpInterceptor = JWKSHandlerInterceptor(args.HttpInterceptor, keyIDStr, c.pub)

return args, nil
}
Expand Down
9 changes: 7 additions & 2 deletions go/libraries/doltcore/sqle/cluster/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,21 @@ func (h JWKSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Write(b)
}

func JWKSHandlerInterceptor(keyID string, pub ed25519.PublicKey) func(http.Handler) http.Handler {
func JWKSHandlerInterceptor(existing func(http.Handler) http.Handler, keyID string, pub ed25519.PublicKey) func(http.Handler) http.Handler {
jh := JWKSHandler{KeyID: keyID, PublicKey: pub}
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
this := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.EscapedPath() == "/.well-known/jwks.json" {
jh.ServeHTTP(w, r)
return
}
h.ServeHTTP(w, r)
})
if existing != nil {
return existing(this)
} else {
return this
}
}
}

Expand Down
88 changes: 88 additions & 0 deletions go/libraries/doltcore/sqle/remotesrv.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ package sqle

import (
"context"
"errors"
"net/http"

"github.com/dolthub/go-mysql-server/sql"
"google.golang.org/grpc"

"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/remotesrv"
Expand Down Expand Up @@ -96,3 +99,88 @@ func WithUserPasswordAuth(args remotesrv.ServerArgs, authnz remotesrv.AccessCont
args.Options = append(args.Options, si.Options()...)
return args
}

type SqlContextServerInterceptor struct {
Factory func(context.Context) (*sql.Context, error)
}

type serverStreamWrapper struct {
grpc.ServerStream
ctx context.Context
}

func (s serverStreamWrapper) Context() context.Context {
return s.ctx
}

type sqlContextInterceptorKey struct{}

func GetInterceptorSqlContext(ctx context.Context) (*sql.Context, error) {
if v := ctx.Value(sqlContextInterceptorKey{}); v != nil {
return v.(*sql.Context), nil
}
return nil, errors.New("misconfiguration; a sql.Context should always be available from the intercetpor chain.")
}

func (si SqlContextServerInterceptor) Stream() grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
sqlCtx, err := si.Factory(ss.Context())
if err != nil {
return err
}
sql.SessionCommandBegin(sqlCtx.Session)
defer sql.SessionCommandEnd(sqlCtx.Session)
defer sql.SessionEnd(sqlCtx.Session)
newCtx := context.WithValue(ss.Context(), sqlContextInterceptorKey{}, sqlCtx)
newSs := serverStreamWrapper{
ServerStream: ss,
ctx: newCtx,
}
return handler(srv, newSs)
}
}

func (si SqlContextServerInterceptor) Unary() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
sqlCtx, err := si.Factory(ctx)
if err != nil {
return nil, err
}
sql.SessionCommandBegin(sqlCtx.Session)
defer sql.SessionCommandEnd(sqlCtx.Session)
defer sql.SessionEnd(sqlCtx.Session)
newCtx := context.WithValue(ctx, sqlContextInterceptorKey{}, sqlCtx)
return handler(newCtx, req)
}
}

func (si SqlContextServerInterceptor) HTTP(existing func(http.Handler) http.Handler) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
this := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
sqlCtx, err := si.Factory(ctx)
if err != nil {
http.Error(w, "could not initialize sql.Context", http.StatusInternalServerError)
return
}
sql.SessionCommandBegin(sqlCtx.Session)
defer sql.SessionCommandEnd(sqlCtx.Session)
defer sql.SessionEnd(sqlCtx.Session)
newCtx := context.WithValue(ctx, sqlContextInterceptorKey{}, sqlCtx)
newReq := r.WithContext(newCtx)
h.ServeHTTP(w, newReq)
})
if existing != nil {
return existing(this)
} else {
return this
}
}
}

func (si SqlContextServerInterceptor) Options() []grpc.ServerOption {
return []grpc.ServerOption{
grpc.ChainUnaryInterceptor(si.Unary()),
grpc.ChainStreamInterceptor(si.Stream()),
}
}

0 comments on commit a6b1a26

Please sign in to comment.