Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Freeman committed Oct 6, 2024
1 parent 4adcf8b commit b5ad89b
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 29 deletions.
5 changes: 5 additions & 0 deletions cmd/event-ingest/configs/.env
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,8 @@ NATS_MAX_WAIT=10s
NATS_MAX_PULL_WAIT=10
NATS_BATCH_SIZE=100
GRPC_SERVER_ADDRESS=
KEYCLOAK_URL=https://keycloak.threadr.ai
KEYCLOAK_REALM=NGS
OAUTH_CLIENT_ID=eventrunner
OAUTH_CLIENT_SECRET=secret
TOKEN_INTROSPECT_URL=https://keycloak.threadr.ai/realms/CarverAuto/protocol/openid-connect/token/introspect
8 changes: 4 additions & 4 deletions cmd/event-ingest/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func main() {
jwtMiddleware.Validate,
middleware.AuthenticateAPIKey,
middleware.RequireRole("admin", "event_publisher"),
func(cc *customctx.CustomContext) (interface{}, error) {
func(cc *customctx.Context) (interface{}, error) {
return httpServer.HandleEvent(cc)
},
))
Expand All @@ -60,14 +60,14 @@ func combineMiddleware(middlewares ...interface{}) gofr.Handler {
return func(c *gofr.Context) (interface{}, error) {
cc := customctx.NewCustomContext(c)

var handler func(*customctx.CustomContext) (interface{}, error)
var handler func(*customctx.Context) (interface{}, error)

// Apply middlewares in reverse order
for i := len(middlewares) - 1; i >= 0; i-- {
switch m := middlewares[i].(type) {
case func(*customctx.CustomContext) (interface{}, error):
case func(*customctx.Context) (interface{}, error):
handler = m
case func(func(*customctx.CustomContext) (interface{}, error)) func(*customctx.CustomContext) (interface{}, error):
case func(func(*customctx.Context) (interface{}, error)) func(*customctx.Context) (interface{}, error):
handler = m(handler)
case func(gofr.Handler) gofr.Handler:
return m(func(*gofr.Context) (interface{}, error) {
Expand Down
7 changes: 0 additions & 7 deletions pkg/api/errors.go

This file was deleted.

17 changes: 16 additions & 1 deletion pkg/api/middleware/api_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,32 @@ package middleware

import (
"context"

"net/http"

gofrHTTP "gofr.dev/pkg/gofr/http"
)

// contextKey is a custom type for context keys to avoid collisions.
type contextKey string

// APIKeyContextKey is the key used to store the API key in the context.
const APIKeyContextKey contextKey = "APIKey"

// APIKeyMiddleware is a middleware that extracts the API key from the request headers and stores it in the context.
func APIKeyMiddleware() gofrHTTP.Middleware {
return func(inner http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
apiKey := r.Header.Get("X-API-Key")
ctx := context.WithValue(r.Context(), "APIKey", apiKey)
ctx := context.WithValue(r.Context(), APIKeyContextKey, apiKey)
inner.ServeHTTP(w, r.WithContext(ctx))
})
}
}

// GetAPIKeyFromContext retrieves the API key from the context.
func GetAPIKeyFromContext(ctx context.Context) (string, bool) {
apiKey, ok := ctx.Value(APIKeyContextKey).(string)

return apiKey, ok
}
2 changes: 1 addition & 1 deletion pkg/api/middleware/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func NewJWTMiddleware(ctx context.Context, config *config.OAuthConfig) (*JWTMidd
}

// Validate validates a JWT token.
func (m *JWTMiddleware) Validate(next func(*customctx.CustomContext) (interface{}, error)) gofr.Handler {
func (m *JWTMiddleware) Validate(next func(*customctx.Context) (interface{}, error)) gofr.Handler {
return func(c *gofr.Context) (interface{}, error) {
cc := customctx.NewCustomContext(c)

Expand Down
11 changes: 6 additions & 5 deletions pkg/api/middleware/rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (
)

// AuthenticateAPIKey checks if the API key is valid and active, otherwise returns an error.
func AuthenticateAPIKey(next func(*customctx.CustomContext) (interface{}, error)) func(*customctx.CustomContext) (interface{}, error) {
return func(cc *customctx.CustomContext) (interface{}, error) {
func AuthenticateAPIKey(next func(*customctx.Context) (interface{}, error)) func(*customctx.Context) (interface{}, error) {
return func(cc *customctx.Context) (interface{}, error) {
apiKey, ok := cc.GetAPIKey()
if !ok || apiKey == "" {
return nil, eventingest.NewAuthError("Missing API Key")
Expand All @@ -32,15 +32,16 @@ func AuthenticateAPIKey(next func(*customctx.CustomContext) (interface{}, error)
// The user's role is stored in the JWT token. The roles parameter is a list of roles that are allowed
// to access the resource.
func RequireRole(roles ...string) func(
func(*customctx.CustomContext) (interface{}, error)) func(*customctx.CustomContext) (interface{}, error) {
func(*customctx.Context) (interface{}, error)) func(*customctx.Context) (interface{}, error) {
return func(
next func(*customctx.CustomContext) (interface{}, error)) func(*customctx.CustomContext) (interface{}, error) {
return func(cc *customctx.CustomContext) (interface{}, error) {
next func(*customctx.Context) (interface{}, error)) func(*customctx.Context) (interface{}, error) {
return func(cc *customctx.Context) (interface{}, error) {
userRole, ok := cc.GetStringClaim("user_role")
if !ok {
return nil, eventingest.NewAuthError("Missing user role")
}

// Check if the user has the required role
for _, role := range roles {
if userRole == role {
return next(cc)
Expand Down
18 changes: 9 additions & 9 deletions pkg/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,33 @@ import (
"gofr.dev/pkg/gofr"
)

type CustomContext struct {
type Context struct {
*gofr.Context
claims map[string]interface{}
}

// NewCustomContext creates a new CustomContext.
func NewCustomContext(c *gofr.Context) *CustomContext {
return &CustomContext{
// NewCustomContext creates a new Context.
func NewCustomContext(c *gofr.Context) *Context {
return &Context{
Context: c,
claims: make(map[string]interface{}),
}
}

// SetClaim sets the claim value.
func (c *CustomContext) SetClaim(key string, value interface{}) {
func (c *Context) SetClaim(key string, value interface{}) {
c.claims[key] = value
}

// GetClaim returns the claim value.
func (c *CustomContext) GetClaim(key string) (interface{}, bool) {
func (c *Context) GetClaim(key string) (interface{}, bool) {
value, ok := c.claims[key]

return value, ok
}

// GetStringClaim returns the claim value as a string.
func (c *CustomContext) GetStringClaim(key string) (string, bool) {
func (c *Context) GetStringClaim(key string) (string, bool) {
value, ok := c.claims[key]
if !ok {
return "", false
Expand All @@ -43,7 +43,7 @@ func (c *CustomContext) GetStringClaim(key string) (string, bool) {
}

// GetUUIDClaim returns the claim value as a UUID.
func (c *CustomContext) GetUUIDClaim(key string) (uuid.UUID, bool) {
func (c *Context) GetUUIDClaim(key string) (uuid.UUID, bool) {
value, ok := c.claims[key]
if !ok {
return uuid.UUID{}, false
Expand All @@ -55,7 +55,7 @@ func (c *CustomContext) GetUUIDClaim(key string) (uuid.UUID, bool) {
}

// GetAPIKey returns the API key from the context.
func (c *CustomContext) GetAPIKey() (string, bool) {
func (c *Context) GetAPIKey() (string, bool) {
if apiKey, ok := c.Context.Request.Context().Value("APIKey").(string); ok {
return apiKey, true
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/eventingest/http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ func NewHTTPServer(app *gofr.App, forwarder EventForwarder) *HTTPServer {
}
}

// HandleEvent handles an event request, it accepts a CustomContext and returns an interface and an error.
func (s *HTTPServer) HandleEvent(cc *customctx.CustomContext) (interface{}, error) {
// HandleEvent handles an event request, it accepts a Context and returns an interface and an error.
func (s *HTTPServer) HandleEvent(cc *customctx.Context) (interface{}, error) {
tenantID, ok := cc.GetUUIDClaim("tenant_id")
if !ok {
return nil, NewAuthError("Missing tenant ID")
Expand Down

0 comments on commit b5ad89b

Please sign in to comment.