Skip to content

Commit

Permalink
Setup the test suite
Browse files Browse the repository at this point in the history
Signed-off-by: JmPotato <[email protected]>
  • Loading branch information
JmPotato committed Jan 22, 2025
1 parent 48ff881 commit 46fd376
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 13 deletions.
22 changes: 22 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,16 @@ func (c *client) GetMinTS(ctx context.Context) (physical int64, logical int64, e
return minTS.Physical, minTS.Logical, nil
}

// EnableRouterClient enables the router client.
// This is only for test currently.
func (c *client) EnableRouterClient() {
c.inner.enableRouterClient.Store(true)
}

func (c *client) isRouterClientEnabled() bool {
return c.inner.enableRouterClient.Load()
}

// GetRegionFromMember implements the RPCClient interface.
func (c *client) GetRegionFromMember(ctx context.Context, key []byte, memberURLs []string, _ ...opt.GetRegionOption) (*router.Region, error) {
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
Expand Down Expand Up @@ -620,6 +630,10 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio
ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout)
defer cancel()

if c.isRouterClientEnabled() {
return c.inner.routerClient.GetRegion(ctx, key, opts...)
}

options := &opt.GetRegionOp{}
for _, opt := range opts {
opt(options)
Expand Down Expand Up @@ -660,6 +674,10 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR
ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout)
defer cancel()

if c.isRouterClientEnabled() {
return c.inner.routerClient.GetPrevRegion(ctx, key, opts...)
}

options := &opt.GetRegionOp{}
for _, opt := range opts {
opt(options)
Expand Down Expand Up @@ -700,6 +718,10 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt
ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout)
defer cancel()

if c.isRouterClientEnabled() {
return c.inner.routerClient.GetRegionByID(ctx, regionID, opts...)
}

