Skip to content

Commit

Permalink
feat: support gRPC graceful shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
DMwangnima committed Oct 11, 2024
1 parent 4e1dbe9 commit 119c886
Show file tree
Hide file tree
Showing 9 changed files with 280 additions and 28 deletions.
2 changes: 1 addition & 1 deletion pkg/remote/trans/nphttp2/conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (p *connPool) newTransport(ctx context.Context, dialer remote.Dialer, netwo
opts,
p.remoteService,
func(grpc.GoAwayReason) {
// do nothing
p.Clean(network, address)
},
func() {
// do nothing
Expand Down
14 changes: 10 additions & 4 deletions pkg/remote/trans/nphttp2/grpc/controlbuf.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,11 @@ func (h *headerFrame) isTransportResponseFrame() bool {
}

type cleanupStream struct {
streamID uint32
rst bool
rstCode http2.ErrCode
onWrite func()
streamID uint32
rst bool
rstCode http2.ErrCode
onWrite func()
onFinishWrite func()
}

func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM
Expand Down Expand Up @@ -778,6 +779,11 @@ func (l *loopyWriter) outFlowControlSizeRequestHandler(o *outFlowControlSizeRequ
}

func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error {
defer func() {
if c.onFinishWrite != nil {
c.onFinishWrite()
}
}()
c.onWrite()
if str, ok := l.estdStreams[c.streamID]; ok {
// On the server side it could be a trailers-only response or
Expand Down
52 changes: 52 additions & 0 deletions pkg/remote/trans/nphttp2/grpc/graceful_shutdown_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package grpc

import (
"context"
"math"
"testing"
"time"

"github.com/cloudwego/kitex/internal/test"
)

func TestGracefulShutdown(t *testing.T) {
srv, cli := setUp(t, 0, math.MaxUint32, gracefulShutdown)
defer cli.Close(errSelfCloseForTest)

stream, err := cli.NewStream(context.Background(), &CallHdr{})
test.Assert(t, err == nil, err)
<-srv.srvReady
go srv.gracefulShutdown()
err = cli.Write(stream, nil, []byte("hello"), &Options{})
test.Assert(t, err == nil, err)
msg := make([]byte, 5)
num, err := stream.Read(msg)
test.Assert(t, err == nil, err)
test.Assert(t, num == 5, num)
_, err = cli.NewStream(context.Background(), &CallHdr{})
test.Assert(t, err != nil, err)
t.Logf("NewStream err: %v", err)
time.Sleep(1 * time.Second)
err = cli.Write(stream, nil, []byte("hello"), &Options{})
test.Assert(t, err != nil, err)
t.Logf("After timeout, Write err: %v", err)
_, err = stream.Read(msg)
test.Assert(t, err != nil, err)
t.Logf("After timeout, Read err: %v", err)
}
50 changes: 39 additions & 11 deletions pkg/remote/trans/nphttp2/grpc/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ package grpc

import (
"context"
"errors"
"fmt"
"io"
"math"
"net"
Expand Down Expand Up @@ -468,9 +470,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
s.id = h.streamID
s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
t.mu.Lock()
if t.activeStreams == nil { // Can be niled from Close().
// Don't create a stream if the transport is in a state of graceful shutdown or already closed
if t.state == draining || t.activeStreams == nil { // Can be niled from Close().
t.mu.Unlock()
return false // Don't create a stream if the transport is already closed.
return false
}
t.activeStreams[s.id] = s
t.mu.Unlock()
Expand Down Expand Up @@ -533,7 +536,11 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
)
if err != nil {
rst = true
rstCode = http2.ErrCodeCancel
if errors.Is(err, errGracefulShutdown) {
rstCode = gracefulShutdownCode
} else {
rstCode = http2.ErrCodeCancel
}
}
t.closeStream(s, err, rst, rstCode, status.Convert(err), nil, false)
}
Expand Down Expand Up @@ -812,7 +819,13 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
statusCode = codes.DeadlineExceeded
}
}
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode), nil, false)
var msg string
if statusCode == codes.Unavailable {
msg = gracefulShutdownMsg
} else {
msg = fmt.Sprintf("stream terminated by RST_STREAM with error code: %v", f.ErrCode)
}
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(statusCode, msg), nil, false)
}

