Skip to content

Commit

Permalink
optimize(sign): 精简代码
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Nov 9, 2024
1 parent c4ada2b commit b6ac607
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 82 deletions.
4 changes: 1 addition & 3 deletions client/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ func NewClient(uin uint32, appInfo *auth.AppInfo, signUrl ...string) *QQClient {
alive: true,
UA: "LagrangeGo qq/" + appInfo.PackageSign,
}
client.signProvider = sign.NewSignClient(appInfo, func(s string) {
client.debug(s)
}, signUrl...)
client.signProvider = sign.NewSigner(appInfo, client.debug, signUrl...)
client.transport.Version = appInfo
client.transport.Sig.D2Key = make([]byte, 0, 16)
client.highwaySession.Uin = &client.transport.Sig.Uin
Expand Down
74 changes: 26 additions & 48 deletions client/sign/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package sign
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"

"github.com/pkg/errors"
)

var (
Expand Down Expand Up @@ -69,10 +70,10 @@ func containSignPKG(cmd string) bool {
return ok
}

func httpGet(rawUrl string, queryParams map[string]string, timeout time.Duration, target interface{}, header http.Header) error {
func httpGet[T any](rawUrl string, queryParams map[string]string, timeout time.Duration, header http.Header) (target T, err error) {
u, err := url.Parse(rawUrl)
if err != nil {
return fmt.Errorf("failed to parse URL: %w", err)
return
}

q := u.Query()
Expand All @@ -86,90 +87,67 @@ func httpGet(rawUrl string, queryParams map[string]string, timeout time.Duration

req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return fmt.Errorf("failed to create GET request: %w", err)
return
}
for k, vs := range header {
for _, v := range vs {
req.Header.Add(k, v)
}
}
resp, err := httpClient.Do(req)
if err != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return fmt.Errorf("request timed out")
}
resp, err = httpClient.Do(req)
if err != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return fmt.Errorf("request timed out")
}
return fmt.Errorf("failed to perform GET request: %w", err)
}
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
return fmt.Errorf("failed to unmarshal JSON response: %w", err)
}

return nil
return doHTTP[T](ctx, req)
}

func httpPost(rawUrl string, body io.Reader, timeout time.Duration, target interface{}, header http.Header) error {
func httpPost[T any](rawUrl string, body io.Reader, timeout time.Duration, header http.Header) (target T, err error) {
u, err := url.Parse(rawUrl)
if err != nil {
return fmt.Errorf("failed to parse URL: %w", err)
return
}

ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), body)
if err != nil {
return fmt.Errorf("failed to create POST request: %w", err)
err = errors.Wrap(err, "create POST")
return
}
for k, vs := range header {
for _, v := range vs {
req.Header.Add(k, v)
}
}
req.Header.Add("Content-Type", "application/json")
return doHTTP[T](ctx, req)
}

func doHTTP[T any](ctx context.Context, req *http.Request) (target T, err error) {
resp, err := http.DefaultClient.Do(req)
if err != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return fmt.Errorf("request timed out")
err = ctx.Err()
return
}
resp, err = httpClient.Do(req)
if err != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return fmt.Errorf("request timed out")
err = ctx.Err()
return
}
return fmt.Errorf("failed to perform POST request: %w", err)
err = errors.Wrap(err, "perform POST")
return
}
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
err = fmt.Errorf("unexpected status code: %d", resp.StatusCode)
return
}

if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
return fmt.Errorf("failed to unmarshal JSON response: %w", err)
if err = json.NewDecoder(resp.Body).Decode(&target); err != nil {
err = errors.Wrap(err, "unmarshal response")
return
}

return nil
}

type Response struct {
Platform string `json:"platform"`
Version string `json:"version"`
Value struct {
Sign string `json:"sign"`
Extra string `json:"extra"`
Token string `json:"token"`
} `json:"value"`
return
}
10 changes: 10 additions & 0 deletions client/sign/provider.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
package sign

type Response struct {
Platform string `json:"platform"`
Version string `json:"version"`
Value struct {
Sign string `json:"sign"`
Extra string `json:"extra"`
Token string `json:"token"`
} `json:"value"`
}

type Provider interface {
Sign(cmd string, seq uint32, data []byte) (*Response, error)
}
56 changes: 25 additions & 31 deletions client/sign/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"errors"
"fmt"
"math"
"net/http"
"sort"
"strconv"
Expand All @@ -12,43 +13,38 @@ import (
"sync/atomic"
"time"

"github.com/LagrangeDev/LagrangeGo/utils"

"github.com/LagrangeDev/LagrangeGo/client/auth"
"github.com/LagrangeDev/LagrangeGo/utils"
)