options := &opt.GetRegionOp{}
for _, opt := range opts {
opt(options)
Expand Down
5 changes: 1 addition & 4 deletions client/clients/router/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,7 @@ func requestFinisher(resp *pdpb.QueryRegionResponse) batch.FinisherFunc[*Request
} else if req.id != 0 {
id = req.id
}
region, ok := resp.RegionsById[id]
if !ok {
err = errs.ErrClientRegionNotFound.FastGenByArgs(id)
} else {
if region, ok := resp.RegionsById[id]; ok {
req.region = ConvertToRegion(region)
}
req.tryDone(err)
Expand Down
6 changes: 0 additions & 6 deletions client/clients/router/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,6 @@ func (c *Cli) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOp
return req.wait()
}

// GetRegionFromMember implements the Client interface.
func (c *Cli) GetRegionFromMember(ctx context.Context, key []byte, _ []string, opts ...opt.GetRegionOption) (*Region, error) {
// Before we support the follower stream connection, this method is equivalent to `GetRegion`.
return c.GetRegion(ctx, key, opts...)
}

// GetPrevRegion implements the Client interface.
func (c *Cli) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOption) (*Region, error) {
req := c.newRequest(ctx)
Expand Down
1 change: 0 additions & 1 deletion client/errs/errno.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ var (
ErrClientFindGroupByKeyspaceID = errors.Normalize("can't find keyspace group by keyspace id", errors.RFCCodeText("PD:client:ErrClientFindGroupByKeyspaceID"))
ErrClientWatchGCSafePointV2Stream = errors.Normalize("watch gc safe point v2 stream failed", errors.RFCCodeText("PD:client:ErrClientWatchGCSafePointV2Stream"))
ErrCircuitBreakerOpen = errors.Normalize("circuit breaker is open", errors.RFCCodeText("PD:client:ErrCircuitBreakerOpen"))
ErrClientRegionNotFound = errors.Normalize("region %d not found", errors.RFCCodeText("PD:client:ErrClientRegionNotFound"))
)

// grpcutil errors
Expand Down
8 changes: 7 additions & 1 deletion client/inner_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"crypto/tls"
"sync"
"sync/atomic"
"time"

"go.uber.org/zap"
Expand Down Expand Up @@ -46,7 +47,12 @@ type innerClient struct {
serviceDiscovery sd.ServiceDiscovery
tokenDispatcher *tokenDispatcher

routerClient *router.Cli
// The router client is used to get the region info via the streaming gRPC,
// this flag is used to control whether to enable it, currently only used
// in the test.
enableRouterClient atomic.Bool
routerClient *router.Cli

// For service mode switching.
serviceModeKeeper

Expand Down
82 changes: 82 additions & 0 deletions tests/integrations/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/stretchr/testify/suite"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/goleak"
"golang.org/x/exp/rand"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

Expand Down Expand Up @@ -1025,12 +1026,18 @@ type clientTestSuite struct {
grpcPDClient pdpb.PDClient
regionHeartbeat pdpb.PD_RegionHeartbeatClient
reportBucket pdpb.PD_ReportBucketsClient

enableRouterClient bool
}

func TestClientTestSuite(t *testing.T) {
suite.Run(t, new(clientTestSuite))
}

func TestClientTestSuiteWithRouterClient(t *testing.T) {
suite.Run(t, &clientTestSuite{enableRouterClient: true})
}

func (suite *clientTestSuite) SetupSuite() {
var err error
re := suite.Require()
Expand All @@ -1044,6 +1051,9 @@ func (suite *clientTestSuite) SetupSuite() {

suite.ctx, suite.clean = context.WithCancel(context.Background())
suite.client = setupCli(suite.ctx, re, suite.srv.GetEndpoints())
if suite.enableRouterClient {
suite.client.(interface{ EnableRouterClient() }).EnableRouterClient()
}

suite.regionHeartbeat, err = suite.grpcPDClient.RegionHeartbeat(suite.ctx)
re.NoError(err)
Expand Down Expand Up @@ -1337,6 +1347,78 @@ func (suite *clientTestSuite) TestGetRegionByID() {
})
}

func (suite *clientTestSuite) TestGetRegionConcurrently() {
re := suite.Require()
suite.srv.DirectlyGetRaftCluster().ResetRegionCache()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

regionID := regionIDAllocator.alloc()
region := &metapb.Region{
Id: regionID,
RegionEpoch: &metapb.RegionEpoch{
ConfVer: 1,
Version: 1,
},
StartKey: []byte("a"),
EndKey: []byte("b"),
Peers: peers,
}
req := &pdpb.RegionHeartbeatRequest{
Header: newHeader(),
Region: region,
Leader: peers[0],
}
err := suite.regionHeartbeat.Send(req)
re.NoError(err)

const concurrency = 1000

wg := sync.WaitGroup{}
wg.Add(concurrency)
for range concurrency {
go func() {
defer wg.Done()
switch rand.Intn(3) {
case 0:
testutil.Eventually(re, func() bool {
r, err := suite.client.GetRegion(ctx, []byte("a"))
re.NoError(err)
if r == nil {
return false
}
return reflect.DeepEqual(region, r.Meta) &&
reflect.DeepEqual(peers[0], r.Leader) &&
r.Buckets == nil
})
case 1:
testutil.Eventually(re, func() bool {
r, err := suite.client.GetPrevRegion(ctx, []byte("b"))
re.NoError(err)
if r == nil {
return false
}
return reflect.DeepEqual(region, r.Meta) &&
reflect.DeepEqual(peers[0], r.Leader) &&
r.Buckets == nil
})
case 2:
testutil.Eventually(re, func() bool {
r, err := suite.client.GetRegionByID(ctx, regionID)
re.NoError(err)
if r == nil {
return false
}
return reflect.DeepEqual(region, r.Meta) &&
reflect.DeepEqual(peers[0], r.Leader) &&
r.Buckets == nil
})
}
}()
}
wg.Wait()
}

func (suite *clientTestSuite) TestGetStore() {
re := suite.Require()
cluster := suite.srv.GetRaftCluster()
Expand Down
2 changes: 1 addition & 1 deletion tests/integrations/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ require (
go.etcd.io/etcd/client/v3 v3.5.15
go.uber.org/goleak v1.3.0
go.uber.org/zap v1.27.0
golang.org/x/exp v0.0.0-20230711005742-c3f37128e5a4
google.golang.org/grpc v1.62.1
gorm.io/driver/mysql v1.4.5
gorm.io/gorm v1.24.3
Expand Down Expand Up @@ -187,7 +188,6 @@ require (
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.24.0 // indirect
golang.org/x/exp v0.0.0-20230711005742-c3f37128e5a4 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect
Expand Down

0 comments on commit 46fd376

Please sign in to comment.