Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize(sign): 精简代码 #113

Merged
merged 1 commit into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions client/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ func NewClient(uin uint32, appInfo *auth.AppInfo, signUrl ...string) *QQClient {
alive: true,
UA: "LagrangeGo qq/" + appInfo.PackageSign,
}
client.signProvider = sign.NewSignClient(appInfo, func(s string) {
client.debug(s)
}, signUrl...)
client.signProvider = sign.NewSigner(appInfo, client.debug, signUrl...)
client.transport.Version = appInfo
client.transport.Sig.D2Key = make([]byte, 0, 16)
client.highwaySession.Uin = &client.transport.Sig.Uin
Expand Down
74 changes: 26 additions & 48 deletions client/sign/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package sign
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"

"github.com/pkg/errors"
)

var (
Expand Down Expand Up @@ -69,10 +70,10 @@ func containSignPKG(cmd string) bool {
return ok
}

func httpGet(rawUrl string, queryParams map[string]string, timeout time.Duration, target interface{}, header http.Header) error {
func httpGet[T any](rawUrl string, queryParams map[string]string, timeout time.Duration, header http.Header) (target T, err error) {
u, err := url.Parse(rawUrl)
if err != nil {
return fmt.Errorf("failed to parse URL: %w", err)
return
}

q := u.Query()
Expand All @@ -86,90 +87,67 @@ func httpGet(rawUrl string, queryParams map[string]string, timeout time.Duration

req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return fmt.Errorf("failed to create GET request: %w", err)
return
}
for k, vs := range header {
for _, v := range vs {
req.Header.Add(k, v)
}
}
resp, err := httpClient.Do(req)
if err != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return fmt.Errorf("request timed out")
}
resp, err = httpClient.Do(req)
if err != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return fmt.Errorf("request timed out")
}
return fmt.Errorf("failed to perform GET request: %w", err)
}
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
return fmt.Errorf("failed to unmarshal JSON response: %w", err)
}

return nil
return doHTTP[T](ctx, req)
}

func httpPost(rawUrl string, body io.Reader, timeout time.Duration, target interface{}, header http.Header) error {
func httpPost[T any](rawUrl string, body io.Reader, timeout time.Duration, header http.Header) (target T, err error) {
u, err := url.Parse(rawUrl)
if err != nil {
return fmt.Errorf("failed to parse URL: %w", err)
return
}

ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), body)
if err != nil {
return fmt.Errorf("failed to create POST request: %w", err)
err = errors.Wrap(err, "create POST")
return
}
for k, vs := range header {
for _, v := range vs {
req.Header.Add(k, v)
}
}
req.Header.Add("Content-Type", "application/json")
return doHTTP[T](ctx, req)
}

func doHTTP[T any](ctx context.Context, req *http.Request) (target T, err error) {
resp, err := http.DefaultClient.Do(req)
if err != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return fmt.Errorf("request timed out")
err = ctx.Err()
return
}
resp, err = httpClient.Do(req)
if err != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return fmt.Errorf("request timed out")
err = ctx.Err()
return
}
return fmt.Errorf("failed to perform POST request: %w", err)
err = errors.Wrap(err, "perform POST")
return
}
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
err = fmt.Errorf("unexpected status code: %d", resp.StatusCode)
return
}

if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
return fmt.Errorf("failed to unmarshal JSON response: %w", err)
if err = json.NewDecoder(resp.Body).Decode(&target); err != nil {
err = errors.Wrap(err, "unmarshal response")
return
}

return nil
}

type Response struct {
Platform string `json:"platform"`
Version string `json:"version"`
Value struct {
Sign string `json:"sign"`
Extra string `json:"extra"`
Token string `json:"token"`
} `json:"value"`
return
}
10 changes: 10 additions & 0 deletions client/sign/provider.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
package sign

type Response struct {
Platform string `json:"platform"`
Version string `json:"version"`
Value struct {
Sign string `json:"sign"`
Extra string `json:"extra"`
Token string `json:"token"`
} `json:"value"`
}

type Provider interface {
Sign(cmd string, seq uint32, data []byte) (*Response, error)
}
56 changes: 25 additions & 31 deletions client/sign/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"errors"
"fmt"
"math"
"net/http"
"sort"
"strconv"
Expand All @@ -12,43 +13,38 @@ import (
"sync/atomic"
"time"

"github.com/LagrangeDev/LagrangeGo/utils"

"github.com/LagrangeDev/LagrangeGo/client/auth"
"github.com/LagrangeDev/LagrangeGo/utils"
)

