From 0588d59abfe94317121f3419d2495ac474b80d46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Zaffarano?= Date: Sun, 17 Oct 2021 08:38:23 -0300 Subject: [PATCH] Implement max concurrency control (#35) --- pkg/task/daemon.go | 2 +- pkg/task/transport/server.go | 4 +- pkg/task/transport/tls.go | 14 +++++-- pkg/task/transport/tls_test.go | 70 +++++++++++++++++++++++++++++++++- 4 files changed, 81 insertions(+), 9 deletions(-) diff --git a/pkg/task/daemon.go b/pkg/task/daemon.go index 362f2b2..ccd2fd0 100644 --- a/pkg/task/daemon.go +++ b/pkg/task/daemon.go @@ -35,7 +35,7 @@ func Serve(cfg config.Config) (err error) { Process(client, auth, ra) } - server, err := transport.NewServer(tlsConfig, handler) + server, err := transport.NewServer(tlsConfig, cfg.GetInt(QueueSize), handler) if err != nil { return fmt.Errorf("initializing server: %v", err) } diff --git a/pkg/task/transport/server.go b/pkg/task/transport/server.go index 4b6f87e..14e6395 100644 --- a/pkg/task/transport/server.go +++ b/pkg/task/transport/server.go @@ -15,6 +15,6 @@ type Server interface { type Handler func(io.ReadWriteCloser) // NewServer creates a new taskd server working according to the configuration -func NewServer(cfg TLSConfig, handler Handler) (Server, error) { - return newTLSServer(cfg, handler) +func NewServer(cfg TLSConfig, maxConcurrency int, handler Handler) (Server, error) { + return newTLSServer(cfg, maxConcurrency, handler) } diff --git a/pkg/task/transport/tls.go b/pkg/task/transport/tls.go index ada7893..b66a2ce 100644 --- a/pkg/task/transport/tls.go +++ b/pkg/task/transport/tls.go @@ -26,7 +26,7 @@ func init() { } // NewTlsServer creates a new tls-based server -func newTLSServer(cfg TLSConfig, handlerFunc Handler) (Server, error) { +func newTLSServer(cfg TLSConfig, maxConcurrency int, handlerFunc Handler) (Server, error) { var ca []byte var cert tls.Certificate var err error @@ -72,7 +72,7 @@ func newTLSServer(cfg TLSConfig, handlerFunc Handler) (Server, error) { server.wg.Add(1) server.handler = handlerFunc - go server.serve() + go server.serve(maxConcurrency) return &server, nil } @@ -97,9 +97,11 @@ func (s *tlsServer) Close() error { return err } -func (s *tlsServer) serve() { +func (s *tlsServer) serve(maxConcurrency int) { defer s.wg.Done() + concurrency := make(chan interface{}, maxConcurrency) + for { conn, err := s.listener.Accept() if err != nil { @@ -111,8 +113,12 @@ func (s *tlsServer) serve() { } } s.wg.Add(1) + concurrency <- 1 go func() { - defer s.wg.Done() + defer func() { + <-concurrency + s.wg.Done() + }() s.handler(conn) }() diff --git a/pkg/task/transport/tls_test.go b/pkg/task/transport/tls_test.go index 92dff5a..b3fe86d 100644 --- a/pkg/task/transport/tls_test.go +++ b/pkg/task/transport/tls_test.go @@ -8,6 +8,7 @@ import ( "io/ioutil" "net" "path/filepath" + "sync" "testing" "time" @@ -75,7 +76,7 @@ func TestServer(t *testing.T) { BindAddress: filepath.Join(base, c.bindAddress), } - srv, err := NewServer(cfg, dummyHandler) + srv, err := NewServer(cfg, 1, dummyHandler) assert.NotNil(t, err) assert.Nil(t, srv) }) @@ -83,6 +84,71 @@ func TestServer(t *testing.T) { }) } +func TestMaxConcurrency(t *testing.T) { + maxConcurrency := 3 + + base := filepath.Join("testdata", "certs") + srvConfig := TLSConfig{ + CaCert: filepath.Join(base, "ca.pem"), + ServerCert: filepath.Join(base, "server.pem"), + ServerKey: filepath.Join(base, "server.key"), + BindAddress: fmt.Sprintf("localhost:%d", nextFreePort(t, 1025)), + } + clientCfg := newTLSConfig(t, "client.conf") + var wg sync.WaitGroup + wg.Add(1) + ack := make(chan interface{}) + + handler := func(client io.ReadWriteCloser) { + defer client.Close() + + buf := make([]byte, 10) + count, err := client.Read(buf) + assert.Nil(t, err) + assert.Greater(t, count, 0) + ack <- 1 + wg.Wait() + } + + srv, err := newTLSServer(srvConfig, maxConcurrency, handler) + assert.Nil(t, err) + defer srv.Close() + + for i := 0; i < maxConcurrency+1; i++ { + go func() { + client, err := tls.Dial("tcp", srvConfig.BindAddress, clientCfg) + if err != nil { + assert.FailNow(t, err.Error()) + } + + // force handshake + _, err = client.Write([]byte("ping")) + if err != nil { + assert.FailNow(t, err.Error()) + } + }() + } + + received := 0 + timeouted := false + for received < maxConcurrency+1 { + select { + case <-ack: + received++ + case <-time.After(1000 * time.Millisecond): + assert.False(t, timeouted) + assert.Equal(t, maxConcurrency, received) + timeouted = true + wg.Done() + } + } + if !assert.True(t, timeouted, "No concurrency bounded applied") { + // finish all the ongoing connections + wg.Done() + } + +} + func newTaskdClientServer(t *testing.T, clCfgFile string) (net.Conn, io.ReadWriteCloser, func()) { t.Helper() @@ -112,7 +178,7 @@ func newTaskdClientServer(t *testing.T, clCfgFile string) (net.Conn, io.ReadWrit ready <- buf[:size] } - srv, err := newTLSServer(srvConfig, handler) + srv, err := newTLSServer(srvConfig, 1, handler) if err != nil { assert.FailNowf(t, "Error creating server: %s", err.Error()) }