Skip to content

Commit

Permalink
feat(authN): Redesign JWT token auth #372
Browse files Browse the repository at this point in the history
Redesign JWT token authentication middleware to support
additional/alternative authentication method
  • Loading branch information
michalkrzyz committed Nov 19, 2024
1 parent 73c9774 commit 6072be7
Show file tree
Hide file tree
Showing 12 changed files with 195 additions and 194 deletions.
3 changes: 1 addition & 2 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ services:
DB_NAME: ${DB_NAME}
DB_SCHEMA: /app_sqlschema/schema.sql
SEED_MODE: ${SEED_MODE}
AUTH_TYPE: token
AUTH_TOKEN_SECRET: xxx
#AUTH_TOKEN_SECRET: xxx
volumes:
- ./internal/database/mariadb/init/schema.sql:/app_sqlschema/schema.sql
depends_on:
Expand Down
57 changes: 44 additions & 13 deletions internal/api/graphql/access/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
package access

import (
"strings"
"fmt"
"net/http"
"reflect"

"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
Expand All @@ -17,23 +19,52 @@ type Logger interface {
Warn(...interface{})
}

type Auth interface {
GetMiddleware() gin.HandlerFunc
func NewAuth(cfg *util.Config) *Auth {
l := newLogger()
auth := Auth{logger: l}
auth.AppendInstance(NewTokenAuthMethod(l, cfg))
//TODO: auth.AppendInstance(NewOidcAuthMethod(l, cfg))
return &auth
}

func NewAuth(cfg *util.Config) Auth {
l := newLogger()
type Auth struct {
chain []AuthMethod
logger Logger
}

authType := strings.ToLower(cfg.AuthType)
if authType == "token" {
return NewTokenAuth(l, cfg)
} else if authType == "none" {
return NewNoAuth()
}
type AuthMethod interface {
Verify(*gin.Context) error
}

l.Warn("AUTH_TYPE is not set, assuming 'none' authorization method")
func (a *Auth) GetMiddleware() gin.HandlerFunc {
return func(authCtx *gin.Context) {
if len(a.chain) > 0 {
var retMsg string
for _, auth := range a.chain {
if err := auth.Verify(authCtx); err == nil {
authCtx.Next()
return
} else {
if retMsg != "" {
retMsg = fmt.Sprintf("%s, ", retMsg)
}
retMsg = fmt.Sprintf("%s%s", retMsg, err)
}
}
a.logger.Error("Unauthorized access: %s", retMsg)
authCtx.JSON(http.StatusUnauthorized, gin.H{"error": retMsg})
authCtx.Abort()
return
}
authCtx.Next()
return
}
}

return NewNoAuth()
func (a *Auth) AppendInstance(am AuthMethod) {
if !reflect.ValueOf(am).IsNil() {
a.chain = append(a.chain, am)
}
}

func newLogger() Logger {
Expand Down
21 changes: 0 additions & 21 deletions internal/api/graphql/access/no_auth.go

This file was deleted.

20 changes: 10 additions & 10 deletions internal/api/graphql/access/test/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
)

const (
testUsername = "testUser"
testClientName = "testClientName"
)

func SendGetRequest(url string, headers map[string]string) *http.Response {
Expand Down Expand Up @@ -53,7 +53,7 @@ type Jwt struct {
signingMethod jwt.SigningMethod
signKey interface{}
expiresAt *jwt.NumericDate
username string
name string
}

func NewJwt(secret string) *Jwt {
Expand All @@ -64,8 +64,8 @@ func NewRsaJwt(privKey *rsa.PrivateKey) *Jwt {
return &Jwt{signKey: privKey, signingMethod: jwt.SigningMethodRS256}
}

func (j *Jwt) WithUsername(username string) *Jwt {
j.username = username
func (j *Jwt) WithName(name string) *Jwt {
j.name = name
return j
}

Expand All @@ -81,7 +81,7 @@ func (j *Jwt) String() string {
ExpiresAt: j.expiresAt,
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: "heureka",
Subject: j.username,
Subject: j.name,
},
}
token := jwt.NewWithClaims(j.signingMethod, claims)
Expand All @@ -92,15 +92,15 @@ func (j *Jwt) String() string {
}

func GenerateJwt(jwtSecret string, expiresIn time.Duration) string {
return NewJwt(jwtSecret).WithExpiresAt(time.Now().Add(expiresIn)).WithUsername(testUsername).String()
return NewJwt(jwtSecret).WithExpiresAt(time.Now().Add(expiresIn)).WithName(testClientName).String()
}

func GenerateJwtWithUsername(jwtSecret string, expiresIn time.Duration, username string) string {
return NewJwt(jwtSecret).WithExpiresAt(time.Now().Add(expiresIn)).WithUsername(username).String()
func GenerateJwtWithName(jwtSecret string, expiresIn time.Duration, name string) string {
return NewJwt(jwtSecret).WithExpiresAt(time.Now().Add(expiresIn)).WithName(name).String()
}

func GenerateInvalidJwt(jwtSecret string) string {
return NewJwt(jwtSecret).WithUsername(testUsername).String()
return NewJwt(jwtSecret).WithName(testClientName).String()
}

func GenerateRsaPrivateKey() *rsa.PrivateKey {
Expand All @@ -110,5 +110,5 @@ func GenerateRsaPrivateKey() *rsa.PrivateKey {
}

func GenerateJwtWithInvalidSigningMethod(jwtSecret string, expiresIn time.Duration) string {
return NewRsaJwt(GenerateRsaPrivateKey()).WithExpiresAt(time.Now().Add(expiresIn)).WithUsername(testUsername).String()
return NewRsaJwt(GenerateRsaPrivateKey()).WithExpiresAt(time.Now().Add(expiresIn)).WithName(testClientName).String()
}
116 changes: 0 additions & 116 deletions internal/api/graphql/access/token_auth.go

This file was deleted.

109 changes: 109 additions & 0 deletions internal/api/graphql/access/token_auth_method.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// SPDX-FileCopyrightText: 2024 SAP SE or an SAP affiliate company and Greenhouse contributors
// SPDX-License-Identifier: Apache-2.0

package access

import (
"context"
"fmt"
"time"

"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"

"github.com/cloudoperators/heureka/internal/util"
)

type ginContextKeyType string

const (
ginContextKey ginContextKeyType = "GinContextKey"
scannerNameKey string = "scannername"
tokenAuthHeader string = "X-Service-Authorization"
)

func NewTokenAuthMethod(l Logger, cfg *util.Config) *TokenAuthMethod {
if cfg.AuthTokenSecret != "" {
return &TokenAuthMethod{logger: l, secret: []byte(cfg.AuthTokenSecret)}
}
return nil
}

type TokenClaims struct {
Version string `json:"version"`
jwt.RegisteredClaims
}

type TokenAuthMethod struct {
logger Logger
secret []byte
}

func (tam TokenAuthMethod) Verify(c *gin.Context) error {
verifyError := func(s string) error {
return fmt.Errorf("TokenAuthMethod(%s)", s)
}

tokenString := c.GetHeader(tokenAuthHeader)
if tokenString == "" {
return verifyError("No authorization header")
}
token, claims, err := tam.parseFromString(tokenString)
if err != nil {
tam.logger.Error("JWT parsing error: ", err)
return verifyError("Token parsing error")
} else if !token.Valid || claims.ExpiresAt == nil {
tam.logger.Error("Invalid token")
return verifyError("Invalid token")
} else if claims.ExpiresAt.Before(time.Now()) {
tam.logger.Warn("Expired token")
return verifyError("Token expired")
}
c.Set(scannerNameKey, claims.RegisteredClaims.Subject)
ctx := context.WithValue(c.Request.Context(), ginContextKey, c)
c.Request = c.Request.WithContext(ctx)
return nil
}

func (tam TokenAuthMethod) parseFromString(tokenString string) (*jwt.Token, *TokenClaims, error) {
claims := &TokenClaims{}
token, err := jwt.ParseWithClaims(tokenString, claims, tam.parse)
return token, claims, err
}

func (tam *TokenAuthMethod) parse(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Invalid JWT parse method")
}
return tam.secret, nil
}

func ScannerNameFromContext(ctx context.Context) (string, error) {
gc, err := ginContextFromContext(ctx)
if err != nil {
return "", err
}

s, ok := gc.Get(scannerNameKey)
if !ok {
return "", fmt.Errorf("could not find scanner name in gin.Context")
}
ss, ok := s.(string)
if !ok {
return "", fmt.Errorf("invalid scanner name type")
}
return ss, nil
}

func ginContextFromContext(ctx context.Context) (*gin.Context, error) {
ginContext := ctx.Value(ginContextKey)
if ginContext == nil {
return nil, fmt.Errorf("could not retrieve gin.Context")
}

gc, ok := ginContext.(*gin.Context)
if !ok {
return nil, fmt.Errorf("gin.Context has wrong type")
}
return gc, nil
}
Loading

0 comments on commit 6072be7

Please sign in to comment.