Skip to content

Commit

Permalink
UC-162 auth cache (#86)
Browse files Browse the repository at this point in the history
* cache auth token after successful authentication

* add test case TestExtendedProtocol_CheckAuth_Invalid_Cached
  • Loading branch information
leroxyl authored Nov 24, 2021
1 parent 74bbd81 commit 07dc247
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 51 deletions.
44 changes: 16 additions & 28 deletions main/adapters/repository/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,6 @@ func (p *ExtendedProtocol) LoadIdentity(uid uuid.UUID) (*ent.Identity, error) {
return nil, err
}

p.authCache.Store(uid, i.AuthToken)

return i, nil
}

Expand Down Expand Up @@ -186,27 +184,8 @@ func (p *ExtendedProtocol) LoadPublicKey(uid uuid.UUID) (pubKeyPEM []byte, err e
return pubKeyPEM, nil
}

func (p *ExtendedProtocol) LoadAuth(uid uuid.UUID) (auth string, err error) {
_auth, found := p.authCache.Load(uid)

if found {
auth, found = _auth.(string)
}

if !found {
i, err := p.LoadIdentity(uid)
if err != nil {
return "", err
}

auth = i.AuthToken
}

return auth, nil
}

func (p *ExtendedProtocol) IsInitialized(uid uuid.UUID) (initialized bool, err error) {
_, err = p.LoadAuth(uid)
_, err = p.LoadPrivateKey(uid)
if err == ErrNotExist {
return false, nil
}
Expand All @@ -218,21 +197,32 @@ func (p *ExtendedProtocol) IsInitialized(uid uuid.UUID) (initialized bool, err e
}

func (p *ExtendedProtocol) CheckAuth(ctx context.Context, uid uuid.UUID, authToCheck string) (ok, found bool, err error) {
pwHash, err := p.LoadAuth(uid)
_auth, found := p.authCache.Load(uid)

if found {
if auth, ok := _auth.(string); ok {
return auth == authToCheck, found, err
}
}

i, err := p.LoadIdentity(uid)
if err == ErrNotExist {
return false, false, nil
return ok, found, nil
}
if err != nil {
return false, false, err
return ok, found, err
}

found = true

needsUpdate, ok, err := p.pwHasher.CheckPassword(ctx, pwHash, authToCheck)
needsUpdate, ok, err := p.pwHasher.CheckPassword(ctx, i.AuthToken, authToCheck)
if err != nil || !ok {
return ok, found, err
}

// auth check was successful
p.authCache.Store(uid, authToCheck)

if needsUpdate {
if err := p.updatePwHash(uid, authToCheck); err != nil {
log.Errorf("%s: password hash update failed: %v", uid, err)
Expand Down Expand Up @@ -273,8 +263,6 @@ func (p *ExtendedProtocol) updatePwHash(uid uuid.UUID, authToCheck string) error
return fmt.Errorf("could not commit transaction after storing updated password hash: %v", err)
}

p.authCache.Store(uid, updatedHash)

return nil
}

Expand Down
105 changes: 82 additions & 23 deletions main/adapters/repository/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ func TestProtocol(t *testing.T) {
_, err = p.LoadPublicKey(testIdentity.Uid)
assert.Equal(t, ErrNotExist, err)

_, err = p.LoadAuth(testIdentity.Uid)
assert.Equal(t, ErrNotExist, err)

// store identity
tx, err := p.StartTransaction(ctx)
require.NoError(t, err)
Expand Down Expand Up @@ -371,6 +368,34 @@ func TestExtendedProtocol_CheckAuth_Invalid(t *testing.T) {
assert.False(t, ok)
}

func TestExtendedProtocol_CheckAuth_Invalid_Cached(t *testing.T) {
ctxMngr := &MockCtxMngr{}
p, err := NewExtendedProtocol(ctxMngr, conf)
require.NoError(t, err)

i := generateRandomIdentity()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// store identity
tx, err := p.StartTransaction(ctx)
require.NoError(t, err)

err = p.StoreIdentity(tx, i)
require.NoError(t, err)

err = tx.Commit()
require.NoError(t, err)

p.authCache.Store(i.Uid, ctxMngr.id.AuthToken)

ok, found, err := p.CheckAuth(ctx, i.Uid, "invalid auth")
require.NoError(t, err)
assert.True(t, found)
assert.False(t, ok)
}

func TestExtendedProtocol_CheckAuth_NotFound(t *testing.T) {
p, err := NewExtendedProtocol(&MockCtxMngr{}, conf)
require.NoError(t, err)
Expand All @@ -382,36 +407,70 @@ func TestExtendedProtocol_CheckAuth_NotFound(t *testing.T) {
}

func TestExtendedProtocol_CheckAuth_Update(t *testing.T) {
ctxMngr := &MockCtxMngr{}
conf.KdUpdateParams = true
p, err := NewExtendedProtocol(ctxMngr, conf)
require.NoError(t, err)

i := generateRandomIdentity()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

uid := uuid.New()
testAuth := "password123"
kd := &pw.Argon2idKeyDerivator{
Params: pw.GetArgon2idParams(pw.DefaultMemory, pw.DefaultTime,
2*pw.DefaultParallelism, pw.DefaultKeyLen, pw.DefaultSaltLen),
}
pwHash, err := kd.GeneratePasswordHash(ctx, testAuth)
// store identity
tx, err := p.StartTransaction(ctx)
require.NoError(t, err)

ctxMngr := &MockCtxMngr{}
ctxMngr.id.Uid = uid
ctxMngr.id.AuthToken = pwHash
err = p.StoreIdentity(tx, i)
require.NoError(t, err)

p := &ExtendedProtocol{
ContextManager: ctxMngr,
pwHasher: pw.NewArgon2idKeyDerivator(0, pw.GetDefaultArgon2idParams(), true),
authCache: &sync.Map{},
}
err = tx.Commit()
require.NoError(t, err)

p.authCache.Store(uid, pwHash)
pwHashPreUpdate := ctxMngr.id.AuthToken
p.pwHasher.Params = pw.GetArgon2idParams(pw.DefaultMemory, pw.DefaultTime,
2*pw.DefaultParallelism, pw.DefaultKeyLen, pw.DefaultSaltLen)

ok, found, err := p.CheckAuth(ctx, uid, testAuth)
ok, found, err := p.CheckAuth(ctx, i.Uid, i.AuthToken)
require.NoError(t, err)
assert.True(t, found)
assert.True(t, ok)
require.True(t, found)
require.True(t, ok)

assert.NotEqual(t, pwHashPreUpdate, ctxMngr.id.AuthToken)

ok, found, err = p.CheckAuth(ctx, i.Uid, i.AuthToken)
require.NoError(t, err)
require.True(t, found)
require.True(t, ok)
}

func TestExtendedProtocol_CheckAuth_AuthCache(t *testing.T) {
p, err := NewExtendedProtocol(&MockCtxMngr{}, conf)
require.NoError(t, err)

i := generateRandomIdentity()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// store identity
tx, err := p.StartTransaction(ctx)
require.NoError(t, err)

err = p.StoreIdentity(tx, i)
require.NoError(t, err)

err = tx.Commit()
require.NoError(t, err)

ok, found, err := p.CheckAuth(ctx, i.Uid, i.AuthToken)
require.NoError(t, err)
require.True(t, found)
require.True(t, ok)

assert.NotEqual(t, pwHash, ctxMngr.id.AuthToken)
cachedAuth, found := p.authCache.Load(i.Uid)
require.True(t, found)
assert.Equal(t, i.AuthToken, cachedAuth.(string))
}

func TestProtocol_Cache(t *testing.T) {
Expand Down

0 comments on commit 07dc247

Please sign in to comment.