type (
Status uint32

Client struct {
lock sync.RWMutex
signCount atomic.Uint32
instances []*Instance
instances []*remote
app *auth.AppInfo
httpClient *http.Client
extraHeaders http.Header
log func(string)
log func(string, ...any)
lastTestTime time.Time
}

Instance struct {
remote struct {
server string
latency atomic.Uint32
status atomic.Uint32
}
)

const (
OK Status = iota
Down
serverLatencyDown = math.MaxUint32
)

var VersionMismatchError = errors.New("sign version mismatch")
var ErrVersionMismatch = errors.New("sign version mismatch")

func NewSignClient(appinfo *auth.AppInfo, log func(string), signServers ...string) *Client {
func NewSigner(appinfo *auth.AppInfo, log func(string, ...any), signServers ...string) *Client {
client := &Client{
instances: utils.Map[string, *Instance](signServers, func(s string) *Instance {
return &Instance{server: s}
instances: utils.Map(signServers, func(s string) *remote {
return &remote{server: s}
}),
app: appinfo,
httpClient: &http.Client{},
Expand All @@ -72,24 +68,24 @@ func (c *Client) AddRequestHeader(header map[string]string) {
func (c *Client) AddSignServer(signServers ...string) {
c.lock.Lock()
defer c.lock.Unlock()
c.instances = append(c.instances, utils.Map[string, *Instance](signServers, func(s string) *Instance {
return &Instance{server: s}
c.instances = append(c.instances, utils.Map[string, *remote](signServers, func(s string) *remote {
return &remote{server: s}
})...)
}

func (c *Client) GetSignServer() []string {
c.lock.RLock()
defer c.lock.RUnlock()
return utils.Map[*Instance, string](c.instances, func(sign *Instance) string {
return utils.Map(c.instances, func(sign *remote) string {
return sign.server
})
}

func (c *Client) getAvailableSign() *Instance {
func (c *Client) getAvailableSign() *remote {
c.lock.RLock()
defer c.lock.RUnlock()
for _, i := range c.instances {
if Status(i.status.Load()) == OK {
if i.latency.Load() < serverLatencyDown {
return i
}
}
Expand All @@ -116,10 +112,10 @@ func (c *Client) Sign(cmd string, seq uint32, data []byte) (*Response, error) {
if sign := c.getAvailableSign(); sign != nil {
resp, err := sign.sign(cmd, seq, data, c.extraHeaders)
if err != nil {
sign.status.Store(uint32(Down))
sign.latency.Store(serverLatencyDown)
continue
} else if resp.Version != c.app.CurrentVersion && resp.Value.Extra != c.app.SignExtraHexLower && resp.Value.Extra != c.app.SignExtraHexUpper {
return nil, VersionMismatchError
return nil, ErrVersionMismatch
}
c.log(fmt.Sprintf("signed for [%s:%d](%dms)",
cmd, seq, time.Now().UnixMilli()-startTime))
Expand All @@ -137,6 +133,7 @@ func (c *Client) Sign(cmd string, seq uint32, data []byte) (*Response, error) {
func (c *Client) test() {
c.lock.Lock()
if time.Now().Before(c.lastTestTime.Add(10 * time.Minute)) {
c.lock.Unlock()
return
}
c.lastTestTime = time.Now()
Expand All @@ -147,42 +144,39 @@ func (c *Client) test() {
c.sortByLatency()
}

func (i *Instance) sign(cmd string, seq uint32, buf []byte, header http.Header) (*Response, error) {
func (i *remote) sign(cmd string, seq uint32, buf []byte, header http.Header) (*Response, error) {
if !containSignPKG(cmd) {
return nil, nil
}
resp := Response{}
sb := strings.Builder{}
sb.WriteString(`{"cmd":"` + cmd + `",`)
sb.WriteString(`"seq":` + strconv.Itoa(int(seq)) + `,`)
sb.WriteString(`"src":"` + fmt.Sprintf("%x", buf) + `"}`)
err := httpPost(i.server, bytes.NewReader(utils.S2B(sb.String())), 8*time.Second, &resp, header)
resp, err := httpPost[Response](i.server, bytes.NewReader(utils.S2B(sb.String())), 8*time.Second, header)
if err != nil || resp.Value.Sign == "" {
err := httpGet(i.server, map[string]string{
resp, err = httpGet[Response](i.server, map[string]string{
"cmd": cmd,
"seq": strconv.Itoa(int(seq)),
"src": fmt.Sprintf("%x", buf),
}, 8*time.Second, &resp, header)
}, 8*time.Second, header)
if err != nil {
return nil, err
}
}
return &resp, nil
}

func (i *Instance) test() {
func (i *remote) test() {
startTime := time.Now().UnixMilli()
resp, err := i.sign("wtlogin.login", 1, []byte{11, 45, 14}, nil)
if err != nil || resp.Value.Sign == "" {
i.status.Store(uint32(Down))
i.latency.Store(99999)
i.latency.Store(serverLatencyDown)
return
}
// 有长连接的情况,取两次平均值
resp, err = i.sign("wtlogin.login", 1, []byte{11, 45, 14}, nil)
if err != nil || resp.Value.Sign == "" {
i.status.Store(uint32(Down))
i.latency.Store(99999)
i.latency.Store(serverLatencyDown)
return
}
// 粗略计算,应该足够了
Expand Down

0 comments on commit b6ac607

Please sign in to comment.