Skip to content

Commit

Permalink
add: add support client rate limit (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
mimuret authored Feb 15, 2023
1 parent 88b1f21 commit 004d998
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 5 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ require (
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2 // indirect
golang.org/x/sys v0.0.0-20220209214540-3681064d5158 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/time v0.3.0 // indirect
google.golang.org/protobuf v1.27.1 // indirect
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
Expand Down
41 changes: 37 additions & 4 deletions pkg/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@ import (
"os"
"reflect"
"time"

"golang.org/x/time/rate"
)

const DefaultEndpoint = "https://api.dns-platform.jp/dpf/v1"

type ClientInterface interface {
SetRoundTripper(rt http.RoundTripper)
Read(ctx context.Context, s Spec) (string, error)
List(ctx context.Context, s ListSpec, keywords SearchParams) (string, error)
ListAll(ctx context.Context, s CountableListSpec, keywords SearchParams) (string, error)
Expand All @@ -34,9 +37,10 @@ var _ ClientInterface = &Client{}
type Client struct {
Endpoint string
Token string
logger Logger

Client *http.Client
logger Logger
client *http.Client

LastRequest *RequestInfo
LastResponse *ResponseInfo
}
Expand All @@ -52,14 +56,43 @@ type ResponseInfo struct {
Body []byte
}

type RateRoundTripper struct {
RroundTripper http.RoundTripper
Limiter *rate.Limiter
}

func NewRateRoundTripper(rt http.RoundTripper, limiter *rate.Limiter) *RateRoundTripper {
return &RateRoundTripper{
RroundTripper: rt,
Limiter: limiter,
}
}

func (r *RateRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if r.Limiter == nil {
r.Limiter = rate.NewLimiter(rate.Limit(1.0), 5)
}
if r.RroundTripper == nil {
r.RroundTripper = http.DefaultTransport
}
if err := r.Limiter.Wait(req.Context()); err != nil {
return nil, fmt.Errorf("request rate-limit by client side: %w", err)
}
return r.RroundTripper.RoundTrip(req)
}

func NewClient(token string, endpoint string, logger Logger) *Client {
if endpoint == "" {
endpoint = DefaultEndpoint
}
if logger == nil {
logger = NewStdLogger(os.Stderr, "dpf-client", 0, 4)
}
return &Client{Endpoint: endpoint, Token: token, logger: logger, Client: http.DefaultClient}
return &Client{Endpoint: endpoint, Token: token, logger: logger, client: &http.Client{Transport: NewRateRoundTripper(nil, nil)}}
}

func (c *Client) SetRoundTripper(rt http.RoundTripper) {
c.client.Transport = rt
}

func (c *Client) marshalJSON(action Action, body interface{}) ([]byte, error) {
Expand Down Expand Up @@ -138,7 +171,7 @@ func (c *Client) Do(ctx context.Context, spec Spec, action Action, body interfac
return "", err
}
// request
resp, err := c.Client.Do(req)
resp, err := c.client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to get http response: %w", err)
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"golang.org/x/time/rate"

"github.com/jarcoal/httpmock"
"github.com/mimuret/golang-iij-dpf/pkg/api"
Expand Down Expand Up @@ -84,6 +85,7 @@ var _ = Describe("Client", func() {
"error_type": "NotFound",
"error_message": "Specified resource not found."
}`)))
c.SetRoundTripper(api.NewRateRoundTripper(nil, rate.NewLimiter(rate.Inf, 0)))
})
Context("NewClient", func() {
BeforeEach(func() {
Expand Down
6 changes: 5 additions & 1 deletion pkg/testtool/testclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,15 @@ func NewTestClient(token, endpoint string, logger api.Logger) *TestClient {
RequestHeaders: make(map[string]http.Header),
RequestBody: make(map[string]string),
}
cl.Client.Transport = nop
cl.SetRoundTripper(nop)
nop.Client = cl
return nop
}

func (n *TestClient) SetRoundTripper(rt http.RoundTripper) {
n.Client.SetRoundTripper(rt)
}

func (n *TestClient) RoundTrip(req *http.Request) (*http.Response, error) {
if req.Body != nil {
bs, err := ioutil.ReadAll(req.Body)
Expand Down

0 comments on commit 004d998

Please sign in to comment.