From 119c88631d8020cba7a9cf2963c633ebea429392 Mon Sep 17 00:00:00 2001 From: "yuxuan.wang1" Date: Thu, 19 Sep 2024 09:19:18 +0800 Subject: [PATCH] feat: support gRPC graceful shutdown --- pkg/remote/trans/nphttp2/conn_pool.go | 2 +- pkg/remote/trans/nphttp2/grpc/controlbuf.go | 14 +++- .../nphttp2/grpc/graceful_shutdown_test.go | 52 ++++++++++++ pkg/remote/trans/nphttp2/grpc/http2_client.go | 50 ++++++++--- pkg/remote/trans/nphttp2/grpc/http2_server.go | 53 ++++++++++-- pkg/remote/trans/nphttp2/grpc/http_util.go | 1 + pkg/remote/trans/nphttp2/grpc/transport.go | 1 + .../trans/nphttp2/grpc/transport_test.go | 82 +++++++++++++++++-- pkg/remote/trans/nphttp2/server_handler.go | 53 +++++++++++- 9 files changed, 280 insertions(+), 28 deletions(-) create mode 100644 pkg/remote/trans/nphttp2/grpc/graceful_shutdown_test.go diff --git a/pkg/remote/trans/nphttp2/conn_pool.go b/pkg/remote/trans/nphttp2/conn_pool.go index cdd1a2f785..14c422f123 100644 --- a/pkg/remote/trans/nphttp2/conn_pool.go +++ b/pkg/remote/trans/nphttp2/conn_pool.go @@ -121,7 +121,7 @@ func (p *connPool) newTransport(ctx context.Context, dialer remote.Dialer, netwo opts, p.remoteService, func(grpc.GoAwayReason) { - // do nothing + p.Clean(network, address) }, func() { // do nothing diff --git a/pkg/remote/trans/nphttp2/grpc/controlbuf.go b/pkg/remote/trans/nphttp2/grpc/controlbuf.go index 7dd4701124..a6158c882c 100644 --- a/pkg/remote/trans/nphttp2/grpc/controlbuf.go +++ b/pkg/remote/trans/nphttp2/grpc/controlbuf.go @@ -145,10 +145,11 @@ func (h *headerFrame) isTransportResponseFrame() bool { } type cleanupStream struct { - streamID uint32 - rst bool - rstCode http2.ErrCode - onWrite func() + streamID uint32 + rst bool + rstCode http2.ErrCode + onWrite func() + onFinishWrite func() } func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM @@ -778,6 +779,11 @@ func (l *loopyWriter) outFlowControlSizeRequestHandler(o *outFlowControlSizeRequ } func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error { + defer func() { + if c.onFinishWrite != nil { + c.onFinishWrite() + } + }() c.onWrite() if str, ok := l.estdStreams[c.streamID]; ok { // On the server side it could be a trailers-only response or diff --git a/pkg/remote/trans/nphttp2/grpc/graceful_shutdown_test.go b/pkg/remote/trans/nphttp2/grpc/graceful_shutdown_test.go new file mode 100644 index 0000000000..884aa52503 --- /dev/null +++ b/pkg/remote/trans/nphttp2/grpc/graceful_shutdown_test.go @@ -0,0 +1,52 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package grpc + +import ( + "context" + "math" + "testing" + "time" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestGracefulShutdown(t *testing.T) { + srv, cli := setUp(t, 0, math.MaxUint32, gracefulShutdown) + defer cli.Close(errSelfCloseForTest) + + stream, err := cli.NewStream(context.Background(), &CallHdr{}) + test.Assert(t, err == nil, err) + <-srv.srvReady + go srv.gracefulShutdown() + err = cli.Write(stream, nil, []byte("hello"), &Options{}) + test.Assert(t, err == nil, err) + msg := make([]byte, 5) + num, err := stream.Read(msg) + test.Assert(t, err == nil, err) + test.Assert(t, num == 5, num) + _, err = cli.NewStream(context.Background(), &CallHdr{}) + test.Assert(t, err != nil, err) + t.Logf("NewStream err: %v", err) + time.Sleep(1 * time.Second) + err = cli.Write(stream, nil, []byte("hello"), &Options{}) + test.Assert(t, err != nil, err) + t.Logf("After timeout, Write err: %v", err) + _, err = stream.Read(msg) + test.Assert(t, err != nil, err) + t.Logf("After timeout, Read err: %v", err) +} diff --git a/pkg/remote/trans/nphttp2/grpc/http2_client.go b/pkg/remote/trans/nphttp2/grpc/http2_client.go index 58b2eaba01..4db387270b 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_client.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_client.go @@ -22,6 +22,8 @@ package grpc import ( "context" + "errors" + "fmt" "io" "math" "net" @@ -468,9 +470,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea s.id = h.streamID s.fc = &inFlow{limit: uint32(t.initialWindowSize)} t.mu.Lock() - if t.activeStreams == nil { // Can be niled from Close(). + // Don't create a stream if the transport is in a state of graceful shutdown or already closed + if t.state == draining || t.activeStreams == nil { // Can be niled from Close(). t.mu.Unlock() - return false // Don't create a stream if the transport is already closed. + return false } t.activeStreams[s.id] = s t.mu.Unlock() @@ -533,7 +536,11 @@ func (t *http2Client) CloseStream(s *Stream, err error) { ) if err != nil { rst = true - rstCode = http2.ErrCodeCancel + if errors.Is(err, errGracefulShutdown) { + rstCode = gracefulShutdownCode + } else { + rstCode = http2.ErrCodeCancel + } } t.closeStream(s, err, rst, rstCode, status.Convert(err), nil, false) } @@ -812,7 +819,13 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { statusCode = codes.DeadlineExceeded } } - t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode), nil, false) + var msg string + if statusCode == codes.Unavailable { + msg = gracefulShutdownMsg + } else { + msg = fmt.Sprintf("stream terminated by RST_STREAM with error code: %v", f.ErrCode) + } + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(statusCode, msg), nil, false) } func (t *http2Client) handleSettings(f *grpcframe.SettingsFrame, isFirst bool) { @@ -917,10 +930,16 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) { // Notify the clientconn about the GOAWAY before we set the state to // draining, to allow the client to stop attempting to create streams // before disallowing new streams on this connection. - if t.onGoAway != nil { - t.onGoAway(t.goAwayReason) + if t.state != draining { + if t.onGoAway != nil { + // todo(DMwangnima): remove this lock since it is not necessary + // to hold lock processing onGoAway + t.mu.Unlock() + t.onGoAway(t.goAwayReason) + t.mu.Lock() + } + t.state = draining } - t.state = draining } // All streams with IDs greater than the GoAwayId // and smaller than the previous GoAway ID should be killed. @@ -928,18 +947,27 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) { if upperLimit == 0 { // This is the first GoAway Frame. upperLimit = math.MaxUint32 // Kill all streams after the GoAway ID. } + t.prevGoAwayID = id + active := len(t.activeStreams) + if active <= 0 { + t.mu.Unlock() + t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) + return + } + + var unprocessedStream []*Stream for streamID, stream := range t.activeStreams { if streamID > id && streamID <= upperLimit { // The stream was unprocessed by the server. atomic.StoreUint32(&stream.unprocessed, 1) + unprocessedStream = append(unprocessedStream, stream) t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false) } } - t.prevGoAwayID = id - active := len(t.activeStreams) t.mu.Unlock() - if active == 0 { - t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) + + for _, stream := range unprocessedStream { + t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false) } } diff --git a/pkg/remote/trans/nphttp2/grpc/http2_server.go b/pkg/remote/trans/nphttp2/grpc/http2_server.go index c2b84efe13..01f62827fe 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_server.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_server.go @@ -50,6 +50,11 @@ import ( "github.com/cloudwego/kitex/pkg/utils" ) +const ( + gracefulShutdownCode = http2.ErrCode(20) + gracefulShutdownMsg = "graceful shutdown" +) + var ( // ErrIllegalHeaderWrite indicates that setting header is illegal because of // the stream's state. @@ -67,6 +72,8 @@ var ( errNotReachable = status.New(codes.Canceled, "transport: server not reachable").Err() errMaxAgeClosing = status.New(codes.Canceled, "transport: closing server transport due to maximum connection age").Err() errIdleClosing = status.New(codes.Canceled, "transport: closing server transport due to idleness").Err() + + errGracefulShutdown = status.Err(codes.Unavailable, gracefulShutdownMsg) ) func init() { @@ -966,6 +973,44 @@ func (t *http2Server) Close() error { return t.closeWithErr(nil) } +func (t *http2Server) GracefulClose() { + t.mu.Lock() + if t.state == closing { + t.mu.Unlock() + return + } + t.state = closing + streams := t.activeStreams + t.activeStreams = nil + t.mu.Unlock() + + var wg sync.WaitGroup + for _, s := range streams { + oldState := s.swapState(streamDone) + if oldState == streamDone { + // If the stream was already done, continue + continue + } + wg.Add(1) + s.cancel(errGracefulShutdown) + t.controlBuf.put(&cleanupStream{ + streamID: s.id, + rst: true, + rstCode: gracefulShutdownCode, + onWrite: func() {}, + onFinishWrite: func() { + wg.Done() + }, + }) + } + wg.Wait() + // make sure all RSTStream Frames are sent out + t.framer.writer.Flush() + t.controlBuf.finish() + close(t.done) + t.conn.Close() +} + func (t *http2Server) closeWithErr(reason error) error { t.mu.Lock() if t.state == closing { @@ -1081,10 +1126,8 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) { if err := t.framer.WriteGoAway(sid, g.code, g.debugData); err != nil { return false, err } + t.framer.writer.Flush() if g.closeConn { - // Abruptly close the connection following the GoAway (via - // loopywriter). But flush out what's inside the buffer first. - t.framer.writer.Flush() return false, fmt.Errorf("transport: Connection closing") } return true, nil @@ -1096,7 +1139,7 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) { // originated before the GoAway reaches the client. // After getting the ack or timer expiration send out another GoAway this // time with an ID of the max stream server intends to process. - if err := t.framer.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil { + if err := t.framer.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, g.debugData); err != nil { return false, err } if err := t.framer.WritePing(false, goAwayPing.data); err != nil { @@ -1104,7 +1147,7 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) { } gofunc.RecoverGoFuncWithInfo(context.Background(), func() { - timer := time.NewTimer(time.Minute) + timer := time.NewTimer(10 * time.Second) defer timer.Stop() select { case <-t.drainChan: diff --git a/pkg/remote/trans/nphttp2/grpc/http_util.go b/pkg/remote/trans/nphttp2/grpc/http_util.go index ecb6bb7ec1..49af9275cf 100644 --- a/pkg/remote/trans/nphttp2/grpc/http_util.go +++ b/pkg/remote/trans/nphttp2/grpc/http_util.go @@ -77,6 +77,7 @@ var ( http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted, http2.ErrCodeInadequateSecurity: codes.PermissionDenied, http2.ErrCodeHTTP11Required: codes.Internal, + gracefulShutdownCode: codes.Unavailable, } statusCodeConvTab = map[codes.Code]http2.ErrCode{ codes.Internal: http2.ErrCodeInternal, diff --git a/pkg/remote/trans/nphttp2/grpc/transport.go b/pkg/remote/trans/nphttp2/grpc/transport.go index e2b5f0ef1f..3bbf8d2b63 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport.go +++ b/pkg/remote/trans/nphttp2/grpc/transport.go @@ -715,6 +715,7 @@ type ServerTransport interface { // should not be accessed any more. All the pending streams and their // handlers will be terminated asynchronously. Close() error + GracefulClose() // RemoteAddr returns the remote network address. RemoteAddr() net.Addr diff --git a/pkg/remote/trans/nphttp2/grpc/transport_test.go b/pkg/remote/trans/nphttp2/grpc/transport_test.go index ec98259a52..78b5ebf8e5 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport_test.go +++ b/pkg/remote/trans/nphttp2/grpc/transport_test.go @@ -57,6 +57,10 @@ type server struct { conns map[ServerTransport]bool h *testStreamHandler ready chan struct{} + hdlWG sync.WaitGroup + transWG sync.WaitGroup + + srvReady chan struct{} } var ( @@ -77,6 +81,7 @@ func init() { type testStreamHandler struct { t *http2Server + srv *server notify chan struct{} getNotified chan struct{} } @@ -92,6 +97,8 @@ const ( invalidHeaderField delayRead pingpong + + gracefulShutdown ) func (h *testStreamHandler) handleStreamAndNotify(s *Stream) { @@ -292,6 +299,20 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) { } } +func (h *testStreamHandler) gracefulShutdown(t *testing.T, s *Stream) { + close(h.srv.srvReady) + msg := make([]byte, 5) + num, err := s.Read(msg) + test.Assert(t, err == nil, err) + test.Assert(t, num == 5, num) + test.Assert(t, string(msg) == "hello", string(msg)) + err = h.t.Write(s, nil, msg, &Options{}) + test.Assert(t, err == nil, err) + _, err = s.Read(msg) + test.Assert(t, err != nil, err) + t.Logf("Server-side after timeout err: %v", err) +} + // start starts server. Other goroutines should block on s.readyChan for further operations. func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hType) { // 创建 listener @@ -329,6 +350,7 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT s.conns[transport] = true h := &testStreamHandler{t: transport.(*http2Server)} s.h = h + h.srv = s s.mu.Unlock() switch ht { case notifyCall: @@ -379,12 +401,26 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT }, func(ctx context.Context, method string) context.Context { return ctx }) + case gracefulShutdown: + s.transWG.Add(1) + go func() { + defer s.transWG.Done() + transport.HandleStreams(func(stream *Stream) { + s.hdlWG.Add(1) + go func() { + defer s.hdlWG.Done() + h.gracefulShutdown(t, stream) + }() + }, func(ctx context.Context, method string) context.Context { return ctx }) + }() default: - go transport.HandleStreams(func(s *Stream) { - go h.handleStream(t, s) - }, func(ctx context.Context, method string) context.Context { - return ctx - }) + go func() { + transport.HandleStreams(func(s *Stream) { + go h.handleStream(t, s) + }, func(ctx context.Context, method string) context.Context { + return ctx + }) + }() } return ctx } @@ -434,6 +470,40 @@ func (s *server) stop() { s.mu.Unlock() } +func (s *server) gracefulShutdown() { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + s.lis.Close() + s.mu.Lock() + for trans := range s.conns { + trans.Drain() + } + s.mu.Unlock() + timeout, _ := ctx.Deadline() + graceTimer := time.NewTimer(time.Until(timeout)) + exitCh := make(chan struct{}) + go func() { + select { + case <-graceTimer.C: + s.mu.Lock() + for trans := range s.conns { + trans.Close() + } + s.mu.Unlock() + return + case <-exitCh: + return + } + }() + s.hdlWG.Wait() + s.transWG.Wait() + close(exitCh) + s.conns = nil + if err := s.eventLoop.Shutdown(ctx); err != nil { + fmt.Printf("netpoll server exit failed, err=%v", err) + } +} + func (s *server) addr() string { if s.lis == nil { return "" @@ -442,7 +512,7 @@ func (s *server) addr() string { } func setUpServerOnly(t *testing.T, port int, serverConfig *ServerConfig, ht hType) *server { - server := &server{startedErr: make(chan error, 1), ready: make(chan struct{})} + server := &server{startedErr: make(chan error, 1), ready: make(chan struct{}), srvReady: make(chan struct{})} go server.start(t, port, serverConfig, ht) server.wait(t, time.Second) return server diff --git a/pkg/remote/trans/nphttp2/server_handler.go b/pkg/remote/trans/nphttp2/server_handler.go index f249f84242..3f19f16c02 100644 --- a/pkg/remote/trans/nphttp2/server_handler.go +++ b/pkg/remote/trans/nphttp2/server_handler.go @@ -62,6 +62,7 @@ func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) { opt: opt, svcSearcher: opt.SvcSearcher, codec: grpc.NewGRPCCodec(grpc.WithThriftCodec(opt.PayloadCodec)), + transports: make(map[grpcTransport.ServerTransport]struct{}), }, nil } @@ -72,6 +73,11 @@ type svrTransHandler struct { svcSearcher remote.ServiceSearcher inkHdlFunc endpoint.Endpoint codec remote.Codec + mu sync.Mutex + transports map[grpcTransport.ServerTransport]struct{} + + hdlWG sync.WaitGroup + transWG sync.WaitGroup } var prefaceReadAtMost = func() int { @@ -119,9 +125,11 @@ func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Me func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { svrTrans := ctx.Value(ctxKeySvrTransport).(*SvrTrans) tr := svrTrans.tr - + defer t.transWG.Done() tr.HandleStreams(func(s *grpcTransport.Stream) { + t.hdlWG.Add(1) gofunc.GoFunc(ctx, func() { + defer t.hdlWG.Done() t.handleFunc(s, svrTrans, conn) }) }, func(ctx context.Context, method string) context.Context { @@ -315,6 +323,10 @@ func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context. if err != nil { return nil, err } + t.transWG.Add(1) + t.mu.Lock() + t.transports[tr] = struct{}{} + t.mu.Unlock() pool := &sync.Pool{ New: func() interface{} { // init rpcinfo @@ -329,6 +341,9 @@ func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context. // 连接关闭时回调 func (t *svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) { tr := ctx.Value(ctxKeySvrTransport).(*SvrTrans).tr + t.mu.Lock() + delete(t.transports, tr) + t.mu.Unlock() tr.Close() } @@ -349,6 +364,42 @@ func (t *svrTransHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) { func (t *svrTransHandler) SetPipeline(p *remote.TransPipeline) { } +func (t *svrTransHandler) GracefulShutdown(ctx context.Context) error { + t.mu.Lock() + for trans := range t.transports { + trans.Drain() + } + t.mu.Unlock() + + exitCh := make(chan struct{}) + // todo: think about a better grace time duration + graceTime := time.Minute * 3 + exitTimeout, ok := ctx.Deadline() + if ok { + graceTime = time.Until(exitTimeout) + } + graceTimer := time.NewTimer(graceTime) + gofunc.GoFunc(ctx, func() { + select { + case <-graceTimer.C: + t.mu.Lock() + for trans := range t.transports { + trans.GracefulClose() + } + t.mu.Unlock() + return + case <-exitCh: + return + } + }) + + t.hdlWG.Wait() + t.transWG.Wait() + close(exitCh) + + return nil +} + func (t *svrTransHandler) startTracer(ctx context.Context, ri rpcinfo.RPCInfo) context.Context { c := t.opt.TracerCtl.DoStart(ctx, ri) return c