Skip to content

Commit

Permalink
Merge pull request #43 from Sovietaced/background-jwk
Browse files Browse the repository at this point in the history
Add background fetch for JWT sets
  • Loading branch information
Sovietaced authored Apr 12, 2024
2 parents 317bbfe + 7000aa7 commit 5614bbc
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 33 deletions.
107 changes: 98 additions & 9 deletions keyfunc/okta/okta.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,31 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/sovietaced/okta-jwt-verifier/metadata"
"io"
"log/slog"
"net/http"
"sync"
"time"
)

type FetchStrategy int64

const (
Lazy FetchStrategy = iota // Fetch new Okta JWT set inline with requests (when not cached)
// Background Fetch new Okta JWT set in the background regardless of requests being made. This option was designed
// for eliminating in-line Okta JWK set calls and minimizing latency in production use. Warning: this option will
// attempt to seed Okta JWT sets on initialization and block.
Background

DefaultCacheTtl = 5 * time.Minute
)

// Options are configurable options for the KeyfuncProvider.
type Options struct {
httpClient *http.Client
clock clock.Clock
cacheTtl time.Duration
httpClient *http.Client
clock clock.Clock
cacheTtl time.Duration
fetchStrategy FetchStrategy
backgroundCtx context.Context
}

// WithHttpClient allows for a configurable http client.
Expand All @@ -40,11 +55,28 @@ func WithCacheTtl(ttl time.Duration) Option {
}
}

// WithFetchStrategy specifies a strategy for fetching new Okta JWK sets.
func WithFetchStrategy(fetchStrategy FetchStrategy) Option {
return func(mo *Options) {
mo.fetchStrategy = fetchStrategy
}
}

// WithBackgroundCtx specified the context to use in order to control the lifecycle of the background fetching
// goroutine.
func WithBackgroundCtx(ctx context.Context) Option {
return func(mo *Options) {
mo.backgroundCtx = ctx
}
}

func defaultOptions() *Options {
opts := &Options{}
WithHttpClient(http.DefaultClient)(opts)
withClock(clock.New())(opts)
WithCacheTtl(5 * time.Minute)(opts)
WithCacheTtl(DefaultCacheTtl)(opts)
WithFetchStrategy(Lazy)(opts)
WithBackgroundCtx(context.Background())(opts)
return opts
}

Expand All @@ -69,16 +101,38 @@ type KeyfuncProvider struct {
keyfuncMutex sync.Mutex
cacheTtl time.Duration
cachedKeyfunc *cachedKeyfunc
fetchStrategy FetchStrategy
}

// NewKeyfuncProvider creates a new KeyfuncProvider.
func NewKeyfuncProvider(mp metadata.Provider, options ...Option) *KeyfuncProvider {
func NewKeyfuncProvider(mp metadata.Provider, options ...Option) (*KeyfuncProvider, error) {
opts := defaultOptions()
for _, option := range options {
option(opts)
}

return &KeyfuncProvider{mp: mp, httpClient: opts.httpClient, clock: opts.clock, cacheTtl: opts.cacheTtl}
kp := &KeyfuncProvider{
mp: mp,
httpClient: opts.httpClient,
clock: opts.clock,
cacheTtl: opts.cacheTtl,
fetchStrategy: opts.fetchStrategy,
}

if opts.fetchStrategy == Background {
md, err := kp.mp.GetMetadata(opts.backgroundCtx)
if err != nil {
return nil, fmt.Errorf("getting metadata: %w", err)
}

_, err = kp.backgroundFetchAndCache(opts.backgroundCtx, md.JwksUri)
if err != nil {
return nil, fmt.Errorf("failed to seed Okta JWK set: %w", err)
}
go kp.backgroundFetchLoop(opts.backgroundCtx)
}

return kp, nil
}

// GetKeyfunc gets a jwt.Keyfunc based on the OIDC metadata.
Expand All @@ -88,15 +142,15 @@ func (kp *KeyfuncProvider) GetKeyfunc(ctx context.Context) (jwt.Keyfunc, error)
return nil, fmt.Errorf("getting metadata: %w", err)
}

