Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

story(issue-295): rest interface out mux #296

Merged
merged 5 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
[![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go)
[![Go Reference](https://pkg.go.dev/badge/github.com/z5labs/bedrock.svg)](https://pkg.go.dev/github.com/z5labs/bedrock)
[![Go Report Card](https://goreportcard.com/badge/github.com/z5labs/bedrock)](https://goreportcard.com/report/github.com/z5labs/bedrock)
![Coverage](https://img.shields.io/badge/Coverage-97.0%25-brightgreen)
![Coverage](https://img.shields.io/badge/Coverage-96.5%25-brightgreen)
[![build](https://github.com/z5labs/bedrock/actions/workflows/build.yaml/badge.svg)](https://github.com/z5labs/bedrock/actions/workflows/build.yaml)

**bedrock provides a minimal, modular and composable foundation for
Expand Down
3 changes: 2 additions & 1 deletion example/custom_framework/framework/rest/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/z5labs/bedrock"
"github.com/z5labs/bedrock/pkg/app"
"github.com/z5labs/bedrock/rest"
"github.com/z5labs/bedrock/rest/mux"
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/sdk/resource"
Expand Down Expand Up @@ -120,7 +121,7 @@ func HttpServer(cfg HttpServerConfig) Option {
}

type Endpoint struct {
Method string
Method mux.Method
Path string
Operation Operation
}
Expand Down
172 changes: 172 additions & 0 deletions rest/mux/mux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
// Copyright (c) 2024 Z5Labs and Contributors
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT

// Package mux defines a simple API for all http multiplexers to implement.
package mux

import (
"fmt"
"net/http"
"path"
"slices"
"strings"
"sync"
)

// Method defines an HTTP method expected to be used in a RESTful API.
type Method string

const (
MethodGet Method = http.MethodGet
MethodPut Method = http.MethodPut
MethodPost Method = http.MethodPost
MethodDelete Method = http.MethodDelete
)

// HttpOption defines a configuration option for [Http].
type HttpOption func(*Http)

// NotFoundHandler will register the given [http.Handler] to handle
// any HTTP requests that do not match any other method-pattern combinations.
func NotFoundHandler(h http.Handler) HttpOption {
return func(mux *Http) {
mux.notFound = h
}
}

// MethodNotAllowedHandler will register the given [http.Handler] to handle
// any HTTP requests whose method does not match the method registered to a pattern.
func MethodNotAllowedHandler(h http.Handler) HttpOption {
return func(mux *Http) {
mux.methodNotAllowed = h
}
}

// Http wraps a [http.ServeMux] and provides some helpers around overriding
// the default "HTTP 404 Not Found" and "HTTP 405 Method Not Allowed" behaviour.
type Http struct {
mux *http.ServeMux

initFallbacksOnce sync.Once
notFound http.Handler
methodNotAllowed http.Handler

pathMethods map[string][]Method
}

// NewHttp initializes a request multiplexer using the standard [http.ServeMux.]
func NewHttp(opts ...HttpOption) *Http {
mux := &Http{
mux: http.NewServeMux(),
pathMethods: make(map[string][]Method),
}
for _, opt := range opts {
opt(mux)
}
return mux
}

// Handle will register the [http.Handler] for the given method and pattern
// with the underlying [http.ServeMux]. The method and pattern will be formatted
// together as "method pattern" when calling [http.ServeMux.Handle].
func (m *Http) Handle(method Method, pattern string, h http.Handler) {
m.pathMethods[pattern] = append(m.pathMethods[pattern], method)
m.mux.Handle(fmt.Sprintf("%s %s", method, pattern), h)

// {$} is a special case where we only want to exact match the path pattern.
if strings.HasSuffix(pattern, "{$}") {
return
}

if strings.HasSuffix(pattern, "/") {
withoutTrailingSlash := pattern[:len(pattern)-1]
if len(withoutTrailingSlash) == 0 {
return
}

m.pathMethods[withoutTrailingSlash] = append(m.pathMethods[withoutTrailingSlash], method)
m.mux.Handle(fmt.Sprintf("%s %s", method, withoutTrailingSlash), h)
return
}

// if the end of the path contains the "..." wildcard segment
// then we can't add a "/" to it since "..." should not be followed
// by a "/", per the http.ServeMux docs.
base := path.Base(pattern)
if strings.Contains(base, "...") {
return
}

withTrailingSlash := pattern + "/"
m.pathMethods[withTrailingSlash] = append(m.pathMethods[withTrailingSlash], method)
m.mux.Handle(fmt.Sprintf("%s %s", method, withTrailingSlash), h)
}

// ServeHTTP implements the [http.Handler] interface.
func (m *Http) ServeHTTP(w http.ResponseWriter, r *http.Request) {
m.initFallbacksOnce.Do(m.registerFallbackHandlers)

m.mux.ServeHTTP(w, r)
}

func (m *Http) registerFallbackHandlers() {
fs := []func(*http.ServeMux){
registerNotFoundHandler(m.notFound),
registerMethodNotAllowedHandler(m.methodNotAllowed, m.pathMethods),
}
for _, f := range fs {
f(m.mux)
}
}

func registerNotFoundHandler(h http.Handler) func(*http.ServeMux) {
return func(mux *http.ServeMux) {
if h == nil {
return
}
mux.Handle("/{path...}", h)
}
}

func registerMethodNotAllowedHandler(h http.Handler, pathMethods map[string][]Method) func(*http.ServeMux) {
return func(mux *http.ServeMux) {
if h == nil {
return
}
if len(pathMethods) == 0 {
return
}

// this list is pulled from the OpenAPI v3 Path Item Object documentation.
supportedMethods := []Method{
http.MethodGet,
http.MethodPut,
http.MethodPost,
http.MethodDelete,
http.MethodOptions,
http.MethodHead,
http.MethodPatch,
http.MethodTrace,
}

for path, methods := range pathMethods {
unsupportedMethods := diffSets(supportedMethods, methods)
for _, method := range unsupportedMethods {
mux.Handle(fmt.Sprintf("%s %s", method, path), h)
}
}
}
}

func diffSets[T comparable](xs, ys []T) []T {
zs := make([]T, 0, len(xs))
for _, x := range xs {
if slices.Contains(ys, x) {
continue
}
zs = append(zs, x)
}
return zs
}
Loading