func (t *http2Client) handleSettings(f *grpcframe.SettingsFrame, isFirst bool) {
Expand Down Expand Up @@ -917,29 +930,44 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) {
// Notify the clientconn about the GOAWAY before we set the state to
// draining, to allow the client to stop attempting to create streams
// before disallowing new streams on this connection.
if t.onGoAway != nil {
t.onGoAway(t.goAwayReason)
if t.state != draining {
if t.onGoAway != nil {
// todo(DMwangnima): remove this lock since it is not necessary
// to hold lock processing onGoAway
t.mu.Unlock()
t.onGoAway(t.goAwayReason)
t.mu.Lock()
}
t.state = draining
}
t.state = draining
}
// All streams with IDs greater than the GoAwayId
// and smaller than the previous GoAway ID should be killed.
upperLimit := t.prevGoAwayID
if upperLimit == 0 { // This is the first GoAway Frame.
upperLimit = math.MaxUint32 // Kill all streams after the GoAway ID.
}
t.prevGoAwayID = id
active := len(t.activeStreams)
if active <= 0 {
t.mu.Unlock()
t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams"))
return
}

var unprocessedStream []*Stream
for streamID, stream := range t.activeStreams {
if streamID > id && streamID <= upperLimit {
// The stream was unprocessed by the server.
atomic.StoreUint32(&stream.unprocessed, 1)
unprocessedStream = append(unprocessedStream, stream)
t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false)
}
}
t.prevGoAwayID = id
active := len(t.activeStreams)
t.mu.Unlock()
if active == 0 {
t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams"))

for _, stream := range unprocessedStream {
t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false)
}
}

Expand Down
53 changes: 48 additions & 5 deletions pkg/remote/trans/nphttp2/grpc/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ import (
"github.com/cloudwego/kitex/pkg/utils"
)

const (
gracefulShutdownCode = http2.ErrCode(20)
gracefulShutdownMsg = "graceful shutdown"
)

var (
// ErrIllegalHeaderWrite indicates that setting header is illegal because of
// the stream's state.
Expand All @@ -67,6 +72,8 @@ var (
errNotReachable = status.New(codes.Canceled, "transport: server not reachable").Err()
errMaxAgeClosing = status.New(codes.Canceled, "transport: closing server transport due to maximum connection age").Err()
errIdleClosing = status.New(codes.Canceled, "transport: closing server transport due to idleness").Err()

errGracefulShutdown = status.Err(codes.Unavailable, gracefulShutdownMsg)
)

func init() {
Expand Down Expand Up @@ -966,6 +973,44 @@ func (t *http2Server) Close() error {
return t.closeWithErr(nil)
}

func (t *http2Server) GracefulClose() {
t.mu.Lock()
if t.state == closing {
t.mu.Unlock()
return
}
t.state = closing
streams := t.activeStreams
t.activeStreams = nil
t.mu.Unlock()

var wg sync.WaitGroup
for _, s := range streams {
oldState := s.swapState(streamDone)
if oldState == streamDone {
// If the stream was already done, continue
continue
}
wg.Add(1)
s.cancel(errGracefulShutdown)
t.controlBuf.put(&cleanupStream{
streamID: s.id,
rst: true,
rstCode: gracefulShutdownCode,
onWrite: func() {},
onFinishWrite: func() {
wg.Done()
},
})
}
wg.Wait()
// make sure all RSTStream Frames are sent out
t.framer.writer.Flush()
t.controlBuf.finish()
close(t.done)
t.conn.Close()
}

func (t *http2Server) closeWithErr(reason error) error {
t.mu.Lock()
if t.state == closing {
Expand Down Expand Up @@ -1081,10 +1126,8 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) {
if err := t.framer.WriteGoAway(sid, g.code, g.debugData); err != nil {
return false, err
}
t.framer.writer.Flush()
if g.closeConn {
// Abruptly close the connection following the GoAway (via
// loopywriter). But flush out what's inside the buffer first.
t.framer.writer.Flush()
return false, fmt.Errorf("transport: Connection closing")
}
return true, nil
Expand All @@ -1096,15 +1139,15 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) {
// originated before the GoAway reaches the client.
// After getting the ack or timer expiration send out another GoAway this
// time with an ID of the max stream server intends to process.
if err := t.framer.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil {
if err := t.framer.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, g.debugData); err != nil {
return false, err
}
if err := t.framer.WritePing(false, goAwayPing.data); err != nil {
return false, err
}

gofunc.RecoverGoFuncWithInfo(context.Background(), func() {
timer := time.NewTimer(time.Minute)
timer := time.NewTimer(10 * time.Second)
defer timer.Stop()
select {
case <-t.drainChan:
Expand Down
1 change: 1 addition & 0 deletions pkg/remote/trans/nphttp2/grpc/http_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ var (
http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted,
http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
http2.ErrCodeHTTP11Required: codes.Internal,
gracefulShutdownCode: codes.Unavailable,
}
statusCodeConvTab = map[codes.Code]http2.ErrCode{
codes.Internal: http2.ErrCodeInternal,
Expand Down
1 change: 1 addition & 0 deletions pkg/remote/trans/nphttp2/grpc/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,7 @@ type ServerTransport interface {
// should not be accessed any more. All the pending streams and their
// handlers will be terminated asynchronously.
Close() error
GracefulClose()

// RemoteAddr returns the remote network address.
RemoteAddr() net.Addr
Expand Down
Loading

0 comments on commit 119c886

Please sign in to comment.