diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml index de11c47..795467b 100644 --- a/.github/workflows/cla.yml +++ b/.github/workflows/cla.yml @@ -18,7 +18,7 @@ jobs: steps: - name: "CLA Assistant" if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target' - uses: contributor-assistant/github-action@v2.3.1 + uses: contributor-assistant/github-action@v2.3.2 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PERSONAL_ACCESS_TOKEN: ${{ secrets.CLA_DATABASE_ACCESS_TOKEN }} @@ -28,4 +28,5 @@ jobs: path-to-signatures: 'signatures/${{ github.event.repository.name }}-${{ github.repository_id }}/cla.json' path-to-document: 'https://github.com/trpc-group/cla-database/blob/main/Tencent-Contributor-License-Agreement.md' # branch should not be protected - branch: 'main' \ No newline at end of file + branch: 'main' + allowlist: dependabot \ No newline at end of file diff --git a/client/client.go b/client/client.go index 8860c91..946d58b 100644 --- a/client/client.go +++ b/client/client.go @@ -393,7 +393,7 @@ func selectorFilter(ctx context.Context, req interface{}, rsp interface{}, next if err != nil { return OptionsFromContext(ctx).fixTimeout(err) } - ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address) + ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address, node.ParseAddr) // Start to process the next filter and report. begin := time.Now() @@ -471,11 +471,21 @@ func getNode(opts *Options) (*registry.Node, error) { return node, nil } -func ensureMsgRemoteAddr(msg codec.Msg, network string, address string) { +func ensureMsgRemoteAddr( + msg codec.Msg, + network, address string, + parseAddr func(network, address string) net.Addr, +) { // If RemoteAddr has already been set, just return. if msg.RemoteAddr() != nil { return } + + if parseAddr != nil { + msg.WithRemoteAddr(parseAddr(network, address)) + return + } + switch network { case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": // Check if address can be parsed as an ip. @@ -484,7 +494,6 @@ func ensureMsgRemoteAddr(msg codec.Msg, network string, address string) { return } } - var addr net.Addr switch network { case "tcp", "tcp4", "tcp6": diff --git a/client/client_test.go b/client/client_test.go index 3734816..22fb5be 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -16,6 +16,8 @@ package client_test import ( "context" "errors" + "fmt" + "net" "testing" "time" @@ -409,6 +411,31 @@ func TestFixTimeout(t *testing.T) { }) } +func TestSelectorRemoteAddrUseUserProvidedParser(t *testing.T) { + selector.Register(t.Name(), &fSelector{ + selectNode: func(s string, option ...selector.Option) (*registry.Node, error) { + return ®istry.Node{ + Network: t.Name(), + Address: t.Name(), + ParseAddr: func(network, address string) net.Addr { + return newUnresolvedAddr(network, address) + }}, nil + }, + report: func(node *registry.Node, duration time.Duration, err error) error { return nil }, + }) + fake := "fake" + codec.Register(fake, nil, &fakeCodec{}) + ctx := trpc.BackgroundContext() + require.NotNil(t, client.New().Invoke(ctx, "failbody", nil, + client.WithServiceName(t.Name()), + client.WithProtocol(fake), + client.WithTarget(fmt.Sprintf("%s://xxx", t.Name())))) + addr := trpc.Message(ctx).RemoteAddr() + require.NotNil(t, addr) + require.Equal(t, t.Name(), addr.Network()) + require.Equal(t, t.Name(), addr.String()) +} + type multiplexedTransport struct { require func(context.Context, []byte, ...transport.RoundTripOption) fakeTransport @@ -423,7 +450,11 @@ func (t *multiplexedTransport) RoundTrip( return t.fakeTransport.RoundTrip(ctx, req, opts...) } -type fakeTransport struct{} +type fakeTransport struct { + send func() error + recv func() ([]byte, error) + close func() +} func (c *fakeTransport) RoundTrip(ctx context.Context, req []byte, roundTripOpts ...transport.RoundTripOption) (rsp []byte, err error) { @@ -447,18 +478,15 @@ func (c *fakeTransport) RoundTrip(ctx context.Context, req []byte, } func (c *fakeTransport) Send(ctx context.Context, req []byte, opts ...transport.RoundTripOption) error { + if c.send != nil { + return c.send() + } return nil } func (c *fakeTransport) Recv(ctx context.Context, opts ...transport.RoundTripOption) ([]byte, error) { - body, ok := ctx.Value("recv-decode-error").(string) - if ok { - return []byte(body), nil - } - - err, ok := ctx.Value("recv-error").(string) - if ok { - return nil, errors.New(err) + if c.recv != nil { + return c.recv() } return []byte("body"), nil } @@ -467,7 +495,9 @@ func (c *fakeTransport) Init(ctx context.Context, opts ...transport.RoundTripOpt return nil } func (c *fakeTransport) Close(ctx context.Context) { - return + if c.close != nil { + c.close() + } } type fakeCodec struct { @@ -524,3 +554,39 @@ func (c *fakeSelector) Select(serviceName string, opt ...selector.Option) (*regi func (c *fakeSelector) Report(node *registry.Node, cost time.Duration, err error) error { return nil } + +type fSelector struct { + selectNode func(string, ...selector.Option) (*registry.Node, error) + report func(*registry.Node, time.Duration, error) error +} + +func (s *fSelector) Select(serviceName string, opts ...selector.Option) (*registry.Node, error) { + return s.selectNode(serviceName, opts...) +} + +func (s *fSelector) Report(node *registry.Node, cost time.Duration, err error) error { + return s.report(node, cost, err) +} + +// newUnresolvedAddr returns a new unresolvedAddr. +func newUnresolvedAddr(network, address string) *unresolvedAddr { + return &unresolvedAddr{network: network, address: address} +} + +var _ net.Addr = (*unresolvedAddr)(nil) + +// unresolvedAddr is a net.Addr which returns the original network or address. +type unresolvedAddr struct { + network string + address string +} + +// Network returns the unresolved original network. +func (a *unresolvedAddr) Network() string { + return a.network +} + +// String returns the unresolved original address. +func (a *unresolvedAddr) String() string { + return a.address +} diff --git a/client/config.go b/client/config.go index e1747ab..64c12ea 100644 --- a/client/config.go +++ b/client/config.go @@ -374,6 +374,11 @@ func RegisterConfig(conf map[string]*BackendConfig) error { // RegisterClientConfig is called to replace backend config of single callee service by name. func RegisterClientConfig(callee string, conf *BackendConfig) error { + if callee == "*" { + // Reset the callee and service name to enable wildcard matching. + conf.Callee = "" + conf.ServiceName = "" + } opts, err := conf.genOptions() if err != nil { return err diff --git a/client/config_test.go b/client/config_test.go index 9de0ef9..639877a 100644 --- a/client/config_test.go +++ b/client/config_test.go @@ -312,3 +312,25 @@ func TestConfig(t *testing.T) { } require.Nil(t, client.RegisterClientConfig("trpc.test.helloworld3", backconfig)) } + +func TestRegisterWildcardClient(t *testing.T) { + cfg := client.Config("*") + t.Cleanup(func() { + client.RegisterClientConfig("*", cfg) + }) + client.RegisterClientConfig("*", &client.BackendConfig{ + DisableServiceRouter: true, + }) + + ch := make(chan *client.Options, 1) + c := client.New() + ctx, _ := codec.EnsureMessage(context.Background()) + require.Nil(t, c.Invoke(ctx, nil, nil, client.WithFilter( + func(ctx context.Context, _, _ interface{}, _ filter.ClientHandleFunc) error { + ch <- client.OptionsFromContext(ctx) + // Skip next. + return nil + }))) + opts := <-ch + require.True(t, opts.DisableServiceRouter) +} diff --git a/client/stream.go b/client/stream.go index a40ac8e..fbfd136 100644 --- a/client/stream.go +++ b/client/stream.go @@ -66,11 +66,16 @@ type RecvControl interface { // It serializes the message and sends it to server through stream transport. // It's safe to call Recv and Send in different goroutines concurrently, but calling // Send in different goroutines concurrently is not thread-safe. -func (s *stream) Send(ctx context.Context, m interface{}) error { +func (s *stream) Send(ctx context.Context, m interface{}) (err error) { + defer func() { + if err != nil { + s.opts.StreamTransport.Close(ctx) + } + }() + msg := codec.Message(ctx) reqBodyBuf, err := serializeAndCompress(ctx, msg, m, s.opts) if err != nil { - s.opts.StreamTransport.Close(ctx) return err } @@ -87,7 +92,6 @@ func (s *stream) Send(ctx context.Context, m interface{}) error { } if err := s.opts.StreamTransport.Send(ctx, reqBuf); err != nil { - s.opts.StreamTransport.Close(ctx) return err } return nil @@ -97,18 +101,24 @@ func (s *stream) Send(ctx context.Context, m interface{}) error { // It decodes and decompresses the message and leaves serialization to upper layer. // It's safe to call Recv and Send in different goroutines concurrently, but calling // Send in different goroutines concurrently is not thread-safe. -func (s *stream) Recv(ctx context.Context) ([]byte, error) { +func (s *stream) Recv(ctx context.Context) (buf []byte, err error) { + defer func() { + if err != nil { + s.opts.StreamTransport.Close(ctx) + } + }() rspBuf, err := s.opts.StreamTransport.Recv(ctx) if err != nil { - s.opts.StreamTransport.Close(ctx) return nil, err } msg := codec.Message(ctx) rspBodyBuf, err := s.opts.Codec.Decode(msg, rspBuf) if err != nil { - s.opts.StreamTransport.Close(ctx) return nil, errs.NewFrameError(errs.RetClientDecodeFail, "client codec Decode: "+err.Error()) } + if err := msg.ClientRspErr(); err != nil { + return nil, err + } if len(rspBodyBuf) > 0 { compressType := msg.CompressType() if icodec.IsValidCompressType(s.opts.CurrentCompressType) { @@ -118,9 +128,7 @@ func (s *stream) Recv(ctx context.Context) ([]byte, error) { if icodec.IsValidCompressType(compressType) && compressType != codec.CompressTypeNoop { rspBodyBuf, err = codec.Decompress(compressType, rspBodyBuf) if err != nil { - s.opts.StreamTransport.Close(ctx) - return nil, - errs.NewFrameError(errs.RetClientDecodeFail, "client codec Decompress: "+err.Error()) + return nil, errs.NewFrameError(errs.RetClientDecodeFail, "client codec Decompress: "+err.Error()) } } } @@ -154,7 +162,7 @@ func (s *stream) Init(ctx context.Context, opt ...Option) (*Options, error) { report.SelectNodeFail.Incr() return nil, err } - ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address) + ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address, node.ParseAddr) const invalidCost = -1 opts.Node.set(node, node.Address, invalidCost) if opts.Codec == nil { diff --git a/client/stream_test.go b/client/stream_test.go index 6b9d398..24a9bd7 100644 --- a/client/stream_test.go +++ b/client/stream_test.go @@ -15,6 +15,7 @@ package client_test import ( "context" + "errors" "testing" "time" @@ -37,85 +38,123 @@ func TestStream(t *testing.T) { // calling without error streamCli := client.NewStream() - require.NotNil(t, streamCli) - opts, err := streamCli.Init(ctx, client.WithTarget("ip://127.0.0.1:8000"), - client.WithTimeout(time.Second), client.WithSerializationType(codec.SerializationTypeNoop), - client.WithStreamTransport(&fakeTransport{}), client.WithProtocol("fake")) - require.Nil(t, err) - require.NotNil(t, opts) - err = streamCli.Invoke(ctx) - require.Nil(t, err) - err = streamCli.Send(ctx, reqBody) - require.Nil(t, err) - rsp, err := streamCli.Recv(ctx) - require.Nil(t, err) - require.Equal(t, []byte("body"), rsp) - err = streamCli.Close(ctx) - require.Nil(t, err) + t.Run("calling without error", func(t *testing.T) { + require.NotNil(t, streamCli) + opts, err := streamCli.Init(ctx, + client.WithTarget("ip://127.0.0.1:8000"), + client.WithTimeout(time.Second), + client.WithSerializationType(codec.SerializationTypeNoop), + client.WithStreamTransport(&fakeTransport{}), + client.WithProtocol("fake"), + ) + require.Nil(t, err) + require.NotNil(t, opts) + err = streamCli.Invoke(ctx) + require.Nil(t, err) + err = streamCli.Send(ctx, reqBody) + require.Nil(t, err) + rsp, err := streamCli.Recv(ctx) + require.Nil(t, err) + require.Equal(t, []byte("body"), rsp) + err = streamCli.Close(ctx) + require.Nil(t, err) + }) - // test nil Codec - opts, err = streamCli.Init(ctx, - client.WithTarget("ip://127.0.0.1:8080"), - client.WithTimeout(time.Second), - client.WithProtocol("fake-nil"), - client.WithSerializationType(codec.SerializationTypeNoop), - client.WithStreamTransport(&fakeTransport{})) - require.NotNil(t, err) - require.Nil(t, opts) - err = streamCli.Invoke(ctx) - require.Nil(t, err) + t.Run("test nil Codec", func(t *testing.T) { + opts, err := streamCli.Init(ctx, + client.WithTarget("ip://127.0.0.1:8080"), + client.WithTimeout(time.Second), + client.WithProtocol("fake-nil"), + client.WithSerializationType(codec.SerializationTypeNoop), + client.WithStreamTransport(&fakeTransport{})) + require.NotNil(t, err) + require.Nil(t, opts) + err = streamCli.Invoke(ctx) + require.Nil(t, err) + }) - // test selectNode with error - opts, err = streamCli.Init(ctx, client.WithTarget("ip/:/127.0.0.1:8080"), - client.WithProtocol("fake")) - require.NotNil(t, err) - require.Contains(t, err.Error(), "invalid") - require.Nil(t, opts) - - // test stream recv failure - ctx = context.WithValue(ctx, "recv-error", "recv failed") - opts, err = streamCli.Init(ctx, client.WithTarget("ip://127.0.0.1:8000"), - client.WithTimeout(time.Second), client.WithSerializationType(codec.SerializationTypeNoop), - client.WithStreamTransport(&fakeTransport{}), client.WithProtocol("fake")) - require.Nil(t, err) - require.NotNil(t, opts) - err = streamCli.Invoke(ctx) - require.Nil(t, err) - rsp, err = streamCli.Recv(ctx) - require.Nil(t, rsp) - require.NotNil(t, err) - - // test decode failure - ctx = context.WithValue(ctx, "recv-decode-error", "businessfail") - rsp, err = streamCli.Recv(ctx) - require.Nil(t, rsp) - require.NotNil(t, err) - - // test compress failure - ctx = context.Background() - opts, err = streamCli.Init(ctx, client.WithTarget("ip://127.0.0.1:8000"), - client.WithTimeout(time.Second), client.WithSerializationType(codec.SerializationTypeNoop), - client.WithStreamTransport(&fakeTransport{}), client.WithCurrentCompressType(codec.CompressTypeGzip), - client.WithProtocol("fake")) - require.Nil(t, err) - require.NotNil(t, opts) - err = streamCli.Invoke(ctx) - require.Nil(t, err) - _, err = streamCli.Recv(ctx) - require.NotNil(t, err) - - // test compress without error - opts, err = streamCli.Init(ctx, client.WithTarget("ip://127.0.0.1:8000"), - client.WithTimeout(time.Second), client.WithSerializationType(codec.SerializationTypeNoop), - client.WithStreamTransport(&fakeTransport{}), client.WithCurrentCompressType(codec.CompressTypeNoop), - client.WithProtocol("fake")) - require.Nil(t, err) - require.NotNil(t, opts) - err = streamCli.Invoke(ctx) - require.Nil(t, err) - rsp, err = streamCli.Recv(ctx) - require.Nil(t, err) - require.NotNil(t, rsp) + t.Run("test selectNode with error", func(t *testing.T) { + opts, err := streamCli.Init(ctx, + client.WithTarget("ip/:/127.0.0.1:8080"), + client.WithProtocol("fake"), + ) + require.NotNil(t, err) + require.Contains(t, err.Error(), "invalid") + require.Nil(t, opts) + }) + + t.Run("test stream recv failure", func(t *testing.T) { + opts, err := streamCli.Init(ctx, + client.WithTarget("ip://127.0.0.1:8000"), + client.WithTimeout(time.Second), + client.WithSerializationType(codec.SerializationTypeNoop), + client.WithStreamTransport(&fakeTransport{ + recv: func() ([]byte, error) { + return nil, errors.New("recv failed") + }, + }), + client.WithProtocol("fake"), + ) + require.Nil(t, err) + require.NotNil(t, opts) + err = streamCli.Invoke(ctx) + require.Nil(t, err) + rsp, err := streamCli.Recv(ctx) + require.Nil(t, rsp) + require.NotNil(t, err) + }) + + t.Run("test decode failure", func(t *testing.T) { + _, err := streamCli.Init(ctx, + client.WithTarget("ip://127.0.0.1:8000"), + client.WithTimeout(time.Second), + client.WithSerializationType(codec.SerializationTypeNoop), + client.WithStreamTransport(&fakeTransport{ + recv: func() ([]byte, error) { + return []byte("businessfail"), nil + }, + }), + client.WithProtocol("fake"), + ) + require.Nil(t, err) + rsp, err := streamCli.Recv(ctx) + require.Nil(t, rsp) + require.NotNil(t, err) + }) + + t.Run("test compress failure", func(t *testing.T) { + opts, err := streamCli.Init(context.Background(), + client.WithTarget("ip://127.0.0.1:8000"), + client.WithTimeout(time.Second), + client.WithSerializationType(codec.SerializationTypeNoop), + client.WithStreamTransport(&fakeTransport{}), + client.WithCurrentCompressType(codec.CompressTypeGzip), + client.WithProtocol("fake")) + require.Nil(t, err) + require.NotNil(t, opts) + err = streamCli.Invoke(ctx) + require.Nil(t, err) + _, err = streamCli.Recv(ctx) + require.NotNil(t, err) + }) + + t.Run("test compress without error", func(t *testing.T) { + opts, err := streamCli.Init(ctx, + client.WithTarget("ip://127.0.0.1:8000"), + client.WithTimeout(time.Second), + client.WithSerializationType(codec.SerializationTypeNoop), + client.WithStreamTransport(&fakeTransport{}), + client.WithCurrentCompressType(codec.CompressTypeNoop), + client.WithProtocol("fake"), + ) + require.Nil(t, err) + require.NotNil(t, opts) + err = streamCli.Invoke(ctx) + require.Nil(t, err) + rsp, err := streamCli.Recv(ctx) + require.Nil(t, err) + require.NotNil(t, rsp) + }) } func TestGetStreamFilter(t *testing.T) { @@ -151,3 +190,46 @@ func TestStreamGetAddress(t *testing.T) { require.NotNil(t, msg.RemoteAddr()) require.Equal(t, addr, msg.RemoteAddr().String()) } + +func TestStreamCloseTransport(t *testing.T) { + codec.Register("fake", nil, &fakeCodec{}) + t.Run("close transport when send fail", func(t *testing.T) { + var isClose bool + streamCli := client.NewStream() + _, err := streamCli.Init(context.Background(), + client.WithTarget("ip://127.0.0.1:8000"), + client.WithStreamTransport(&fakeTransport{ + send: func() error { + return errors.New("expected error") + }, + close: func() { + isClose = true + }, + }), + client.WithProtocol("fake"), + ) + require.Nil(t, err) + require.NotNil(t, streamCli.Send(context.Background(), nil)) + require.True(t, isClose) + }) + t.Run("close transport when recv fail", func(t *testing.T) { + var isClose bool + streamCli := client.NewStream() + _, err := streamCli.Init(context.Background(), + client.WithTarget("ip://127.0.0.1:8000"), + client.WithStreamTransport(&fakeTransport{ + recv: func() ([]byte, error) { + return nil, errors.New("expected error") + }, + close: func() { + isClose = true + }, + }), + client.WithProtocol("fake"), + ) + require.Nil(t, err) + _, err = streamCli.Recv(context.Background()) + require.NotNil(t, err) + require.True(t, isClose) + }) +} diff --git a/codec.go b/codec.go index 3c690c2..2fa0d0d 100644 --- a/codec.go +++ b/codec.go @@ -27,7 +27,6 @@ import ( "trpc.group/trpc-go/trpc-go/codec" "trpc.group/trpc-go/trpc-go/errs" "trpc.group/trpc-go/trpc-go/internal/attachment" - icodec "trpc.group/trpc-go/trpc-go/internal/codec" "trpc.group/trpc-go/trpc-go/transport" "google.golang.org/protobuf/proto" @@ -314,9 +313,7 @@ func msgWithRequestProtocol(msg codec.Msg, req *trpcpb.RequestProtocol, attm []b msg.WithCallerServiceName(string(req.GetCaller())) msg.WithCalleeServiceName(string(req.GetCallee())) // set server handler method name - rpcName := string(req.GetFunc()) - msg.WithServerRPCName(rpcName) - msg.WithCalleeMethod(icodec.MethodFromRPCName(rpcName)) + msg.WithServerRPCName(string(req.GetFunc())) // set body serialization type msg.WithSerializationType(int(req.GetContentType())) // set body compression type @@ -665,19 +662,6 @@ func loadOrStoreDefaultUnaryFrameHead(msg codec.Msg) *FrameHead { return frameHead } -func copyRspHead(dst, src *trpcpb.ResponseProtocol) { - dst.Version = src.Version - dst.CallType = src.CallType - dst.RequestId = src.RequestId - dst.Ret = src.Ret - dst.FuncRet = src.FuncRet - dst.ErrorMsg = src.ErrorMsg - dst.MessageType = src.MessageType - dst.TransInfo = src.TransInfo - dst.ContentType = src.ContentType - dst.ContentEncoding = src.ContentEncoding -} - func updateMsg(msg codec.Msg, frameHead *FrameHead, rsp *trpcpb.ResponseProtocol, attm []byte) error { msg.WithFrameHead(frameHead) msg.WithCompressType(int(rsp.GetContentEncoding())) diff --git a/codec/message_impl.go b/codec/message_impl.go index e4d78c0..1cc89e0 100644 --- a/codec/message_impl.go +++ b/codec/message_impl.go @@ -259,6 +259,10 @@ func (m *msg) WithClientRPCName(s string) { } func (m *msg) updateMethodNameUsingRPCName(s string) { + if rpcNameIsTRPCForm(s) { + m.WithCalleeMethod(methodFromRPCName(s)) + return + } if m.CalleeMethod() == "" { m.WithCalleeMethod(s) } @@ -616,6 +620,9 @@ func WithCloneContextAndMessage(ctx context.Context) (context.Context, Msg) { // copyCommonMessage copy common data of message. func copyCommonMessage(m *msg, newMsg *msg) { + // Do not copy compress type here, as it will cause subsequence RPC calls to inherit the upstream + // compress type which is not the expected behavior. Compress type should not be propagated along + // the entire RPC invocation chain. newMsg.frameHead = m.frameHead newMsg.requestTimeout = m.requestTimeout newMsg.serializationType = m.serializationType @@ -729,3 +736,48 @@ func getAppServerService(s string) (app, server, service string) { service = s[j:] return } + +// methodFromRPCName returns the method parsed from rpc string. +func methodFromRPCName(s string) string { + return s[strings.LastIndex(s, "/")+1:] +} + +// rpcNameIsTRPCForm checks whether the given string is of trpc form. +// It is equivalent to: +// +// var r = regexp.MustCompile(`^/[^/.]+\.[^/]+/[^/.]+$`) +// +// func rpcNameIsTRPCForm(s string) bool { +// return r.MatchString(s) +// } +// +// But regexp is much slower than the current version. +// Refer to BenchmarkRPCNameIsTRPCForm in message_bench_test.go. +func rpcNameIsTRPCForm(s string) bool { + if len(s) == 0 { + return false + } + if s[0] != '/' { // ^/ + return false + } + const start = 1 + firstDot := strings.Index(s[start:], ".") + if firstDot == -1 || firstDot == 0 { // [^.]+\. + return false + } + if strings.Contains(s[start:start+firstDot], "/") { // [^/]+\. + return false + } + secondSlash := strings.Index(s[start+firstDot:], "/") + if secondSlash == -1 || secondSlash == 1 { // [^/]+/ + return false + } + if start+firstDot+secondSlash == len(s)-1 { // The second slash should not be the last character. + return false + } + const offset = 1 + if strings.ContainsAny(s[start+firstDot+secondSlash+offset:], "/.") { // [^/.]+$ + return false + } + return true +} diff --git a/codec/message_internal_test.go b/codec/message_internal_test.go new file mode 100644 index 0000000..d9a546b --- /dev/null +++ b/codec/message_internal_test.go @@ -0,0 +1,67 @@ +package codec + +import ( + "regexp" + "testing" + + "github.com/stretchr/testify/require" +) + +func BenchmarkRPCNameIsTRPCForm(b *testing.B) { + rpcNames := []string{ + "/trpc.app.server.service/method", + "/sdadfasd/xadfasdf/zxcasd/asdfasd/v2", + "trpc.app.server.service", + "/trpc.app.server.service", + "/trpc.app.", + "/trpc/asdf/asdf", + "/trpc.asdfasdf/asdfasdf/sdfasdfa/", + "/trpc.app/method/", + "/trpc.app/method/hhhhh", + } + b.Run("bench regexp", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for j := range rpcNames { + rpcNameIsTRPCFormRegExp(rpcNames[j]) + } + } + }) + b.Run("bench vanilla", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for j := range rpcNames { + rpcNameIsTRPCForm(rpcNames[j]) + } + } + }) +} + +func TestEnsureEqualSemacticOfTRPCFormChecking(t *testing.T) { + rpcNames := []string{ + "/trpc.app.server.service/method", + "/trpc.app.server.service/", + "/trpc", + "//", + "/./", + "/xx/.", + "/x./method", + "/.x/method", + "/sdadfasd/xadfasdf/zxcasd/asdfasd/v2", + "trpc.app.server.service", + "/trpc.app.server.service", + "/trpc.app.", + "/trpc/asdf/asdf", + "/trpc.asdfasdf/asdfasdf/sdfasdfa/", + "/trpc.app/method/", + "/trpc.app/method/hhhhh", + } + for _, s := range rpcNames { + v1, v2 := rpcNameIsTRPCFormRegExp(s), rpcNameIsTRPCForm(s) + require.True(t, v1 == v2, "%s %v %v", s, v1, v2) + } +} + +var r = regexp.MustCompile(`^/[^/.]+\.[^/]+/[^/.]+$`) + +func rpcNameIsTRPCFormRegExp(s string) bool { + return r.MatchString(s) +} diff --git a/codec/message_test.go b/codec/message_test.go index 31f9308..2f49427 100644 --- a/codec/message_test.go +++ b/codec/message_test.go @@ -477,3 +477,44 @@ func TestEnsureMessage(t *testing.T) { require.Equal(t, ctx, newCtx) require.Equal(t, msg, newMsg) } + +func TestSetMethodNameUsingRPCName(t *testing.T) { + msg := codec.Message(context.Background()) + testSetMethodNameUsingRPCName(t, msg, msg.WithServerRPCName) + testSetMethodNameUsingRPCName(t, msg, msg.WithClientRPCName) +} + +func testSetMethodNameUsingRPCName(t *testing.T, msg codec.Msg, msgWithRPCName func(string)) { + var cases = []struct { + name string + originalMethod string + rpcName string + expectMethod string + }{ + {"normal trpc rpc name", "", "/trpc.app.server.service/method", "method"}, + {"normal http url path", "", "/v1/subject/info/get", "/v1/subject/info/get"}, + {"invalid trpc rpc name (method name is empty)", "", "trpc.app.server.service", "trpc.app.server.service"}, + {"invalid trpc rpc name (method name is not mepty)", "/v1/subject/info/get", "trpc.app.server.service", "/v1/subject/info/get"}, + {"valid trpc rpc name will override existing method name", "/v1/subject/info/get", "/trpc.app.server.service/method", "method"}, + {"invalid trpc rpc will not override existing method name", "/v1/subject/info/get", "/trpc.app.server.service", "/v1/subject/info/get"}, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + resetMsgRPCNameAndMethodName(msg) + msg.WithCalleeMethod(tt.originalMethod) + msgWithRPCName(tt.rpcName) + method := msg.CalleeMethod() + if method != tt.expectMethod { + t.Errorf("given original method %s and rpc name %s, expect new method name %s, got %s", + tt.originalMethod, tt.rpcName, tt.expectMethod, method) + } + }) + } +} + +func resetMsgRPCNameAndMethodName(msg codec.Msg) { + msg.WithCalleeMethod("") + msg.WithClientRPCName("") + msg.WithServerRPCName("") +} diff --git a/codec/serialization_json.go b/codec/serialization_json.go index f7fad9e..10dfc66 100644 --- a/codec/serialization_json.go +++ b/codec/serialization_json.go @@ -17,8 +17,15 @@ import ( jsoniter "github.com/json-iterator/go" ) -// JSONAPI is json packing and unpacking object, users can change -// the internal parameter. +// JSONAPI is used by tRPC JSON serialization when the object does +// not conform to protobuf proto.Message interface. +// +// Deprecated: This global variable is exportable due to backward comparability issue but +// should not be modified. If users want to change the default behavior of +// internal JSON serialization, please use register your customized serializer +// function like: +// +// codec.RegisterSerializer(codec.SerializationTypeJSON, yourOwnJSONSerializer) var JSONAPI = jsoniter.ConfigCompatibleWithStandardLibrary // JSONSerialization provides json serialization mode. diff --git a/codec/serialization_jsonpb.go b/codec/serialization_jsonpb.go index cc50d45..79e245c 100644 --- a/codec/serialization_jsonpb.go +++ b/codec/serialization_jsonpb.go @@ -29,7 +29,9 @@ var Marshaler = protojson.MarshalOptions{EmitUnpopulated: true, UseProtoNames: t var Unmarshaler = protojson.UnmarshalOptions{DiscardUnknown: false} // JSONPBSerialization provides jsonpb serialization mode. It is based on -// protobuf/jsonpb. +// protobuf/jsonpb. This serializer will firstly try jsonpb's serialization. If +// object does not conform to protobuf proto.Message interface, json-iterator +// will be used. type JSONPBSerialization struct{} // Unmarshal deserialize the in bytes into body. diff --git a/codec_stream.go b/codec_stream.go index a7e6bad..d6ee81d 100644 --- a/codec_stream.go +++ b/codec_stream.go @@ -23,7 +23,6 @@ import ( "trpc.group/trpc-go/trpc-go/codec" "trpc.group/trpc-go/trpc-go/errs" "trpc.group/trpc-go/trpc-go/internal/addrutil" - icodec "trpc.group/trpc-go/trpc-go/internal/codec" trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" "google.golang.org/protobuf/proto" @@ -377,9 +376,7 @@ func (s *ServerStreamCodec) setInitMeta(msg codec.Msg) error { defer s.m.RUnlock() if streamIDToInitMeta, ok := s.initMetas[addr]; ok { if initMeta, ok := streamIDToInitMeta[streamID]; ok { - rpcName := string(initMeta.GetRequestMeta().GetFunc()) - msg.WithServerRPCName(rpcName) - msg.WithCalleeMethod(icodec.MethodFromRPCName(rpcName)) + msg.WithServerRPCName(string(initMeta.GetRequestMeta().GetFunc())) return nil } } @@ -468,9 +465,7 @@ func (s *ServerStreamCodec) updateMsg(msg codec.Msg, initMeta *trpcpb.TrpcStream msg.WithCallerServiceName(string(req.GetCaller())) msg.WithCalleeServiceName(string(req.GetCallee())) // set server handler method name - rpcName := string(req.GetFunc()) - msg.WithServerRPCName(rpcName) - msg.WithCalleeMethod(icodec.MethodFromRPCName(rpcName)) + msg.WithServerRPCName(string(req.GetFunc())) // set body serialization type msg.WithSerializationType(int(initMeta.GetContentType())) // set body compression type diff --git a/config.go b/config.go index 6ea24e6..fc8184c 100644 --- a/config.go +++ b/config.go @@ -25,6 +25,7 @@ import ( "time" yaml "gopkg.in/yaml.v3" + "trpc.group/trpc-go/trpc-go/internal/expandenv" trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" "trpc.group/trpc-go/trpc-go/client" @@ -608,11 +609,8 @@ func parseConfigFromFile(configPath string) (*Config, error) { if err != nil { return nil, err } - // expand environment variables - buf = []byte(expandEnv(string(buf))) - cfg := defaultConfig() - if err := yaml.Unmarshal(buf, cfg); err != nil { + if err := yaml.Unmarshal(expandenv.ExpandEnv(buf), cfg); err != nil { return nil, err } return cfg, nil @@ -678,6 +676,17 @@ func RepairConfig(cfg *Config) error { } codec.SetReaderSize(*cfg.Global.ReadBufferSize) + // nic -> ip + if err := repairServiceIPWithNic(cfg); err != nil { + return err + } + + // Set empty ip to "0.0.0.0" to prevent malformed key matching + // for passed listeners during hot restart. + const defaultIP = "0.0.0.0" + setDefault(&cfg.Global.LocalIP, defaultIP) + setDefault(&cfg.Server.Admin.IP, cfg.Global.LocalIP) + // protocol network ip empty for _, serviceCfg := range cfg.Server.Service { setDefault(&serviceCfg.Protocol, cfg.Server.Protocol) diff --git a/config/README.md b/config/README.md index 7f4320e..5ac469b 100644 --- a/config/README.md +++ b/config/README.md @@ -31,7 +31,7 @@ For managing business configuration, we recommend the best practice of using a c ## What is Multiple Data Sources? -A data source is the source from which configuration is retrieved and where it is stored. Common data sources include: file, etcd, configmap, etc. The tRPC framework supports setting different data sources for different business configurations. The framework uses a plugin-based approach to extend support for more data sources. In the implementation principle section later, we will describe in detail how the framework supports multiple data sources. +A data source is the source from which configuration is retrieved and where it is stored. Common data sources include: file, etcd, configmap, env, etc. The tRPC framework supports setting different data sources for different business configurations. The framework uses a plugin-based approach to extend support for more data sources. In the implementation principle section later, we will describe in detail how the framework supports multiple data sources. ## What is Codec? diff --git a/config/README.zh_CN.md b/config/README.zh_CN.md index 7642b1f..57c6b34 100644 --- a/config/README.zh_CN.md +++ b/config/README.zh_CN.md @@ -26,7 +26,7 @@ 业务配置也支持本地文件。对于本地文件,大部分使用场景是客户端作为独立的工具使用,或者程序在开发调试阶段使用。好处在于不需要依赖外部系统就能工作。 ## 什么是多数据源 -数据源就获取配置的来源,配置存储的地方。常见的数据源包括:file,etcd,configmap 等。tRPC 框架支持对不同业务配置设定不同的数据源。框架采用插件化方式来扩展对更多数据源的支持。在后面的实现原理章节,我们会详细介绍框架是如何实现对多数据源的支持的。 +数据源就获取配置的来源,配置存储的地方。常见的数据源包括:file,etcd,configmap,env 等。tRPC 框架支持对不同业务配置设定不同的数据源。框架采用插件化方式来扩展对更多数据源的支持。在后面的实现原理章节,我们会详细介绍框架是如何实现对多数据源的支持的。 ## 什么是 Codec 业务配置中的 Codec 是指从配置源获取到的配置的格式,常见的配置文件格式为:yaml,json,toml 等。框架采用插件化方式来扩展对更多解码格式的支持。 diff --git a/config/config.go b/config/config.go index 3648305..8bf13bf 100644 --- a/config/config.go +++ b/config/config.go @@ -23,7 +23,6 @@ import ( "sync" "github.com/BurntSushi/toml" - yaml "gopkg.in/yaml.v3" ) diff --git a/config/options.go b/config/options.go index dc6d2f9..ef7144b 100644 --- a/config/options.go +++ b/config/options.go @@ -27,6 +27,28 @@ func WithProvider(name string) LoadOption { } } +// WithExpandEnv replaces ${var} in raw bytes with environment value of var. +// Note, method TrpcConfig.Bytes will return the replaced bytes. +func WithExpandEnv() LoadOption { + return func(c *TrpcConfig) { + c.expandEnv = true + } +} + +// WithWatch returns an option to start watch model +func WithWatch() LoadOption { + return func(c *TrpcConfig) { + c.watch = true + } +} + +// WithWatchHook returns an option to set log func for config change logger +func WithWatchHook(f func(msg WatchMessage)) LoadOption { + return func(c *TrpcConfig) { + c.watchHook = f + } +} + // options is config option. type options struct{} diff --git a/config/trpc_config.go b/config/trpc_config.go index dea82c6..4a88cdf 100644 --- a/config/trpc_config.go +++ b/config/trpc_config.go @@ -23,6 +23,7 @@ import ( "github.com/BurntSushi/toml" "github.com/spf13/cast" yaml "gopkg.in/yaml.v3" + "trpc.group/trpc-go/trpc-go/internal/expandenv" "trpc.group/trpc-go/trpc-go/log" ) @@ -49,68 +50,59 @@ type LoadOption func(*TrpcConfig) // TrpcConfigLoader is a config loader for trpc. type TrpcConfigLoader struct { - configMap map[string]Config - rwl sync.RWMutex + watchers sync.Map } // Load returns the config specified by input parameter. func (loader *TrpcConfigLoader) Load(path string, opts ...LoadOption) (Config, error) { - yc := newTrpcConfig(path) - for _, o := range opts { - o(yc) - } - if yc.decoder == nil { - return nil, ErrCodecNotExist - } - if yc.p == nil { - return nil, ErrProviderNotExist + c, err := newTrpcConfig(path, opts...) + if err != nil { + return nil, err } - key := fmt.Sprintf("%s.%s.%s", yc.decoder.Name(), yc.p.Name(), path) - loader.rwl.RLock() - if c, ok := loader.configMap[key]; ok { - loader.rwl.RUnlock() - return c, nil + w := &watcher{} + i, loaded := loader.watchers.LoadOrStore(c.p, w) + if !loaded { + c.p.Watch(w.watch) + } else { + w = i.(*watcher) } - loader.rwl.RUnlock() - if err := yc.Load(); err != nil { + c = w.getOrCreate(c.path).getOrStore(c) + if err = c.init(); err != nil { return nil, err } - - loader.rwl.Lock() - loader.configMap[key] = yc - loader.rwl.Unlock() - - yc.p.Watch(func(p string, data []byte) { - if p == path { - loader.rwl.Lock() - delete(loader.configMap, key) - loader.rwl.Unlock() - } - }) - return yc, nil + return c, nil } // Reload reloads config data. func (loader *TrpcConfigLoader) Reload(path string, opts ...LoadOption) error { - yc := newTrpcConfig(path) - for _, o := range opts { - o(yc) + c, err := newTrpcConfig(path, opts...) + if err != nil { + return err } - key := fmt.Sprintf("%s.%s.%s", yc.decoder.Name(), yc.p.Name(), path) - loader.rwl.RLock() - if config, ok := loader.configMap[key]; ok { - loader.rwl.RUnlock() - config.Reload() - return nil + + v, ok := loader.watchers.Load(c.p) + if !ok { + return ErrConfigNotExist + } + w := v.(*watcher) + + s := w.get(path) + if s == nil { + return ErrConfigNotExist } - loader.rwl.RUnlock() - return ErrConfigNotExist + + oc := s.get(c.id) + if oc == nil { + return ErrConfigNotExist + } + + return oc.Load() } func newTrpcConfigLoad() *TrpcConfigLoader { - return &TrpcConfigLoader{configMap: map[string]Config{}, rwl: sync.RWMutex{}} + return &TrpcConfigLoader{} } // DefaultConfigLoader is the default config loader. @@ -155,63 +147,256 @@ func (c *TomlCodec) Unmarshal(in []byte, out interface{}) error { return toml.Unmarshal(in, out) } +// watch manage one data provider +type watcher struct { + sets sync.Map // *set +} + +// get config item by path +func (w *watcher) get(path string) *set { + if i, ok := w.sets.Load(path); ok { + return i.(*set) + } + return nil +} + +// getOrCreate get config item by path if not exist and create and return +func (w *watcher) getOrCreate(path string) *set { + i, _ := w.sets.LoadOrStore(path, &set{}) + return i.(*set) +} + +// watch func +func (w *watcher) watch(path string, data []byte) { + if v := w.get(path); v != nil { + v.watch(data) + } +} + +// set manages configs with same provider and name with different type +// used config.id as unique identifier +type set struct { + path string + mutex sync.RWMutex + items []*TrpcConfig +} + +// get data +func (s *set) get(id string) *TrpcConfig { + s.mutex.RLock() + defer s.mutex.RUnlock() + for _, v := range s.items { + if v.id == id { + return v + } + } + return nil +} + +func (s *set) getOrStore(tc *TrpcConfig) *TrpcConfig { + if v := s.get(tc.id); v != nil { + return v + } + + s.mutex.Lock() + for _, item := range s.items { + if item.id == tc.id { + s.mutex.Unlock() + return item + } + } + // not found and add + s.items = append(s.items, tc) + s.mutex.Unlock() + return tc +} + +// watch data change, delete no watch model config and update watch model config and target notify +func (s *set) watch(data []byte) { + var items []*TrpcConfig + var del []*TrpcConfig + s.mutex.Lock() + for _, v := range s.items { + if v.watch { + items = append(items, v) + } else { + del = append(del, v) + } + } + s.items = items + s.mutex.Unlock() + + for _, item := range items { + err := item.doWatch(data) + item.notify(data, err) + } + + for _, item := range del { + item.notify(data, nil) + } +} + +// defaultNotifyChange default hook for notify config changed +var defaultWatchHook = func(message WatchMessage) {} + +// SetDefaultWatchHook set default hook notify when config changed +func SetDefaultWatchHook(f func(message WatchMessage)) { + defaultWatchHook = f +} + +// WatchMessage change message +type WatchMessage struct { + Provider string // provider name + Path string // config path + ExpandEnv bool // expend env status + Codec string // codec + Watch bool // status for start watch + Value []byte // config content diff ? + Error error // load error message, success is empty string +} + +var _ Config = (*TrpcConfig)(nil) + // TrpcConfig is used to parse yaml config file for trpc. type TrpcConfig struct { - p DataProvider - unmarshalledData interface{} - path string - decoder Codec - rawData []byte + id string // config identity + msg WatchMessage // new to init message for notify only copy + + p DataProvider // config provider + path string // config name + decoder Codec // config codec + expandEnv bool // status for whether replace the variables in the configuration with environment variables + + // because function is not support comparable in singleton, so the following options work only for the first load + watch bool + watchHook func(message WatchMessage) + + mutex sync.RWMutex + value *entity // store config value } -func newTrpcConfig(path string) *TrpcConfig { - return &TrpcConfig{ - p: GetProvider("file"), - unmarshalledData: make(map[string]interface{}), - path: path, - decoder: &YamlCodec{}, - } +type entity struct { + raw []byte // current binary data + data interface{} // unmarshal type to use point type, save latest no error data } -// Unmarshal deserializes the config into input param. -func (c *TrpcConfig) Unmarshal(out interface{}) error { - return c.decoder.Unmarshal(c.rawData, out) +func newEntity() *entity { + return &entity{ + data: make(map[string]interface{}), + } } -// Load loads config. -func (c *TrpcConfig) Load() error { +func newTrpcConfig(path string, opts ...LoadOption) (*TrpcConfig, error) { + c := &TrpcConfig{ + path: path, + p: GetProvider("file"), + decoder: GetCodec("yaml"), + watchHook: func(message WatchMessage) { + defaultWatchHook(message) + }, + } + for _, o := range opts { + o(c) + } if c.p == nil { - return ErrProviderNotExist + return nil, ErrProviderNotExist + } + if c.decoder == nil { + return nil, ErrCodecNotExist + } + + c.msg.Provider = c.p.Name() + c.msg.Path = c.path + c.msg.Codec = c.decoder.Name() + c.msg.ExpandEnv = c.expandEnv + c.msg.Watch = c.watch + + // since reflect.String() cannot uniquely identify a type, this id is used as a preliminary judgment basis + const idFormat = "provider:%s path:%s codec:%s env:%t watch:%t" + c.id = fmt.Sprintf(idFormat, c.p.Name(), c.path, c.decoder.Name(), c.expandEnv, c.watch) + return c, nil +} + +func (c *TrpcConfig) get() *entity { + c.mutex.RLock() + defer c.mutex.RUnlock() + if c.value != nil { + return c.value + } + return newEntity() +} + +// init return config entity error when entity is empty and load run loads config once +func (c *TrpcConfig) init() error { + c.mutex.RLock() + if c.value != nil { + c.mutex.RUnlock() + return nil + } + c.mutex.RUnlock() + + c.mutex.Lock() + defer c.mutex.Unlock() + if c.value != nil { + return nil } data, err := c.p.Read(c.path) if err != nil { - return fmt.Errorf("trpc/config: failed to load %s: %s", c.path, err.Error()) + return fmt.Errorf("trpc/config failed to load error: %w config id: %s", err, c.id) + } + return c.set(data) +} +func (c *TrpcConfig) doWatch(data []byte) error { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.set(data) +} +func (c *TrpcConfig) set(data []byte) error { + if c.expandEnv { + data = expandenv.ExpandEnv(data) } - c.rawData = data - if err := c.decoder.Unmarshal(c.rawData, &c.unmarshalledData); err != nil { - return fmt.Errorf("trpc/config: failed to parse %s: %s", c.path, err.Error()) + e := newEntity() + e.raw = data + err := c.decoder.Unmarshal(data, &e.data) + if err != nil { + return fmt.Errorf("trpc/config: failed to parse:%w, id:%s", err, c.id) } + c.value = e return nil } +func (c *TrpcConfig) notify(data []byte, err error) { + m := c.msg -// Reload reloads config. -func (c *TrpcConfig) Reload() { + m.Value = data + if err != nil { + m.Error = err + } + + c.watchHook(m) +} + +// Load loads config. +func (c *TrpcConfig) Load() error { if c.p == nil { - return + return ErrProviderNotExist } + c.mutex.Lock() + defer c.mutex.Unlock() data, err := c.p.Read(c.path) if err != nil { - log.Tracef("trpc/config: failed to reload %s: %v", c.path, err) - return + return fmt.Errorf("trpc/config failed to load error: %w config id: %s", err, c.id) } - c.rawData = data - if err := c.decoder.Unmarshal(data, &c.unmarshalledData); err != nil { - log.Tracef("trpc/config: failed to parse %s: %v", c.path, err) - return + return c.set(data) +} + +// Reload reloads config. +func (c *TrpcConfig) Reload() { + if err := c.Load(); err != nil { + log.Tracef("trpc/config: failed to reload %s: %v", c.id, err) } } @@ -223,9 +408,14 @@ func (c *TrpcConfig) Get(key string, defaultValue interface{}) interface{} { return defaultValue } +// Unmarshal deserializes the config into input param. +func (c *TrpcConfig) Unmarshal(out interface{}) error { + return c.decoder.Unmarshal(c.get().raw, out) +} + // Bytes returns original config data as bytes. func (c *TrpcConfig) Bytes() []byte { - return c.rawData + return c.get().raw } // GetInt returns int value by key, the second parameter @@ -333,7 +523,9 @@ func (c *TrpcConfig) findWithDefaultValue(key string, defaultValue interface{}) } func (c *TrpcConfig) search(key string) (interface{}, bool) { - unmarshalledData, ok := c.unmarshalledData.(map[string]interface{}) + e := c.get() + + unmarshalledData, ok := e.data.(map[string]interface{}) if !ok { return nil, false } diff --git a/config/trpc_config_test.go b/config/trpc_config_test.go index a83e60c..11eb852 100644 --- a/config/trpc_config_test.go +++ b/config/trpc_config_test.go @@ -16,10 +16,16 @@ package config import ( "errors" "fmt" + "os" + "reflect" + "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "trpc.group/trpc-go/trpc-go/errs" + "trpc.group/trpc-go/trpc-go/log" ) func Test_search(t *testing.T) { @@ -115,6 +121,15 @@ func Test_search(t *testing.T) { } } +func TestTrpcConfig_Load(t *testing.T) { + t.Run("parse failed", func(t *testing.T) { + c, err := newTrpcConfig("../testdata/trpc_go.yaml") + require.Nil(t, err) + c.decoder = &TomlCodec{} + err = c.Load() + require.Contains(t, errs.Msg(err), "failed to parse") + }) +} func TestYamlCodec_Unmarshal(t *testing.T) { t.Run("interface", func(t *testing.T) { var tt interface{} @@ -126,3 +141,169 @@ func TestYamlCodec_Unmarshal(t *testing.T) { require.NotNil(t, GetCodec("yaml").Unmarshal([]byte("[1, 2]"), &tt)) }) } + +func TestEnvExpanded(t *testing.T) { + RegisterProvider(NewEnvProvider(t.Name(), []byte(` +password: ${pwd} +`))) + + t.Setenv("pwd", t.Name()) + cfg, err := DefaultConfigLoader.Load( + t.Name(), + WithProvider(t.Name()), + WithExpandEnv()) + require.Nil(t, err) + + require.Equal(t, t.Name(), cfg.GetString("password", "")) + require.Contains(t, string(cfg.Bytes()), fmt.Sprintf("password: %s", t.Name())) +} + +func TestCodecUnmarshalDstMustBeMap(t *testing.T) { + filePath := t.TempDir() + "/conf.map" + require.Nil(t, os.WriteFile(filePath, []byte{}, 0600)) + RegisterCodec(dstMustBeMapCodec{}) + _, err := DefaultConfigLoader.Load(filePath, WithCodec(dstMustBeMapCodec{}.Name())) + require.Nil(t, err) +} + +func NewEnvProvider(name string, data []byte) *EnvProvider { + return &EnvProvider{ + name: name, + data: data, + } +} + +type EnvProvider struct { + name string + data []byte +} + +func (ep *EnvProvider) Name() string { + return ep.name +} + +func (ep *EnvProvider) Read(string) ([]byte, error) { + return ep.data, nil +} + +func (ep *EnvProvider) Watch(cb ProviderCallback) { + cb("", ep.data) +} + +func TestWatch(t *testing.T) { + p := manualTriggerWatchProvider{} + var msgs = make(chan WatchMessage) + SetDefaultWatchHook(func(msg WatchMessage) { + if msg.Error != nil { + log.Errorf("config watch error: %+v", msg) + } else { + log.Infof("config watch error: %+v", msg) + } + msgs <- msg + }) + + RegisterProvider(&p) + p.Set("key", []byte(`key: value`)) + ops := []LoadOption{WithProvider(p.Name()), WithCodec("yaml"), WithWatch()} + c1, err := DefaultConfigLoader.Load("key", ops...) + require.Nilf(t, err, "first load config:%+v", c1) + require.True(t, c1.IsSet("key"), "first load config key exist") + require.Equal(t, c1.Get("key", "default"), "value", "first load config get key value") + + var c2 Config + c2, err = DefaultConfigLoader.Load("key", ops...) + require.Nil(t, err, "second load config:%+v", c2) + require.Equal(t, c1, c2, "first and second load config not equal") + require.True(t, c2.IsSet("key"), "second load config key exist") + require.Equal(t, c2.Get("key", "default"), "value", "second load config get key value") + + var gw sync.WaitGroup + gw.Add(1) + go func() { + defer gw.Done() + tt := time.NewTimer(time.Second) + select { + case <-msgs: + case <-tt.C: + t.Errorf("receive message timeout") + } + }() + + p.Set("key", []byte(`:key: value:`)) + gw.Wait() + + var c3 Config + c3, err = DefaultConfigLoader.Load("key", WithProvider(p.Name()), WithWatchHook(func(msg WatchMessage) { + msgs <- msg + })) + require.Contains(t, errs.Msg(err), "failed to parse") + require.Nil(t, c3, "update error") + + require.True(t, c2.IsSet("key"), "third load config key exist") + require.Equal(t, c2.Get("key", "default"), "value", "third load config get key value") + + gw.Add(1) + go func() { + defer gw.Done() + for i := 0; i < 2; i++ { + tt := time.NewTimer(time.Second) + select { + case <-msgs: + case <-tt.C: + t.Errorf("receive message timeout number%d ", i) + } + } + }() + p.Set("key", []byte(`key: value2`)) + gw.Wait() + + require.Truef(t, c2.IsSet("key"), "after update config and get key exist") + require.Equal(t, c2.Get("key", "default"), "value2", "after update config and config get value") +} + +var _ DataProvider = (*manualTriggerWatchProvider)(nil) + +type manualTriggerWatchProvider struct { + values sync.Map + callbacks []ProviderCallback +} + +func (m *manualTriggerWatchProvider) Name() string { + return "manual_trigger_watch_provider" +} + +func (m *manualTriggerWatchProvider) Read(s string) ([]byte, error) { + if v, ok := m.values.Load(s); ok { + return v.([]byte), nil + } + return nil, fmt.Errorf("not found config") +} + +func (m *manualTriggerWatchProvider) Watch(callback ProviderCallback) { + m.callbacks = append(m.callbacks, callback) +} + +func (m *manualTriggerWatchProvider) Set(key string, v []byte) { + m.values.Store(key, v) + for _, callback := range m.callbacks { + callback(key, v) + } +} + +type dstMustBeMapCodec struct{} + +func (c dstMustBeMapCodec) Name() string { + return "map" +} + +func (c dstMustBeMapCodec) Unmarshal(bts []byte, dst interface{}) error { + rv := reflect.ValueOf(dst) + if rv.Kind() != reflect.Ptr || + rv.Elem().Kind() != reflect.Interface || + rv.Elem().Elem().Kind() != reflect.Map || + rv.Elem().Elem().Type().Key().Kind() != reflect.String || + rv.Elem().Elem().Type().Elem().Kind() != reflect.Interface { + return errors.New("the dst of codec.Unmarshal must be a map") + } + return nil +} diff --git a/docs/basics_tutorial.md b/docs/basics_tutorial.md index c53395f..71295e7 100644 --- a/docs/basics_tutorial.md +++ b/docs/basics_tutorial.md @@ -39,7 +39,7 @@ Note that `Method` has a `{}` at the end, which can also have content. We will s ### Write Client and Server Code -What protobuf gives is a language-independent service definition, and we need to use [trpc command line tool](https://github.com/trpc-group/trpc-cmdline) to translate it into a corresponding language stub code. You can see the various options it supports with `$ tprc create -h`. You can refer to the quick start [helloworld](/examples/helloworld/pb/Makefile) project to quickly create your own stub code. +What protobuf gives is a language-independent service definition, and we need to use [trpc command line tool](https://github.com/trpc-group/trpc-cmdline) to translate it into a corresponding language stub code. You can see the various options it supports with `$ trpc create -h`. You can refer to the quick start [helloworld](/examples/helloworld/pb/Makefile) project to quickly create your own stub code. The stub code is mainly divided into two parts: client and server. Below is part of the generated client code. In [Quick Start](./quick_start.md), we use `NewGreeterClientProxy` to create a client instance and call its `Hello` method: diff --git a/docs/basics_tutorial.zh_CN.md b/docs/basics_tutorial.zh_CN.md index 35bb22f..cc178a9 100644 --- a/docs/basics_tutorial.zh_CN.md +++ b/docs/basics_tutorial.zh_CN.md @@ -39,7 +39,7 @@ message HelloRsp { ### 编写客户端和服务端代码 -protobuf 给出的是一个语言无关的服务定义,我们还要用 [trpc 命令行工具](https://github.com/trpc-group/trpc-cmdline)将它翻译成对应语言的桩代码。你可以通过 `$ tprc create -h` 查看它支持的各种选项。你可以参考快速开始的 [helloworld](/examples/helloworld/pb/Makefile) 项目来快速创建你自己的桩代码。 +protobuf 给出的是一个语言无关的服务定义,我们还要用 [trpc 命令行工具](https://github.com/trpc-group/trpc-cmdline)将它翻译成对应语言的桩代码。你可以通过 `$ trpc create -h` 查看它支持的各种选项。你可以参考快速开始的 [helloworld](/examples/helloworld/pb/Makefile) 项目来快速创建你自己的桩代码。 桩代码主要分为 client 和 server 两部分。 下面是生成的部分 client 代码。在[快速开始](./quick_start.zh_CN.md)中,我们通过 `NewGreeterClientProxy` 来创建一个 client 实例,并调用了它的 `Hello` 方法: diff --git a/docs/user_guide/client/connection_mode.md b/docs/user_guide/client/connection_mode.md index 2a00bb6..5040874 100644 --- a/docs/user_guide/client/connection_mode.md +++ b/docs/user_guide/client/connection_mode.md @@ -128,6 +128,48 @@ if err != nil { log.Info("req:%v, rsp:%v, err:%v", req, rsp, err) ``` + +#### Setting Idle Connection Timeout + +For the client's connection pool mode, the framework sets a default idle timeout of 50 seconds. + +* For `go-net`, the connection pool maintains a list of idle connections. The idle timeout only affects the connections in this idle list and is only triggered when the connection is retrieved next time, causing idle connections to be closed due to the idle timeout. +* For `tnet`, the idle timeout is implemented by maintaining a timer on each connection. Even if a connection is being used for a client's call, if the downstream does not return a result within the idle timeout period, the connection will still be triggered by the idle timeout and forcibly closed. + +The methods to change the idle timeout are as follows: + +* `go-net` + +```go +import "trpc.group/trpc-go/trpc-go/pool/connpool" + +func init() { + connpool.DefaultConnectionPool = connpool.NewConnectionPool( + connpool.WithIdleTimeout(0), // Setting to 0 disables it. + ) +} +``` + +tnet + +```go +import ( + "trpc.group/trpc-go/trpc-go/pool/connpool" + tnettrans "trpc.group/trpc-go/trpc-go/transport/tnet" +) + +func init() { + tnettrans.DefaultConnPool = connpool.NewConnectionPool( + connpool.WithDialFunc(tnettrans.Dial), + connpool.WithIdleTimeout(0), // Setting to 0 disables it. + connpool.WithHealthChecker(tnettrans.HealthChecker), + ) +} +``` + +**Note**: The server also has a default idle timeout, which is 60 seconds. This time is designed to be longer than the 50 seconds, so that under default conditions, it is the client that triggers the idle connection timeout to actively close the connection, rather than the server triggering a forced cleanup. For methods to change the server's idle timeout, see the server usage documentation. + + ### I/O multiplexing ```go diff --git a/docs/user_guide/client/connection_mode.zh_CN.md b/docs/user_guide/client/connection_mode.zh_CN.md index dbe025b..de1b4d2 100644 --- a/docs/user_guide/client/connection_mode.zh_CN.md +++ b/docs/user_guide/client/connection_mode.zh_CN.md @@ -126,7 +126,46 @@ if err != nil { log.Info("req:%v, rsp:%v, err:%v", req, rsp, err) ``` -###连接多路复用 +#### 设置空闲连接超时 + +在客户端的连接池模式中,框架默认会设置一个 50 秒的空闲超时时间。 + +* 对于 `go-net` 来说,连接池会维护一个空闲连接列表。空闲超时时间仅对列表中的空闲连接有效,并且只有在下一次尝试获取连接时,才会触发检查并关闭超时的空闲连接。 +* 对于 `tnet`,则是通过在每个连接上设置定时器来实现空闲超时。即便连接正在被用于客户端的调用,如果下游服务在空闲超时时间内没有返回结果,该连接仍然会因为空闲超时而被强制关闭。 + +可以按照以下方式更改空闲超时时间: + +* `go-net` + +```go +import "trpc.group/trpc-go/trpc-go/pool/connpool" + +func init() { + connpool.DefaultConnectionPool = connpool.NewConnectionPool( + connpool.WithIdleTimeout(0), // 设置为 0 以禁用空闲超时 + ) +} +``` + +* `tnet` + +```go +import ( + "trpc.group/trpc-go/trpc-go/pool/connpool" + tnettrans "trpc.group/trpc-go/trpc-go/transport/tnet" +) + +func init() { + tnettrans.DefaultConnPool = connpool.NewConnectionPool( + connpool.WithDialFunc(tnettrans.Dial), + connpool.WithIdleTimeout(0), // 设置为 0 以禁用空闲超时 + connpool.WithHealthChecker(tnettrans.HealthChecker), + ) +} +``` + +**注**:服务端默认也设置了一个空闲超时时间,为 60 秒。这个时间比客户端的默认时间长,以确保在大多数情况下,是客户端主动触发空闲超时并关闭连接,而不是服务端强制进行清理。服务端空闲超时时间的修改方法,请参见服务端使用文档。 +### 连接多路复用 ```go opts := []client.Option{ diff --git a/docs/user_guide/framework_conf.md b/docs/user_guide/framework_conf.md index 0198b47..a934838 100644 --- a/docs/user_guide/framework_conf.md +++ b/docs/user_guide/framework_conf.md @@ -110,6 +110,7 @@ server: - # Optional, whether to prohibit inheriting the upstream timeout time, used to close the full link timeout mechanism, the default is false disable_request_timeout: Boolean # Optional, the IP address of the service monitors, if it is empty, it will try to get the network card IP, if it is still empty, use global.local_ip + # To listen on all addresses, please use "0.0.0.0" (IPv4) or "::" (IPv6). ip: String(ipv4 or ipv6) # Required, the service name, used for service discovery name: String diff --git a/docs/user_guide/framework_conf.zh_CN.md b/docs/user_guide/framework_conf.zh_CN.md index fe85470..4ab74e7 100644 --- a/docs/user_guide/framework_conf.zh_CN.md +++ b/docs/user_guide/framework_conf.zh_CN.md @@ -108,6 +108,7 @@ server: - # 选填,是否禁止继承上游的超时时间,用于关闭全链路超时机制,默认为 false disable_request_timeout: Boolean # 选填,service 监听的 IP 地址,如果为空,则会尝试获取网卡 IP,如果仍为空,则使用 global.local_ip + # 如果需要监听所有地址的话,请使用 "0.0.0.0" (ipv4) 或 "::" (ipv6) ip: String(ipv4 or ipv6) # 必填,服务名,用于服务发现 name: String diff --git a/docs/user_guide/server/overview.md b/docs/user_guide/server/overview.md index d2bc2a5..eac2b36 100644 --- a/docs/user_guide/server/overview.md +++ b/docs/user_guide/server/overview.md @@ -268,6 +268,19 @@ tRPC-Go provides three timeout mechanisms for RPC calls: link timeout, message t This feature requires protocol support (the protocol needs to carry timeout metadata downstream). The tRPC protocol, generic HTTP RPC protocol all support timeout control. +## Idle Timeout + +The server has a default idle timeout of 60 seconds to prevent excessive idle connections from consuming server-side resources. This value can be modified through the `idletimeout` setting in the framework configuration: + +```yaml +server: + service: + - name: trpc.server.service.Method + network: tcp + protocol: trpc + idletime: 60000 # The unit is milliseconds. Setting it to -1 means there is no idle timeout (setting it to 0 will still default to the 60s by the framework) +``` + ## Link transmission The tRPC-Go framework provides a mechanism for passing fields between the client and server and passing them down the entire call chain. For the mechanism and usage of link transmission, please refer to [tRPC-Go Link Transmission](/docs/user_guide/metadata_transmission.md). @@ -288,4 +301,4 @@ tRPC-Go allows businesses to define and register serialization and deserializati ## Setting the maximum number of service coroutines -tRPC-Go supports service-level synchronous/asynchronous packet processing modes. For asynchronous mode, a coroutine pool is used to improve coroutine usage efficiency and performance. Users can set the maximum number of service coroutines through framework configuration and Option configuration. For details, please refer to the service configuration in the [tPRC-Go Framework Configuration](/docs/user_guide/framework_conf.md) section. +tRPC-Go supports service-level synchronous/asynchronous packet processing modes. For asynchronous mode, a coroutine pool is used to improve coroutine usage efficiency and performance. Users can set the maximum number of service coroutines through framework configuration and Option configuration. For details, please refer to the service configuration in the [tRPC-Go Framework Configuration](/docs/user_guide/framework_conf.md) section. diff --git a/docs/user_guide/server/overview.zh_CN.md b/docs/user_guide/server/overview.zh_CN.md index 72ee65b..7c74efd 100644 --- a/docs/user_guide/server/overview.zh_CN.md +++ b/docs/user_guide/server/overview.zh_CN.md @@ -266,7 +266,20 @@ tRPC-Go 从设计之初就考虑了框架的易测性,在通过 pb 生成桩 tRPC-Go 为 RPC 调用提供了 3 种超时机制控制:链路超时,消息超时和调用超时。关于这 3 种超时机制的原理介绍和相关配置,请参考 [tRPC-Go 超时控制](/docs/user_guide/timeout_control.zh_CN.md)。 -此功能需要协议的支持(协议需要携带 timeout 元数据到下游),tRPC 协议,泛 HTTP RPC 协议均支持超时控制功能。其 +此功能需要协议的支持(协议需要携带 timeout 元数据到下游),tRPC 协议,泛 HTTP RPC 协议均支持超时控制功能。 + +## 空闲超时 + +服务默认存在一个 60s 的空闲超时时间,以防止过多空闲连接消耗服务侧的资源,这个值可以通过框架配置中的 `idletimeout` 来进行修改: + +```yaml +server: + service: + - name: trpc.server.service.Method + network: tcp + protocol: trpc + idletime: 60000 # 单位是毫秒, 设置为 -1 的时候表示没有空闲超时(这里设置为 0 时框架仍会自动转为默认的 60s) +``` ## 链路透传 @@ -288,4 +301,4 @@ tRPC-Go 自定义 RPC 消息体的序列化、反序列化方式,业务可以 ## 设置服务最大协程数 -tRPC-Go 支持服务级别的同/异步包处理模式,对于异步模式采用协程池来提升协程使用效率和性能。用户可以通过框架配置和 Option 配置两种方式来设置服务的最大协程数,具体请参考 [tPRC-Go 框架配置](/docs/user_guide/framework_conf.zh_CN.md) 章节的 service 配置。 +tRPC-Go 支持服务级别的同/异步包处理模式,对于异步模式采用协程池来提升协程使用效率和性能。用户可以通过框架配置和 Option 配置两种方式来设置服务的最大协程数,具体请参考 [tRPC-Go 框架配置](/docs/user_guide/framework_conf.zh_CN.md) 章节的 service 配置。 diff --git a/examples/features/config/client/main.go b/examples/features/config/client/main.go index 55f68f9..23eeb78 100644 --- a/examples/features/config/client/main.go +++ b/examples/features/config/client/main.go @@ -41,4 +41,26 @@ func main() { return } fmt.Printf("Get msg: %s\n", rsp.GetMsg()) + // print + // + // Get msg: trpc-go-server response: Hello trpc-go-client + // load once config: number_1 + // start watch config:number_1 + + req = &pb.HelloRequest{ + Msg: "change config", // target config change + } + + // Send request. + rsp, err = clientProxy.SayHello(ctx, req) + if err != nil { + fmt.Println("Say hi err:%v", err) + return + } + fmt.Printf("Get msg: %s\n", rsp.GetMsg()) + // print + // + // Get msg: trpc-go-server response: Hello trpc-go-client + // load once config: number_1 + // start watch config:number_2 } diff --git a/examples/features/config/server/main.go b/examples/features/config/server/main.go index 2e36ff8..0351fbd 100644 --- a/examples/features/config/server/main.go +++ b/examples/features/config/server/main.go @@ -17,6 +17,7 @@ package main import ( "context" "fmt" + "sync" trpc "trpc.group/trpc-go/trpc-go" "trpc.group/trpc-go/trpc-go/config" @@ -27,33 +28,43 @@ import ( func main() { // Parse configuration files in yaml format. - conf, err := config.Load("server/custom.yaml", config.WithCodec("yaml"), config.WithProvider("file")) + // Load default codec is `yaml` and provider is `file` + c, err := config.Load("custom.yaml", config.WithCodec("yaml"), config.WithProvider("file")) if err != nil { fmt.Println(err) return } + fmt.Printf("test : %s \n", c.GetString("custom.test", "")) + fmt.Printf("key1 : %s \n", c.GetString("custom.test_obj.key1", "")) + fmt.Printf("key2 : %t \n", c.GetBool("custom.test_obj.key2", false)) + fmt.Printf("key2 : %d \n", c.GetInt32("custom.test_obj.key3", 0)) + + // print + // test : customConfigFromServer + // key1 : value1 + // key2 : true + // key3 : 1234 + // The format of the configuration file corresponds to custom struct. var custom customStruct - if err := conf.Unmarshal(&custom); err != nil { + if err := c.Unmarshal(&custom); err != nil { fmt.Println(err) } fmt.Printf("Get config - custom : %v \n", custom) - - fmt.Printf("test : %s \n", conf.GetString("custom.test", "")) - fmt.Printf("key1 : %s \n", conf.GetString("custom.test_obj.key1", "")) - fmt.Printf("key2 : %t \n", conf.GetBool("custom.test_obj.key2", false)) - fmt.Printf("key2 : %d \n", conf.GetInt32("custom.test_obj.key3", 0)) + // print: Get config - custom : {{customConfigFromServer {value1 true 1234}}} // Init server. s := trpc.NewServer() + config.RegisterProvider(p) // Register service. - greeterImpl := &greeterImpl{ - customConf: conf.GetString("custom.test", ""), - } - pb.RegisterGreeterService(s, greeterImpl) + imp := &greeterImpl{} + imp.once, _ = config.Load(p.Name(), config.WithProvider(p.Name())) + imp.watch, _ = config.Load(p.Name(), config.WithProvider(p.Name()), config.WithWatch()) + + pb.RegisterGreeterService(s, imp) // Serve and listen. if err := s.Serve(); err != nil { @@ -62,6 +73,58 @@ func main() { } +const cf = `custom : + test : number_%d + test_obj : + key1 : value_%d + key2 : %t + key3 : %d` + +var p = &provider{} + +// mock provider to trigger config change +type provider struct { + mu sync.Mutex + data []byte + num int + callbacks []config.ProviderCallback +} + +func (p *provider) Name() string { + return "test" +} + +func (p *provider) Read(s string) ([]byte, error) { + if s != p.Name() { + return nil, fmt.Errorf("not found config %s", s) + } + p.mu.Lock() + defer p.mu.Unlock() + + if p.data == nil { + p.num++ + p.data = []byte(fmt.Sprintf(cf, p.num, p.num, p.num%2 == 0, p.num)) + } + return p.data, nil +} + +func (p *provider) Watch(callback config.ProviderCallback) { + p.mu.Lock() + defer p.mu.Unlock() + p.callbacks = append(p.callbacks, callback) +} + +func (p *provider) update() { + p.mu.Lock() + p.num++ + p.data = []byte(fmt.Sprintf(cf, p.num, p.num, p.num%2 == 0, p.num)) + callbacks := p.callbacks + p.mu.Unlock() + for _, callback := range callbacks { + callback(p.Name(), p.data) + } +} + // customStruct it defines the struct of the custom configuration file read. type customStruct struct { Custom struct { @@ -78,16 +141,37 @@ type customStruct struct { type greeterImpl struct { common.GreeterServerImpl - customConf string + once config.Config + watch config.Config } // SayHello say hello request. Rewrite SayHello to inform server config. func (g *greeterImpl) SayHello(_ context.Context, req *pb.HelloRequest) (*pb.HelloReply, error) { fmt.Printf("trpc-go-server SayHello, req.msg:%s\n", req.Msg) + if req.Msg == "change config" { + p.update() + } + rsp := &pb.HelloReply{} - rsp.Msg = "trpc-go-server response: Hello " + req.Msg + ". Custom config from server: " + g.customConf + rsp.Msg = "trpc-go-server response: Hello " + req.Msg + + fmt.Sprintf("\nload once config: %s", g.once.GetString("custom.test", "")) + + fmt.Sprintf("\nstart watch config: %s", g.watch.GetString("custom.test", "")) + fmt.Printf("trpc-go-server SayHello, rsp.msg:%s\n", rsp.Msg) return rsp, nil } + +// first print +// +// trpc-go-server SayHello, rsp.msg:trpc-go-server response: Hello trpc-go-client +// load once config: number_1 +// start watch config:number_1 +// +// second print +// +// trpc-go-server SayHello, req.msg:change config +// trpc-go-server SayHello, rsp.msg:trpc-go-server response: Hello change config +// load once config: number_1 +// start watch config:number_2 diff --git a/examples/go.mod b/examples/go.mod index d6d8b8b..08de7b7 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -6,8 +6,7 @@ replace trpc.group/trpc-go/trpc-go => ../ require ( github.com/golang/protobuf v1.5.2 - github.com/stretchr/testify v1.8.0 - google.golang.org/protobuf v1.30.0 + google.golang.org/protobuf v1.33.0 trpc.group/trpc-go/trpc-go v0.0.0-00010101000000-000000000000 trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0 ) @@ -15,7 +14,6 @@ require ( require ( github.com/BurntSushi/toml v0.3.1 // indirect github.com/andybalholm/brotli v1.0.4 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/go-playground/form/v4 v4.2.0 // indirect github.com/golang/mock v1.4.4 // indirect @@ -31,7 +29,6 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/panjf2000/ants/v2 v2.4.6 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/cast v1.3.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.43.0 // indirect @@ -44,5 +41,5 @@ require ( golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - trpc.group/trpc-go/tnet v1.0.0 // indirect + trpc.group/trpc-go/tnet v1.0.1 // indirect ) diff --git a/examples/go.sum b/examples/go.sum index f50aca0..53365b8 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -59,13 +59,10 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.43.0 h1:Gy4sb32C98fbzVWZlTM1oTMdLWGyvxR03VhM6cBIU4g= @@ -120,19 +117,19 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -trpc.group/trpc-go/tnet v1.0.0 h1:XsdA82/sOHLa4TFAlCZbb3xi4+Q92NNuxEMTj0UfFZ0= -trpc.group/trpc-go/tnet v1.0.0/go.mod h1:s/webUFYWEFBHErKyFmj7LYC7XfC2LTLCcwfSnJ04M0= +trpc.group/trpc-go/tnet v1.0.1 h1:Yzqyrgyfm+W742FzGr39c4+OeQmLi7PWotJxrOBtV9o= +trpc.group/trpc-go/tnet v1.0.1/go.mod h1:s/webUFYWEFBHErKyFmj7LYC7XfC2LTLCcwfSnJ04M0= trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0 h1:rMtHYzI0ElMJRxHtT5cD99SigFE6XzKK4PFtjcwokI0= trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0/go.mod h1:K+a1K/Gnlcg9BFHWx30vLBIEDhxODhl25gi1JjA54CQ= diff --git a/go.mod b/go.mod index c98e1c8..ad60b03 100644 --- a/go.mod +++ b/go.mod @@ -26,9 +26,9 @@ require ( golang.org/x/net v0.17.0 golang.org/x/sync v0.1.0 golang.org/x/sys v0.13.0 - google.golang.org/protobuf v1.30.0 + google.golang.org/protobuf v1.33.0 gopkg.in/yaml.v3 v3.0.1 - trpc.group/trpc-go/tnet v1.0.0 + trpc.group/trpc-go/tnet v1.0.1 trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0 ) diff --git a/go.sum b/go.sum index 258b5d9..237009b 100644 --- a/go.sum +++ b/go.sum @@ -127,6 +127,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -137,7 +139,7 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -trpc.group/trpc-go/tnet v1.0.0 h1:XsdA82/sOHLa4TFAlCZbb3xi4+Q92NNuxEMTj0UfFZ0= -trpc.group/trpc-go/tnet v1.0.0/go.mod h1:s/webUFYWEFBHErKyFmj7LYC7XfC2LTLCcwfSnJ04M0= +trpc.group/trpc-go/tnet v1.0.1 h1:Yzqyrgyfm+W742FzGr39c4+OeQmLi7PWotJxrOBtV9o= +trpc.group/trpc-go/tnet v1.0.1/go.mod h1:s/webUFYWEFBHErKyFmj7LYC7XfC2LTLCcwfSnJ04M0= trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0 h1:rMtHYzI0ElMJRxHtT5cD99SigFE6XzKK4PFtjcwokI0= trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0/go.mod h1:K+a1K/Gnlcg9BFHWx30vLBIEDhxODhl25gi1JjA54CQ= diff --git a/http/codec.go b/http/codec.go index 6f58a28..0ea39f7 100644 --- a/http/codec.go +++ b/http/codec.go @@ -361,8 +361,8 @@ func unmarshalTransInfo(msg codec.Msg, v string) (map[string][]byte, error) { // getReqbody gets the body of request. func (sc *ServerCodec) getReqbody(head *Header, msg codec.Msg) ([]byte, error) { - msg.WithServerRPCName(head.Request.URL.Path) msg.WithCalleeMethod(head.Request.URL.Path) + msg.WithServerRPCName(head.Request.URL.Path) if !sc.AutoReadBody { return nil, nil diff --git a/http/restful_server_transport.go b/http/restful_server_transport.go index 499b787..3a4a40e 100644 --- a/http/restful_server_transport.go +++ b/http/restful_server_transport.go @@ -83,7 +83,6 @@ func putRESTMsgInCtx( ctx, msg := codec.WithNewMessage(ctx) msg.WithCalleeServiceName(service) msg.WithServerRPCName(method) - msg.WithCalleeMethod(method) msg.WithSerializationType(codec.SerializationTypePB) if v := headerGetter(TrpcTimeout); v != "" { i, _ := strconv.Atoi(v) @@ -160,12 +159,20 @@ func (st *RESTServerTransport) ListenAndServe(ctx context.Context, opt ...transp ln = tls.NewListener(ln, tlsConf) } + go func() { + <-opts.StopListening + ln.Close() + }() + return st.serve(ctx, ln, opts) } // serve starts service. -func (st *RESTServerTransport) serve(ctx context.Context, ln net.Listener, - opts *transport.ListenServeOptions) error { +func (st *RESTServerTransport) serve( + ctx context.Context, + ln net.Listener, + opts *transport.ListenServeOptions, +) error { // Get router. router := restful.GetRouter(opts.ServiceName) if router == nil { diff --git a/http/serialization_form.go b/http/serialization_form.go index 8de5736..916b067 100644 --- a/http/serialization_form.go +++ b/http/serialization_form.go @@ -41,16 +41,16 @@ func NewFormSerialization(tag string) codec.Serializer { decoder.SetTagName(tag) return &FormSerialization{ tagname: tag, - encoder: encoder, - decoder: decoder, + encode: encoder.Encode, + decode: wrapDecodeWithRecovery(decoder.Decode), } } // FormSerialization packages the kv structure of http get request. type FormSerialization struct { tagname string - encoder *form.Encoder - decoder *form.Decoder + encode func(interface{}) (url.Values, error) + decode func(interface{}, url.Values) error } // Unmarshal unpacks kv structure. @@ -61,13 +61,14 @@ func (j *FormSerialization) Unmarshal(in []byte, body interface{}) error { } switch body.(type) { // go-playground/form does not support map structure. - case map[string]interface{}, *map[string]interface{}, map[string]string, *map[string]string: + case map[string]interface{}, *map[string]interface{}, map[string]string, *map[string]string, + url.Values, *url.Values: // Essentially, the underlying type of 'url.Values' is also a map. return unmarshalValues(j.tagname, values, body) default: } // First try using go-playground/form, it can handle nested struct. // But it cannot handle Chinese characters in byte slice. - err = j.decoder.Decode(body, values) + err = j.decode(body, values) if err == nil { return nil } @@ -78,8 +79,38 @@ func (j *FormSerialization) Unmarshal(in []byte, body interface{}) error { return nil } +// wrapDecodeWithRecovery wraps the decode function, adding panic recovery to handle +// panics as errors. This function is designed to prevent malformed query parameters +// from causing a panic, which is the default behavior of the go-playground/form decoder +// implementation. This is because, in certain cases, it's more acceptable to receive +// a degraded result rather than experiencing a direct server crash. +// Besides, the behavior of not panicking also ensures backward compatibility ( 0 { + // invalid matching, remove the $ + } else if name == nil { + buf = append(buf, s[j]) // keep the $ + } else { + buf = append(buf, os.Getenv(string(name))...) + } + j += w + i = j + 1 + } + } + if buf == nil { + return s + } + return append(buf, s[i:]...) +} + +// getEnvName gets env name, that is, var from ${var}. +// The env name and its len will be returned. +func getEnvName(s []byte) ([]byte, int) { + // look for right curly bracket '}' + // it's guaranteed that the first char is '{' and the string has at least two char + for i := 1; i < len(s); i++ { + if s[i] == ' ' || s[i] == '\n' || s[i] == '"' { // "xx${xxx" + return nil, 0 // encounter invalid char, keep the $ + } + if s[i] == '}' { + if i == 1 { // ${} + return nil, 2 // remove ${} + } + return s[1:i], i + 1 + } + } + return nil, 0 // no },keep the $ +} diff --git a/internal/expandenv/expand_env_test.go b/internal/expandenv/expand_env_test.go new file mode 100644 index 0000000..96439b5 --- /dev/null +++ b/internal/expandenv/expand_env_test.go @@ -0,0 +1,47 @@ +package expandenv_test + +import ( + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/require" + . "trpc.group/trpc-go/trpc-go/internal/expandenv" +) + +func TestExpandEnv(t *testing.T) { + key := "env_key" + t.Run("no env", func(t *testing.T) { + require.Equal(t, []byte("abc"), ExpandEnv([]byte("abc"))) + }) + t.Run("${..} is expanded", func(t *testing.T) { + t.Setenv(key, t.Name()) + require.Equal(t, fmt.Sprintf("head_%s_tail", t.Name()), + string(ExpandEnv([]byte(fmt.Sprintf("head_${%s}_tail", key))))) + }) + t.Run("${ is not expanded", func(t *testing.T) { + require.Equal(t, "head_${_tail", + string(ExpandEnv([]byte(fmt.Sprintf("head_${_tail"))))) + }) + t.Run("${} is expanded as empty", func(t *testing.T) { + require.Equal(t, "head__tail", + string(ExpandEnv([]byte("head_${}_tail")))) + }) + t.Run("${..} is not expanded if .. contains any space", func(t *testing.T) { + t.Setenv("key key", t.Name()) + require.Equal(t, "head_${key key}_tail", + string(ExpandEnv([]byte("head_${key key}_tail")))) + }) + t.Run("${..} is not expanded if .. contains any new line", func(t *testing.T) { + t.Setenv("key\nkey", t.Name()) + require.Equal(t, t.Name(), os.Getenv("key\nkey")) + require.Equal(t, "head_${key\nkey}_tail", + string(ExpandEnv([]byte("head_${key\nkey}_tail")))) + }) + t.Run(`${..} is not expanded if .. contains any "`, func(t *testing.T) { + t.Setenv(`key"key`, t.Name()) + require.Equal(t, t.Name(), os.Getenv(`key"key`)) + require.Equal(t, `head_${key"key}_tail`, + string(ExpandEnv([]byte(`head_${key"key}_tail`)))) + }) +} diff --git a/log/log.go b/log/log.go index 311dc95..6460660 100644 --- a/log/log.go +++ b/log/log.go @@ -96,7 +96,7 @@ func RedirectStdLogAt(logger Logger, level zapcore.Level) (func(), error) { return nil, errors.New("log: only supports redirecting std logs to trpc zap logger") } -// Trace logs to TRACE log. Arguments are handled in the manner of fmt.Print. +// Trace logs to TRACE log. Arguments are handled in the manner of fmt.Println. func Trace(args ...interface{}) { if traceEnabled { GetDefaultLogger().Trace(args...) @@ -110,7 +110,7 @@ func Tracef(format string, args ...interface{}) { } } -// TraceContext logs to TRACE log. Arguments are handled in the manner of fmt.Print. +// TraceContext logs to TRACE log. Arguments are handled in the manner of fmt.Println. func TraceContext(ctx context.Context, args ...interface{}) { if !traceEnabled { return @@ -134,7 +134,7 @@ func TraceContextf(ctx context.Context, format string, args ...interface{}) { GetDefaultLogger().Tracef(format, args...) } -// Debug logs to DEBUG log. Arguments are handled in the manner of fmt.Print. +// Debug logs to DEBUG log. Arguments are handled in the manner of fmt.Println. func Debug(args ...interface{}) { GetDefaultLogger().Debug(args...) } @@ -144,7 +144,7 @@ func Debugf(format string, args ...interface{}) { GetDefaultLogger().Debugf(format, args...) } -// Info logs to INFO log. Arguments are handled in the manner of fmt.Print. +// Info logs to INFO log. Arguments are handled in the manner of fmt.Println. func Info(args ...interface{}) { GetDefaultLogger().Info(args...) } @@ -154,7 +154,7 @@ func Infof(format string, args ...interface{}) { GetDefaultLogger().Infof(format, args...) } -// Warn logs to WARNING log. Arguments are handled in the manner of fmt.Print. +// Warn logs to WARNING log. Arguments are handled in the manner of fmt.Println. func Warn(args ...interface{}) { GetDefaultLogger().Warn(args...) } @@ -164,7 +164,7 @@ func Warnf(format string, args ...interface{}) { GetDefaultLogger().Warnf(format, args...) } -// Error logs to ERROR log. Arguments are handled in the manner of fmt.Print. +// Error logs to ERROR log. Arguments are handled in the manner of fmt.Println. func Error(args ...interface{}) { GetDefaultLogger().Error(args...) } @@ -174,7 +174,7 @@ func Errorf(format string, args ...interface{}) { GetDefaultLogger().Errorf(format, args...) } -// Fatal logs to ERROR log. Arguments are handled in the manner of fmt.Print. +// Fatal logs to ERROR log. Arguments are handled in the manner of fmt.Println. // All Fatal logs will exit by calling os.Exit(1). // Implementations may also call os.Exit() with a non-zero exit code. func Fatal(args ...interface{}) { @@ -211,7 +211,7 @@ func WithContextFields(ctx context.Context, fields ...string) context.Context { return ctx } -// DebugContext logs to DEBUG log. Arguments are handled in the manner of fmt.Print. +// DebugContext logs to DEBUG log. Arguments are handled in the manner of fmt.Println. func DebugContext(ctx context.Context, args ...interface{}) { if l, ok := codec.Message(ctx).Logger().(Logger); ok { l.Debug(args...) @@ -229,7 +229,7 @@ func DebugContextf(ctx context.Context, format string, args ...interface{}) { GetDefaultLogger().Debugf(format, args...) } -// InfoContext logs to INFO log. Arguments are handled in the manner of fmt.Print. +// InfoContext logs to INFO log. Arguments are handled in the manner of fmt.Println. func InfoContext(ctx context.Context, args ...interface{}) { if l, ok := codec.Message(ctx).Logger().(Logger); ok { l.Info(args...) @@ -247,7 +247,7 @@ func InfoContextf(ctx context.Context, format string, args ...interface{}) { GetDefaultLogger().Infof(format, args...) } -// WarnContext logs to WARNING log. Arguments are handled in the manner of fmt.Print. +// WarnContext logs to WARNING log. Arguments are handled in the manner of fmt.Println. func WarnContext(ctx context.Context, args ...interface{}) { if l, ok := codec.Message(ctx).Logger().(Logger); ok { l.Warn(args...) @@ -266,7 +266,7 @@ func WarnContextf(ctx context.Context, format string, args ...interface{}) { } -// ErrorContext logs to ERROR log. Arguments are handled in the manner of fmt.Print. +// ErrorContext logs to ERROR log. Arguments are handled in the manner of fmt.Println. func ErrorContext(ctx context.Context, args ...interface{}) { if l, ok := codec.Message(ctx).Logger().(Logger); ok { l.Error(args...) @@ -284,7 +284,7 @@ func ErrorContextf(ctx context.Context, format string, args ...interface{}) { GetDefaultLogger().Errorf(format, args...) } -// FatalContext logs to ERROR log. Arguments are handled in the manner of fmt.Print. +// FatalContext logs to ERROR log. Arguments are handled in the manner of fmt.Println. // All Fatal logs will exit by calling os.Exit(1). // Implementations may also call os.Exit() with a non-zero exit code. func FatalContext(ctx context.Context, args ...interface{}) { diff --git a/log/logger.go b/log/logger.go index ff1a208..9c8c31f 100644 --- a/log/logger.go +++ b/log/logger.go @@ -74,27 +74,27 @@ type Field struct { // Logger is the underlying logging work for tRPC framework. type Logger interface { - // Trace logs to TRACE log. Arguments are handled in the manner of fmt.Print. + // Trace logs to TRACE log. Arguments are handled in the manner of fmt.Println. Trace(args ...interface{}) // Tracef logs to TRACE log. Arguments are handled in the manner of fmt.Printf. Tracef(format string, args ...interface{}) - // Debug logs to DEBUG log. Arguments are handled in the manner of fmt.Print. + // Debug logs to DEBUG log. Arguments are handled in the manner of fmt.Println. Debug(args ...interface{}) // Debugf logs to DEBUG log. Arguments are handled in the manner of fmt.Printf. Debugf(format string, args ...interface{}) - // Info logs to INFO log. Arguments are handled in the manner of fmt.Print. + // Info logs to INFO log. Arguments are handled in the manner of fmt.Println. Info(args ...interface{}) // Infof logs to INFO log. Arguments are handled in the manner of fmt.Printf. Infof(format string, args ...interface{}) - // Warn logs to WARNING log. Arguments are handled in the manner of fmt.Print. + // Warn logs to WARNING log. Arguments are handled in the manner of fmt.Println. Warn(args ...interface{}) // Warnf logs to WARNING log. Arguments are handled in the manner of fmt.Printf. Warnf(format string, args ...interface{}) - // Error logs to ERROR log. Arguments are handled in the manner of fmt.Print. + // Error logs to ERROR log. Arguments are handled in the manner of fmt.Println. Error(args ...interface{}) // Errorf logs to ERROR log. Arguments are handled in the manner of fmt.Printf. Errorf(format string, args ...interface{}) - // Fatal logs to ERROR log. Arguments are handled in the manner of fmt.Print. + // Fatal logs to ERROR log. Arguments are handled in the manner of fmt.Println. // All Fatal logs will exit by calling os.Exit(1). // Implementations may also call os.Exit() with a non-zero exit code. Fatal(args ...interface{}) diff --git a/log/rollwriter/async_roll_writer.go b/log/rollwriter/async_roll_writer.go index 697ee42..1b84534 100644 --- a/log/rollwriter/async_roll_writer.go +++ b/log/rollwriter/async_roll_writer.go @@ -117,6 +117,17 @@ func (w *AsyncRollWriter) batchWriteLog() { buffer.Reset() } case data := <-w.logQueue: + if len(data) >= w.opts.WriteLogSize { + // If the length of the current data exceeds the expected maximum value, + // we directly write it to the underlying logger instead of placing it into the buffer. + // This prevents the buffer from being overwhelmed by excessively large data, + // which could lead to memory leaks. + // Prior to that, we need to write the existing data in the buffer to the underlying logger. + _, _ = w.logger.Write(buffer.Bytes()) + buffer.Reset() + _, _ = w.logger.Write(data) + continue + } buffer.Write(data) if buffer.Len() >= w.opts.WriteLogSize { _, err := w.logger.Write(buffer.Bytes()) diff --git a/log/rollwriter/roll_writer_test.go b/log/rollwriter/roll_writer_test.go index 79bc7a2..357c70a 100644 --- a/log/rollwriter/roll_writer_test.go +++ b/log/rollwriter/roll_writer_test.go @@ -376,6 +376,16 @@ func TestAsyncRollWriterSyncTwice(t *testing.T) { require.Nil(t, w.Close()) } +func TestAsyncRollWriterDirectWrite(t *testing.T) { + logSize := 1 + w := NewAsyncRollWriter(&noopWriteCloser{}, WithWriteLogSize(logSize)) + _, _ = w.Write([]byte("hello")) + time.Sleep(time.Millisecond) + require.Nil(t, w.Sync()) + require.Nil(t, w.Sync()) + require.Nil(t, w.Close()) +} + func TestRollWriterError(t *testing.T) { logDir := t.TempDir() t.Run("reopen file", func(t *testing.T) { diff --git a/log/zaplogger.go b/log/zaplogger.go index 63b48af..4b61c5e 100644 --- a/log/zaplogger.go +++ b/log/zaplogger.go @@ -119,14 +119,26 @@ func newEncoder(c *OutputConfig) zapcore.Encoder { if c.EnableColor { encoderCfg.EncodeLevel = zapcore.CapitalColorLevelEncoder } - switch c.Formatter { - case "console": - return zapcore.NewConsoleEncoder(encoderCfg) - case "json": - return zapcore.NewJSONEncoder(encoderCfg) - default: - return zapcore.NewConsoleEncoder(encoderCfg) + if newFormatEncoder, ok := formatEncoders[c.Formatter]; ok { + return newFormatEncoder(encoderCfg) } + // Defaults to console encoder. + return zapcore.NewConsoleEncoder(encoderCfg) +} + +var formatEncoders = map[string]NewFormatEncoder{ + "console": zapcore.NewConsoleEncoder, + "json": zapcore.NewJSONEncoder, +} + +// NewFormatEncoder is the function type for creating a format encoder out of an encoder config. +type NewFormatEncoder func(zapcore.EncoderConfig) zapcore.Encoder + +// RegisterFormatEncoder registers a NewFormatEncoder with the specified formatName key. +// The existing formats include "console" and "json", but you can override these format encoders +// or provide a new custom one. +func RegisterFormatEncoder(formatName string, newFormatEncoder NewFormatEncoder) { + formatEncoders[formatName] = newFormatEncoder } // GetLogEncoderKey gets user defined log output name, uses defKey if empty. @@ -270,7 +282,8 @@ func (l *zapLog) With(fields ...Field) Logger { } func getLogMsg(args ...interface{}) string { - msg := fmt.Sprint(args...) + msg := fmt.Sprintln(args...) + msg = msg[:len(msg)-1] report.LogWriteSize.IncrBy(float64(len(msg))) return msg } @@ -281,7 +294,7 @@ func getLogMsgf(format string, args ...interface{}) string { return msg } -// Trace logs to TRACE log. Arguments are handled in the manner of fmt.Print. +// Trace logs to TRACE log. Arguments are handled in the manner of fmt.Println. func (l *zapLog) Trace(args ...interface{}) { if l.logger.Core().Enabled(zapcore.DebugLevel) { l.logger.Debug(getLogMsg(args...)) @@ -295,7 +308,7 @@ func (l *zapLog) Tracef(format string, args ...interface{}) { } } -// Debug logs to DEBUG log. Arguments are handled in the manner of fmt.Print. +// Debug logs to DEBUG log. Arguments are handled in the manner of fmt.Println. func (l *zapLog) Debug(args ...interface{}) { if l.logger.Core().Enabled(zapcore.DebugLevel) { l.logger.Debug(getLogMsg(args...)) @@ -309,7 +322,7 @@ func (l *zapLog) Debugf(format string, args ...interface{}) { } } -// Info logs to INFO log. Arguments are handled in the manner of fmt.Print. +// Info logs to INFO log. Arguments are handled in the manner of fmt.Println. func (l *zapLog) Info(args ...interface{}) { if l.logger.Core().Enabled(zapcore.InfoLevel) { l.logger.Info(getLogMsg(args...)) @@ -323,7 +336,7 @@ func (l *zapLog) Infof(format string, args ...interface{}) { } } -// Warn logs to WARNING log. Arguments are handled in the manner of fmt.Print. +// Warn logs to WARNING log. Arguments are handled in the manner of fmt.Println. func (l *zapLog) Warn(args ...interface{}) { if l.logger.Core().Enabled(zapcore.WarnLevel) { l.logger.Warn(getLogMsg(args...)) @@ -337,7 +350,7 @@ func (l *zapLog) Warnf(format string, args ...interface{}) { } } -// Error logs to ERROR log. Arguments are handled in the manner of fmt.Print. +// Error logs to ERROR log. Arguments are handled in the manner of fmt.Println. func (l *zapLog) Error(args ...interface{}) { if l.logger.Core().Enabled(zapcore.ErrorLevel) { l.logger.Error(getLogMsg(args...)) @@ -351,7 +364,7 @@ func (l *zapLog) Errorf(format string, args ...interface{}) { } } -// Fatal logs to FATAL log. Arguments are handled in the manner of fmt.Print. +// Fatal logs to FATAL log. Arguments are handled in the manner of fmt.Println. func (l *zapLog) Fatal(args ...interface{}) { if l.logger.Core().Enabled(zapcore.FatalLevel) { l.logger.Fatal(getLogMsg(args...)) diff --git a/log/zaplogger_test.go b/log/zaplogger_test.go index 58eca32..5846fec 100644 --- a/log/zaplogger_test.go +++ b/log/zaplogger_test.go @@ -17,12 +17,14 @@ import ( "errors" "fmt" "runtime" + "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap" + "go.uber.org/zap/buffer" "go.uber.org/zap/zapcore" "go.uber.org/zap/zaptest/observer" @@ -334,3 +336,50 @@ func TestLogEnableColor(t *testing.T) { l.Warn("hello") l.Error("hello") } + +func TestLogNewFormatEncoder(t *testing.T) { + const myFormatter = "myformatter" + log.RegisterFormatEncoder(myFormatter, func(ec zapcore.EncoderConfig) zapcore.Encoder { + return &consoleEncoder{ + Encoder: zapcore.NewJSONEncoder(zapcore.EncoderConfig{}), + pool: buffer.NewPool(), + cfg: ec, + } + }) + cfg := []log.OutputConfig{{Writer: "console", Level: "trace", Formatter: myFormatter}} + l := log.NewZapLog(cfg).With(log.Field{Key: "trace-id", Value: "xx"}) + l.Trace("hello") + l.Debug("hello") + l.Info("hello") + l.Warn("hello") + l.Error("hello") + // 2023/12/14 10:54:55 {"trace-id":"xx"} DEBUG hello + // 2023/12/14 10:54:55 {"trace-id":"xx"} DEBUG hello + // 2023/12/14 10:54:55 {"trace-id":"xx"} INFO hello + // 2023/12/14 10:54:55 {"trace-id":"xx"} WARN hello + // 2023/12/14 10:54:55 {"trace-id":"xx"} ERROR hello +} + +type consoleEncoder struct { + zapcore.Encoder + pool buffer.Pool + cfg zapcore.EncoderConfig +} + +func (c consoleEncoder) Clone() zapcore.Encoder { + return consoleEncoder{Encoder: c.Encoder.Clone(), pool: buffer.NewPool(), cfg: c.cfg} +} + +func (c consoleEncoder) EncodeEntry(entry zapcore.Entry, fields []zapcore.Field) (*buffer.Buffer, error) { + buf, err := c.Encoder.EncodeEntry(zapcore.Entry{}, nil) + if err != nil { + return nil, err + } + buffer := c.pool.Get() + buffer.AppendString(entry.Time.Format("2006/01/02 15:04:05")) + field := buf.String() + buffer.AppendString(" " + field[:len(field)-1] + " ") + buffer.AppendString(strings.ToUpper(entry.Level.String()) + " ") + buffer.AppendString(entry.Message + "\n") + return buffer, nil +} diff --git a/naming/registry/node.go b/naming/registry/node.go index d4f4d19..c5f0a34 100644 --- a/naming/registry/node.go +++ b/naming/registry/node.go @@ -15,6 +15,7 @@ package registry import ( "fmt" + "net" "time" ) @@ -30,6 +31,9 @@ type Node struct { CostTime time.Duration // 当次请求耗时 EnvKey string // 透传的环境信息 Metadata map[string]interface{} + // ParseAddr should be used to convert Node to net.Addr if it's not nil. + // See test case TestSelectorRemoteAddrUseUserProvidedParser in client package. + ParseAddr func(network, address string) net.Addr } // String returns an abbreviation information of node. diff --git a/pool/connpool/checker_unix_test.go b/pool/connpool/checker_unix_test.go index 6f303a6..5ec677c 100644 --- a/pool/connpool/checker_unix_test.go +++ b/pool/connpool/checker_unix_test.go @@ -53,7 +53,7 @@ func TestRemoteEOF(t *testing.T) { require.Nil(t, pc.Close()) } -func TestUnexceptedRead(t *testing.T) { +func TestUnexpectedRead(t *testing.T) { var s server require.Nil(t, s.init()) diff --git a/restful/router.go b/restful/router.go index 2a7df82..2683e9f 100644 --- a/restful/router.go +++ b/restful/router.go @@ -199,7 +199,6 @@ var DefaultHeaderMatcher = func( func withNewMessage(ctx context.Context, serviceName, methodName string) context.Context { ctx, msg := codec.WithNewMessage(ctx) msg.WithServerRPCName(methodName) - msg.WithCalleeMethod(methodName) msg.WithCalleeServiceName(serviceName) msg.WithSerializationType(codec.SerializationTypePB) return ctx diff --git a/restful/serialize_jsonpb.go b/restful/serialize_jsonpb.go index 0e6d1ea..67297f4 100644 --- a/restful/serialize_jsonpb.go +++ b/restful/serialize_jsonpb.go @@ -31,12 +31,23 @@ func init() { // JSONPBSerializer is used for content-Type: application/json. // It's based on google.golang.org/protobuf/encoding/protojson. +// +// This serializer will firstly try jsonpb's serialization. If object does not +// conform to protobuf proto.Message interface, the serialization will switch to +// json-iterator. type JSONPBSerializer struct { AllowUnmarshalNil bool // allow unmarshalling nil body } // JSONAPI is a copy of jsoniter.ConfigCompatibleWithStandardLibrary. // github.com/json-iterator/go is faster than Go's standard json library. +// +// Deprecated: This global variable is exportable due to backward comparability issue but +// should not be modified. If users want to change the default behavior of +// internal JSON serialization, please use register your customized serializer +// function like: +// +// restful.RegisterSerializer(yourOwnJSONSerializer) var JSONAPI = jsoniter.ConfigCompatibleWithStandardLibrary // Marshaller is a configurable protojson marshaler. diff --git a/server/serve_unix.go b/server/serve_unix.go index 16f4902..102f8a6 100644 --- a/server/serve_unix.go +++ b/server/serve_unix.go @@ -115,9 +115,6 @@ func (s *Server) StartNewProcess(args ...string) (uintptr, error) { return 0, err } - for _, f := range listenersFds { - f.OriginalListenCloser.Close() - } return uintptr(childPID), nil } diff --git a/server/server_unix_test.go b/server/server_unix_test.go index 8447e17..4873e06 100644 --- a/server/server_unix_test.go +++ b/server/server_unix_test.go @@ -59,20 +59,10 @@ func TestStartNewProcess(t *testing.T) { s.AddService("trpc.test.helloworld.Greeter1", service) err := s.Register(nil, nil) - assert.NotNil(t, err) impl := &GreeterServerImpl{} err = s.Register(&GreeterServerServiceDesc, impl) - assert.Nil(t, err) - go func() { - var netOpError *net.OpError - assert.ErrorAs( - t, - s.Serve(), - &netOpError, - `it is normal to have "use of closed network connection" error during hot restart`, - ) - }() + go s.Serve() time.Sleep(time.Second * 1) log.Info(os.Environ()) diff --git a/server/service.go b/server/service.go index 631a54f..25980fe 100644 --- a/server/service.go +++ b/server/service.go @@ -102,13 +102,14 @@ type Stream interface { // service is an implementation of Service type service struct { + activeCount int64 // active requests count for graceful close if set MaxCloseWaitTime ctx context.Context // context of this service cancel context.CancelFunc // function that cancels this service opts *Options // options of this service handlers map[string]Handler // rpcname => handler streamHandlers map[string]StreamHandler streamInfo map[string]*StreamServerInfo - activeCount int64 // active requests count for graceful close if set MaxCloseWaitTime + stopListening chan<- struct{} } // New creates a service. @@ -534,21 +535,28 @@ func (s *service) Close(ch chan struct{}) error { } } } - s.waitBeforeClose() + if remains := s.waitBeforeClose(); remains > 0 { + log.Infof("process %d service %s remains %d requests before close", + os.Getpid(), s.opts.ServiceName, remains) + } // this will cancel all children ctx. s.cancel() + timeout := time.Millisecond * 300 if s.opts.Timeout > timeout { // use the larger one timeout = s.opts.Timeout } - time.Sleep(timeout) + if remains := s.waitInactive(timeout); remains > 0 { + log.Infof("process %d service %s remains %d requests after close", + os.Getpid(), s.opts.ServiceName, remains) + } log.Infof("process:%d, %s service:%s, closed", pid, s.opts.protocol, s.opts.ServiceName) ch <- struct{}{} return nil } -func (s *service) waitBeforeClose() { +func (s *service) waitBeforeClose() int64 { closeWaitTime := s.opts.CloseWaitTime if closeWaitTime > MaxCloseWaitTime { closeWaitTime = MaxCloseWaitTime @@ -562,18 +570,17 @@ func (s *service) waitBeforeClose() { os.Getpid(), s.opts.ServiceName, atomic.LoadInt64(&s.activeCount), closeWaitTime) time.Sleep(closeWaitTime) } + return s.waitInactive(s.opts.MaxCloseWaitTime - closeWaitTime) +} + +func (s *service) waitInactive(maxWaitTime time.Duration) int64 { const sleepTime = 100 * time.Millisecond - if s.opts.MaxCloseWaitTime > closeWaitTime { - spinCount := int((s.opts.MaxCloseWaitTime - closeWaitTime) / sleepTime) - for i := 0; i < spinCount; i++ { - if atomic.LoadInt64(&s.activeCount) <= 0 { - break - } - time.Sleep(sleepTime) + for start := time.Now(); time.Since(start) < maxWaitTime; time.Sleep(sleepTime) { + if atomic.LoadInt64(&s.activeCount) <= 0 { + return 0 } - log.Infof("process %d service %s remain %d requests when closing service", - os.Getpid(), s.opts.ServiceName, atomic.LoadInt64(&s.activeCount)) } + return atomic.LoadInt64(&s.activeCount) } func checkProcessStatus() (isGracefulRestart, isParentalProcess bool) { diff --git a/server/service_test.go b/server/service_test.go index 61b222a..70a28f3 100644 --- a/server/service_test.go +++ b/server/service_test.go @@ -306,55 +306,84 @@ func TestServiceUDP(t *testing.T) { require.Nil(t, err) } -func TestServiceCloseWait(t *testing.T) { - const waitChildTime = 300 * time.Millisecond - const schTime = 10 * time.Millisecond - cases := []struct { - closeWaitTime time.Duration - maxCloseWaitTime time.Duration - waitTime time.Duration - }{ - { - waitTime: waitChildTime, - }, - { - closeWaitTime: 50 * time.Millisecond, - waitTime: waitChildTime + 50*time.Millisecond, - }, - { - closeWaitTime: 50 * time.Millisecond, - maxCloseWaitTime: 30 * time.Millisecond, - waitTime: waitChildTime + 50*time.Millisecond, - }, - { - closeWaitTime: 50 * time.Millisecond, - maxCloseWaitTime: 100 * time.Millisecond, - waitTime: waitChildTime + 50*time.Millisecond, - }, +func TestCloseWaitTime(t *testing.T) { + startService := func(opts ...server.Option) (chan struct{}, func()) { + received, done := make(chan struct{}), make(chan struct{}) + addr, stop := startService(t, &Greeter{}, append([]server.Option{server.WithFilter( + func(ctx context.Context, req interface{}, next filter.ServerHandleFunc) (rsp interface{}, err error) { + received <- struct{}{} + <-done + return nil, errors.New("must fail") + })}, opts...)...) + go func() { + _, _ = pb.NewGreeterClientProxy(client.WithTarget("ip://"+addr)). + SayHello(context.Background(), &pb.HelloRequest{}) + }() + <-received + return done, stop } - for _, c := range cases { - service := server.New( - server.WithRegistry(&fakeRegistry{}), - server.WithCloseWaitTime(c.closeWaitTime), - server.WithMaxCloseWaitTime(c.maxCloseWaitTime), - ) + t.Run("active requests feature is not enabled on missing MaxCloseWaitTime", func(t *testing.T) { + done, stop := startService() + defer close(done) start := time.Now() - err := service.Close(nil) - assert.Nil(t, err) - cost := time.Since(start) - assert.GreaterOrEqual(t, cost, c.waitTime) - assert.LessOrEqual(t, cost, c.waitTime+schTime) - } + stop() + require.Less(t, time.Since(start), time.Millisecond*100) + }) + t.Run("total wait time should not significantly greater than MaxCloseWaitTime", func(t *testing.T) { + const closeWaitTime, maxCloseWaitTime = time.Millisecond * 500, time.Second + done, stop := startService( + server.WithMaxCloseWaitTime(maxCloseWaitTime), + server.WithCloseWaitTime(closeWaitTime)) + defer close(done) + start := time.Now() + stop() + require.WithinRange(t, time.Now(), + // 300ms comes from the internal implementation when close service + start.Add(maxCloseWaitTime).Add(time.Millisecond*300), + start.Add(maxCloseWaitTime).Add(time.Millisecond*500)) + }) + t.Run("total wait time is at least CloseWaitTime", func(t *testing.T) { + const closeWaitTime, maxCloseWaitTime = time.Millisecond * 500, time.Second + done, stop := startService( + server.WithMaxCloseWaitTime(maxCloseWaitTime), + server.WithCloseWaitTime(closeWaitTime)) + start := time.Now() + time.AfterFunc(closeWaitTime/2, func() { close(done) }) + stop() + require.WithinRange(t, time.Now(), start.Add(closeWaitTime), start.Add(closeWaitTime+time.Millisecond*100)) + }) + t.Run("no active request before MaxCloseWaitTime", func(t *testing.T) { + const closeWaitTime, maxCloseWaitTime = time.Millisecond * 500, time.Second + done, stop := startService( + server.WithMaxCloseWaitTime(maxCloseWaitTime), + server.WithCloseWaitTime(closeWaitTime)) + start := time.Now() + time.AfterFunc((closeWaitTime+maxCloseWaitTime)/2, func() { close(done) }) + stop() + require.WithinRange(t, time.Now(), start.Add(closeWaitTime), start.Add(maxCloseWaitTime)) + }) + t.Run("no active request before service timeout", func(t *testing.T) { + const closeWaitTime, maxCloseWaitTime, timeout = time.Millisecond * 500, time.Second, time.Second + done, stop := startService( + server.WithMaxCloseWaitTime(maxCloseWaitTime), + server.WithCloseWaitTime(closeWaitTime), + server.WithTimeout(timeout)) + start := time.Now() + time.AfterFunc(maxCloseWaitTime+time.Millisecond*100, func() { close(done) }) + stop() + require.WithinRange(t, time.Now(), start.Add(maxCloseWaitTime+time.Millisecond*100), start.Add(maxCloseWaitTime+timeout)) + }) } func startService(t *testing.T, gs GreeterServer, opts ...server.Option) (addr string, stop func()) { l, err := net.Listen("tcp", "0.0.0.0:0") require.Nil(t, err) - s := server.New(append(append([]server.Option{ - server.WithNetwork("tcp"), - server.WithProtocol("trpc"), - }, opts...), + s := server.New(append(append( + []server.Option{ + server.WithNetwork("tcp"), + server.WithProtocol("trpc"), + }, opts...), server.WithListener(l), )...) require.Nil(t, s.Register(&GreeterServerServiceDesc, gs)) diff --git a/stream/README.zh_CN.md b/stream/README.zh_CN.md index 892d870..0bedb1f 100644 --- a/stream/README.zh_CN.md +++ b/stream/README.zh_CN.md @@ -322,7 +322,7 @@ func main() { - 接收端每消费 1/4 的初始窗口大小进行 feedback,发送一个 feedback 帧,携带增量的 window size,发送端接收到这个增量 window size 之后加到本地可发送的 window 大小 - 帧分优先级,对于 feedback 的帧不做流控,优先级高于 Data 帧,防止因为优先级问题导致 feedback 帧发生阻塞 -tPRC-Go 默认启用流控,目前默认窗口大小为 65535,如果连续发送超过 65535 大小的数据(序列化和压缩后),接收方没调用 Recv,则发送方会 block +tRPC-Go 默认启用流控,目前默认窗口大小为 65535,如果连续发送超过 65535 大小的数据(序列化和压缩后),接收方没调用 Recv,则发送方会 block 如果要设置客户端接收窗口大小,使用 client option `WithMaxWindowSize` ```go diff --git a/stream/client.go b/stream/client.go index 65a6f85..e68ba6e 100644 --- a/stream/client.go +++ b/stream/client.go @@ -17,6 +17,7 @@ package stream import ( "context" "errors" + "fmt" "io" "sync" "sync/atomic" @@ -187,7 +188,7 @@ func (cs *clientStream) SendMsg(m interface{}) error { msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA, cs.streamID)) msg.WithStreamID(cs.streamID) msg.WithClientRPCName(cs.method) - msg.WithCalleeMethod(icodec.MethodFromRPCName(cs.method)) + msg.WithCompressType(codec.Message(cs.ctx).CompressType()) return cs.stream.Send(ctx, m) } @@ -215,7 +216,6 @@ func (cs *clientStream) CloseSend() error { func (cs *clientStream) prepare(opt ...client.Option) error { msg := codec.Message(cs.ctx) msg.WithClientRPCName(cs.method) - msg.WithCalleeMethod(icodec.MethodFromRPCName(cs.method)) msg.WithStreamID(cs.streamID) opt = append([]client.Option{client.WithStreamTransport(transport.DefaultClientStreamTransport)}, opt...) @@ -237,8 +237,8 @@ func (cs *clientStream) invoke(ctx context.Context, _ *client.ClientStreamDesc) copyMetaData(newMsg, codec.Message(cs.ctx)) newMsg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT, cs.streamID)) newMsg.WithClientRPCName(cs.method) - newMsg.WithCalleeMethod(icodec.MethodFromRPCName(cs.method)) newMsg.WithStreamID(cs.streamID) + newMsg.WithCompressType(codec.Message(cs.ctx).CompressType()) newMsg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{ RequestMeta: &trpcpb.TrpcStreamInitRequestMeta{}, InitWindowSize: w, @@ -252,15 +252,13 @@ func (cs *clientStream) invoke(ctx context.Context, _ *client.ClientStreamDesc) if _, err := cs.stream.Recv(newCtx); err != nil { return nil, err } - if newMsg.ClientRspErr() != nil { - return nil, newMsg.ClientRspErr() - } - - initWindowSize := defaultInitWindowSize - if initRspMeta, ok := newMsg.StreamFrame().(*trpcpb.TrpcStreamInitMeta); ok { - // If the server has feedback, use the server's window, if not, use the default window. - initWindowSize = initRspMeta.GetInitWindowSize() + initRspMeta, ok := newMsg.StreamFrame().(*trpcpb.TrpcStreamInitMeta) + if !ok { + return nil, fmt.Errorf("client stream (method = %s, streamID = %d) recv "+ + "unexpected frame type: %T, expected: %T", + cs.method, cs.streamID, newMsg.StreamFrame(), (*trpcpb.TrpcStreamInitMeta)(nil)) } + initWindowSize := initRspMeta.GetInitWindowSize() cs.configSendControl(initWindowSize) // Start the dispatch goroutine loop to send packets. @@ -286,7 +284,6 @@ func (cs *clientStream) feedback(i uint32) error { msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK, cs.streamID)) msg.WithStreamID(cs.streamID) msg.WithClientRPCName(cs.method) - msg.WithCalleeMethod(icodec.MethodFromRPCName(cs.method)) msg.WithStreamFrame(&trpcpb.TrpcStreamFeedBackMeta{WindowSizeIncrement: i}) return cs.stream.Send(ctx, nil) } @@ -339,12 +336,13 @@ func (cs *clientStream) dispatch() { }() for { ctx, msg := codec.WithCloneContextAndMessage(cs.ctx) + msg.WithCompressType(codec.Message(cs.ctx).CompressType()) msg.WithStreamID(cs.streamID) respData, err := cs.stream.Recv(ctx) if err != nil { // return to client on error. cs.recvQueue.Put(&response{ - err: errs.Wrap(err, errs.RetClientStreamReadEnd, streamClosed), + err: errs.WrapFrameError(err, errs.RetClientStreamReadEnd, streamClosed), }) return } diff --git a/stream/client_test.go b/stream/client_test.go index 281770d..3e28f2d 100644 --- a/stream/client_test.go +++ b/stream/client_test.go @@ -16,11 +16,11 @@ package stream_test import ( "context" + "crypto/rand" "encoding/binary" "errors" "fmt" "io" - "math/rand" "testing" "time" @@ -41,6 +41,8 @@ var ctx = context.Background() type fakeTransport struct { expectChan chan recvExpect + send func() error + close func() } // RoundTrip Mock RoundTrip method. @@ -51,9 +53,8 @@ func (c *fakeTransport) RoundTrip(ctx context.Context, req []byte, // Send Mock Send method. func (c *fakeTransport) Send(ctx context.Context, req []byte, opts ...transport.RoundTripOption) error { - err, ok := ctx.Value("send-error").(string) - if ok { - return errors.New(err) + if c.send != nil { + return c.send() } return nil } @@ -80,7 +81,9 @@ func (c *fakeTransport) Init(ctx context.Context, opts ...transport.RoundTripOpt // Close Mock Close method. func (c *fakeTransport) Close(ctx context.Context) { - return + if c.close != nil { + c.close() + } } type fakeCodec struct { @@ -330,6 +333,18 @@ func TestClientError(t *testing.T) { assert.Nil(t, cs) assert.NotNil(t, err) + // receive unexpected stream frame type + f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { + msg.WithStreamFrame(int(1)) + return nil, nil + } + ft.expectChan <- f + cs, err = cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello", + client.WithTarget("ip://127.0.0.1:8000"), + client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop), + client.WithStreamTransport(ft), client.WithClientStreamQueueSize(100000)) + assert.Nil(t, cs) + assert.Contains(t, err.Error(), "unexpected frame type") } // TestClientContext tests the case of streaming client context cancel and timeout. @@ -687,7 +702,8 @@ func TestClientStreamReturn(t *testing.T) { rsp := getBytes(dataLen) err = clientStream.RecvMsg(rsp) - assert.EqualValues(t, int32(101), err.(*errs.Error).Code) + + assert.EqualValues(t, int32(101), errs.Code(err.(*errs.Error).Unwrap())) } // TestClientSendFailWhenServerUnavailable test when the client blocks @@ -746,3 +762,79 @@ func TestClientReceiveErrorWhenServerUnavailable(t *testing.T) { assert.NotEqual(t, io.EOF, err) assert.ErrorIs(t, err, io.EOF) } + +func TestClientNewStreamFail(t *testing.T) { + codec.Register("mock", nil, &fakeCodec{}) + t.Run("Close Transport when Send Fail", func(t *testing.T) { + var isClosed bool + tp := &fakeTransport{expectChan: make(chan recvExpect, 1)} + tp.send = func() error { + return errors.New("client error") + } + tp.close = func() { + isClosed = true + } + _, err := stream.NewStreamClient().NewStream(ctx, &client.ClientStreamDesc{}, "", + client.WithProtocol("mock"), + client.WithTarget("ip://127.0.0.1:8000"), + client.WithStreamTransport(tp), + ) + assert.NotNil(t, err) + assert.True(t, isClosed) + }) + t.Run("Close Transport when Recv Fail", func(t *testing.T) { + var isClosed bool + tp := &fakeTransport{expectChan: make(chan recvExpect, 1)} + tp.expectChan <- func(fh *trpc.FrameHead, m codec.Msg) ([]byte, error) { + m.WithClientRspErr(errors.New("server error")) + return nil, nil + } + tp.close = func() { + isClosed = true + } + _, err := stream.NewStreamClient().NewStream(ctx, &client.ClientStreamDesc{}, "", + client.WithProtocol("mock"), + client.WithTarget("ip://127.0.0.1:8000"), + client.WithStreamTransport(tp), + ) + assert.NotNil(t, err) + assert.True(t, isClosed) + }) +} + +func TestClientServerCompress(t *testing.T) { + var ( + dataLen = 1024 + compressType = codec.CompressTypeSnappy + ) + svrOpts := []server.Option{ + server.WithAddress("127.0.0.1:30211"), + } + handle := func(s server.Stream) error { + assert.Equal(t, compressType, codec.Message(s.Context()).CompressType()) + req := getBytes(dataLen) + s.RecvMsg(req) + rsp := req + s.SendMsg(rsp) + return nil + } + svr := startStreamServer(handle, svrOpts) + defer closeStreamServer(svr) + + cliOpts := []client.Option{ + client.WithTarget("ip://127.0.0.1:30211"), + client.WithCompressType(compressType), + } + + clientStream, err := getClientStream(context.Background(), clientDesc, cliOpts) + assert.Nil(t, err) + req := getBytes(dataLen) + rand.Read(req.Data) + err = clientStream.SendMsg(req) + assert.Nil(t, err) + + rsp := getBytes(dataLen) + err = clientStream.RecvMsg(rsp) + assert.Equal(t, rsp.Data, req.Data) + assert.Nil(t, err) +} diff --git a/stream/flow_control.go b/stream/flow_control.go index 1395345..a379897 100644 --- a/stream/flow_control.go +++ b/stream/flow_control.go @@ -85,7 +85,6 @@ func checkUpdate(updatedWindow, increment int64) bool { type receiveControl struct { buffer uint32 // upper limit. unUpdated uint32 // Consumed, no window update sent. - left uint32 // remaining available buffer. fb feedback // function for feedback. } @@ -93,7 +92,6 @@ func newReceiveControl(buffer uint32, fb feedback) *receiveControl { return &receiveControl{ buffer: buffer, fb: fb, - left: buffer, } } @@ -103,17 +101,10 @@ func (r *receiveControl) OnRecv(n uint32) error { if r.unUpdated >= r.buffer/4 { increment := r.unUpdated r.unUpdated = 0 - r.updateLeft() if r.fb != nil { return r.fb(increment) } return nil } - r.updateLeft() return nil } - -// updateLeft updates the remaining available buffers. -func (r *receiveControl) updateLeft() { - atomic.StoreUint32(&r.left, r.buffer-r.unUpdated) -} diff --git a/stream/flow_control_test.go b/stream/flow_control_test.go index 4ffa4e4..e0b4935 100644 --- a/stream/flow_control_test.go +++ b/stream/flow_control_test.go @@ -15,7 +15,6 @@ package stream import ( "errors" - "sync/atomic" "testing" "time" @@ -41,8 +40,8 @@ func TestSendControl(t *testing.T) { }() err = sc.GetWindow(200) assert.Nil(t, err) - t2 := int64(time.Now().Sub(t1)) - assert.GreaterOrEqual(t, t2, int64(500*time.Millisecond)) + t2 := time.Since(t1) + assert.GreaterOrEqual(t, t2, 500*time.Millisecond) } // TestReceiveControl test. @@ -54,9 +53,6 @@ func TestReceiveControl(t *testing.T) { err := rc.OnRecv(100) assert.Nil(t, err) - n := atomic.LoadUint32(&rc.left) - assert.Equal(t, defaultInitWindowSize-uint32(100), n) - // need to send updates. err = rc.OnRecv(defaultInitWindowSize / 4) assert.Nil(t, err) diff --git a/stream/server.go b/stream/server.go index 3e13185..0a669f0 100644 --- a/stream/server.go +++ b/stream/server.go @@ -17,9 +17,10 @@ import ( "context" "errors" "io" - "net" "sync" + "go.uber.org/atomic" + "trpc.group/trpc-go/trpc-go/internal/addrutil" trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" trpc "trpc.group/trpc-go/trpc-go" @@ -40,7 +41,7 @@ type serverStream struct { opts *server.Options recvQueue *queue.Queue[*response] done chan struct{} - err error // Carry the server tcp failure information. + err atomic.Error // Carry the server tcp failure information. once sync.Once rControl *receiveControl // Receiver flow control. sControl *sendControl // Sender flow control. @@ -48,11 +49,15 @@ type serverStream struct { // SendMsg is the API that users use to send streaming messages. func (s *serverStream) SendMsg(m interface{}) error { + if err := s.err.Load(); err != nil { + return errs.WrapFrameError(err, errs.Code(err), "stream sending error") + } msg := codec.Message(s.ctx) ctx, newMsg := codec.WithCloneContextAndMessage(s.ctx) defer codec.PutBackMessage(newMsg) newMsg.WithLocalAddr(msg.LocalAddr()) newMsg.WithRemoteAddr(msg.RemoteAddr()) + newMsg.WithCompressType(msg.CompressType()) newMsg.WithStreamID(s.streamID) // Refer to the pb code generated by trpc.proto, common to each language, automatically generated code. newMsg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA, s.streamID)) @@ -61,7 +66,7 @@ func (s *serverStream) SendMsg(m interface{}) error { err error reqBodyBuffer []byte ) - serializationType, compressType := s.serializationAndCompressType(msg) + serializationType, compressType := s.serializationAndCompressType(newMsg) if icodec.IsValidSerializationType(serializationType) { reqBodyBuffer, err = codec.Marshal(serializationType, m) if err != nil { @@ -119,8 +124,8 @@ func (s *serverStream) serializationAndCompressType(msg codec.Msg) (int, int) { func (s *serverStream) RecvMsg(m interface{}) error { resp, ok := s.recvQueue.Get() if !ok { - if s.err != nil { - return s.err + if err := s.err.Load(); err != nil { + return err } return errs.NewFrameError(errs.RetServerSystemErr, streamClosed) } @@ -217,9 +222,12 @@ func (s *serverStream) Context() context.Context { // The structure of streamDispatcher is used to distribute streaming data. type streamDispatcher struct { - m sync.RWMutex - streamIDToServerStream map[net.Addr]map[uint32]*serverStream - opts *server.Options + m sync.RWMutex + // local address + remote address + network + // => stream ID + // => serverStream + addrToServerStream map[string]map[uint32]*serverStream + opts *server.Options } // DefaultStreamDispatcher is the default implementation of the trpc dispatcher, @@ -229,45 +237,45 @@ var DefaultStreamDispatcher = NewStreamDispatcher() // NewStreamDispatcher returns a new dispatcher. func NewStreamDispatcher() server.StreamHandle { return &streamDispatcher{ - streamIDToServerStream: make(map[net.Addr]map[uint32]*serverStream), + addrToServerStream: make(map[string]map[uint32]*serverStream), } } // storeServerStream msg contains the socket address of the client connection, // there are multiple streams under each socket address, and map it to serverStream // again according to the id of the stream. -func (sd *streamDispatcher) storeServerStream(addr net.Addr, streamID uint32, ss *serverStream) { +func (sd *streamDispatcher) storeServerStream(addr string, streamID uint32, ss *serverStream) { sd.m.Lock() defer sd.m.Unlock() - if addrToStreamID, ok := sd.streamIDToServerStream[addr]; !ok { + if addrToStreamID, ok := sd.addrToServerStream[addr]; !ok { // Does not exist, indicating that a new connection is coming, re-create the structure. - sd.streamIDToServerStream[addr] = map[uint32]*serverStream{streamID: ss} + sd.addrToServerStream[addr] = map[uint32]*serverStream{streamID: ss} } else { addrToStreamID[streamID] = ss } } // deleteServerStream deletes the serverStream from cache. -func (sd *streamDispatcher) deleteServerStream(addr net.Addr, streamID uint32) { +func (sd *streamDispatcher) deleteServerStream(addr string, streamID uint32) { sd.m.Lock() defer sd.m.Unlock() - if addrToStreamID, ok := sd.streamIDToServerStream[addr]; ok { + if addrToStreamID, ok := sd.addrToServerStream[addr]; ok { if _, ok = addrToStreamID[streamID]; ok { delete(addrToStreamID, streamID) } if len(addrToStreamID) == 0 { - delete(sd.streamIDToServerStream, addr) + delete(sd.addrToServerStream, addr) } } } // loadServerStream loads the stored serverStream through the socket address // of the client connection and the id of the stream. -func (sd *streamDispatcher) loadServerStream(addr net.Addr, streamID uint32) (*serverStream, error) { +func (sd *streamDispatcher) loadServerStream(addr string, streamID uint32) (*serverStream, error) { sd.m.RLock() defer sd.m.RUnlock() - addrToStream, ok := sd.streamIDToServerStream[addr] - if !ok || addr == nil { + addrToStream, ok := sd.addrToServerStream[addr] + if !ok { return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchAddr) } @@ -293,7 +301,7 @@ func (sd *streamDispatcher) Init(opts *server.Options) error { // startStreamHandler is used to start the goroutine, execute streamHandler, // streamHandler is implemented for the specific streaming server. -func (sd *streamDispatcher) startStreamHandler(addr net.Addr, streamID uint32, +func (sd *streamDispatcher) startStreamHandler(addr string, streamID uint32, ss *serverStream, si *server.StreamServerInfo, sh server.StreamHandler) { defer func() { sd.deleteServerStream(addr, streamID) @@ -320,7 +328,7 @@ func (sd *streamDispatcher) startStreamHandler(addr net.Addr, streamID uint32, err = ss.CloseSend(int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_CLOSE), 0, "") } if err != nil { - ss.err = err + ss.err.Store(err) log.Trace(closeSendFail, err) } } @@ -357,7 +365,7 @@ func (sd *streamDispatcher) handleInit(ctx context.Context, ss := newServerStream(ctx, streamID, sd.opts) w := getWindowSize(sd.opts.MaxWindowSize) ss.rControl = newReceiveControl(w, ss.feedback) - sd.storeServerStream(msg.RemoteAddr(), streamID, ss) + sd.storeServerStream(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), streamID, ss) cw, err := ss.setSendControl(msg) if err != nil { @@ -390,13 +398,13 @@ func (sd *streamDispatcher) handleInit(ctx context.Context, } // Initiate a goroutine to execute specific business logic. - go sd.startStreamHandler(msg.RemoteAddr(), streamID, ss, si, sh) + go sd.startStreamHandler(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), streamID, ss, si, sh) return nil, errs.ErrServerNoResponse } // handleData handles data messages. func (sd *streamDispatcher) handleData(msg codec.Msg, req []byte) ([]byte, error) { - ss, err := sd.loadServerStream(msg.RemoteAddr(), msg.StreamID()) + ss, err := sd.loadServerStream(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), msg.StreamID()) if err != nil { return nil, err } @@ -406,7 +414,7 @@ func (sd *streamDispatcher) handleData(msg codec.Msg, req []byte) ([]byte, error // handleClose handles the Close message. func (sd *streamDispatcher) handleClose(msg codec.Msg) ([]byte, error) { - ss, err := sd.loadServerStream(msg.RemoteAddr(), msg.StreamID()) + ss, err := sd.loadServerStream(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), msg.StreamID()) if err != nil { // The server has sent the Close frame. // Since the timing of the Close frame is unpredictable, when the server receives the Close frame from the client, @@ -429,17 +437,17 @@ func (sd *streamDispatcher) handleError(msg codec.Msg) ([]byte, error) { sd.m.Lock() defer sd.m.Unlock() - addr := msg.RemoteAddr() - addrToStream, ok := sd.streamIDToServerStream[addr] - if !ok || addr == nil { + addr := addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()) + addrToStream, ok := sd.addrToServerStream[addr] + if !ok { return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchAddr) } for streamID, ss := range addrToStream { - ss.err = msg.ServerRspErr() + ss.err.Store(msg.ServerRspErr()) ss.once.Do(func() { close(ss.done) }) delete(addrToStream, streamID) } - delete(sd.streamIDToServerStream, addr) + delete(sd.addrToServerStream, addr) return nil, errs.ErrServerNoResponse } @@ -462,7 +470,7 @@ func (sd *streamDispatcher) StreamHandleFunc(ctx context.Context, // handleFeedback handles the feedback frame. func (sd *streamDispatcher) handleFeedback(msg codec.Msg) ([]byte, error) { - ss, err := sd.loadServerStream(msg.RemoteAddr(), msg.StreamID()) + ss, err := sd.loadServerStream(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), msg.StreamID()) if err != nil { return nil, err } diff --git a/stream/server_test.go b/stream/server_test.go index 7a5cb80..4fd79ae 100644 --- a/stream/server_test.go +++ b/stream/server_test.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "math/rand" + "net" "sync" "testing" "time" @@ -147,6 +148,7 @@ func TestStreamDispatcherHandleInit(t *testing.T) { msg.WithFrameHead(fh) msg.WithStreamID(uint32(100)) msg.WithRemoteAddr(&fakeAddr{}) + msg.WithLocalAddr(&fakeAddr{}) rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init")) assert.Nil(t, rsp) assert.Equal(t, err, errs.ErrServerNoResponse) @@ -207,6 +209,7 @@ func TestStreamDispatcherHandleData(t *testing.T) { msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) addr := &fakeAddr{} msg.WithRemoteAddr(addr) + msg.WithLocalAddr(addr) rsp, err := dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init")) assert.Nil(t, rsp) assert.Equal(t, err, errs.ErrServerNoResponse) @@ -219,7 +222,8 @@ func TestStreamDispatcherHandleData(t *testing.T) { assert.Equal(t, err, errs.ErrServerNoResponse) // handleData error no such addr - msg.WithRemoteAddr(nil) + raddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:1") + msg.WithRemoteAddr(raddr) fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) msg.WithFrameHead(fh) rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("data")) @@ -264,6 +268,7 @@ func TestStreamDispatcherHandleClose(t *testing.T) { addr := &fakeAddr{} msg.WithRemoteAddr(addr) + msg.WithLocalAddr(addr) msg.WithFrameHead(fh) rsp, err := dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init")) assert.Nil(t, rsp) @@ -278,7 +283,8 @@ func TestStreamDispatcherHandleClose(t *testing.T) { // handle close no such addr msg.WithFrameHead(fh) - msg.WithRemoteAddr(nil) + raddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:1") + msg.WithRemoteAddr(raddr) rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("close")) assert.Nil(t, rsp) assert.Equal(t, errs.ErrServerNoResponse, err) @@ -330,6 +336,7 @@ func TestServerStreamSendMsg(t *testing.T) { msg.WithFrameHead(fh) msg.WithStreamID(uint32(100)) msg.WithRemoteAddr(&fakeAddr{}) + msg.WithLocalAddr(&fakeAddr{}) msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) opts.CurrentCompressType = codec.CompressTypeNoop @@ -407,6 +414,7 @@ func TestServerStreamRecvMsg(t *testing.T) { msg.WithFrameHead(fh) msg.WithStreamID(uint32(100)) msg.WithRemoteAddr(&fakeAddr{}) + msg.WithLocalAddr(&fakeAddr{}) msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) opts.CurrentCompressType = codec.CompressTypeNoop opts.CurrentSerializationType = codec.SerializationTypeNoop @@ -467,6 +475,7 @@ func TestServerStreamRecvMsgFail(t *testing.T) { msg.WithFrameHead(fh) msg.WithStreamID(uint32(100)) msg.WithRemoteAddr(&fakeAddr{}) + msg.WithLocalAddr(&fakeAddr{}) msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) opts.CurrentCompressType = codec.CompressTypeGzip @@ -520,6 +529,7 @@ func TestHandleError(t *testing.T) { msg.WithFrameHead(fh) msg.WithStreamID(uint32(100)) msg.WithRemoteAddr(&fakeAddr{}) + msg.WithLocalAddr(&fakeAddr{}) msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) opts.CurrentCompressType = codec.CompressTypeGzip @@ -585,12 +595,15 @@ func TestStreamDispatcherHandleFeedback(t *testing.T) { addr := &fakeAddr{} msg.WithRemoteAddr(addr) + msg.WithLocalAddr(addr) rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) assert.Nil(t, rsp) assert.Equal(t, err, errs.ErrServerNoResponse) // handle feedback get server stream fail - msg.WithRemoteAddr(nil) + raddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:1") + msg.WithRemoteAddr(raddr) + msg.WithLocalAddr(raddr) fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK) msg.WithFrameHead(fh) rsp, err = dispatcher.StreamHandleFunc(ctx, nil, si, []byte("feedback")) @@ -599,6 +612,7 @@ func TestStreamDispatcherHandleFeedback(t *testing.T) { // handle feedback invalid stream msg.WithRemoteAddr(addr) + msg.WithLocalAddr(addr) fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK) msg.WithFrameHead(fh) rsp, err = dispatcher.StreamHandleFunc(ctx, nil, si, []byte("feedback")) @@ -635,6 +649,7 @@ func TestServerFlowControl(t *testing.T) { msg.WithStreamID(uint32(100)) addr := &fakeAddr{} msg.WithRemoteAddr(addr) + msg.WithLocalAddr(addr) msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{InitWindowSize: 65535}) opts.CurrentCompressType = codec.CompressTypeNoop opts.CurrentSerializationType = codec.SerializationTypeNoop @@ -661,6 +676,7 @@ func TestServerFlowControl(t *testing.T) { newCtx, newMsg := codec.WithNewMessage(newCtx) newMsg.WithStreamID(uint32(100)) newMsg.WithRemoteAddr(addr) + newMsg.WithLocalAddr(addr) newFh := &trpc.FrameHead{} newFh.StreamID = uint32(100) newFh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) @@ -911,3 +927,119 @@ func serverFilterAdd2(ss server.Stream, si *server.StreamServerInfo, err := handler(newWrappedServerStream(ss)) return err } + +// TestServerStreamAllFailWhenConnectionClosedAndReconnect tests when a connection +// is closed and then reconnected (with the same client IP and port), both SendMsg +// and RecvMsg on the server side result in errors. +func TestServerStreamAllFailWhenConnectionClosedAndReconnect(t *testing.T) { + ch := make(chan struct{}) + addr := "127.0.0.1:30211" + svrOpts := []server.Option{ + server.WithAddress(addr), + } + handle := func(s server.Stream) error { + <-ch + err := s.SendMsg(getBytes(100)) + assert.Equal(t, errs.Code(err), errs.RetServerSystemErr) + err = s.RecvMsg(getBytes(100)) + assert.Equal(t, errs.Code(err), errs.RetServerSystemErr) + ch <- struct{}{} + return nil + } + svr := startStreamServer(handle, svrOpts) + defer closeStreamServer(svr) + + // Init a stream + dialer := net.Dialer{ + LocalAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 20001}, + } + conn, err := dialer.Dial("tcp", addr) + assert.Nil(t, err) + _, msg := codec.WithNewMessage(context.Background()) + msg.WithFrameHead(&trpc.FrameHead{ + FrameType: uint8(trpcpb.TrpcDataFrameType_TRPC_STREAM_FRAME), + StreamFrameType: uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT), + }) + msg.WithClientRPCName("/trpc.test.stream.Greeter/StreamSayHello") + initReq, err := trpc.DefaultClientCodec.Encode(msg, nil) + assert.Nil(t, err) + _, err = conn.Write(initReq) + assert.Nil(t, err) + + // Close the connection + conn.Close() + + // Dial another connection using the same client ip:port + time.Sleep(time.Millisecond * 200) + _, err = dialer.Dial("tcp", addr) + assert.Nil(t, err) + + // Notify server to send and receive + ch <- struct{}{} + + // Wait server sending and receiving result assertion + <-ch +} + +func TestSameClientAddrDiffServerAddr(t *testing.T) { + dp := stream.NewStreamDispatcher() + dp.Init(&server.Options{ + Transport: &fakeServerTransport{}, + Codec: &fakeServerCodec{}, + CurrentSerializationType: codec.SerializationTypeNoop}) + wg := sync.WaitGroup{} + + initFrame := func(localAddr, remoteAddr net.Addr) { + ctx, msg := codec.WithNewMessage(context.Background()) + msg.WithFrameHead(&trpc.FrameHead{StreamFrameType: uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT)}) + msg.WithStreamID(200) + msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) + msg.WithRemoteAddr(remoteAddr) + msg.WithLocalAddr(localAddr) + msg.WithSerializationType(codec.SerializationTypeNoop) + wg.Add(1) + rsp, err := dp.StreamHandleFunc( + ctx, + func(s server.Stream) error { + err := s.RecvMsg(&codec.Body{}) + assert.Nil(t, err) + wg.Done() + return nil + }, + &server.StreamServerInfo{}, + []byte("init")) + assert.Nil(t, rsp) + assert.Equal(t, errs.ErrServerNoResponse, err) + } + + dataFrame := func(localAddr, remoteAddr net.Addr) { + ctx, msg := codec.WithNewMessage(context.Background()) + msg.WithFrameHead(&trpc.FrameHead{StreamFrameType: uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA)}) + msg.WithStreamID(200) + msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) + msg.WithRemoteAddr(remoteAddr) + msg.WithLocalAddr(localAddr) + rsp, err := dp.StreamHandleFunc(ctx, nil, &server.StreamServerInfo{}, []byte("data")) + assert.Nil(t, rsp) + assert.Equal(t, errs.ErrServerNoResponse, err) + } + + clientAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:9000") + serverAddr1, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:10001") + serverAddr2, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:10002") + initFrame(serverAddr1, clientAddr) + initFrame(serverAddr2, clientAddr) + dataFrame(serverAddr1, clientAddr) + dataFrame(serverAddr2, clientAddr) + + c := make(chan struct{}) + go func() { + defer close(c) + wg.Wait() + }() + select { + case <-c: // completed normally + case <-time.After(time.Millisecond * 500): // timed out + assert.FailNow(t, "server did not receive data frame") + } +} diff --git a/test/filter_test.go b/test/filter_test.go index d534779..b23dec4 100644 --- a/test/filter_test.go +++ b/test/filter_test.go @@ -15,6 +15,7 @@ package test import ( "context" + "errors" "time" "github.com/stretchr/testify/require" @@ -189,6 +190,9 @@ func (s *TestSuite) TestStreamServerFilter() { s1.Send(&testpb.StreamingInputCallRequest{}) require.Nil(s.T(), err) _, err = s1.CloseAndRecv() + require.Equal(s.T(), errs.RetClientStreamReadEnd, errs.Code(err)) + + err = errors.Unwrap(err) require.Equal(s.T(), errs.Code(filterTestError), errs.Code(err)) require.Equal(s.T(), errs.Msg(filterTestError), errs.Msg(err)) diff --git a/test/go.mod b/test/go.mod index 632f2bb..a92e774 100644 --- a/test/go.mod +++ b/test/go.mod @@ -10,7 +10,7 @@ require ( go.uber.org/zap v1.26.0 golang.org/x/net v0.17.0 golang.org/x/sync v0.4.0 - google.golang.org/protobuf v1.31.0 + google.golang.org/protobuf v1.33.0 gopkg.in/yaml.v3 v3.0.1 trpc.group/trpc-go/trpc-go v0.0.0-00010101000000-000000000000 trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0 @@ -43,5 +43,5 @@ require ( go.uber.org/multierr v1.10.0 // indirect golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect - trpc.group/trpc-go/tnet v1.0.0 // indirect + trpc.group/trpc-go/tnet v1.0.1 // indirect ) diff --git a/test/go.sum b/test/go.sum index 2821134..c0689bf 100644 --- a/test/go.sum +++ b/test/go.sum @@ -114,8 +114,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= -google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -125,7 +125,7 @@ gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -trpc.group/trpc-go/tnet v1.0.0 h1:XsdA82/sOHLa4TFAlCZbb3xi4+Q92NNuxEMTj0UfFZ0= -trpc.group/trpc-go/tnet v1.0.0/go.mod h1:s/webUFYWEFBHErKyFmj7LYC7XfC2LTLCcwfSnJ04M0= +trpc.group/trpc-go/tnet v1.0.1 h1:Yzqyrgyfm+W742FzGr39c4+OeQmLi7PWotJxrOBtV9o= +trpc.group/trpc-go/tnet v1.0.1/go.mod h1:s/webUFYWEFBHErKyFmj7LYC7XfC2LTLCcwfSnJ04M0= trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0 h1:rMtHYzI0ElMJRxHtT5cD99SigFE6XzKK4PFtjcwokI0= trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0/go.mod h1:K+a1K/Gnlcg9BFHWx30vLBIEDhxODhl25gi1JjA54CQ= diff --git a/test/graceful_restart_test.go b/test/graceful_restart_test.go index 741cffb..dd4bf3c 100644 --- a/test/graceful_restart_test.go +++ b/test/graceful_restart_test.go @@ -42,6 +42,9 @@ func (s *TestSuite) TestServerGracefulRestart() { s.Run("SendNonGracefulRestartSignal", func() { s.testSendNonGracefulRestartSignal() }) + s.Run("GracefulRestartForEmptyIP", func() { + s.testGracefulRestartForEmptyIP() + }) } func (s *TestSuite) testServerGracefulRestartIsIdempotent() { @@ -250,6 +253,37 @@ func (s *TestSuite) testSendNonGracefulRestartSignal() { }) } +func (s *TestSuite) testGracefulRestartForEmptyIP() { + const ( + binaryFile = "./gracefulrestart/trpc/server.o" + sourceFile = "./gracefulrestart/trpc/server.go" + configFile = "./gracefulrestart/trpc/trpc_go_emptyip.yaml" + ) + + cmd, err := startServerFromBash( + sourceFile, + configFile, + binaryFile, + ) + require.Nil(s.T(), err) + defer func() { + require.Nil(s.T(), exec.Command("rm", binaryFile).Run()) + require.Nil(s.T(), cmd.Process.Kill()) + }() + + const target = "ip://127.0.0.1:17777" + sp, err := getServerProcessByEmptyCall(target) + require.Nil(s.T(), err) + pid := sp.Pid + require.Nil(s.T(), sp.Signal(server.DefaultServerGracefulSIG)) + time.Sleep(1 * time.Second) + sp, err = getServerProcessByEmptyCall(target) + require.Nil(s.T(), err) + require.NotEqual(s.T(), pid, sp.Pid) + pid = sp.Pid + require.Nil(s.T(), sp.Kill()) +} + func startServerFromBash(sourceFile, configFile, targetFile string) (*exec.Cmd, error) { cmd := exec.Command( "bash", diff --git a/test/gracefulrestart/trpc/trpc_go_emptyip.yaml b/test/gracefulrestart/trpc/trpc_go_emptyip.yaml new file mode 100644 index 0000000..28f6367 --- /dev/null +++ b/test/gracefulrestart/trpc/trpc_go_emptyip.yaml @@ -0,0 +1,13 @@ +global: + namespace: Development + env_name: test +server: + app: testing + server: end2end + admin: + port: 19999 + service: + - name: trpc.testing.end2end.TestTRPC + protocol: trpc + network: tcp + port: 17777 diff --git a/test/http_test.go b/test/http_test.go index c0cc32c..91cbcaa 100644 --- a/test/http_test.go +++ b/test/http_test.go @@ -151,6 +151,34 @@ func (s *TestSuite) testSendHTTPSRequestToHTTPServer(e *httpRPCEnv) { require.Contains(s.T(), errs.Msg(err), "codec empty") } +func (s *TestSuite) TestHandleErrServerNoResponse() { + for _, e := range allHTTPRPCEnvs { + if e.client.multiplexed { + continue + } + s.Run(e.String(), func() { s.testHandleErrServerNoResponse(e) }) + } +} +func (s *TestSuite) testHandleErrServerNoResponse(e *httpRPCEnv) { + s.startServer(&testHTTPService{TRPCService: TRPCService{UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return nil, errs.ErrServerNoResponse + }}}, server.WithServerAsync(e.server.async)) + + s.T().Cleanup(func() { s.closeServer(nil) }) + + bts, err := proto.Marshal(s.defaultSimpleRequest) + require.Nil(s.T(), err) + + c := thttp.NewStdHTTPClient("http-client") + rsp, err := c.Post(s.unaryCallCustomURL(), "application/pb", bytes.NewReader(bts)) + require.Nil(s.T(), err) + require.Equal(s.T(), http.StatusInternalServerError, rsp.StatusCode) + + bts, err = io.ReadAll(rsp.Body) + require.Nil(s.T(), err) + require.Containsf(s.T(), string(bts), "http server handle error: type:framework, code:0, msg:server no response", "full err: %+v", err) +} + func (s *TestSuite) TestStatusBadRequestDueToServerValidateFail() { for _, e := range allHTTPRPCEnvs { if e.client.multiplexed { diff --git a/transport/client_transport_test.go b/transport/client_transport_test.go index 70db95f..1f94911 100644 --- a/transport/client_transport_test.go +++ b/transport/client_transport_test.go @@ -18,7 +18,9 @@ import ( "errors" "fmt" "io" + "math" "net" + "strings" "testing" "time" @@ -348,6 +350,24 @@ func TestClientTransport_RoundTrip(t *testing.T) { }() time.Sleep(20 * time.Millisecond) + t.Run("write: message too long", func(t *testing.T) { + c := mustListenUDP(t) + t.Cleanup(func() { + if err := c.Close(); err != nil { + t.Log(err) + } + }) + largeRequest := encodeLengthDelimited(strings.Repeat("1", math.MaxInt32/4)) + _, err := transport.RoundTrip(context.Background(), largeRequest, + transport.WithClientFramerBuilder(fb), + transport.WithDialNetwork("udp"), + transport.WithDialAddress(c.LocalAddr().String()), + transport.WithReqType(transport.SendAndRecv), + ) + require.Equal(t, errs.RetClientNetErr, errs.Code(err)) + require.Contains(t, errs.Msg(err), "udp client transport WriteTo") + }) + var err error _, err = transport.RoundTrip(context.Background(), encodeLengthDelimited("helloworld")) assert.NotNil(t, err) @@ -489,6 +509,14 @@ func TestClientTransport_RoundTrip(t *testing.T) { assert.Contains(t, err.Error(), remainingBytesError.Error()) } +func mustListenUDP(t *testing.T) net.PacketConn { + c, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + return c +} + // Frame a stream of bytes based on a length prefix // +------------+--------------------------------+ // | len: uint8 | frame payload | @@ -609,7 +637,6 @@ func TestClientTransport_MultiplexedErr(t *testing.T) { } func TestClientTransport_RoundTrip_PreConnected(t *testing.T) { - go func() { err := transport.ListenAndServe( transport.WithListenNetwork("udp"), diff --git a/transport/server_listenserve_options.go b/transport/server_listenserve_options.go index f8b6c8f..dac0397 100644 --- a/transport/server_listenserve_options.go +++ b/transport/server_listenserve_options.go @@ -43,6 +43,9 @@ type ListenServeOptions struct { // This used for rpc transport layer like http, it's unrelated to // the TCP keep-alives. DisableKeepAlives bool + + // StopListening is used to instruct the server transport to stop listening. + StopListening <-chan struct{} } // ListenServeOption modifies the ListenServeOptions. @@ -149,3 +152,10 @@ func WithServerIdleTimeout(timeout time.Duration) ListenServeOption { options.IdleTimeout = timeout } } + +// WithStopListening returns a ListenServeOption which notifies the transport to stop listening. +func WithStopListening(ch <-chan struct{}) ListenServeOption { + return func(options *ListenServeOptions) { + options.StopListening = ch + } +} diff --git a/transport/server_transport.go b/transport/server_transport.go index 6ca98eb..36c86c7 100644 --- a/transport/server_transport.go +++ b/transport/server_transport.go @@ -18,7 +18,6 @@ import ( "crypto/tls" "errors" "fmt" - "io" "net" "os" "runtime" @@ -242,6 +241,23 @@ func mayLiftToTLSListener(ln net.Listener, opts *ListenServeOptions) (net.Listen } func (s *serverTransport) serveStream(ctx context.Context, ln net.Listener, opts *ListenServeOptions) error { + var once sync.Once + closeListener := func() { ln.Close() } + defer once.Do(closeListener) + // Create a goroutine to watch ctx.Done() channel. + // Once Server.Close(), TCP listener should be closed immediately and won't accept any new connection. + go func() { + select { + case <-ctx.Done(): + // ctx.Done will perform the following two actions: + // 1. Stop listening. + // 2. Cancel all currently established connections. + // Whereas opts.StopListening will only stop listening. + case <-opts.StopListening: + } + log.Tracef("recv server close event") + once.Do(closeListener) + }() return s.serveTCP(ctx, ln, opts) } @@ -386,11 +402,10 @@ func getPassedListener(network, address string) (interface{}, error) { // ListenFd is the listener fd. type ListenFd struct { - OriginalListenCloser io.Closer - Fd uintptr - Name string - Network string - Address string + Fd uintptr + Name string + Network string + Address string } // inheritListeners stores the listener according to start listenfd and number of listenfd passed @@ -460,11 +475,10 @@ func getPacketConnFd(c net.PacketConn) (*ListenFd, error) { return nil, fmt.Errorf("getPacketConnFd getRawFd err: %w", err) } return &ListenFd{ - OriginalListenCloser: c, - Fd: lnFd, - Name: "a udp listener fd", - Network: c.LocalAddr().Network(), - Address: c.LocalAddr().String(), + Fd: lnFd, + Name: "a udp listener fd", + Network: c.LocalAddr().Network(), + Address: c.LocalAddr().String(), }, nil } @@ -478,11 +492,10 @@ func getListenerFd(ln net.Listener) (*ListenFd, error) { return nil, fmt.Errorf("getListenerFd getRawFd err: %w", err) } return &ListenFd{ - OriginalListenCloser: ln, - Fd: fd, - Name: "a tcp listener fd", - Network: ln.Addr().Network(), - Address: ln.Addr().String(), + Fd: fd, + Name: "a tcp listener fd", + Network: ln.Addr().Network(), + Address: ln.Addr().String(), }, nil } diff --git a/transport/server_transport_tcp.go b/transport/server_transport_tcp.go index ec91715..b930ee6 100644 --- a/transport/server_transport_tcp.go +++ b/transport/server_transport_tcp.go @@ -79,16 +79,6 @@ func createRoutinePool(size int) *ants.PoolWithFunc { } func (s *serverTransport) serveTCP(ctx context.Context, ln net.Listener, opts *ListenServeOptions) error { - var once sync.Once - closeListener := func() { ln.Close() } - defer once.Do(closeListener) - // Create a goroutine to watch ctx.Done() channel. - // Once Server.Close(), TCP listener should be closed immediately and won't accept any new connection. - go func() { - <-ctx.Done() - log.Tracef("recv server close event") - once.Do(closeListener) - }() // Create a goroutine pool if ServerAsync enabled. var pool *ants.PoolWithFunc if opts.ServerAsync { diff --git a/transport/server_transport_test.go b/transport/server_transport_test.go index 7596237..fd5ad20 100644 --- a/transport/server_transport_test.go +++ b/transport/server_transport_test.go @@ -1024,3 +1024,24 @@ func TestListenAndServeTLSFail(t *testing.T) { transport.WithListener(ln), )) } + +func TestListenAndServeWithStopListener(t *testing.T) { + s := transport.NewServerTransport() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.Nil(t, err) + ch := make(chan struct{}) + require.Nil(t, s.ListenAndServe(ctx, + transport.WithListenNetwork("tcp"), + transport.WithServerFramerBuilder(&framerBuilder{}), + transport.WithListener(ln), + transport.WithStopListening(ch), + )) + _, err = net.Dial("tcp", ln.Addr().String()) + require.Nil(t, err) + close(ch) + time.Sleep(time.Millisecond) + _, err = net.Dial("tcp", ln.Addr().String()) + require.NotNil(t, err) +} diff --git a/transport/tnet/client_transport_test.go b/transport/tnet/client_transport_test.go index 503e3f9..e3b807b 100644 --- a/transport/tnet/client_transport_test.go +++ b/transport/tnet/client_transport_test.go @@ -50,19 +50,6 @@ func TestDial(t *testing.T) { return assert.Contains(t, err.Error(), "unknown network") }, }, - { - name: "invalid idle timeout", - opts: &connpool.DialOptions{ - CACertFile: "", - Network: "tcp", - Address: l.Addr().String(), - IdleTimeout: -1, - }, - want: nil, - wantErr: func(t assert.TestingT, err error, msg ...interface{}) bool { - return assert.Contains(t, err.Error(), "delay time is too short") - }, - }, { name: "wrong CACertFile and TLSServerName ", opts: &connpool.DialOptions{ diff --git a/transport/tnet/server_transport_tcp.go b/transport/tnet/server_transport_tcp.go index ec5ab4f..442ce85 100644 --- a/transport/tnet/server_transport_tcp.go +++ b/transport/tnet/server_transport_tcp.go @@ -154,6 +154,10 @@ func (s *serverTransport) startService( pool *ants.PoolWithFunc, opts *transport.ListenServeOptions, ) error { + go func() { + <-opts.StopListening + listener.Close() + }() tnetOpts := []tnet.Option{ tnet.WithOnTCPOpened(func(conn tnet.Conn) error { tc := s.onConnOpened(conn, pool, opts) diff --git a/trpc_util.go b/trpc_util.go index ae2e4d0..1aaf73e 100644 --- a/trpc_util.go +++ b/trpc_util.go @@ -16,7 +16,6 @@ package trpc import ( "context" "net" - "os" "runtime" "sync" "time" @@ -275,58 +274,6 @@ func Go(ctx context.Context, timeout time.Duration, handler func(context.Context return DefaultGoer.Go(ctx, timeout, handler) } -// expandEnv looks for ${var} in s and replaces them with value of the -// corresponding environment variable. -// $var is considered invalid. -// It's not like os.ExpandEnv which will handle both ${var} and $var. -// Since configurations like password for redis/mysql may contain $, this -// method is needed. -func expandEnv(s string) string { - var buf []byte - i := 0 - for j := 0; j < len(s); j++ { - if s[j] == '$' && j+2 < len(s) && s[j+1] == '{' { // only ${var} instead of $var is valid - if buf == nil { - buf = make([]byte, 0, 2*len(s)) - } - buf = append(buf, s[i:j]...) - name, w := getEnvName(s[j+1:]) - if name == "" && w > 0 { - // invalid matching, remove the $ - } else if name == "" { - buf = append(buf, s[j]) // keep the $ - } else { - buf = append(buf, os.Getenv(name)...) - } - j += w - i = j + 1 - } - } - if buf == nil { - return s - } - return string(buf) + s[i:] -} - -// getEnvName gets env name, that is, var from ${var}. -// And content of var and its len will be returned. -func getEnvName(s string) (string, int) { - // look for right curly bracket '}' - // it's guaranteed that the first char is '{' and the string has at least two char - for i := 1; i < len(s); i++ { - if s[i] == ' ' || s[i] == '\n' || s[i] == '"' { // "xx${xxx" - return "", 0 // encounter invalid char, keep the $ - } - if s[i] == '}' { - if i == 1 { // ${} - return "", 2 // remove ${} - } - return s[1:i], i + 1 - } - } - return "", 0 // no },keep the $ -} - // --------------- the following code is IP Config related -----------------// // nicIP defines the parameters used to record the ip address (ipv4 & ipv6) of the nic.