From e714b00803744220b05fc5ea3294a7d2c4126124 Mon Sep 17 00:00:00 2001 From: coyzeng Date: Wed, 7 Dec 2022 18:52:59 +0800 Subject: [PATCH 1/2] loadbalance strategy extension --- roundrobin/options.go | 8 ++-- roundrobin/rebalancer.go | 4 +- roundrobin/rebalancer_test.go | 8 ++-- roundrobin/rr.go | 50 +++++++++++++++++-------- roundrobin/strategy.go | 70 +++++++++++++++++++++++++++++++++++ 5 files changed, 114 insertions(+), 26 deletions(-) create mode 100644 roundrobin/strategy.go diff --git a/roundrobin/options.go b/roundrobin/options.go index a794ae6..1382589 100644 --- a/roundrobin/options.go +++ b/roundrobin/options.go @@ -67,15 +67,15 @@ func RebalancerDebug(debug bool) RebalancerOption { } // ServerOption provides various options for server, e.g. weight. -type ServerOption func(*server) error +type ServerOption func(s Server) error // Weight is an optional functional argument that sets weight of the server. func Weight(w int) ServerOption { - return func(s *server) error { + return func(s Server) error { if w < 0 { - return fmt.Errorf("Weight should be >= 0") + return fmt.Errorf("Weight should be >= 0 ") } - s.weight = w + s.Set(w) return nil } } diff --git a/roundrobin/rebalancer.go b/roundrobin/rebalancer.go index 9e5f6b6..ba3a9c5 100644 --- a/roundrobin/rebalancer.go +++ b/roundrobin/rebalancer.go @@ -29,7 +29,7 @@ type BalancerHandler interface { ServerWeight(u *url.URL) (int, bool) RemoveServer(u *url.URL) error UpsertServer(u *url.URL, options ...ServerOption) error - NextServer() (*url.URL, error) + NextServer(w http.ResponseWriter, req *http.Request, neq *http.Request) (*url.URL, error) Next() http.Handler } @@ -144,7 +144,7 @@ func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { } if !stuck { - fwdURL, err := rb.next.NextServer() + fwdURL, err := rb.next.NextServer(w, req, &newReq) if err != nil { rb.errHandler.ServeHTTP(w, req, err) return diff --git a/roundrobin/rebalancer_test.go b/roundrobin/rebalancer_test.go index e8e2425..be5b020 100644 --- a/roundrobin/rebalancer_test.go +++ b/roundrobin/rebalancer_test.go @@ -123,8 +123,8 @@ func TestRebalancerRecovery(t *testing.T) { assert.Equal(t, 1, rb.servers[0].curWeight) assert.Equal(t, FSMMaxWeight, rb.servers[1].curWeight) - assert.Equal(t, 1, lb.servers[0].weight) - assert.Equal(t, FSMMaxWeight, lb.servers[1].weight) + assert.Equal(t, 1, lb.servers[0].Weight()) + assert.Equal(t, FSMMaxWeight, lb.servers[1].Weight()) // server a is now recovering, the weights should go back to the original state rb.servers[0].meter.(*testMeter).rating = 0 @@ -141,8 +141,8 @@ func TestRebalancerRecovery(t *testing.T) { assert.Equal(t, 1, rb.servers[1].curWeight) // Make sure we have applied the weights to the inner load balancer - assert.Equal(t, 1, lb.servers[0].weight) - assert.Equal(t, 1, lb.servers[1].weight) + assert.Equal(t, 1, lb.servers[0].Weight()) + assert.Equal(t, 1, lb.servers[1].Weight()) } // Test scenario when increaing the weight on good endpoints made it worse. diff --git a/roundrobin/rr.go b/roundrobin/rr.go index f1b891d..564a919 100644 --- a/roundrobin/rr.go +++ b/roundrobin/rr.go @@ -3,6 +3,7 @@ package roundrobin import ( "fmt" + "math/rand" "net/http" "net/url" "sync" @@ -17,7 +18,7 @@ type RoundRobin struct { errHandler utils.ErrorHandler // Current index (starts from -1) index int - servers []*server + servers []Server currentWeight int stickySession *StickySession requestRewriteListener RequestRewriteListener @@ -32,7 +33,7 @@ func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) { next: next, index: -1, mutex: &sync.Mutex{}, - servers: []*server{}, + servers: []Server{}, stickySession: nil, log: &utils.NoopLogger{}, @@ -76,7 +77,7 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) { } if !stuck { - uri, err := r.NextServer() + uri, err := r.NextServer(w, req, &newReq) if err != nil { r.errHandler.ServeHTTP(w, req, err) return @@ -103,15 +104,20 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) { } // NextServer gets the next server. -func (r *RoundRobin) NextServer() (*url.URL, error) { - srv, err := r.nextServer() +func (r *RoundRobin) NextServer(w http.ResponseWriter, req *http.Request, neq *http.Request) (*url.URL, error) { + // Use extension balance server, if extension return multiple servers, choose anyone. + if ss := Strategy().Next(w, req, neq, r.servers); len(ss) > 0 { + srv := ss[rand.Intn(len(ss))] + return utils.CopyURL(srv.URL()), nil + } + srv, err := r.nextServer(w, req) if err != nil { return nil, err } - return utils.CopyURL(srv.url), nil + return utils.CopyURL(srv.URL()), nil } -func (r *RoundRobin) nextServer() (*server, error) { +func (r *RoundRobin) nextServer(w http.ResponseWriter, req *http.Request) (Server, error) { r.mutex.Lock() defer r.mutex.Unlock() @@ -140,7 +146,7 @@ func (r *RoundRobin) nextServer() (*server, error) { } } srv := r.servers[r.index] - if srv.weight >= r.currentWeight { + if srv.Weight() >= r.currentWeight { return srv, nil } } @@ -167,7 +173,7 @@ func (r *RoundRobin) Servers() []*url.URL { out := make([]*url.URL, len(r.servers)) for i, srv := range r.servers { - out[i] = srv.url + out[i] = srv.URL() } return out } @@ -178,7 +184,7 @@ func (r *RoundRobin) ServerWeight(u *url.URL) (int, bool) { defer r.mutex.Unlock() if s, _ := r.findServerByURL(u); s != nil { - return s.weight, true + return s.Weight(), true } return -1, false } @@ -227,12 +233,12 @@ func (r *RoundRobin) resetState() { r.resetIterator() } -func (r *RoundRobin) findServerByURL(u *url.URL) (*server, int) { +func (r *RoundRobin) findServerByURL(u *url.URL) (Server, int) { if len(r.servers) == 0 { return nil, -1 } for i, s := range r.servers { - if sameURL(u, s.url) { + if sameURL(u, s.URL()) { return s, i } } @@ -242,8 +248,8 @@ func (r *RoundRobin) findServerByURL(u *url.URL) (*server, int) { func (r *RoundRobin) maxWeight() int { max := -1 for _, s := range r.servers { - if s.weight > max { - max = s.weight + if s.Weight() > max { + max = s.Weight() } } return max @@ -253,9 +259,9 @@ func (r *RoundRobin) weightGcd() int { divisor := -1 for _, s := range r.servers { if divisor == -1 { - divisor = s.weight + divisor = s.Weight() } else { - divisor = gcd(divisor, s.weight) + divisor = gcd(divisor, s.Weight()) } } return divisor @@ -275,6 +281,18 @@ type server struct { weight int } +func (that *server) URL() *url.URL { + return that.url +} + +func (that *server) Weight() int { + return that.weight +} + +func (that *server) Set(weight int) { + that.weight = weight +} + var defaultWeight = 1 // SetDefaultWeight sets the default server weight. diff --git a/roundrobin/strategy.go b/roundrobin/strategy.go new file mode 100644 index 0000000..f9b452d --- /dev/null +++ b/roundrobin/strategy.go @@ -0,0 +1,70 @@ +package roundrobin + +import ( + "net/http" + "net/url" + "sort" +) + +func init() { + var _ LBStrategy = new(CompositeStrategy) +} +func Strategy() LBStrategy { + return strategies +} + +func Provide(lbs LBStrategy) { + strategies.Add(lbs) +} + +var strategies = new(CompositeStrategy) + +type Server interface { + + // URL server url. + URL() *url.URL + + // Weight Relative weight for the endpoint to other endpoints in the load balancer. + Weight() int + + // Set the weight. + Set(weight int) +} + +type LBStrategy interface { + + // Name is the strategy name. + Name() string + + // Priority more than has more priority. + Priority() int + + // Next servers + // Load balancer extension for custom rules filter. + Next(w http.ResponseWriter, req *http.Request, neq *http.Request, servers []Server) []Server +} + +type CompositeStrategy struct { + strategies []LBStrategy +} + +func (that *CompositeStrategy) Add(lbs LBStrategy) *CompositeStrategy { + that.strategies = append(that.strategies, lbs) + sort.Slice(that.strategies, func(i, j int) bool { return that.strategies[i].Priority() < that.strategies[j].Priority() }) + return that +} + +func (that *CompositeStrategy) Name() string { + return "composite" +} + +func (that *CompositeStrategy) Priority() int { + return 0 +} + +func (that *CompositeStrategy) Next(w http.ResponseWriter, req *http.Request, neq *http.Request, servers []Server) []Server { + for _, strategy := range that.strategies { + servers = strategy.Next(w, req, neq, servers) + } + return servers +} From 1be0c32e5e7796a55762c4d56ec1583f5ad6cc61 Mon Sep 17 00:00:00 2001 From: coyzeng Date: Wed, 7 Dec 2022 19:37:23 +0800 Subject: [PATCH 2/2] strip strategy --- roundrobin/rr.go | 5 ++--- roundrobin/strategy.go | 10 ++++++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/roundrobin/rr.go b/roundrobin/rr.go index 564a919..64d2abe 100644 --- a/roundrobin/rr.go +++ b/roundrobin/rr.go @@ -106,9 +106,8 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) { // NextServer gets the next server. func (r *RoundRobin) NextServer(w http.ResponseWriter, req *http.Request, neq *http.Request) (*url.URL, error) { // Use extension balance server, if extension return multiple servers, choose anyone. - if ss := Strategy().Next(w, req, neq, r.servers); len(ss) > 0 { - srv := ss[rand.Intn(len(ss))] - return utils.CopyURL(srv.URL()), nil + if ss := Strategy().Next(w, req, neq, r.servers); len(ss) > 0 && (len(ss) < len(r.servers) || len(r.servers) < 1) { + return Strategy().Strip(w, req, neq, utils.CopyURL(ss[rand.Intn(len(ss))].URL())), nil } srv, err := r.nextServer(w, req) if err != nil { diff --git a/roundrobin/strategy.go b/roundrobin/strategy.go index f9b452d..206a410 100644 --- a/roundrobin/strategy.go +++ b/roundrobin/strategy.go @@ -42,6 +42,9 @@ type LBStrategy interface { // Next servers // Load balancer extension for custom rules filter. Next(w http.ResponseWriter, req *http.Request, neq *http.Request, servers []Server) []Server + + // Strip filter the server URL + Strip(w http.ResponseWriter, req *http.Request, neq *http.Request, uri *url.URL) *url.URL } type CompositeStrategy struct { @@ -68,3 +71,10 @@ func (that *CompositeStrategy) Next(w http.ResponseWriter, req *http.Request, ne } return servers } + +func (that *CompositeStrategy) Strip(w http.ResponseWriter, req *http.Request, neq *http.Request, uri *url.URL) *url.URL { + for _, strategy := range that.strategies { + uri = strategy.Strip(w, req, neq, uri) + } + return uri +}