keyfunc, err := kp.getOrFetchKeyfunc(ctx, md.JwksUri)
kf, err := kp.lazyFetchAndCache(ctx, md.JwksUri)
if err != nil {
return nil, fmt.Errorf("getting or fetching keyfunc: %w", err)
}

return keyfunc, nil
return kf, nil
}

func (kp *KeyfuncProvider) getOrFetchKeyfunc(ctx context.Context, jwksUri string) (jwt.Keyfunc, error) {
func (kp *KeyfuncProvider) lazyFetchAndCache(ctx context.Context, jwksUri string) (jwt.Keyfunc, error) {
cachedKeyfuncCopy := kp.cachedKeyfunc

if cachedKeyfuncCopy != nil && kp.clock.Now().Before(cachedKeyfuncCopy.expiration) {
Expand Down Expand Up @@ -125,6 +179,41 @@ func (kp *KeyfuncProvider) getOrFetchKeyfunc(ctx context.Context, jwksUri string
return kf, nil
}

func (kp *KeyfuncProvider) backgroundFetchAndCache(ctx context.Context, jwksUri string) (jwt.Keyfunc, error) {
// Acquire a lock
kp.keyfuncMutex.Lock()
defer kp.keyfuncMutex.Unlock()

expiration := kp.clock.Now().Add(kp.cacheTtl)

newKeyfunc, err := kp.fetchKeyfunc(ctx, jwksUri)
if err != nil {
return nil, fmt.Errorf("failed to fetch new fresh key func: %w", err)
}
kp.cachedKeyfunc = newCachedKeyfunc(expiration, newKeyfunc)
return kp.cachedKeyfunc.keyfunc, nil
}

func (kp *KeyfuncProvider) backgroundFetchLoop(ctx context.Context) {
ticker := kp.clock.Ticker(kp.cacheTtl / 2)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
md, err := kp.mp.GetMetadata(ctx)
if err != nil {
slog.ErrorContext(ctx, fmt.Sprintf("failed to fetch and cache metadata: %s", err.Error()))
}
_, err = kp.backgroundFetchAndCache(ctx, md.JwksUri)
if err != nil {
slog.ErrorContext(ctx, fmt.Sprintf("failed to fetch and cache key func: %s", err.Error()))
}
}
}
}

func (kp *KeyfuncProvider) fetchKeyfunc(ctx context.Context, jwksUri string) (jwt.Keyfunc, error) {

httpRequest, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksUri, nil)
Expand Down
21 changes: 13 additions & 8 deletions keyfunc/okta/okta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"time"
)

func TestKeyfuncProvider(t *testing.T) {
func TestLazyKeyfuncProvider(t *testing.T) {

// Generate RSA key.
pk, err := rsa.GenerateKey(rand.Reader, 2048)
Expand All @@ -36,7 +36,8 @@ func TestKeyfuncProvider(t *testing.T) {
},
}

kp := NewKeyfuncProvider(mp)
kp, err := NewKeyfuncProvider(mp)
require.NoError(t, err)

keyfunc, err := kp.GetKeyfunc(ctx)
require.NoError(t, err)
Expand All @@ -53,7 +54,8 @@ func TestKeyfuncProvider(t *testing.T) {
}

clock := clock2.NewMock()
kp := NewKeyfuncProvider(mp, withClock(clock))
kp, err := NewKeyfuncProvider(mp, withClock(clock))
require.NoError(t, err)

keyfunc, err := kp.GetKeyfunc(ctx)
require.NoError(t, err)
Expand Down Expand Up @@ -95,7 +97,8 @@ func TestKeyfuncProvider(t *testing.T) {
},
}

kp := NewKeyfuncProvider(mp, WithHttpClient(&httpClient))
kp, err := NewKeyfuncProvider(mp, WithHttpClient(&httpClient))
require.NoError(t, err)

tracer := provider.Tracer("test")
spanCtx, span := tracer.Start(ctx, "test")
Expand All @@ -120,9 +123,10 @@ func TestKeyfuncProvider(t *testing.T) {
t.Run("get keyfunc and metadata provider returns error", func(t *testing.T) {
mp := errorMetadataProvider{err: fmt.Errorf("synthetic error")}

kp := NewKeyfuncProvider(&mp)
kp, err := NewKeyfuncProvider(&mp)
require.NoError(t, err)

_, err := kp.GetKeyfunc(ctx)
_, err = kp.GetKeyfunc(ctx)
require.ErrorContains(t, err, "getting metadata: synthetic error")
})

Expand All @@ -133,9 +137,10 @@ func TestKeyfuncProvider(t *testing.T) {
},
}

