Skip to content

Commit

Permalink
Write normalized scheme and host to routing.Route fields
Browse files Browse the repository at this point in the history
This change is needed for proxy to easily have normalized host and use it
in endpointregistry for dynamic host-wide data to be easily fetchable.

Signed-off-by: Roman Zavodskikh <[email protected]>
  • Loading branch information
Roman Zavodskikh committed Jan 5, 2024
1 parent e297b55 commit 01bfbd8
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 83 deletions.
42 changes: 2 additions & 40 deletions filters/fadein/fadein.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@ package fadein

import (
"fmt"
"net"
"net/url"
"strings"
"time"

log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
snet "github.com/zalando/skipper/net"
"github.com/zalando/skipper/routing"
)

Expand Down Expand Up @@ -101,42 +99,6 @@ func NewEndpointCreated() filters.Spec {

func (endpointCreated) Name() string { return filters.EndpointCreatedName }

func normalizeSchemeHost(s, h string) (string, string, error) {
// endpoint address cannot contain path, the rest is not case sensitive
s, h = strings.ToLower(s), strings.ToLower(h)

hh, p, err := net.SplitHostPort(h)
if err != nil {
// what is the actual right way of doing this, considering IPv6 addresses, too?
if !strings.Contains(err.Error(), "missing port") {
return "", "", err
}

p = ""
} else {
h = hh
}

switch {
case p == "" && s == "http":
p = "80"
case p == "" && s == "https":
p = "443"
}

h = net.JoinHostPort(h, p)
return s, h, nil
}

func normalizeEndpoint(e string) (string, string, error) {
u, err := url.Parse(e)
if err != nil {
return "", "", err
}

return normalizeSchemeHost(u.Scheme, u.Host)
}

func endpointKey(scheme, host string) string {
return fmt.Sprintf("%s://%s", scheme, host)
}
Expand All @@ -151,7 +113,7 @@ func (endpointCreated) CreateFilter(args []interface{}) (filters.Filter, error)
return nil, filters.ErrInvalidFilterParameters
}

