Skip to content

Commit

Permalink
story(issue-277): endpoint pass contextcontext to errorhandler (#278)
Browse files Browse the repository at this point in the history
* feat(issue-277): update error handler interface to accept context

* docs(issue-277): improve package documentation
  • Loading branch information
Zaba505 authored Sep 17, 2024
1 parent 74d5fa6 commit 6d37d14
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 36 deletions.
73 changes: 48 additions & 25 deletions rest/endpoint/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,29 @@ import (
"go.opentelemetry.io/otel"
)

// Handler
// Handler defines an RPC inspired way of handling HTTP requests.
//
// Req and Resp can implement various interfaces which [Operation]
// uses to automate many tasks before and after calling your Handler.
// For example, [Operation] handles unmarshaling and marshaling the request (Req)
// and response (Resp) types automatically if they implement [encoding.BinaryUnmarshaler]
// and [encoding.BinaryMarshaler], respectively.
type Handler[Req, Resp any] interface {
Handle(context.Context, *Req) (*Resp, error)
}

// HandlerFunc
// HandlerFunc is an adapter type to allow the use of ordinary functions as [Handler]s.
type HandlerFunc[Req, Resp any] func(context.Context, *Req) (*Resp, error)

// Handle implements the [Handler] interface.
func (f HandlerFunc[Req, Resp]) Handle(ctx context.Context, req *Req) (*Resp, error) {
return f(ctx, req)
}

// ErrorHandler
// ErrorHandler defines the behaviour taken by [Operation]
// when a [Handler] returns an [error].
type ErrorHandler interface {
HandleError(http.ResponseWriter, error)
HandleError(context.Context, http.ResponseWriter, error)
}

type options struct {
Expand All @@ -52,10 +59,12 @@ type options struct {
openapi openapi3.Operation
}

// Option
// Option configures a [Operation].
type Option func(*options)

// Operation
// Operation is a RPC inspired [http.Handler] (aka endpoint) that also
// keeps track of the associated types and parameters
// in order to construct an OpenAPI operation definition.
type Operation[Req, Resp any] struct {
validators []func(*http.Request) error
injectors []injector
Expand All @@ -68,23 +77,26 @@ type Operation[Req, Resp any] struct {
openapi openapi3.Operation
}

// DefaultStatusCode is the default HTTP status code returned
// by an [Operation] when the underlying [Handler] does not return an [error].
const DefaultStatusCode = http.StatusOK

// StatusCode
// StatusCode will change the HTTP status code that is returned
// by an [Operation] when the underlying [Handler] does not return an [error].
func StatusCode(statusCode int) Option {
return func(ho *options) {
ho.defaultStatusCode = statusCode
}
}

// PathParam
// PathParam defines a URL path parameter e.g. /book/{id} where id is the path param.
type PathParam struct {
Name string
Pattern string
Required bool
}

// PathParams
// PathParams registers the [PathParam]s with the OpenAPI operation definition.
func PathParams(ps ...PathParam) Option {
return func(o *options) {
for _, p := range ps {
Expand All @@ -107,14 +119,14 @@ func PathParams(ps ...PathParam) Option {
}
}

// Header
// Header defines a HTTP header.
type Header struct {
Name string
Pattern string
Required bool
}

// Headers
// Headers registers the [Header]s with the OpenAPI operation definition.
func Headers(hs ...Header) Option {
return func(o *options) {
for _, h := range hs {
Expand All @@ -139,14 +151,14 @@ func Headers(hs ...Header) Option {
}
}

// QueryParam
// QueryParam defines a URL query parameter e.g. /book?id=123
type QueryParam struct {
Name string
Pattern string
Required bool
}

// QueryParams
// QueryParams registers the [QueryParam]s with the OpenAPI operation definition.
func QueryParams(qps ...QueryParam) Option {
return func(o *options) {
for _, qp := range qps {
Expand All @@ -171,12 +183,16 @@ func QueryParams(qps ...QueryParam) Option {
}
}

// ContentTyper
// ContentTyper is the interface which request and response types
// should implement in order to allow the [Operation] to automatically
// validate and set the "Content-Type" HTTP Header along with
// properly documenting the types in the OpenAPI operation definition.
type ContentTyper interface {
ContentType() string
}

// Accepts
// Accepts registers the Req type in the OpenAPI operation definition
// as a possible request to the [Operation].
func Accepts[Req any]() Option {
return func(o *options) {
contentType := ""
Expand Down Expand Up @@ -214,7 +230,8 @@ func Accepts[Req any]() Option {
}
}

// Returns
// Returns registers the status code in the OpenAPI operation
// definition as a possible response from the [Operation].
func Returns(status int) Option {
return func(o *options) {
o.openapi.Responses.MapOfResponseOrRefValues[strconv.Itoa(status)] = openapi3.ResponseOrRef{
Expand All @@ -223,7 +240,8 @@ func Returns(status int) Option {
}
}

// ReturnsWith
// ReturnsWith registers the Resp type and status code in the OpenAPI
// operation definition as a possible response from the [Operation].
func ReturnsWith[Resp any](status int) Option {
return func(o *options) {
var resp Resp
Expand Down Expand Up @@ -256,20 +274,24 @@ func ReturnsWith[Resp any](status int) Option {
}
}

// OnError
// OnError registers the [ErrorHandler] with the [Operation]. Any
// [error]s returned by the underlying [Handler] will be passed to
// this [ErrorHandler].
func OnError(eh ErrorHandler) Option {
return func(o *options) {
o.errHandler = eh
}
}

type errorHandlerFunc func(http.ResponseWriter, error)
type errorHandlerFunc func(context.Context, http.ResponseWriter, error)

func (f errorHandlerFunc) HandleError(w http.ResponseWriter, err error) {
f(w, err)
func (f errorHandlerFunc) HandleError(ctx context.Context, w http.ResponseWriter, err error) {
f(ctx, w, err)
}

// DefaultErrorStatusCode
// DefaultErrorStatusCode is the default HTTP status code returned by
// an [Operation] if no [ErrorHandler] has been registered with the
// [OnError] option and the underlying [Handler] returns an [error].
const DefaultErrorStatusCode = http.StatusInternalServerError

// NewOperation initializes a Operation.
Expand All @@ -279,7 +301,7 @@ func NewOperation[Req, Resp any](handler Handler[Req, Resp], opts ...Option) *Op
pathParams: make(map[PathParam]struct{}),
headerParams: make(map[Header]struct{}),
queryParams: make(map[QueryParam]struct{}),
errHandler: errorHandlerFunc(func(w http.ResponseWriter, err error) {
errHandler: errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) {
w.WriteHeader(DefaultErrorStatusCode)
}),
openapi: openapi3.Operation{
Expand Down Expand Up @@ -337,6 +359,7 @@ func initInjectors(o *options) []injector {
return injectors
}

// OpenApi returns the OpenAPI operation definition for this endpoint.
func (op *Operation[Req, Resp]) OpenApi() openapi3.Operation {
return op.openapi
}
Expand Down Expand Up @@ -471,8 +494,8 @@ func (op *Operation[Req, Resp]) writeResponse(ctx context.Context, w http.Respon
}

func (op *Operation[Req, Resp]) handleError(ctx context.Context, w http.ResponseWriter, err error) {
_, span := otel.Tracer("endpoint").Start(ctx, "Operation.handleError")
spanCtx, span := otel.Tracer("endpoint").Start(ctx, "Operation.handleError")
defer span.End()

op.errHandler.HandleError(w, err)
op.errHandler.HandleError(spanCtx, w, err)
}
22 changes: 11 additions & 11 deletions rest/endpoint/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) {
HandlerFunc[Empty, JsonContent](func(_ context.Context, _ *Empty) (*JsonContent, error) {
return nil, nil
}),
OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) {
OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) {
caughtError = err

w.WriteHeader(DefaultErrorStatusCode)
Expand All @@ -401,7 +401,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) {
HandlerFunc[Empty, ReaderContent](func(_ context.Context, _ *Empty) (*ReaderContent, error) {
return nil, nil
}),
OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) {
OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) {
caughtError = err

w.WriteHeader(DefaultErrorStatusCode)
Expand Down Expand Up @@ -432,7 +432,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) {
HandlerFunc[Empty, Empty](func(_ context.Context, _ *Empty) (*Empty, error) {
return nil, errors.New("failed")
}),
OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) {
OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) {
w.WriteHeader(errStatusCode)
})),
)
Expand All @@ -458,7 +458,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) {
Required: true,
},
),
OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) {
OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) {
caughtError = err

w.WriteHeader(DefaultErrorStatusCode)
Expand Down Expand Up @@ -494,7 +494,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) {
Pattern: "^[a-zA-Z]*$",
},
),
OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) {
OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) {
caughtError = err

w.WriteHeader(DefaultErrorStatusCode)
Expand Down Expand Up @@ -531,7 +531,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) {
Required: true,
},
),
OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) {
OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) {
caughtError = err

w.WriteHeader(DefaultErrorStatusCode)
Expand Down Expand Up @@ -567,7 +567,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) {
Pattern: "^[a-zA-Z]*$",
},
),
OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) {
OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) {
caughtError = err

w.WriteHeader(DefaultErrorStatusCode)
Expand Down Expand Up @@ -599,7 +599,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) {
HandlerFunc[JsonContent, Empty](func(_ context.Context, _ *JsonContent) (*Empty, error) {
return &Empty{}, nil
}),
OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) {
OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) {
caughtError = err

w.WriteHeader(DefaultErrorStatusCode)
Expand Down Expand Up @@ -632,7 +632,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) {
HandlerFunc[FailUnmarshalBinary, Empty](func(_ context.Context, _ *FailUnmarshalBinary) (*Empty, error) {
return &Empty{}, nil
}),
OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) {
OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) {
caughtError = err

w.WriteHeader(DefaultErrorStatusCode)
Expand All @@ -659,7 +659,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) {
HandlerFunc[InvalidRequest, Empty](func(_ context.Context, _ *InvalidRequest) (*Empty, error) {
return &Empty{}, nil
}),
OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) {
OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) {
caughtError = err

w.WriteHeader(DefaultErrorStatusCode)
Expand All @@ -686,7 +686,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) {
HandlerFunc[Empty, FailMarshalBinary](func(_ context.Context, _ *Empty) (*FailMarshalBinary, error) {
return &FailMarshalBinary{}, nil
}),
OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) {
OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) {
caughtError = err

w.WriteHeader(DefaultErrorStatusCode)
Expand Down

0 comments on commit 6d37d14

Please sign in to comment.