Skip to content

Commit

Permalink
{client, naming}: allow selector to define its own net.Addr parser (#176
Browse files Browse the repository at this point in the history
)

This is used to avoid unnecessary addr parse which is commonly used
in trpc-database.

To properly update the DSN library, we need to introduce this feature into
the open-source tRPC-Go.
  • Loading branch information
YoungFr authored May 15, 2024
1 parent f727602 commit 82ee6e8
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 4 deletions.
15 changes: 12 additions & 3 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ func selectorFilter(ctx context.Context, req interface{}, rsp interface{}, next
if err != nil {
return OptionsFromContext(ctx).fixTimeout(err)
}
ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address)
ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address, node.ParseAddr)

// Start to process the next filter and report.
begin := time.Now()
Expand Down Expand Up @@ -471,11 +471,21 @@ func getNode(opts *Options) (*registry.Node, error) {
return node, nil
}

func ensureMsgRemoteAddr(msg codec.Msg, network string, address string) {
func ensureMsgRemoteAddr(
msg codec.Msg,
network, address string,
parseAddr func(network, address string) net.Addr,
) {
// If RemoteAddr has already been set, just return.
if msg.RemoteAddr() != nil {
return
}

if parseAddr != nil {
msg.WithRemoteAddr(parseAddr(network, address))
return
}

switch network {
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
// Check if address can be parsed as an ip.
Expand All @@ -484,7 +494,6 @@ func ensureMsgRemoteAddr(msg codec.Msg, network string, address string) {
return
}
}

var addr net.Addr
switch network {
case "tcp", "tcp4", "tcp6":
Expand Down
63 changes: 63 additions & 0 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package client_test
import (
"context"
"errors"
"fmt"
"net"
"testing"
"time"

Expand Down Expand Up @@ -409,6 +411,31 @@ func TestFixTimeout(t *testing.T) {
})
}

func TestSelectorRemoteAddrUseUserProvidedParser(t *testing.T) {
selector.Register(t.Name(), &fSelector{
selectNode: func(s string, option ...selector.Option) (*registry.Node, error) {
return &registry.Node{
Network: t.Name(),
Address: t.Name(),
ParseAddr: func(network, address string) net.Addr {
return newUnresolvedAddr(network, address)
}}, nil
},
report: func(node *registry.Node, duration time.Duration, err error) error { return nil },
})
fake := "fake"
codec.Register(fake, nil, &fakeCodec{})
ctx := trpc.BackgroundContext()
require.NotNil(t, client.New().Invoke(ctx, "failbody", nil,
client.WithServiceName(t.Name()),
client.WithProtocol(fake),
client.WithTarget(fmt.Sprintf("%s://xxx", t.Name()))))
addr := trpc.Message(ctx).RemoteAddr()
require.NotNil(t, addr)
require.Equal(t, t.Name(), addr.Network())
require.Equal(t, t.Name(), addr.String())
}

type multiplexedTransport struct {
require func(context.Context, []byte, ...transport.RoundTripOption)
fakeTransport
Expand Down Expand Up @@ -527,3 +554,39 @@ func (c *fakeSelector) Select(serviceName string, opt ...selector.Option) (*regi
func (c *fakeSelector) Report(node *registry.Node, cost time.Duration, err error) error {
return nil
}

type fSelector struct {
selectNode func(string, ...selector.Option) (*registry.Node, error)
report func(*registry.Node, time.Duration, error) error
}

func (s *fSelector) Select(serviceName string, opts ...selector.Option) (*registry.Node, error) {
return s.selectNode(serviceName, opts...)
}

func (s *fSelector) Report(node *registry.Node, cost time.Duration, err error) error {
return s.report(node, cost, err)
}

// newUnresolvedAddr returns a new unresolvedAddr.
func newUnresolvedAddr(network, address string) *unresolvedAddr {
return &unresolvedAddr{network: network, address: address}
}

var _ net.Addr = (*unresolvedAddr)(nil)

// unresolvedAddr is a net.Addr which returns the original network or address.
type unresolvedAddr struct {
network string
address string
}

// Network returns the unresolved original network.
func (a *unresolvedAddr) Network() string {
return a.network
}

// String returns the unresolved original address.
func (a *unresolvedAddr) String() string {
return a.address
}
2 changes: 1 addition & 1 deletion client/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func (s *stream) Init(ctx context.Context, opt ...Option) (*Options, error) {
report.SelectNodeFail.Incr()
return nil, err
}
ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address)
ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address, node.ParseAddr)
const invalidCost = -1
opts.Node.set(node, node.Address, invalidCost)
if opts.Codec == nil {
Expand Down
4 changes: 4 additions & 0 deletions naming/registry/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package registry

import (
"fmt"
"net"
"time"
)

Expand All @@ -30,6 +31,9 @@ type Node struct {
CostTime time.Duration // 当次请求耗时
EnvKey string // 透传的环境信息
Metadata map[string]interface{}
// ParseAddr should be used to convert Node to net.Addr if it's not nil.
// See test case TestSelectorRemoteAddrUseUserProvidedParser in client package.
ParseAddr func(network, address string) net.Addr
}

// String returns an abbreviation information of node.
Expand Down

0 comments on commit 82ee6e8

Please sign in to comment.