Skip to content

Commit

Permalink
Merge pull request #166 from monzo/add-max-connection-age-server-option
Browse files Browse the repository at this point in the history
Add max connection age typhon server option
  • Loading branch information
danielchatfield authored May 5, 2023
2 parents b08f231 + 32e4835 commit ea17ae8
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 1 deletion.
39 changes: 38 additions & 1 deletion e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ import (
"time"

"github.com/fortytw2/leaktest"
"github.com/monzo/terrors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/http2"

"github.com/monzo/terrors"

"github.com/monzo/typhon/prototest"
)

Expand Down Expand Up @@ -840,3 +841,39 @@ func TestE2EServerTimeouts(t *testing.T) {
}
})
}

// TestE2EMaxConnectionAge_ConnectionAgeSet tests that when using the
// WithMaxConnectionAge server option the connection age is available
// to the handler. This is a fairly weak test as it doesn't actually
// test that connections are closed. Unfortunately our test framework
// seems to always close connections so this is as good as we can get.
func TestE2EMaxConnectionAge_ConnectionAgeSet(t *testing.T) {
addConnectionStartTimeHeader = true

someFlavours(t, []string{
"http1.1",
"http1.1-tls",
"http2.0-h2",

// The Go h2c implementation doesn't currently support max
// connection age.
}, func(t *testing.T, flav e2eFlavour) {
ctx, cancel := flav.Context()
defer cancel()

srv := Service(func(req Request) Response {
time.Sleep(100 * time.Millisecond)
return NewResponse(req)
})

srv = srv.Filter(ErrorFilter)
s := flav.Serve(srv, WithMaxConnectionAge(time.Hour))
defer s.Stop(ctx)

req := NewRequest(ctx, "GET", flav.URL(s), nil)
rsp := req.Send().Response()
assert.NoError(t, rsp.Error)
connectionStart := rsp.Header.Get(connectionStartTimeHeaderKey)
require.NotEmpty(t, connectionStart)
})
}
78 changes: 78 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,81 @@ func WithTimeout(opts TimeoutOptions) ServerOption {
s.srv.IdleTimeout = opts.Idle
}
}

var (
connectionStartTimeHeaderKey = "X-Typhon-Connection-Start"
// addConnectionStartTimeHeader is set to true within tests to
// make it easier to test the server option.
addConnectionStartTimeHeader = false
)

// WithMaxConnectionAge returns a server option that will enforce a max
// connection age. When a connection has reached the max connection age
// then the next request that is processed on that connection will result
// in the connection being gracefully closed. This does mean that if a
// connection is not being used then it can outlive the maximum connection
// age.
func WithMaxConnectionAge(maxAge time.Duration) ServerOption {
// We have no ability within a handler to get access to the
// underlying net.Conn that the request came on. However,
// the http.Server has a ConnContext field that can be used
// to specify a function that can modify the context used for
// that connection. We can use this to store the connection
// start time in the context and then in the handler we can
// read that out and whenever the maxAge has been exceeded we
// can close the connection.
//
// We could close the connection by calling the Close method
// on the net.Conn. This would have the benefit that we could
// close the connection exactly at the expiry but would have
// the disadvantage that it does not gracefully close the
// connection – it would kill all in-flight requests. Instead,
// we set the 'Connection: close' response header which will
// be translated into an HTTP2 GOAWAY frame and result in the
// connection being gracefully closed.

return func(s *Server) {
// Wrap the current ConnContext (if set) to store a reference
// to the connection start time in the context.
origConnContext := s.srv.ConnContext
s.srv.ConnContext = func(ctx context.Context, conn net.Conn) context.Context {
if origConnContext != nil {
ctx = origConnContext(ctx, conn)
}

return setConnectionStartTimeInContext(ctx, time.Now())
}

// Wrap the handler to set the 'Connection: close' response
// header if the max age has been exceeded.
origHandler := s.srv.Handler
s.srv.Handler = http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
connectionStart, ok := readConnectionStartTimeFromContext(request.Context())
if ok {
if time.Since(connectionStart) > maxAge {
h := writer.Header()
h.Add("Connection", "close")
}

// This is used within tests
if addConnectionStartTimeHeader {
h := writer.Header()
h.Add(connectionStartTimeHeaderKey, connectionStart.String())
}
}

origHandler.ServeHTTP(writer, request)
})
}
}

type connectionContextKey struct{}

func setConnectionStartTimeInContext(parent context.Context, t time.Time) context.Context {
return context.WithValue(parent, connectionContextKey{}, t)
}

func readConnectionStartTimeFromContext(ctx context.Context) (time.Time, bool) {
conn, ok := ctx.Value(connectionContextKey{}).(time.Time)
return conn, ok
}

0 comments on commit ea17ae8

Please sign in to comment.