From b6ac607bf128dec3865cabc21a1221505ce1b468 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 9 Nov 2024 15:18:24 +0900 Subject: [PATCH] =?UTF-8?q?optimize(sign):=20=E7=B2=BE=E7=AE=80=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/base.go | 4 +-- client/sign/http.go | 74 +++++++++++++++-------------------------- client/sign/provider.go | 10 ++++++ client/sign/sign.go | 56 ++++++++++++++----------------- 4 files changed, 62 insertions(+), 82 deletions(-) diff --git a/client/base.go b/client/base.go index c0520650..33dbaf66 100644 --- a/client/base.go +++ b/client/base.go @@ -46,9 +46,7 @@ func NewClient(uin uint32, appInfo *auth.AppInfo, signUrl ...string) *QQClient { alive: true, UA: "LagrangeGo qq/" + appInfo.PackageSign, } - client.signProvider = sign.NewSignClient(appInfo, func(s string) { - client.debug(s) - }, signUrl...) + client.signProvider = sign.NewSigner(appInfo, client.debug, signUrl...) client.transport.Version = appInfo client.transport.Sig.D2Key = make([]byte, 0, 16) client.highwaySession.Uin = &client.transport.Sig.Uin diff --git a/client/sign/http.go b/client/sign/http.go index 6fa99242..ac3c522d 100644 --- a/client/sign/http.go +++ b/client/sign/http.go @@ -3,12 +3,13 @@ package sign import ( "context" "encoding/json" - "errors" "fmt" "io" "net/http" "net/url" "time" + + "github.com/pkg/errors" ) var ( @@ -69,10 +70,10 @@ func containSignPKG(cmd string) bool { return ok } -func httpGet(rawUrl string, queryParams map[string]string, timeout time.Duration, target interface{}, header http.Header) error { +func httpGet[T any](rawUrl string, queryParams map[string]string, timeout time.Duration, header http.Header) (target T, err error) { u, err := url.Parse(rawUrl) if err != nil { - return fmt.Errorf("failed to parse URL: %w", err) + return } q := u.Query() @@ -86,43 +87,20 @@ func httpGet(rawUrl string, queryParams map[string]string, timeout time.Duration req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) if err != nil { - return fmt.Errorf("failed to create GET request: %w", err) + return } for k, vs := range header { for _, v := range vs { req.Header.Add(k, v) } } - resp, err := httpClient.Do(req) - if err != nil { - if errors.Is(ctx.Err(), context.DeadlineExceeded) { - return fmt.Errorf("request timed out") - } - resp, err = httpClient.Do(req) - if err != nil { - if errors.Is(ctx.Err(), context.DeadlineExceeded) { - return fmt.Errorf("request timed out") - } - return fmt.Errorf("failed to perform GET request: %w", err) - } - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - - if err := json.NewDecoder(resp.Body).Decode(target); err != nil { - return fmt.Errorf("failed to unmarshal JSON response: %w", err) - } - - return nil + return doHTTP[T](ctx, req) } -func httpPost(rawUrl string, body io.Reader, timeout time.Duration, target interface{}, header http.Header) error { +func httpPost[T any](rawUrl string, body io.Reader, timeout time.Duration, header http.Header) (target T, err error) { u, err := url.Parse(rawUrl) if err != nil { - return fmt.Errorf("failed to parse URL: %w", err) + return } ctx, cancel := context.WithTimeout(context.Background(), timeout) @@ -130,7 +108,8 @@ func httpPost(rawUrl string, body io.Reader, timeout time.Duration, target inter req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), body) if err != nil { - return fmt.Errorf("failed to create POST request: %w", err) + err = errors.Wrap(err, "create POST") + return } for k, vs := range header { for _, v := range vs { @@ -138,38 +117,37 @@ func httpPost(rawUrl string, body io.Reader, timeout time.Duration, target inter } } req.Header.Add("Content-Type", "application/json") + return doHTTP[T](ctx, req) +} + +func doHTTP[T any](ctx context.Context, req *http.Request) (target T, err error) { resp, err := http.DefaultClient.Do(req) if err != nil { if errors.Is(ctx.Err(), context.DeadlineExceeded) { - return fmt.Errorf("request timed out") + err = ctx.Err() + return } resp, err = httpClient.Do(req) if err != nil { if errors.Is(ctx.Err(), context.DeadlineExceeded) { - return fmt.Errorf("request timed out") + err = ctx.Err() + return } - return fmt.Errorf("failed to perform POST request: %w", err) + err = errors.Wrap(err, "perform POST") + return } } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + err = fmt.Errorf("unexpected status code: %d", resp.StatusCode) + return } - if err := json.NewDecoder(resp.Body).Decode(target); err != nil { - return fmt.Errorf("failed to unmarshal JSON response: %w", err) + if err = json.NewDecoder(resp.Body).Decode(&target); err != nil { + err = errors.Wrap(err, "unmarshal response") + return } - return nil -} - -type Response struct { - Platform string `json:"platform"` - Version string `json:"version"` - Value struct { - Sign string `json:"sign"` - Extra string `json:"extra"` - Token string `json:"token"` - } `json:"value"` + return } diff --git a/client/sign/provider.go b/client/sign/provider.go index 77995f8d..b7d7b18c 100644 --- a/client/sign/provider.go +++ b/client/sign/provider.go @@ -1,5 +1,15 @@ package sign +type Response struct { + Platform string `json:"platform"` + Version string `json:"version"` + Value struct { + Sign string `json:"sign"` + Extra string `json:"extra"` + Token string `json:"token"` + } `json:"value"` +} + type Provider interface { Sign(cmd string, seq uint32, data []byte) (*Response, error) } diff --git a/client/sign/sign.go b/client/sign/sign.go index 9dae64dc..9c5e934f 100644 --- a/client/sign/sign.go +++ b/client/sign/sign.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "math" "net/http" "sort" "strconv" @@ -12,43 +13,38 @@ import ( "sync/atomic" "time" - "github.com/LagrangeDev/LagrangeGo/utils" - "github.com/LagrangeDev/LagrangeGo/client/auth" + "github.com/LagrangeDev/LagrangeGo/utils" ) type ( - Status uint32 - Client struct { lock sync.RWMutex signCount atomic.Uint32 - instances []*Instance + instances []*remote app *auth.AppInfo httpClient *http.Client extraHeaders http.Header - log func(string) + log func(string, ...any) lastTestTime time.Time } - Instance struct { + remote struct { server string latency atomic.Uint32 - status atomic.Uint32 } ) const ( - OK Status = iota - Down + serverLatencyDown = math.MaxUint32 ) -var VersionMismatchError = errors.New("sign version mismatch") +var ErrVersionMismatch = errors.New("sign version mismatch") -func NewSignClient(appinfo *auth.AppInfo, log func(string), signServers ...string) *Client { +func NewSigner(appinfo *auth.AppInfo, log func(string, ...any), signServers ...string) *Client { client := &Client{ - instances: utils.Map[string, *Instance](signServers, func(s string) *Instance { - return &Instance{server: s} + instances: utils.Map(signServers, func(s string) *remote { + return &remote{server: s} }), app: appinfo, httpClient: &http.Client{}, @@ -72,24 +68,24 @@ func (c *Client) AddRequestHeader(header map[string]string) { func (c *Client) AddSignServer(signServers ...string) { c.lock.Lock() defer c.lock.Unlock() - c.instances = append(c.instances, utils.Map[string, *Instance](signServers, func(s string) *Instance { - return &Instance{server: s} + c.instances = append(c.instances, utils.Map[string, *remote](signServers, func(s string) *remote { + return &remote{server: s} })...) } func (c *Client) GetSignServer() []string { c.lock.RLock() defer c.lock.RUnlock() - return utils.Map[*Instance, string](c.instances, func(sign *Instance) string { + return utils.Map(c.instances, func(sign *remote) string { return sign.server }) } -func (c *Client) getAvailableSign() *Instance { +func (c *Client) getAvailableSign() *remote { c.lock.RLock() defer c.lock.RUnlock() for _, i := range c.instances { - if Status(i.status.Load()) == OK { + if i.latency.Load() < serverLatencyDown { return i } } @@ -116,10 +112,10 @@ func (c *Client) Sign(cmd string, seq uint32, data []byte) (*Response, error) { if sign := c.getAvailableSign(); sign != nil { resp, err := sign.sign(cmd, seq, data, c.extraHeaders) if err != nil { - sign.status.Store(uint32(Down)) + sign.latency.Store(serverLatencyDown) continue } else if resp.Version != c.app.CurrentVersion && resp.Value.Extra != c.app.SignExtraHexLower && resp.Value.Extra != c.app.SignExtraHexUpper { - return nil, VersionMismatchError + return nil, ErrVersionMismatch } c.log(fmt.Sprintf("signed for [%s:%d](%dms)", cmd, seq, time.Now().UnixMilli()-startTime)) @@ -137,6 +133,7 @@ func (c *Client) Sign(cmd string, seq uint32, data []byte) (*Response, error) { func (c *Client) test() { c.lock.Lock() if time.Now().Before(c.lastTestTime.Add(10 * time.Minute)) { + c.lock.Unlock() return } c.lastTestTime = time.Now() @@ -147,22 +144,21 @@ func (c *Client) test() { c.sortByLatency() } -func (i *Instance) sign(cmd string, seq uint32, buf []byte, header http.Header) (*Response, error) { +func (i *remote) sign(cmd string, seq uint32, buf []byte, header http.Header) (*Response, error) { if !containSignPKG(cmd) { return nil, nil } - resp := Response{} sb := strings.Builder{} sb.WriteString(`{"cmd":"` + cmd + `",`) sb.WriteString(`"seq":` + strconv.Itoa(int(seq)) + `,`) sb.WriteString(`"src":"` + fmt.Sprintf("%x", buf) + `"}`) - err := httpPost(i.server, bytes.NewReader(utils.S2B(sb.String())), 8*time.Second, &resp, header) + resp, err := httpPost[Response](i.server, bytes.NewReader(utils.S2B(sb.String())), 8*time.Second, header) if err != nil || resp.Value.Sign == "" { - err := httpGet(i.server, map[string]string{ + resp, err = httpGet[Response](i.server, map[string]string{ "cmd": cmd, "seq": strconv.Itoa(int(seq)), "src": fmt.Sprintf("%x", buf), - }, 8*time.Second, &resp, header) + }, 8*time.Second, header) if err != nil { return nil, err } @@ -170,19 +166,17 @@ func (i *Instance) sign(cmd string, seq uint32, buf []byte, header http.Header) return &resp, nil } -func (i *Instance) test() { +func (i *remote) test() { startTime := time.Now().UnixMilli() resp, err := i.sign("wtlogin.login", 1, []byte{11, 45, 14}, nil) if err != nil || resp.Value.Sign == "" { - i.status.Store(uint32(Down)) - i.latency.Store(99999) + i.latency.Store(serverLatencyDown) return } // 有长连接的情况,取两次平均值 resp, err = i.sign("wtlogin.login", 1, []byte{11, 45, 14}, nil) if err != nil || resp.Value.Sign == "" { - i.status.Store(uint32(Down)) - i.latency.Store(99999) + i.latency.Store(serverLatencyDown) return } // 粗略计算,应该足够了