s, h, err := normalizeEndpoint(e)
s, h, err := snet.SchemeHost(e)
if err != nil {
return nil, err
}
Expand Down
37 changes: 1 addition & 36 deletions loadbalancer/algorithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@ import (
"fmt"
"math"
"math/rand"
"net"
"net/url"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -428,12 +425,7 @@ func (a Algorithm) String() string {
func parseEndpoints(r *routing.Route) error {
r.LBEndpoints = make([]routing.LBEndpoint, len(r.Route.LBEndpoints))
for i, e := range r.Route.LBEndpoints {
eu, err := url.ParseRequestURI(e)
if err != nil {
return err
}

scheme, host, err := normalizeSchemeHost(eu.Scheme, eu.Host)
scheme, host, err := snet.SchemeHost(e)
if err != nil {
return err
}
Expand Down Expand Up @@ -463,33 +455,6 @@ func setAlgorithm(r *routing.Route) error {
return nil
}

func normalizeSchemeHost(s, h string) (string, string, error) {
// endpoint address cannot contain path, the rest is not case sensitive
s, h = strings.ToLower(s), strings.ToLower(h)

hh, p, err := net.SplitHostPort(h)
if err != nil {
// what is the actual right way of doing this, considering IPv6 addresses, too?
if !strings.Contains(err.Error(), "missing port") {
return "", "", err
}

p = ""
} else {
h = hh
}

switch {
case p == "" && s == "http":
p = "80"
case p == "" && s == "https":
p = "443"
}

h = net.JoinHostPort(h, p)
return s, h, nil
}

// Do implements routing.PostProcessor
func (p *algorithmProvider) Do(r []*routing.Route) []*routing.Route {
rr := make([]*routing.Route, 0, len(r))
Expand Down
45 changes: 45 additions & 0 deletions net/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"net"
"net/http"
"net/netip"
"net/url"
"strings"

"github.com/pkg/errors"
"go4.org/netipx"
)

Expand Down Expand Up @@ -154,3 +156,46 @@ func ParseIPCIDRs(cidrs []string) (*netipx.IPSet, error) {

return ips, nil
}

// SchemeHost parses URI string (without #fragment part) and returns schema used in this URI as first return value and
// host[:port] part as second return value. Port is never omitted for HTTP(S): if no port is specified in URI, default port for given
// schema is used. If URI is invalid, error is returned.
func SchemeHost(input string) (string, string, error) {
u, err := url.ParseRequestURI(input)
if err != nil {
return "", "", err
}
if u.Scheme == "" {
return "", "", errors.Errorf(`parse %q: missing scheme`, input)
}
if u.Host == "" {
return "", "", errors.Errorf(`parse %q: missing host`, input)
}

// endpoint address cannot contain path, the rest is not case sensitive
s, h := strings.ToLower(u.Scheme), strings.ToLower(u.Host)

hh, p, err := net.SplitHostPort(h)
if err != nil {
// what is the actual right way of doing this, considering IPv6 addresses, too?
if !strings.Contains(err.Error(), "missing port") {
return "", "", err
}

p = ""
} else {
h = hh
}

switch {
case p == "" && s == "http":
p = "80"
case p == "" && s == "https":
p = "443"
}

if p != "" {
h = net.JoinHostPort(h, p)
}
return s, h, nil
}
192 changes: 192 additions & 0 deletions net/net_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,41 @@
package net

import (
"fmt"
"net"
"net/http"
"net/netip"
"path/filepath"
"reflect"
"runtime"
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

type tc[T any] struct {
location string
in T
}

// https://github.com/golang/go/issues/52751
func testCase[T any](in T) tc[T] {
_, file, line, _ := runtime.Caller(1)
location := fmt.Sprintf("%s:%d", filepath.Base(file), line)
return tc[T]{location: location, in: in}
}

func (tc *tc[T]) logLocation(t *testing.T) {
t.Helper()
t.Cleanup(func() {
t.Helper()
if t.Failed() {
t.Logf("Test case location: %s", tc.location)
}
})
}

func TestRemoteAddr(t *testing.T) {
for _, tt := range []struct {
name string
Expand Down Expand Up @@ -236,3 +263,168 @@ func TestIPNetsDoNotContain(t *testing.T) {
})
}
}

type TestSchemeHostItem struct {
input string
scheme string
host string
err string
}

func TestSchemeHost(t *testing.T) {
for _, ti := range []tc[TestSchemeHostItem]{
testCase(TestSchemeHostItem{
input: "http://example.com",
scheme: "http",
host: "example.com:80",
err: "",
}),
testCase(TestSchemeHostItem{
input: "http://example.com:80",
scheme: "http",
host: "example.com:80",
err: "",
}),
testCase(TestSchemeHostItem{
input: "http://example.com:8080",
scheme: "http",
host: "example.com:8080",
err: "",
}),

testCase(TestSchemeHostItem{
input: "https://example.com",
scheme: "https",
host: "example.com:443",
err: "",
}),
testCase(TestSchemeHostItem{
input: "https://example.com:443",
scheme: "https",
host: "example.com:443",
err: "",
}),
testCase(TestSchemeHostItem{
input: "https://example.com:8080",
scheme: "https",
host: "example.com:8080",
err: "",
}),

testCase(TestSchemeHostItem{
input: "postgres://example.com",
scheme: "postgres",
host: "example.com",
err: "",
}),
testCase(TestSchemeHostItem{
input: "postgres://example.com:5432",
scheme: "postgres",
host: "example.com:5432",
err: "",
}),
testCase(TestSchemeHostItem{
input: "postgresql://example.com",
scheme: "postgresql",
host: "example.com",
err: "",
}),
testCase(TestSchemeHostItem{
input: "postgresql://example.com:5432",
scheme: "postgresql",
host: "example.com:5432",
err: "",
}),

testCase(TestSchemeHostItem{
input: "someprotocol://example.com",
scheme: "someprotocol",
host: "example.com",
err: "",
}),
testCase(TestSchemeHostItem{
input: "someprotocol://example.com:12345",
scheme: "someprotocol",
host: "example.com:12345",
err: "",
}),

testCase(TestSchemeHostItem{
input: "example.com",
scheme: "",
host: "",
err: `parse "example.com": invalid URI for request`,
}),
testCase(TestSchemeHostItem{
input: "example.com/",
scheme: "",
host: "",
err: `parse "example.com/": invalid URI for request`,
}),
testCase(TestSchemeHostItem{
input: "example.com:80",
scheme: "",
host: "",
err: `parse "example.com:80": missing host`,
}),

testCase(TestSchemeHostItem{
input: "hTTP://exAMPLe.com",
scheme: "http",
host: "example.com:80",
err: "",
}),

testCase(TestSchemeHostItem{
input: "http://example.com/foo/bar",
scheme: "http",
host: "example.com:80",
err: "",
}),
testCase(TestSchemeHostItem{
input: "http://example.com:80/foo/bar",
scheme: "http",
host: "example.com:80",
err: "",
}),
testCase(TestSchemeHostItem{
input: "http://example.com:8080/foo/bar",
scheme: "http",
host: "example.com:8080",
err: "",
}),

testCase(TestSchemeHostItem{
input: "http://example.com?foo=bar",
scheme: "http",
host: "example.com:80",
err: "",
}),
testCase(TestSchemeHostItem{
input: "http://example.com:80?foo=bar",
scheme: "http",
host: "example.com:80",
err: "",
}),
testCase(TestSchemeHostItem{
input: "http://example.com:8080?foo=bar",
scheme: "http",
host: "example.com:8080",
err: "",
}),
} {
t.Run(ti.in.input, func(t *testing.T) {
ti.logLocation(t)

scheme, host, err := SchemeHost(ti.in.input)
if ti.in.err != "" {
assert.EqualError(t, err, ti.in.err)
} else {
if assert.NoError(t, err) {
assert.Equal(t, ti.in.scheme, scheme)
assert.Equal(t, ti.in.host, host)
}
}
})
}
}
Loading

0 comments on commit 01bfbd8

Please sign in to comment.