From 46fd37688d34ae787494e7018bae54c9ddfd81fd Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 22 Jan 2025 12:23:38 +0800 Subject: [PATCH] Setup the test suite Signed-off-by: JmPotato --- client/client.go | 22 +++++++ client/clients/router/client.go | 5 +- client/clients/router/request.go | 6 -- client/errs/errno.go | 1 - client/inner_client.go | 8 ++- tests/integrations/client/client_test.go | 82 ++++++++++++++++++++++++ tests/integrations/go.mod | 2 +- 7 files changed, 113 insertions(+), 13 deletions(-) diff --git a/client/client.go b/client/client.go index 8b21b17169e..5ead9b5443d 100644 --- a/client/client.go +++ b/client/client.go @@ -570,6 +570,16 @@ func (c *client) GetMinTS(ctx context.Context) (physical int64, logical int64, e return minTS.Physical, minTS.Logical, nil } +// EnableRouterClient enables the router client. +// This is only for test currently. +func (c *client) EnableRouterClient() { + c.inner.enableRouterClient.Store(true) +} + +func (c *client) isRouterClientEnabled() bool { + return c.inner.enableRouterClient.Load() +} + // GetRegionFromMember implements the RPCClient interface. func (c *client) GetRegionFromMember(ctx context.Context, key []byte, memberURLs []string, _ ...opt.GetRegionOption) (*router.Region, error) { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { @@ -620,6 +630,10 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout) defer cancel() + if c.isRouterClientEnabled() { + return c.inner.routerClient.GetRegion(ctx, key, opts...) + } + options := &opt.GetRegionOp{} for _, opt := range opts { opt(options) @@ -660,6 +674,10 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout) defer cancel() + if c.isRouterClientEnabled() { + return c.inner.routerClient.GetPrevRegion(ctx, key, opts...) + } + options := &opt.GetRegionOp{} for _, opt := range opts { opt(options) @@ -700,6 +718,10 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout) defer cancel() + if c.isRouterClientEnabled() { + return c.inner.routerClient.GetRegionByID(ctx, regionID, opts...) + } + options := &opt.GetRegionOp{} for _, opt := range opts { opt(options) diff --git a/client/clients/router/client.go b/client/clients/router/client.go index 8d0cfd64d2d..8bd44b8b6a3 100644 --- a/client/clients/router/client.go +++ b/client/clients/router/client.go @@ -227,10 +227,7 @@ func requestFinisher(resp *pdpb.QueryRegionResponse) batch.FinisherFunc[*Request } else if req.id != 0 { id = req.id } - region, ok := resp.RegionsById[id] - if !ok { - err = errs.ErrClientRegionNotFound.FastGenByArgs(id) - } else { + if region, ok := resp.RegionsById[id]; ok { req.region = ConvertToRegion(region) } req.tryDone(err) diff --git a/client/clients/router/request.go b/client/clients/router/request.go index 2e1c2e97aa5..cc1ada0a729 100644 --- a/client/clients/router/request.go +++ b/client/clients/router/request.go @@ -80,12 +80,6 @@ func (c *Cli) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOp return req.wait() } -// GetRegionFromMember implements the Client interface. -func (c *Cli) GetRegionFromMember(ctx context.Context, key []byte, _ []string, opts ...opt.GetRegionOption) (*Region, error) { - // Before we support the follower stream connection, this method is equivalent to `GetRegion`. - return c.GetRegion(ctx, key, opts...) -} - // GetPrevRegion implements the Client interface. func (c *Cli) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOption) (*Region, error) { req := c.newRequest(ctx) diff --git a/client/errs/errno.go b/client/errs/errno.go index 8f81d2d6777..99a426d0776 100644 --- a/client/errs/errno.go +++ b/client/errs/errno.go @@ -70,7 +70,6 @@ var ( ErrClientFindGroupByKeyspaceID = errors.Normalize("can't find keyspace group by keyspace id", errors.RFCCodeText("PD:client:ErrClientFindGroupByKeyspaceID")) ErrClientWatchGCSafePointV2Stream = errors.Normalize("watch gc safe point v2 stream failed", errors.RFCCodeText("PD:client:ErrClientWatchGCSafePointV2Stream")) ErrCircuitBreakerOpen = errors.Normalize("circuit breaker is open", errors.RFCCodeText("PD:client:ErrCircuitBreakerOpen")) - ErrClientRegionNotFound = errors.Normalize("region %d not found", errors.RFCCodeText("PD:client:ErrClientRegionNotFound")) ) // grpcutil errors diff --git a/client/inner_client.go b/client/inner_client.go index 269c2330f8e..464fd413e25 100644 --- a/client/inner_client.go +++ b/client/inner_client.go @@ -18,6 +18,7 @@ import ( "context" "crypto/tls" "sync" + "sync/atomic" "time" "go.uber.org/zap" @@ -46,7 +47,12 @@ type innerClient struct { serviceDiscovery sd.ServiceDiscovery tokenDispatcher *tokenDispatcher - routerClient *router.Cli + // The router client is used to get the region info via the streaming gRPC, + // this flag is used to control whether to enable it, currently only used + // in the test. + enableRouterClient atomic.Bool + routerClient *router.Cli + // For service mode switching. serviceModeKeeper diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go index 91a6d44943e..b08a174e31f 100644 --- a/tests/integrations/client/client_test.go +++ b/tests/integrations/client/client_test.go @@ -36,6 +36,7 @@ import ( "github.com/stretchr/testify/suite" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/goleak" + "golang.org/x/exp/rand" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -1025,12 +1026,18 @@ type clientTestSuite struct { grpcPDClient pdpb.PDClient regionHeartbeat pdpb.PD_RegionHeartbeatClient reportBucket pdpb.PD_ReportBucketsClient + + enableRouterClient bool } func TestClientTestSuite(t *testing.T) { suite.Run(t, new(clientTestSuite)) } +func TestClientTestSuiteWithRouterClient(t *testing.T) { + suite.Run(t, &clientTestSuite{enableRouterClient: true}) +} + func (suite *clientTestSuite) SetupSuite() { var err error re := suite.Require() @@ -1044,6 +1051,9 @@ func (suite *clientTestSuite) SetupSuite() { suite.ctx, suite.clean = context.WithCancel(context.Background()) suite.client = setupCli(suite.ctx, re, suite.srv.GetEndpoints()) + if suite.enableRouterClient { + suite.client.(interface{ EnableRouterClient() }).EnableRouterClient() + } suite.regionHeartbeat, err = suite.grpcPDClient.RegionHeartbeat(suite.ctx) re.NoError(err) @@ -1337,6 +1347,78 @@ func (suite *clientTestSuite) TestGetRegionByID() { }) } +func (suite *clientTestSuite) TestGetRegionConcurrently() { + re := suite.Require() + suite.srv.DirectlyGetRaftCluster().ResetRegionCache() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + regionID := regionIDAllocator.alloc() + region := &metapb.Region{ + Id: regionID, + RegionEpoch: &metapb.RegionEpoch{ + ConfVer: 1, + Version: 1, + }, + StartKey: []byte("a"), + EndKey: []byte("b"), + Peers: peers, + } + req := &pdpb.RegionHeartbeatRequest{ + Header: newHeader(), + Region: region, + Leader: peers[0], + } + err := suite.regionHeartbeat.Send(req) + re.NoError(err) + + const concurrency = 1000 + + wg := sync.WaitGroup{} + wg.Add(concurrency) + for range concurrency { + go func() { + defer wg.Done() + switch rand.Intn(3) { + case 0: + testutil.Eventually(re, func() bool { + r, err := suite.client.GetRegion(ctx, []byte("a")) + re.NoError(err) + if r == nil { + return false + } + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) && + r.Buckets == nil + }) + case 1: + testutil.Eventually(re, func() bool { + r, err := suite.client.GetPrevRegion(ctx, []byte("b")) + re.NoError(err) + if r == nil { + return false + } + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) && + r.Buckets == nil + }) + case 2: + testutil.Eventually(re, func() bool { + r, err := suite.client.GetRegionByID(ctx, regionID) + re.NoError(err) + if r == nil { + return false + } + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) && + r.Buckets == nil + }) + } + }() + } + wg.Wait() +} + func (suite *clientTestSuite) TestGetStore() { re := suite.Require() cluster := suite.srv.GetRaftCluster() diff --git a/tests/integrations/go.mod b/tests/integrations/go.mod index fca5b54bb07..40c5350c18c 100644 --- a/tests/integrations/go.mod +++ b/tests/integrations/go.mod @@ -25,6 +25,7 @@ require ( go.etcd.io/etcd/client/v3 v3.5.15 go.uber.org/goleak v1.3.0 go.uber.org/zap v1.27.0 + golang.org/x/exp v0.0.0-20230711005742-c3f37128e5a4 google.golang.org/grpc v1.62.1 gorm.io/driver/mysql v1.4.5 gorm.io/gorm v1.24.3 @@ -187,7 +188,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.24.0 // indirect - golang.org/x/exp v0.0.0-20230711005742-c3f37128e5a4 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/net v0.26.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect