diff --git a/rest/endpoint/endpoint.go b/rest/endpoint/endpoint.go index e8061a2..f4af127 100644 --- a/rest/endpoint/endpoint.go +++ b/rest/endpoint/endpoint.go @@ -22,12 +22,18 @@ import ( "go.opentelemetry.io/otel" ) -// Handler +// Handler defines an RPC inspired way of handling HTTP requests. +// +// Req and Resp can implement various interfaces which [Operation] +// uses to automate many tasks before and after calling your Handler. +// For example, [Operation] handles unmarshaling and marshaling the request (Req) +// and response (Resp) types automatically if they implement [encoding.BinaryUnmarshaler] +// and [encoding.BinaryMarshaler], respectively. type Handler[Req, Resp any] interface { Handle(context.Context, *Req) (*Resp, error) } -// HandlerFunc +// HandlerFunc is an adapter type to allow the use of ordinary functions as [Handler]s. type HandlerFunc[Req, Resp any] func(context.Context, *Req) (*Resp, error) // Handle implements the [Handler] interface. @@ -35,9 +41,10 @@ func (f HandlerFunc[Req, Resp]) Handle(ctx context.Context, req *Req) (*Resp, er return f(ctx, req) } -// ErrorHandler +// ErrorHandler defines the behaviour taken by [Operation] +// when a [Handler] returns an [error]. type ErrorHandler interface { - HandleError(http.ResponseWriter, error) + HandleError(context.Context, http.ResponseWriter, error) } type options struct { @@ -52,10 +59,12 @@ type options struct { openapi openapi3.Operation } -// Option +// Option configures a [Operation]. type Option func(*options) -// Operation +// Operation is a RPC inspired [http.Handler] (aka endpoint) that also +// keeps track of the associated types and parameters +// in order to construct an OpenAPI operation definition. type Operation[Req, Resp any] struct { validators []func(*http.Request) error injectors []injector @@ -68,23 +77,26 @@ type Operation[Req, Resp any] struct { openapi openapi3.Operation } +// DefaultStatusCode is the default HTTP status code returned +// by an [Operation] when the underlying [Handler] does not return an [error]. const DefaultStatusCode = http.StatusOK -// StatusCode +// StatusCode will change the HTTP status code that is returned +// by an [Operation] when the underlying [Handler] does not return an [error]. func StatusCode(statusCode int) Option { return func(ho *options) { ho.defaultStatusCode = statusCode } } -// PathParam +// PathParam defines a URL path parameter e.g. /book/{id} where id is the path param. type PathParam struct { Name string Pattern string Required bool } -// PathParams +// PathParams registers the [PathParam]s with the OpenAPI operation definition. func PathParams(ps ...PathParam) Option { return func(o *options) { for _, p := range ps { @@ -107,14 +119,14 @@ func PathParams(ps ...PathParam) Option { } } -// Header +// Header defines a HTTP header. type Header struct { Name string Pattern string Required bool } -// Headers +// Headers registers the [Header]s with the OpenAPI operation definition. func Headers(hs ...Header) Option { return func(o *options) { for _, h := range hs { @@ -139,14 +151,14 @@ func Headers(hs ...Header) Option { } } -// QueryParam +// QueryParam defines a URL query parameter e.g. /book?id=123 type QueryParam struct { Name string Pattern string Required bool } -// QueryParams +// QueryParams registers the [QueryParam]s with the OpenAPI operation definition. func QueryParams(qps ...QueryParam) Option { return func(o *options) { for _, qp := range qps { @@ -171,12 +183,16 @@ func QueryParams(qps ...QueryParam) Option { } } -// ContentTyper +// ContentTyper is the interface which request and response types +// should implement in order to allow the [Operation] to automatically +// validate and set the "Content-Type" HTTP Header along with +// properly documenting the types in the OpenAPI operation definition. type ContentTyper interface { ContentType() string } -// Accepts +// Accepts registers the Req type in the OpenAPI operation definition +// as a possible request to the [Operation]. func Accepts[Req any]() Option { return func(o *options) { contentType := "" @@ -214,7 +230,8 @@ func Accepts[Req any]() Option { } } -// Returns +// Returns registers the status code in the OpenAPI operation +// definition as a possible response from the [Operation]. func Returns(status int) Option { return func(o *options) { o.openapi.Responses.MapOfResponseOrRefValues[strconv.Itoa(status)] = openapi3.ResponseOrRef{ @@ -223,7 +240,8 @@ func Returns(status int) Option { } } -// ReturnsWith +// ReturnsWith registers the Resp type and status code in the OpenAPI +// operation definition as a possible response from the [Operation]. func ReturnsWith[Resp any](status int) Option { return func(o *options) { var resp Resp @@ -256,20 +274,24 @@ func ReturnsWith[Resp any](status int) Option { } } -// OnError +// OnError registers the [ErrorHandler] with the [Operation]. Any +// [error]s returned by the underlying [Handler] will be passed to +// this [ErrorHandler]. func OnError(eh ErrorHandler) Option { return func(o *options) { o.errHandler = eh } } -type errorHandlerFunc func(http.ResponseWriter, error) +type errorHandlerFunc func(context.Context, http.ResponseWriter, error) -func (f errorHandlerFunc) HandleError(w http.ResponseWriter, err error) { - f(w, err) +func (f errorHandlerFunc) HandleError(ctx context.Context, w http.ResponseWriter, err error) { + f(ctx, w, err) } -// DefaultErrorStatusCode +// DefaultErrorStatusCode is the default HTTP status code returned by +// an [Operation] if no [ErrorHandler] has been registered with the +// [OnError] option and the underlying [Handler] returns an [error]. const DefaultErrorStatusCode = http.StatusInternalServerError // NewOperation initializes a Operation. @@ -279,7 +301,7 @@ func NewOperation[Req, Resp any](handler Handler[Req, Resp], opts ...Option) *Op pathParams: make(map[PathParam]struct{}), headerParams: make(map[Header]struct{}), queryParams: make(map[QueryParam]struct{}), - errHandler: errorHandlerFunc(func(w http.ResponseWriter, err error) { + errHandler: errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) { w.WriteHeader(DefaultErrorStatusCode) }), openapi: openapi3.Operation{ @@ -337,6 +359,7 @@ func initInjectors(o *options) []injector { return injectors } +// OpenApi returns the OpenAPI operation definition for this endpoint. func (op *Operation[Req, Resp]) OpenApi() openapi3.Operation { return op.openapi } @@ -471,8 +494,8 @@ func (op *Operation[Req, Resp]) writeResponse(ctx context.Context, w http.Respon } func (op *Operation[Req, Resp]) handleError(ctx context.Context, w http.ResponseWriter, err error) { - _, span := otel.Tracer("endpoint").Start(ctx, "Operation.handleError") + spanCtx, span := otel.Tracer("endpoint").Start(ctx, "Operation.handleError") defer span.End() - op.errHandler.HandleError(w, err) + op.errHandler.HandleError(spanCtx, w, err) } diff --git a/rest/endpoint/endpoint_test.go b/rest/endpoint/endpoint_test.go index c8085bc..6bf3037 100644 --- a/rest/endpoint/endpoint_test.go +++ b/rest/endpoint/endpoint_test.go @@ -374,7 +374,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { HandlerFunc[Empty, JsonContent](func(_ context.Context, _ *Empty) (*JsonContent, error) { return nil, nil }), - OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) { caughtError = err w.WriteHeader(DefaultErrorStatusCode) @@ -401,7 +401,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { HandlerFunc[Empty, ReaderContent](func(_ context.Context, _ *Empty) (*ReaderContent, error) { return nil, nil }), - OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) { caughtError = err w.WriteHeader(DefaultErrorStatusCode) @@ -432,7 +432,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { HandlerFunc[Empty, Empty](func(_ context.Context, _ *Empty) (*Empty, error) { return nil, errors.New("failed") }), - OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) { w.WriteHeader(errStatusCode) })), ) @@ -458,7 +458,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { Required: true, }, ), - OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) { caughtError = err w.WriteHeader(DefaultErrorStatusCode) @@ -494,7 +494,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { Pattern: "^[a-zA-Z]*$", }, ), - OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) { caughtError = err w.WriteHeader(DefaultErrorStatusCode) @@ -531,7 +531,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { Required: true, }, ), - OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) { caughtError = err w.WriteHeader(DefaultErrorStatusCode) @@ -567,7 +567,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { Pattern: "^[a-zA-Z]*$", }, ), - OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) { caughtError = err w.WriteHeader(DefaultErrorStatusCode) @@ -599,7 +599,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { HandlerFunc[JsonContent, Empty](func(_ context.Context, _ *JsonContent) (*Empty, error) { return &Empty{}, nil }), - OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) { caughtError = err w.WriteHeader(DefaultErrorStatusCode) @@ -632,7 +632,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { HandlerFunc[FailUnmarshalBinary, Empty](func(_ context.Context, _ *FailUnmarshalBinary) (*Empty, error) { return &Empty{}, nil }), - OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) { caughtError = err w.WriteHeader(DefaultErrorStatusCode) @@ -659,7 +659,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { HandlerFunc[InvalidRequest, Empty](func(_ context.Context, _ *InvalidRequest) (*Empty, error) { return &Empty{}, nil }), - OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) { caughtError = err w.WriteHeader(DefaultErrorStatusCode) @@ -686,7 +686,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { HandlerFunc[Empty, FailMarshalBinary](func(_ context.Context, _ *Empty) (*FailMarshalBinary, error) { return &FailMarshalBinary{}, nil }), - OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + OnError(errorHandlerFunc(func(ctx context.Context, w http.ResponseWriter, err error) { caughtError = err w.WriteHeader(DefaultErrorStatusCode)