From faae8e114f8d95a85eda02e80a9aed7a58973426 Mon Sep 17 00:00:00 2001 From: "Maksim (codesenberg) Fedoseev" Date: Mon, 30 Oct 2023 22:34:02 +0000 Subject: [PATCH] bombardier: change internal URL handling and set timeouts for dialing Type of url field in config struct was changed from string to net/url.URL. The way fasthttpClient is setup was adjusted to resolve the issue mentioned in #105. Dial function used by fasthttp client was adjusted to respect --timeout flag, which should fix #103. --- args_parser.go | 17 +--------- args_parser_test.go | 61 +++++++++++++--------------------- bombardier.go | 2 +- bombardier_performance_test.go | 6 ++-- bombardier_test.go | 32 +++++++++--------- client_cert_test.go | 2 +- clients.go | 41 +++++++++++------------ clients_test.go | 22 ++++++++---- common.go | 14 ++++++-- config.go | 28 +++++++--------- config_test.go | 58 ++++++++++++-------------------- dialer.go | 7 ++-- internal/test_info.go | 8 ++++- template/doc.go | 2 ++ templates.go | 2 +- 15 files changed, 142 insertions(+), 160 deletions(-) diff --git a/args_parser.go b/args_parser.go index eddea8e..e604bd4 100644 --- a/args_parser.go +++ b/args_parser.go @@ -206,7 +206,7 @@ func (k *kingpinParser) parse(args []string) (config, error) { "unknown format or invalid format spec %q", k.formatSpec, ) } - url, err := tryParseURL(k.url) + url, err := urlx.Parse(k.url) if err != nil { return emptyConf, err } @@ -264,18 +264,3 @@ func parsePrintSpec(spec string) (bool, bool, bool, error) { } return pi, pp, pr, nil } - -func tryParseURL(raw string) (string, error) { - u, err := urlx.Parse(raw) - if err != nil { - return "", fmt.Errorf("%q does not appear to be a URL: %v", raw, err) - } - - if u.Scheme != "http" && u.Scheme != "https" { - return "", fmt.Errorf( - "only http and https schemes are supported, which %q is not, url was %q", u.Scheme, raw, - ) - } - - return u.String(), nil -} diff --git a/args_parser_test.go b/args_parser_test.go index 9ba6513..707a1d8 100644 --- a/args_parser_test.go +++ b/args_parser_test.go @@ -59,7 +59,7 @@ func TestArgsParsing(t *testing.T) { timeout: defaultTimeout, headers: new(headersList), method: "GET", - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), printIntro: true, printProgress: true, printResult: true, @@ -75,7 +75,7 @@ func TestArgsParsing(t *testing.T) { timeout: defaultTimeout, headers: new(headersList), method: "GET", - url: "https://localhost", + url: ParseURLOrPanic("https://localhost"), printIntro: true, printProgress: true, printResult: true, @@ -89,7 +89,7 @@ func TestArgsParsing(t *testing.T) { timeout: defaultTimeout, headers: new(headersList), method: "GET", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), printIntro: true, printProgress: true, printResult: true, @@ -133,7 +133,7 @@ func TestArgsParsing(t *testing.T) { headers: new(headersList), method: "GET", numReqs: &defaultNumberOfReqs, - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), printIntro: true, printProgress: true, printResult: true, @@ -159,7 +159,7 @@ func TestArgsParsing(t *testing.T) { headers: new(headersList), printLatencies: true, method: "GET", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), printIntro: true, printProgress: true, printResult: true, @@ -185,7 +185,7 @@ func TestArgsParsing(t *testing.T) { headers: new(headersList), insecure: true, method: "GET", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), printIntro: true, printProgress: true, printResult: true, @@ -214,7 +214,7 @@ func TestArgsParsing(t *testing.T) { method: "GET", keyPath: "testclient.key", certPath: "testclient.cert", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), printIntro: true, printProgress: true, printResult: true, @@ -254,7 +254,7 @@ func TestArgsParsing(t *testing.T) { headers: new(headersList), method: "POST", body: "reqbody", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), printIntro: true, printProgress: true, printResult: true, @@ -290,7 +290,7 @@ func TestArgsParsing(t *testing.T) { {"Two", "Value two"}, }, method: "GET", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), printIntro: true, printProgress: true, printResult: true, @@ -325,7 +325,7 @@ func TestArgsParsing(t *testing.T) { timeout: defaultTimeout, headers: new(headersList), method: "GET", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), rate: &ten, printIntro: true, printProgress: true, @@ -350,7 +350,7 @@ func TestArgsParsing(t *testing.T) { timeout: defaultTimeout, headers: new(headersList), method: "GET", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), clientType: fhttp, printIntro: true, printProgress: true, @@ -371,7 +371,7 @@ func TestArgsParsing(t *testing.T) { timeout: defaultTimeout, headers: new(headersList), method: "GET", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), clientType: nhttp1, printIntro: true, printProgress: true, @@ -392,7 +392,7 @@ func TestArgsParsing(t *testing.T) { timeout: defaultTimeout, headers: new(headersList), method: "GET", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), clientType: nhttp2, printIntro: true, printProgress: true, @@ -424,7 +424,7 @@ func TestArgsParsing(t *testing.T) { headers: new(headersList), method: "GET", bodyFilePath: "testbody.txt", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), printIntro: true, printProgress: true, printResult: true, @@ -450,7 +450,7 @@ func TestArgsParsing(t *testing.T) { headers: new(headersList), method: "GET", stream: true, - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), printIntro: true, printProgress: true, printResult: true, @@ -469,7 +469,7 @@ func TestArgsParsing(t *testing.T) { timeout: defaultTimeout, headers: new(headersList), method: "GET", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), printIntro: true, printProgress: true, printResult: true, @@ -514,7 +514,7 @@ func TestArgsParsing(t *testing.T) { timeout: defaultTimeout, headers: new(headersList), method: "GET", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), printIntro: true, printProgress: true, printResult: true, @@ -559,7 +559,7 @@ func TestArgsParsing(t *testing.T) { timeout: defaultTimeout, headers: new(headersList), method: "GET", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), printIntro: true, printProgress: false, printResult: true, @@ -584,7 +584,7 @@ func TestArgsParsing(t *testing.T) { timeout: defaultTimeout, headers: new(headersList), method: "GET", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), printIntro: false, printProgress: false, printResult: false, @@ -629,7 +629,7 @@ func TestArgsParsing(t *testing.T) { timeout: defaultTimeout, headers: new(headersList), method: "GET", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), printIntro: true, printProgress: true, printResult: true, @@ -674,7 +674,7 @@ func TestArgsParsing(t *testing.T) { timeout: defaultTimeout, headers: new(headersList), method: "GET", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), printIntro: true, printProgress: true, printResult: true, @@ -704,7 +704,7 @@ func TestArgsParsing(t *testing.T) { timeout: defaultTimeout, headers: new(headersList), method: "GET", - url: "https://somehost.somedomain", + url: ParseURLOrPanic("https://somehost.somedomain"), printIntro: true, printProgress: true, printResult: true, @@ -823,21 +823,6 @@ func TestArgsParsingWithInvalidPrintSpec(t *testing.T) { } } -func TestTryParseUrl(t *testing.T) { - invalid := []string{ - "ftp://bla:89", - "http://bla:bla:bla", - "htp:/bla:bla:bla", - } - - for _, url := range invalid { - gotURL, err := tryParseURL(url) - if err == nil { - t.Errorf("%q is not a valid URL, parsed as %q", url, gotURL) - } - } -} - func TestEmbeddedURLParsing(t *testing.T) { p := newKingpinParser() url := "http://127.0.0.1:8080/to?url=http://10.100.99.41:38667" @@ -845,7 +830,7 @@ func TestEmbeddedURLParsing(t *testing.T) { if err != nil { t.Error(err) } - if c.url != url { + if c.url.String() != url { t.Errorf("got %q, wanted %q", c.url, url) } } diff --git a/bombardier.go b/bombardier.go index d966116..7a57728 100644 --- a/bombardier.go +++ b/bombardier.go @@ -134,7 +134,7 @@ func newBombardier(c config) (*bombardier, error) { disableKeepAlives: c.disableKeepAlives, headers: c.headers, - url: c.url, + requestURL: c.url, method: c.method, body: pbody, bodProd: bsp, diff --git a/bombardier_performance_test.go b/bombardier_performance_test.go index 5d740ab..8249811 100644 --- a/bombardier_performance_test.go +++ b/bombardier_performance_test.go @@ -24,13 +24,14 @@ func BenchmarkBombardierSingleReqPerf(b *testing.B) { numConns: defaultNumberOfConns, numReqs: nil, duration: &longDuration, - url: "http://" + addr, + url: ParseURLOrPanic("http://" + addr), headers: new(headersList), timeout: defaultTimeout, method: "GET", body: "", printLatencies: false, clientType: clientTypeFromString(*clientType), + format: knownFormat("json"), }, b) } @@ -40,7 +41,7 @@ func BenchmarkBombardierRateLimitPerf(b *testing.B) { numConns: defaultNumberOfConns, numReqs: nil, duration: &longDuration, - url: "http://" + addr, + url: ParseURLOrPanic("http://" + addr), headers: new(headersList), timeout: defaultTimeout, method: "GET", @@ -48,6 +49,7 @@ func BenchmarkBombardierRateLimitPerf(b *testing.B) { printLatencies: false, rate: &highRate, clientType: clientTypeFromString(*clientType), + format: knownFormat("json"), }, b) } diff --git a/bombardier_test.go b/bombardier_test.go index 335ddc1..de895bd 100644 --- a/bombardier_test.go +++ b/bombardier_test.go @@ -36,7 +36,7 @@ func testBombardierShouldFireSpecifiedNumberOfRequests( b, e := newBombardier(config{ numConns: defaultNumberOfConns, numReqs: &numReqs, - url: s.URL, + url: ParseURLOrPanic(s.URL), headers: noHeaders, timeout: defaultTimeout, method: "GET", @@ -71,7 +71,7 @@ func testBombardierShouldFinish(clientType clientTyp, t *testing.T) { b, e := newBombardier(config{ numConns: defaultNumberOfConns, duration: &desiredTestDuration, - url: s.URL, + url: ParseURLOrPanic(s.URL), headers: noHeaders, timeout: defaultTimeout, method: "GET", @@ -134,7 +134,7 @@ func testBombardierShouldSendHeaders(clientType clientTyp, t *testing.T) { b, e := newBombardier(config{ numConns: defaultNumberOfConns, numReqs: &numReqs, - url: s.URL, + url: ParseURLOrPanic(s.URL), headers: &requestHeaders, timeout: defaultTimeout, method: "GET", @@ -180,7 +180,7 @@ func testBombardierHTTPCodeRecording(clientType clientTyp, t *testing.T) { b, e := newBombardier(config{ numConns: defaultNumberOfConns, numReqs: &numReqs, - url: s.URL, + url: ParseURLOrPanic(s.URL), headers: new(headersList), timeout: defaultTimeout, method: "GET", @@ -229,7 +229,7 @@ func testBombardierTimeoutRecoding(clientType clientTyp, t *testing.T) { numConns: defaultNumberOfConns, numReqs: &numReqs, duration: nil, - url: s.URL, + url: ParseURLOrPanic(s.URL), headers: new(headersList), timeout: shortTimeout, method: "GET", @@ -267,7 +267,7 @@ func testBombardierThroughputRecording(clientType clientTyp, t *testing.T) { b, e := newBombardier(config{ numConns: defaultNumberOfConns, numReqs: &numReqs, - url: s.URL, + url: ParseURLOrPanic(s.URL), headers: new(headersList), timeout: defaultTimeout, method: "GET", @@ -301,7 +301,7 @@ func TestBombardierStatsPrinting(t *testing.T) { b, e := newBombardier(config{ numConns: defaultNumberOfConns, numReqs: &numReqs, - url: s.URL, + url: ParseURLOrPanic(s.URL), headers: new(headersList), timeout: defaultTimeout, method: "GET", @@ -336,7 +336,7 @@ func TestBombardierErrorIfFailToReadClientCert(t *testing.T) { _, e := newBombardier(config{ numConns: defaultNumberOfConns, numReqs: &numReqs, - url: "http://localhost", + url: ParseURLOrPanic("http://localhost"), headers: new(headersList), timeout: defaultTimeout, method: "GET", @@ -402,7 +402,7 @@ func testBombardierClientCerts(clientType clientTyp, t *testing.T) { b, e := newBombardier(config{ numConns: defaultNumberOfConns, numReqs: &singleRequest, - url: server.URL, + url: ParseURLOrPanic(server.URL), headers: new(headersList), timeout: defaultTimeout, method: "GET", @@ -449,7 +449,7 @@ func testBombardierRateLimiting(clientType clientTyp, t *testing.T) { b, e := newBombardier(config{ numConns: defaultNumberOfConns, duration: &testDuration, - url: s.URL, + url: ParseURLOrPanic(s.URL), headers: new(headersList), timeout: defaultTimeout, method: "GET", @@ -507,7 +507,7 @@ func testBombardierSendsBody(clientType clientTyp, t *testing.T) { b, e := newBombardier(config{ numConns: defaultNumberOfConns, numReqs: &one, - url: s.URL, + url: ParseURLOrPanic(s.URL), headers: new(headersList), timeout: defaultTimeout, method: "POST", @@ -556,7 +556,7 @@ func testBombardierSendsBodyFromFile(clientType clientTyp, t *testing.T) { b, e := newBombardier(config{ numConns: defaultNumberOfConns, numReqs: &one, - url: s.URL, + url: ParseURLOrPanic(s.URL), headers: new(headersList), timeout: defaultTimeout, method: "POST", @@ -576,7 +576,7 @@ func TestBombardierFileDoesntExist(t *testing.T) { bodyPath := "/does/not/exist.forreal" _, e := newBombardier(config{ numConns: defaultNumberOfConns, - url: "http://example.com", + url: ParseURLOrPanic("http://example.com"), headers: new(headersList), timeout: defaultTimeout, method: "POST", @@ -620,7 +620,7 @@ func testBombardierStreamsBody(clientType clientTyp, t *testing.T) { b, e := newBombardier(config{ numConns: defaultNumberOfConns, numReqs: &one, - url: s.URL, + url: ParseURLOrPanic(s.URL), headers: new(headersList), timeout: defaultTimeout, method: "POST", @@ -673,7 +673,7 @@ func testBombardierStreamsBodyFromFile(clientType clientTyp, t *testing.T) { b, e := newBombardier(config{ numConns: defaultNumberOfConns, numReqs: &one, - url: s.URL, + url: ParseURLOrPanic(s.URL), headers: new(headersList), timeout: defaultTimeout, method: "POST", @@ -713,7 +713,7 @@ func testBombardierShouldSendCustomHostHeader( b, e := newBombardier(config{ numConns: defaultNumberOfConns, numReqs: &numReqs, - url: s.URL, + url: ParseURLOrPanic(s.URL), headers: &headers, timeout: defaultTimeout, method: "GET", diff --git a/client_cert_test.go b/client_cert_test.go index 1107af8..a580fcd 100644 --- a/client_cert_test.go +++ b/client_cert_test.go @@ -29,7 +29,7 @@ func TestGenerateTLSConfig(t *testing.T) { for _, e := range expectations { _, r := generateTLSConfig( config{ - url: "https://doesnt.exist.com", + url: ParseURLOrPanic("https://doesnt.exist.com"), certPath: e.certPath, keyPath: e.keyPath, }, diff --git a/clients.go b/clients.go index 5835c1a..500c17e 100644 --- a/clients.go +++ b/clients.go @@ -9,7 +9,6 @@ import ( "strings" "time" - "github.com/goware/urlx" "github.com/valyala/fasthttp" ) @@ -27,8 +26,9 @@ type clientOpts struct { tlsConfig *tls.Config disableKeepAlives bool - headers *headersList - url, method string + requestURL *url.URL + headers *headersList + method string body *string bodProd bodyStreamProducer @@ -39,8 +39,9 @@ type clientOpts struct { type fasthttpClient struct { client *fasthttp.Client - headers *fasthttp.RequestHeader - host, requestURI, method, scheme string + headers *fasthttp.RequestHeader + uri *fasthttp.URI + method string body *string bodProd bodyStreamProducer @@ -48,14 +49,15 @@ type fasthttpClient struct { func newFastHTTPClient(opts *clientOpts) client { c := new(fasthttpClient) - u, err := urlx.Parse(opts.url) - if err != nil { - // opts.url guaranteed to be valid at this point + uri := fasthttp.AcquireURI() + if err := uri.Parse( + []byte(opts.requestURL.Host), + []byte(opts.requestURL.String()), + ); err != nil { + // opts.requestURL must always be valid panic(err) } - c.host = u.Host - c.requestURI = u.RequestURI() - c.scheme = u.Scheme + c.uri = uri c.client = &fasthttp.Client{ MaxConnsPerHost: int(opts.maxConns), ReadTimeout: opts.timeout, @@ -64,6 +66,7 @@ func newFastHTTPClient(opts *clientOpts) client { TLSConfig: opts.tlsConfig, Dial: fasthttpDialFunc( opts.bytesRead, opts.bytesWritten, + opts.timeout, ), } c.headers = headersToFastHTTPHeaders(opts.headers) @@ -78,13 +81,12 @@ func (c *fasthttpClient) do() ( // prepare the request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() - req.Header.SetHost(c.host) + req.Header.SetMethod(c.method) if c.headers != nil { c.headers.CopyTo(&req.Header) } - req.SetRequestURI(c.requestURI) - req.Header.SetMethod(c.method) - req.URI().SetScheme(c.scheme) + req.SetURI(c.uri) + req.UseHostHeader = true if c.body != nil { req.SetBodyString(*c.body) } else { @@ -130,8 +132,8 @@ func newHTTPClient(opts *clientOpts) client { MaxIdleConnsPerHost: int(opts.maxConns), DisableKeepAlives: opts.disableKeepAlives, ForceAttemptHTTP2: opts.HTTP2, + DialContext: httpDialContextFunc(opts.bytesRead, opts.bytesWritten, opts.timeout), } - tr.DialContext = httpDialContextFunc(opts.bytesRead, opts.bytesWritten) cl := &http.Client{ Transport: tr, @@ -144,12 +146,7 @@ func newHTTPClient(opts *clientOpts) client { c.headers = headersToHTTPHeaders(opts.headers) c.method, c.body, c.bodProd = opts.method, opts.body, opts.bodProd - var err error - c.url, err = urlx.Parse(opts.url) - if err != nil { - // opts.url guaranteed to be valid at this point - panic(err) - } + c.url = opts.requestURL return client(c) } diff --git a/clients_test.go b/clients_test.go index 3542fa8..757aa00 100644 --- a/clients_test.go +++ b/clients_test.go @@ -7,6 +7,8 @@ import ( "net/http/httptest" "sync/atomic" "testing" + + "github.com/goware/urlx" ) func TestShouldReturnNilIfNoHeadersWhereSet(t *testing.T) { @@ -77,12 +79,16 @@ func TestHTTP2Client(t *testing.T) { defer s.Close() bytesRead, bytesWritten := int64(0), int64(0) + requestURL, err := urlx.Parse(s.URL) + if err != nil { + t.Fatal(err) + } c := newHTTPClient(&clientOpts{ HTTP2: true, - headers: new(headersList), - url: s.URL, - method: "GET", + headers: new(headersList), + requestURL: requestURL, + method: "GET", tlsConfig: &tls.Config{ InsecureSkipVerify: true, }, @@ -127,12 +133,16 @@ func TestHTTP1Clients(t *testing.T) { defer s.Close() bytesRead, bytesWritten := int64(0), int64(0) + requestURL, err := urlx.Parse(s.URL) + if err != nil { + t.Fatal(err) + } cc := &clientOpts{ HTTP2: false, - headers: new(headersList), - url: s.URL, - method: "GET", + headers: new(headersList), + requestURL: requestURL, + method: "GET", body: new(string), diff --git a/common.go b/common.go index 6f9791f..cd3d541 100644 --- a/common.go +++ b/common.go @@ -2,8 +2,11 @@ package main import ( "errors" + "net/url" "sort" "time" + + "github.com/goware/urlx" ) const ( @@ -31,8 +34,7 @@ var ( } cantHaveBody = []string{"HEAD"} - errInvalidURL = errors.New( - "no hostname or invalid scheme") + errUnsupportedScheme = errors.New("unsupported scheme") errInvalidNumberOfConns = errors.New( "invalid number of connections(must be > 0)") errInvalidNumberOfRequests = errors.New( @@ -56,6 +58,14 @@ var ( "empty print spec is not a valid print spec") ) +func ParseURLOrPanic(s string) *url.URL { + u, err := urlx.Parse(s) + if err != nil { + panic(err) + } + return u +} + func init() { sort.Strings(httpMethods) sort.Strings(cantHaveBody) diff --git a/config.go b/config.go index 123c5c2..9dbbc8d 100644 --- a/config.go +++ b/config.go @@ -8,15 +8,16 @@ import ( ) type config struct { - numConns uint64 - numReqs *uint64 - disableKeepAlives bool - duration *time.Duration - url, method, certPath, keyPath string - body, bodyFilePath string - stream bool - headers *headersList - timeout time.Duration + numConns uint64 + numReqs *uint64 + disableKeepAlives bool + duration *time.Duration + url *url.URL + method, certPath, keyPath string + body, bodyFilePath string + stream bool + headers *headersList + timeout time.Duration // TODO(codesenberg): printLatencies should probably be // re(named&maked) into printPercentiles or even let // users provide their own percentiles and not just @@ -84,14 +85,9 @@ func (c *config) testType() testTyp { } func (c *config) checkURL() error { - url, err := url.Parse(c.url) - if err != nil { - return err + if c.url.Scheme != "http" && c.url.Scheme != "https" { + return errUnsupportedScheme } - if url.Host == "" || (url.Scheme != "http" && url.Scheme != "https") { - return errInvalidURL - } - c.url = url.String() return nil } diff --git a/config_test.go b/config_test.go index 0066c3a..f77f945 100644 --- a/config_test.go +++ b/config_test.go @@ -59,26 +59,12 @@ func TestCheckArgs(t *testing.T) { in config out error }{ - { - config{ - numConns: defaultNumberOfConns, - numReqs: &defaultNumberOfReqs, - duration: &defaultTestDuration, - url: "ftp://localhost:8080", - headers: noHeaders, - timeout: defaultTimeout, - method: "GET", - body: "", - format: knownFormat("plain-text"), - }, - errInvalidURL, - }, { config{ numConns: 0, numReqs: &defaultNumberOfReqs, duration: &defaultTestDuration, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: noHeaders, timeout: defaultTimeout, method: "GET", @@ -92,7 +78,7 @@ func TestCheckArgs(t *testing.T) { numConns: defaultNumberOfConns, numReqs: &invalidNumberOfReqs, duration: &defaultTestDuration, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: noHeaders, timeout: defaultTimeout, method: "GET", @@ -106,7 +92,7 @@ func TestCheckArgs(t *testing.T) { numConns: defaultNumberOfConns, numReqs: nil, duration: &smallTestDuration, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: noHeaders, timeout: defaultTimeout, method: "GET", @@ -120,7 +106,7 @@ func TestCheckArgs(t *testing.T) { numConns: defaultNumberOfConns, numReqs: &defaultNumberOfReqs, duration: &defaultTestDuration, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: noHeaders, timeout: negativeTimeoutDuration, method: "GET", @@ -134,7 +120,7 @@ func TestCheckArgs(t *testing.T) { numConns: defaultNumberOfConns, numReqs: &defaultNumberOfReqs, duration: &defaultTestDuration, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: noHeaders, timeout: defaultTimeout, method: "HEAD", @@ -148,7 +134,7 @@ func TestCheckArgs(t *testing.T) { numConns: defaultNumberOfConns, numReqs: &defaultNumberOfReqs, duration: &defaultTestDuration, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: noHeaders, timeout: defaultTimeout, method: "HEAD", @@ -162,7 +148,7 @@ func TestCheckArgs(t *testing.T) { numConns: defaultNumberOfConns, numReqs: &defaultNumberOfReqs, duration: &defaultTestDuration, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: noHeaders, timeout: defaultTimeout, method: "GET", @@ -176,7 +162,7 @@ func TestCheckArgs(t *testing.T) { numConns: defaultNumberOfConns, numReqs: &defaultNumberOfReqs, duration: &defaultTestDuration, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: noHeaders, timeout: defaultTimeout, method: "GET", @@ -190,7 +176,7 @@ func TestCheckArgs(t *testing.T) { numConns: defaultNumberOfConns, numReqs: &defaultNumberOfReqs, duration: &defaultTestDuration, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: noHeaders, timeout: defaultTimeout, method: "GET", @@ -204,7 +190,7 @@ func TestCheckArgs(t *testing.T) { numConns: defaultNumberOfConns, numReqs: &defaultNumberOfReqs, duration: &defaultTestDuration, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: noHeaders, timeout: defaultTimeout, method: "GET", @@ -220,7 +206,7 @@ func TestCheckArgs(t *testing.T) { numConns: defaultNumberOfConns, numReqs: &defaultNumberOfReqs, duration: &defaultTestDuration, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: noHeaders, timeout: defaultTimeout, method: "GET", @@ -236,7 +222,7 @@ func TestCheckArgs(t *testing.T) { numConns: defaultNumberOfConns, numReqs: &defaultNumberOfReqs, duration: &defaultTestDuration, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: noHeaders, timeout: defaultTimeout, method: "GET", @@ -250,7 +236,7 @@ func TestCheckArgs(t *testing.T) { numConns: defaultNumberOfConns, numReqs: &defaultNumberOfReqs, duration: &defaultTestDuration, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: noHeaders, timeout: defaultTimeout, method: "POST", @@ -273,18 +259,18 @@ func TestCheckArgs(t *testing.T) { } } -func TestCheckArgsGarbageUrl(t *testing.T) { +func TestCheckArgsUnsupportedURLScheme(t *testing.T) { c := config{ numConns: defaultNumberOfConns, numReqs: &defaultNumberOfReqs, duration: &defaultTestDuration, - url: "8080", + url: ParseURLOrPanic("ftp://localhost:8080"), headers: nil, timeout: defaultTimeout, method: "GET", body: "", } - if c.checkArgs() == nil { + if c.checkArgs() != errUnsupportedScheme { t.Fail() } } @@ -294,7 +280,7 @@ func TestCheckArgsInvalidRequestMethod(t *testing.T) { numConns: defaultNumberOfConns, numReqs: &defaultNumberOfReqs, duration: &defaultTestDuration, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: nil, timeout: defaultTimeout, method: "ABRACADABRA", @@ -314,7 +300,7 @@ func TestCheckArgsTestType(t *testing.T) { numConns: defaultNumberOfConns, numReqs: &defaultNumberOfReqs, duration: nil, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: nil, timeout: defaultTimeout, method: "GET", @@ -324,7 +310,7 @@ func TestCheckArgsTestType(t *testing.T) { numConns: defaultNumberOfConns, numReqs: nil, duration: &defaultTestDuration, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: nil, timeout: defaultTimeout, method: "GET", @@ -334,7 +320,7 @@ func TestCheckArgsTestType(t *testing.T) { numConns: defaultNumberOfConns, numReqs: &defaultNumberOfReqs, duration: &defaultTestDuration, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: nil, timeout: defaultTimeout, method: "GET", @@ -344,7 +330,7 @@ func TestCheckArgsTestType(t *testing.T) { numConns: defaultNumberOfConns, numReqs: nil, duration: nil, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: nil, timeout: defaultTimeout, method: "GET", @@ -374,7 +360,7 @@ func TestTimeoutMillis(t *testing.T) { numConns: defaultNumberOfConns, numReqs: nil, duration: nil, - url: "http://localhost:8080", + url: ParseURLOrPanic("http://localhost:8080"), headers: nil, timeout: 2 * time.Second, method: "GET", diff --git a/dialer.go b/dialer.go index 7ca85d9..f0a3a41 100644 --- a/dialer.go +++ b/dialer.go @@ -4,6 +4,7 @@ import ( "context" "net" "sync/atomic" + "time" ) type countingConn struct { @@ -33,9 +34,10 @@ func (cc *countingConn) Write(b []byte) (n int, err error) { var fasthttpDialFunc = func( bytesRead, bytesWritten *int64, + dialTimeout time.Duration, ) func(string) (net.Conn, error) { return func(address string) (net.Conn, error) { - conn, err := net.Dial("tcp", address) + conn, err := net.DialTimeout("tcp", address, dialTimeout) if err != nil { return nil, err } @@ -52,8 +54,9 @@ var fasthttpDialFunc = func( var httpDialContextFunc = func( bytesRead, bytesWritten *int64, + dialTimeout time.Duration, ) func(context.Context, string, string) (net.Conn, error) { - dialer := &net.Dialer{} + dialer := &net.Dialer{Timeout: dialTimeout} return func(ctx context.Context, network, address string) (net.Conn, error) { conn, err := dialer.DialContext(ctx, network, address) if err != nil { diff --git a/internal/test_info.go b/internal/test_info.go index 98ed641..1c0faf5 100644 --- a/internal/test_info.go +++ b/internal/test_info.go @@ -2,6 +2,7 @@ package internal import ( "math" + "net/url" "sort" "time" ) @@ -27,7 +28,7 @@ type Spec struct { TestDuration time.Duration Method string - URL string + URL *url.URL Headers []Header @@ -44,6 +45,11 @@ type Spec struct { Rate *uint64 } +// RequestURL returns URL as string. +func (s Spec) RequestURL() string { + return s.URL.String() +} + // IsTimedTest tells if the test was limited by time. func (s Spec) IsTimedTest() bool { return s.TestType == ByTime diff --git a/template/doc.go b/template/doc.go index f66c00f..867931f 100644 --- a/template/doc.go +++ b/template/doc.go @@ -6,6 +6,8 @@ User-defined templates use Go's text/template package, so you might want to check its documentation first. There are a bunch of helper methods available inside a template besides those described in aforementioned documentation, namely: + - URLString() + Returns the URL string used for the load test. - WithLatencies() Tells whether --latencies flag were activated. - FormatBinary(numberOfBytes float64) string diff --git a/templates.go b/templates.go index 4f71edc..dd83772 100644 --- a/templates.go +++ b/templates.go @@ -75,7 +75,7 @@ const ( ,"testType":"number-of-requests","numberOfRequests":{{ .NumberOfRequests }} {{- end -}} -,"method":"{{ .Method }}","url":{{ .URL | printf "%q" }} +,"method":"{{ .Method }}","url":{{ .RequestURL | printf "%q" }} {{- with .Headers -}} ,"headers":[