Skip to content

Commit

Permalink
Implement max concurrency control (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
szaffarano authored Oct 17, 2021
1 parent 165c459 commit 0588d59
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pkg/task/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func Serve(cfg config.Config) (err error) {
Process(client, auth, ra)
}

server, err := transport.NewServer(tlsConfig, handler)
server, err := transport.NewServer(tlsConfig, cfg.GetInt(QueueSize), handler)
if err != nil {
return fmt.Errorf("initializing server: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/task/transport/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ type Server interface {
type Handler func(io.ReadWriteCloser)

// NewServer creates a new taskd server working according to the configuration
func NewServer(cfg TLSConfig, handler Handler) (Server, error) {
return newTLSServer(cfg, handler)
func NewServer(cfg TLSConfig, maxConcurrency int, handler Handler) (Server, error) {
return newTLSServer(cfg, maxConcurrency, handler)
}
14 changes: 10 additions & 4 deletions pkg/task/transport/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func init() {
}

// NewTlsServer creates a new tls-based server
func newTLSServer(cfg TLSConfig, handlerFunc Handler) (Server, error) {
func newTLSServer(cfg TLSConfig, maxConcurrency int, handlerFunc Handler) (Server, error) {
var ca []byte
var cert tls.Certificate
var err error
Expand Down Expand Up @@ -72,7 +72,7 @@ func newTLSServer(cfg TLSConfig, handlerFunc Handler) (Server, error) {
server.wg.Add(1)
server.handler = handlerFunc

go server.serve()
go server.serve(maxConcurrency)

return &server, nil
}
Expand All @@ -97,9 +97,11 @@ func (s *tlsServer) Close() error {
return err
}

func (s *tlsServer) serve() {
func (s *tlsServer) serve(maxConcurrency int) {
defer s.wg.Done()

concurrency := make(chan interface{}, maxConcurrency)

for {
conn, err := s.listener.Accept()
if err != nil {
Expand All @@ -111,8 +113,12 @@ func (s *tlsServer) serve() {
}
}
s.wg.Add(1)
concurrency <- 1
go func() {
defer s.wg.Done()
defer func() {
<-concurrency
s.wg.Done()
}()

s.handler(conn)
}()
Expand Down
70 changes: 68 additions & 2 deletions pkg/task/transport/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io/ioutil"
"net"
"path/filepath"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -75,14 +76,79 @@ func TestServer(t *testing.T) {
BindAddress: filepath.Join(base, c.bindAddress),
}

srv, err := NewServer(cfg, dummyHandler)
srv, err := NewServer(cfg, 1, dummyHandler)
assert.NotNil(t, err)
assert.Nil(t, srv)
})
}
})
}

func TestMaxConcurrency(t *testing.T) {
maxConcurrency := 3

base := filepath.Join("testdata", "certs")
srvConfig := TLSConfig{
CaCert: filepath.Join(base, "ca.pem"),
ServerCert: filepath.Join(base, "server.pem"),
ServerKey: filepath.Join(base, "server.key"),
BindAddress: fmt.Sprintf("localhost:%d", nextFreePort(t, 1025)),
}
clientCfg := newTLSConfig(t, "client.conf")
var wg sync.WaitGroup
wg.Add(1)
ack := make(chan interface{})

handler := func(client io.ReadWriteCloser) {
defer client.Close()

buf := make([]byte, 10)
count, err := client.Read(buf)
assert.Nil(t, err)
assert.Greater(t, count, 0)
ack <- 1
wg.Wait()
}

srv, err := newTLSServer(srvConfig, maxConcurrency, handler)
assert.Nil(t, err)
defer srv.Close()

for i := 0; i < maxConcurrency+1; i++ {
go func() {
client, err := tls.Dial("tcp", srvConfig.BindAddress, clientCfg)
if err != nil {
assert.FailNow(t, err.Error())
}

// force handshake
_, err = client.Write([]byte("ping"))
if err != nil {
assert.FailNow(t, err.Error())
}
}()
}

received := 0
timeouted := false
for received < maxConcurrency+1 {
select {
case <-ack:
received++
case <-time.After(1000 * time.Millisecond):
assert.False(t, timeouted)
assert.Equal(t, maxConcurrency, received)
timeouted = true
wg.Done()
}
}
if !assert.True(t, timeouted, "No concurrency bounded applied") {
// finish all the ongoing connections
wg.Done()
}

}

func newTaskdClientServer(t *testing.T, clCfgFile string) (net.Conn, io.ReadWriteCloser, func()) {
t.Helper()

Expand Down Expand Up @@ -112,7 +178,7 @@ func newTaskdClientServer(t *testing.T, clCfgFile string) (net.Conn, io.ReadWrit
ready <- buf[:size]
}

srv, err := newTLSServer(srvConfig, handler)
srv, err := newTLSServer(srvConfig, 1, handler)
if err != nil {
assert.FailNowf(t, "Error creating server: %s", err.Error())
}
Expand Down

0 comments on commit 0588d59

Please sign in to comment.