Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: improve get all questions and error handling #300

Merged
merged 3 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/debug.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"errors"
"os"
"strings"

Expand Down Expand Up @@ -42,7 +43,7 @@ var whoamiCmd = &cobra.Command{
return err
}
if !user.IsSignedIn {
return leetcode.ErrForbidden
return errors.New("user not signed in")
}
cmd.Println(user.Whoami(c))
return nil
Expand Down
2 changes: 1 addition & 1 deletion leetcode/cache_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (c *jsonCache) Update() error {
if err != nil {
return err
}
log.Info("cache updated", "path", c.path)
log.Info("questions cache updated", "count", len(all), "path", c.path)
return nil
}

Expand Down
3 changes: 2 additions & 1 deletion leetcode/cache_sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ func (c *sqliteCache) Update() error {
if err != nil {
return err
}
count := len(all)
placeholder := "(" + strings.Repeat("?,", 19) + "?)"
batch := 100
for len(all) > 0 {
Expand Down Expand Up @@ -367,6 +368,6 @@ func (c *sqliteCache) Update() error {
if err != nil {
return err
}
log.Info("cache updated", "path", c.path)
log.Info("questions cache updated", "count", count, "path", c.path)
return nil
}
112 changes: 61 additions & 51 deletions leetcode/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,29 +27,38 @@ var (
ErrPaidOnlyQuestion = errors.New("this is paid only question, you need to subscribe to LeetCode Premium")
ErrQuestionNotFound = errors.New("no such question")
ErrContestNotStarted = errors.New("contest has not started")
ErrForbidden = UnexpectedStatusCode{
Code: 403,
Body: "access is forbidden, your cookies may have expired or LeetCode has restricted its API access",
}
ErrTooManyRequests = UnexpectedStatusCode{
Code: 429,
Body: "LeetCode limited you access rate, you may be submitting too frequently",
}
)

type UnexpectedStatusCode struct {
Code int
Body string
}

func (e UnexpectedStatusCode) IsError() bool {
return e.Code != 0
}

func (e UnexpectedStatusCode) Error() string {
body := "<empty>"
if len(e.Body) > 1024 {
body = e.Body[:1024]
body := "<empty body>"
if len(e.Body) > 100 {
body = e.Body[:100] + "..."
}
return fmt.Sprintf("[%d %s] %s", e.Code, http.StatusText(e.Code), body)
}

func NewUnexpectedStatusCode(code int, body []byte) UnexpectedStatusCode {
err := UnexpectedStatusCode{Code: code}
switch code {
case http.StatusTooManyRequests:
err.Body = "LeetCode limited you access rate, you may be submitting too frequently"
case http.StatusForbidden:
err.Body = "Access is forbidden, your cookies may have expired or LeetCode has restricted its API access"
default:
err.Body = utils.BytesToString(body)
}
return err
}

type Client interface {
BaseURI() string
Inspect(typ string) (map[string]any, error)
Expand Down Expand Up @@ -182,7 +191,7 @@ const (
problemsApiTagsPath = "/problems/api/tags/"
)

func (c *cnClient) send(req *http.Request, authType authType, result any, failure any) (*http.Response, error) {
func (c *cnClient) send(req *http.Request, authType authType, result any) (*http.Response, error) {
switch authType {
case withoutAuth:
case withAuth:
Expand All @@ -204,31 +213,29 @@ func (c *cnClient) send(req *http.Request, authType authType, result any, failur
log.Debug("request", "method", req.Method, "url", req.URL.String(), "body", utils.BytesToString(bodyStr))
}

var resp *http.Response
err := retry.Do(
func() error {
var err error
resp, err = c.http.Do(req, result, failure)
var (
err error
respErr UnexpectedStatusCode
)
_, err = c.http.Do(req, result, &respErr)
if err != nil {
return err
}
if !(200 <= resp.StatusCode && resp.StatusCode <= 299) {
switch resp.StatusCode {
case http.StatusTooManyRequests:
return ErrTooManyRequests
case http.StatusForbidden:
return ErrForbidden
default:
body, _ := io.ReadAll(resp.Body)
return UnexpectedStatusCode{Code: resp.StatusCode, Body: string(body)}
}
if respErr.IsError() {
return respErr
}
return nil
},
retry.RetryIf(
func(err error) bool {
// Do not retry on 429
return !errors.Is(err, ErrTooManyRequests)
var e UnexpectedStatusCode
if errors.As(err, &e) && e.Code == http.StatusTooManyRequests {
return false
}
return true
},
),
retry.Attempts(3),
Expand All @@ -244,7 +251,7 @@ func (c *cnClient) send(req *http.Request, authType authType, result any, failur
}

//nolint:unused
func (c *cnClient) graphqlGet(req graphqlRequest, result any, failure any) (*http.Response, error) {
func (c *cnClient) graphqlGet(req graphqlRequest, result any) (*http.Response, error) {
type params struct {
Query string `url:"query"`
OperationName string `url:"operationName"`
Expand All @@ -262,10 +269,10 @@ func (c *cnClient) graphqlGet(req graphqlRequest, result any, failure any) (*htt
if err != nil {
return nil, err
}
return c.send(r, req.authType, result, failure)
return c.send(r, req.authType, result)
}

func (c *cnClient) graphqlPost(req graphqlRequest, result any, failure any) (*http.Response, error) {
func (c *cnClient) graphqlPost(req graphqlRequest, result any) (*http.Response, error) {
v := req.variables
if v == nil {
v = make(map[string]any)
Expand All @@ -279,23 +286,23 @@ func (c *cnClient) graphqlPost(req graphqlRequest, result any, failure any) (*ht
if err != nil {
return nil, err
}
return c.send(r, req.authType, result, failure)
return c.send(r, req.authType, result)
}

func (c *cnClient) jsonGet(url string, query any, authType authType, result any, failure any) (*http.Response, error) {
func (c *cnClient) jsonGet(url string, query any, authType authType, result any) (*http.Response, error) {
r, err := c.http.New().Get(url).QueryStruct(query).Request()
if err != nil {
return nil, err
}
return c.send(r, authType, result, failure)
return c.send(r, authType, result)
}

func (c *cnClient) jsonPost(url string, json any, authType authType, result any, failure any) (*http.Response, error) {
func (c *cnClient) jsonPost(url string, json any, authType authType, result any) (*http.Response, error) {
r, err := c.http.New().Post(url).BodyJSON(json).Request()
if err != nil {
return nil, err
}
return c.send(r, authType, result, failure)
return c.send(r, authType, result)
}

func (c *cnClient) BaseURI() string {
Expand Down Expand Up @@ -339,7 +346,6 @@ query a {
_, err := c.graphqlGet(
graphqlRequest{query: query},
&resp,
nil,
)
return resp, err
}
Expand Down Expand Up @@ -407,7 +413,7 @@ query globalData {
} `json:"data"`
}
_, err := c.graphqlPost(
graphqlRequest{query: query, authType: requireAuth}, &resp, nil,
graphqlRequest{query: query, authType: requireAuth}, &resp,
)
if err != nil {
return nil, err
Expand All @@ -428,7 +434,7 @@ func (c *cnClient) getQuestionData(slug string, query string, authType authType)
operationName: "questionData",
variables: map[string]any{"titleSlug": slug},
authType: authType,
}, &resp, nil,
}, &resp,
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -500,7 +506,7 @@ func (c *cnClient) GetAllQuestions() ([]*QuestionData, error) {
graphqlRequest{
query: query,
operationName: "AllQuestionUrls",
}, &resp, nil,
}, &resp,
)
if err != nil {
return nil, err
Expand All @@ -523,11 +529,15 @@ func (c *cnClient) GetAllQuestions() ([]*QuestionData, error) {
go pw.Render()

var qs []*QuestionData
var respErr UnexpectedStatusCode
dec := progressDecoder{smartDecoder{LogResponse: false}, tracker}
_, err = c.http.New().Get(url).ResponseDecoder(dec).ReceiveSuccess(&qs)
_, err = c.http.New().Get(url).ResponseDecoder(dec).Receive(&qs, &respErr)
if err != nil {
return nil, err
}
if respErr.IsError() {
return nil, respErr
}
for i := range qs {
qs[i].client = c
qs[i].partial = 1
Expand All @@ -552,7 +562,7 @@ func (c *cnClient) GetTodayQuestion() (*QuestionData, error) {
query: query,
operationName: "questionOfToday",
authType: withoutAuth,
}, &resp, nil,
}, &resp,
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -582,7 +592,7 @@ func (c *cnClient) GetQuestionOfDate(date time.Time) (*QuestionData, error) {
},
authType: withAuth,
},
&resp, nil,
&resp,
)
if err != nil {
return nil, err
Expand All @@ -601,7 +611,7 @@ func (c *cnClient) GetQuestionOfDate(date time.Time) (*QuestionData, error) {
func (c *cnClient) getContest(contestSlug string) (*Contest, error) {
path := fmt.Sprintf(contestInfoPath, contestSlug)
var resp gjson.Result
_, err := c.jsonGet(path, nil, withAuth, &resp, nil)
_, err := c.jsonGet(path, nil, withAuth, &resp)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -653,7 +663,7 @@ func (c *cnClient) GetContestQuestionData(contestSlug string, questionSlug strin
path := fmt.Sprintf(contestProblemsPath, contestSlug, questionSlug)
var html []byte
req, _ := c.http.New().Get(path).Request()
_, err := c.send(req, requireAuth, &html, nil)
_, err := c.send(req, requireAuth, &html)
if err != nil {
var e UnexpectedStatusCode
if errors.As(err, &e) && e.Code == 302 {
Expand Down Expand Up @@ -800,7 +810,7 @@ func (c *cnClient) RunCode(q *QuestionData, lang string, code string, dataInput
"question_id": q.QuestionId,
"typed_code": code,
"data_input": dataInput,
}, requireAuth, &resp, nil,
}, requireAuth, &resp,
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -828,7 +838,7 @@ func (c *cnClient) SubmitCode(q *QuestionData, lang string, code string) (string
"questionSlug": q.TitleSlug,
"question_id": q.QuestionId,
"typed_code": code,
}, requireAuth, &resp, nil,
}, requireAuth, &resp,
)
return resp.Get("submission_id").String(), err
}
Expand All @@ -839,7 +849,7 @@ func (c *cnClient) CheckResult(submissionId string) (
) {
path := fmt.Sprintf(checkResultPath, submissionId)
var result gjson.Result
_, err := c.jsonGet(path, nil, requireAuth, &result, nil)
_, err := c.jsonGet(path, nil, requireAuth, &result)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -871,7 +881,7 @@ func (c *cnClient) GetUpcomingContests() ([]*Contest, error) {
`
var resp gjson.Result
_, err := c.graphqlPost(
graphqlRequest{query: query, authType: withAuth}, &resp, nil,
graphqlRequest{query: query, authType: withAuth}, &resp,
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -903,7 +913,7 @@ func (c *cnClient) GetUpcomingContests() ([]*Contest, error) {

func (c *cnClient) RegisterContest(slug string) error {
path := fmt.Sprintf(contestRegisterPath, slug)
_, err := c.jsonPost(path, nil, requireAuth, nil, nil)
_, err := c.jsonPost(path, nil, requireAuth, nil)
var e UnexpectedStatusCode
if errors.As(err, &e) && e.Code == http.StatusFound {
err = nil
Expand All @@ -914,7 +924,7 @@ func (c *cnClient) RegisterContest(slug string) error {
func (c *cnClient) UnregisterContest(slug string) error {
path := fmt.Sprintf(contestRegisterPath, slug)
req, _ := c.http.New().Delete(path).Request()
_, err := c.send(req, requireAuth, nil, nil)
_, err := c.send(req, requireAuth, nil)
return err
}

Expand Down Expand Up @@ -964,7 +974,7 @@ query problemsetQuestionList($categorySlug: String, $limit: Int, $skip: Int, $fi
graphqlRequest{
query: query,
variables: vars,
}, &resp, nil,
}, &resp,
)
if err != nil {
return QuestionList{}, err
Expand All @@ -986,7 +996,7 @@ query problemsetQuestionList($categorySlug: String, $limit: Int, $skip: Int, $fi

func (c *cnClient) GetQuestionTags() ([]QuestionTag, error) {
var resp gjson.Result
_, err := c.jsonGet(problemsApiTagsPath, nil, withAuth, &resp, nil)
_, err := c.jsonGet(problemsApiTagsPath, nil, withAuth, &resp)
if err != nil {
return nil, err
}
Expand Down
Loading