kp := NewKeyfuncProvider(mp)
kp, err := NewKeyfuncProvider(mp)
require.NoError(t, err)

_, err := kp.GetKeyfunc(ctx)
_, err = kp.GetKeyfunc(ctx)
require.Error(t, err)
require.ErrorContains(t, err, "getting or fetching keyfunc: fetching keyfunc: making http request for jwks")
})
Expand Down
8 changes: 1 addition & 7 deletions metadata/okta/okta.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,7 @@ func (mp *MetadataProvider) backgroundFetchAndCache(ctx context.Context) (metada
}

func (mp *MetadataProvider) backgroundFetchLoop(ctx context.Context) {
// Seed cache initially
_, err := mp.backgroundFetchAndCache(ctx)
if err != nil {
slog.ErrorContext(ctx, fmt.Sprintf("failed to fetch and cache metadata: %s", err.Error()))
}

ticker := time.NewTicker(mp.cacheTtl / 2)
ticker := mp.clock.Ticker(mp.cacheTtl / 2)
defer ticker.Stop()
for {
select {
Expand Down
22 changes: 16 additions & 6 deletions metadata/okta/okta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,17 @@ func TestBackgroundMetadataProvider(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 1, serverCount)

// Fast forward time and invalidate the cache
fakeClock.Add(10 * time.Minute)
require.Eventually(t, func() bool {
// Fast forward time and invalidate the cache
fakeClock.Add(mp.cacheTtl)
return serverCount > 1
}, 5*time.Second, time.Millisecond)

newServerCount := serverCount

_, err = mp.GetMetadata(ctx)
require.NoError(t, err)
require.Equal(t, 2, serverCount)
require.Equal(t, newServerCount, serverCount)
})

t.Run("get metadata and verify cached after server error", func(t *testing.T) {
Expand Down Expand Up @@ -222,12 +227,17 @@ func TestBackgroundMetadataProvider(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 1, serverCount)

// Fast forward time and invalidate the cache
fakeClock.Add(10 * time.Minute)
require.Eventually(t, func() bool {
// Fast forward time and invalidate the cache
fakeClock.Add(mp.cacheTtl)
return serverCount > 1
}, 5*time.Second, time.Millisecond)

newServerCount := serverCount

_, err = mp.GetMetadata(ctx)
require.NoError(t, err)
require.Equal(t, 2, serverCount)
require.Equal(t, newServerCount, serverCount)
})
}

Expand Down
7 changes: 6 additions & 1 deletion verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ func defaultOptions(issuer string) (*Options, error) {
if err != nil {
return nil, fmt.Errorf("creating default metadata provider: %w", err)
}
WithKeyfuncProvider(okta.NewKeyfuncProvider(mp))(opts)

kp, err := okta.NewKeyfuncProvider(mp)
if err != nil {
return nil, fmt.Errorf("creating new key func provider: %w", err)
}
WithKeyfuncProvider(kp)(opts)
WithLeeway(DefaultLeeway)(opts)
return opts, nil
}
Expand Down
8 changes: 6 additions & 2 deletions verifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ func TestVerifierVerifyIdToken(t *testing.T) {
},
}

kp := okta.NewKeyfuncProvider(mp)
kp, err := okta.NewKeyfuncProvider(mp)
require.NoError(t, err)

v, err := NewVerifier(issuer, clientId, WithKeyfuncProvider(kp))
require.NoError(t, err)

Expand Down Expand Up @@ -212,7 +214,9 @@ func TestVerifierVerifyAccessToken(t *testing.T) {
},
}

kp := okta.NewKeyfuncProvider(mp)
kp, err := okta.NewKeyfuncProvider(mp)
require.NoError(t, err)

v, err := NewVerifier(issuer, clientId, WithKeyfuncProvider(kp))
require.NoError(t, err)

Expand Down

0 comments on commit 5614bbc

Please sign in to comment.