type (
Status uint32

Client struct {
lock sync.RWMutex
signCount atomic.Uint32
instances []*Instance
instances []*remote
app *auth.AppInfo
httpClient *http.Client
extraHeaders http.Header
log func(string)
log func(string, ...any)
lastTestTime time.Time
}

Instance struct {
remote struct {
server string
latency atomic.Uint32
status atomic.Uint32
}
)

const (
OK Status = iota
Down
serverLatencyDown = math.MaxUint32
)

var VersionMismatchError = errors.New("sign version mismatch")
var ErrVersionMismatch = errors.New("sign version mismatch")

func NewSignClient(appinfo *auth.AppInfo, log func(string), signServers ...string) *Client {
func NewSigner(appinfo *auth.AppInfo, log func(string, ...any), signServers ...string) *Client {
client := &Client{
instances: utils.Map[string, *Instance](signServers, func(s string) *Instance {
return &Instance{server: s}
instances: utils.Map(signServers, func(s string) *remote {
return &remote{server: s}
}),
app: appinfo,
httpClient: &http.Client{},
Expand All @@ -72,24 +68,24 @@ func (c *Client) AddRequestHeader(header map[string]string) {
func (c *Client) AddSignServer(signServers ...string) {
c.lock.Lock()
defer c.lock.Unlock()
c.instances = append(c.instances, utils.Map[string, *Instance](signServers, func(s string) *Instance {
return &Instance{server: s}
c.instances = append(c.instances, utils.Map[string, *remote](signServers, func(s string) *remote {
return &remote{server: s}
})...)
}

func (c *Client) GetSignServer() []string {
c.lock.RLock()
defer c.lock.RUnlock()
return utils.Map[*Instance, string](c.instances, func(sign *Instance) string {
return utils.Map(c.instances, func(sign *remote) string {
return sign.server
})
}

func (c *Client) getAvailableSign() *Instance {
func (c *Client) getAvailableSign() *remote {
c.lock.RLock()
defer c.lock.RUnlock()
for _, i := range c.instances {
if Status(i.status.Load()) == OK {
if i.latency.Load() < serverLatencyDown {
return i
}
}
Expand All @@ -116,10 +112,10 @@ func (c *Client) Sign(cmd string, seq uint32, data []byte) (*Response, error) {
if sign := c.getAvailableSign(); sign != nil {
resp, err := sign.sign(cmd, seq, data, c.extraHeaders)
if err != nil {
sign.status.Store(uint32(Down))
sign.latency.Store(serverLatencyDown)
continue
} else if resp.Version != c.app.CurrentVersion && resp.Value.Extra != c.app.SignExtraHexLower && resp.Value.Extra != c.app.SignExtraHexUpper {
return nil, VersionMismatchError
return nil, ErrVersionMismatch
}
c.log(fmt.Sprintf("signed for [%s:%d](%dms)",
cmd, seq, time.Now().UnixMilli()-startTime))
Expand All @@ -137,6 +133,7 @@ func (c *Client) Sign(cmd string, seq uint32, data []byte) (*Response, error) {
func (c *Client) test() {
c.lock.Lock()
if time.Now().Before(c.lastTestTime.Add(10 * time.Minute)) {
c.lock.Unlock()
return
}
c.lastTestTime = time.Now()
Expand All @@ -147,42 +144,39 @@ func (c *Client) test() {
c.sortByLatency()
}

func (i *Instance) sign(cmd string, seq uint32, buf []byte, header http.Header) (*Response, error) {
func (i *remote) sign(cmd string, seq uint32, buf []byte, header http.Header) (*Response, error) {
if !containSignPKG(cmd) {
return nil, nil
}
resp := Response{}
sb := strings.Builder{}
sb.WriteString(`{"cmd":"` + cmd + `",`)
sb.WriteString(`"seq":` + strconv.Itoa(int(seq)) + `,`)
sb.WriteString(`"src":"` + fmt.Sprintf("%x", buf) + `"}`)
err := httpPost(i.server, bytes.NewReader(utils.S2B(sb.String())), 8*time.Second, &resp, header)
resp, err := httpPost[Response](i.server, bytes.NewReader(utils.S2B(sb.String())), 8*time.Second, header)
if err != nil || resp.Value.Sign == "" {
err := httpGet(i.server, map[string]string{
resp, err = httpGet[Response](i.server, map[string]string{
"cmd": cmd,
"seq": strconv.Itoa(int(seq)),
"src": fmt.Sprintf("%x", buf),
}, 8*time.Second, &resp, header)
}, 8*time.Second, header)
if err != nil {
return nil, err
}
}
return &resp, nil
}

func (i *Instance) test() {
func (i *remote) test() {
startTime := time.Now().UnixMilli()
resp, err := i.sign("wtlogin.login", 1, []byte{11, 45, 14}, nil)
if err != nil || resp.Value.Sign == "" {
i.status.Store(uint32(Down))
i.latency.Store(99999)
i.latency.Store(serverLatencyDown)
return
}
// 有长连接的情况,取两次平均值
resp, err = i.sign("wtlogin.login", 1, []byte{11, 45, 14}, nil)
if err != nil || resp.Value.Sign == "" {
i.status.Store(uint32(Down))
i.latency.Store(99999)
i.latency.Store(serverLatencyDown)
return
}
// 粗略计算,应该足够了
Expand Down