From 82ee6e836cccbc3a881222c7a3ad1145df661072 Mon Sep 17 00:00:00 2001 From: YoungFr <43751910+YoungFr@users.noreply.github.com> Date: Wed, 15 May 2024 14:49:53 +0800 Subject: [PATCH] {client, naming}: allow selector to define its own net.Addr parser (#176) This is used to avoid unnecessary addr parse which is commonly used in trpc-database. To properly update the DSN library, we need to introduce this feature into the open-source tRPC-Go. --- client/client.go | 15 ++++++++-- client/client_test.go | 63 +++++++++++++++++++++++++++++++++++++++++ client/stream.go | 2 +- naming/registry/node.go | 4 +++ 4 files changed, 80 insertions(+), 4 deletions(-) diff --git a/client/client.go b/client/client.go index 8860c91..946d58b 100644 --- a/client/client.go +++ b/client/client.go @@ -393,7 +393,7 @@ func selectorFilter(ctx context.Context, req interface{}, rsp interface{}, next if err != nil { return OptionsFromContext(ctx).fixTimeout(err) } - ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address) + ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address, node.ParseAddr) // Start to process the next filter and report. begin := time.Now() @@ -471,11 +471,21 @@ func getNode(opts *Options) (*registry.Node, error) { return node, nil } -func ensureMsgRemoteAddr(msg codec.Msg, network string, address string) { +func ensureMsgRemoteAddr( + msg codec.Msg, + network, address string, + parseAddr func(network, address string) net.Addr, +) { // If RemoteAddr has already been set, just return. if msg.RemoteAddr() != nil { return } + + if parseAddr != nil { + msg.WithRemoteAddr(parseAddr(network, address)) + return + } + switch network { case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": // Check if address can be parsed as an ip. @@ -484,7 +494,6 @@ func ensureMsgRemoteAddr(msg codec.Msg, network string, address string) { return } } - var addr net.Addr switch network { case "tcp", "tcp4", "tcp6": diff --git a/client/client_test.go b/client/client_test.go index 34aaf51..22fb5be 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -16,6 +16,8 @@ package client_test import ( "context" "errors" + "fmt" + "net" "testing" "time" @@ -409,6 +411,31 @@ func TestFixTimeout(t *testing.T) { }) } +func TestSelectorRemoteAddrUseUserProvidedParser(t *testing.T) { + selector.Register(t.Name(), &fSelector{ + selectNode: func(s string, option ...selector.Option) (*registry.Node, error) { + return ®istry.Node{ + Network: t.Name(), + Address: t.Name(), + ParseAddr: func(network, address string) net.Addr { + return newUnresolvedAddr(network, address) + }}, nil + }, + report: func(node *registry.Node, duration time.Duration, err error) error { return nil }, + }) + fake := "fake" + codec.Register(fake, nil, &fakeCodec{}) + ctx := trpc.BackgroundContext() + require.NotNil(t, client.New().Invoke(ctx, "failbody", nil, + client.WithServiceName(t.Name()), + client.WithProtocol(fake), + client.WithTarget(fmt.Sprintf("%s://xxx", t.Name())))) + addr := trpc.Message(ctx).RemoteAddr() + require.NotNil(t, addr) + require.Equal(t, t.Name(), addr.Network()) + require.Equal(t, t.Name(), addr.String()) +} + type multiplexedTransport struct { require func(context.Context, []byte, ...transport.RoundTripOption) fakeTransport @@ -527,3 +554,39 @@ func (c *fakeSelector) Select(serviceName string, opt ...selector.Option) (*regi func (c *fakeSelector) Report(node *registry.Node, cost time.Duration, err error) error { return nil } + +type fSelector struct { + selectNode func(string, ...selector.Option) (*registry.Node, error) + report func(*registry.Node, time.Duration, error) error +} + +func (s *fSelector) Select(serviceName string, opts ...selector.Option) (*registry.Node, error) { + return s.selectNode(serviceName, opts...) +} + +func (s *fSelector) Report(node *registry.Node, cost time.Duration, err error) error { + return s.report(node, cost, err) +} + +// newUnresolvedAddr returns a new unresolvedAddr. +func newUnresolvedAddr(network, address string) *unresolvedAddr { + return &unresolvedAddr{network: network, address: address} +} + +var _ net.Addr = (*unresolvedAddr)(nil) + +// unresolvedAddr is a net.Addr which returns the original network or address. +type unresolvedAddr struct { + network string + address string +} + +// Network returns the unresolved original network. +func (a *unresolvedAddr) Network() string { + return a.network +} + +// String returns the unresolved original address. +func (a *unresolvedAddr) String() string { + return a.address +} diff --git a/client/stream.go b/client/stream.go index 3a33159..fbfd136 100644 --- a/client/stream.go +++ b/client/stream.go @@ -162,7 +162,7 @@ func (s *stream) Init(ctx context.Context, opt ...Option) (*Options, error) { report.SelectNodeFail.Incr() return nil, err } - ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address) + ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address, node.ParseAddr) const invalidCost = -1 opts.Node.set(node, node.Address, invalidCost) if opts.Codec == nil { diff --git a/naming/registry/node.go b/naming/registry/node.go index d4f4d19..c5f0a34 100644 --- a/naming/registry/node.go +++ b/naming/registry/node.go @@ -15,6 +15,7 @@ package registry import ( "fmt" + "net" "time" ) @@ -30,6 +31,9 @@ type Node struct { CostTime time.Duration // 当次请求耗时 EnvKey string // 透传的环境信息 Metadata map[string]interface{} + // ParseAddr should be used to convert Node to net.Addr if it's not nil. + // See test case TestSelectorRemoteAddrUseUserProvidedParser in client package. + ParseAddr func(network, address string) net.Addr } // String returns an abbreviation information of node.