Skip to content

Commit

Permalink
Add a login service to the jimm layer (#1314)
Browse files Browse the repository at this point in the history
* add a login service to the jimm layer

* fix test due to changed error message

* add tests and move UserLogin method

The GetUser and UpdateUserLastLogin have been unexported in favour of UserLogin.
The tests in admin_test.go use a mock authenticator but primarily verify that the new methods in jimm/admin.go perform the expected validation and have the baseline desired behaviour.

* remove redundant mock function

* fix test timeout

* add comment about service account IDs

* fix test and reduce duplication

- fixed a test I broke with the recent change.
- reduced the cognitive complexity of `handleAdminFacade` by reducing code duplication.

* update godoc for LoginWithSessionCookie
  • Loading branch information
kian99 authored Aug 19, 2024
1 parent 7a3ae63 commit 7473eaa
Show file tree
Hide file tree
Showing 17 changed files with 444 additions and 349 deletions.
80 changes: 58 additions & 22 deletions internal/jimm/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,55 +8,91 @@ import (
"golang.org/x/oauth2"

"github.com/canonical/jimm/v3/internal/errors"
"github.com/canonical/jimm/v3/internal/jimm/credentials"
"github.com/canonical/jimm/v3/internal/openfga"
"github.com/canonical/jimm/v3/pkg/names"
)

// LoginDevice starts the device login flow.
func LoginDevice(ctx context.Context, authenticator OAuthAuthenticator) (*oauth2.DeviceAuthResponse, error) {
const op = errors.Op("jujuapi.LoginDevice")

deviceResponse, err := authenticator.Device(ctx)
func (j *JIMM) LoginDevice(ctx context.Context) (*oauth2.DeviceAuthResponse, error) {
const op = errors.Op("jimm.LoginDevice")
resp, err := j.OAuthAuthenticator.Device(ctx)
if err != nil {
return nil, errors.E(op, err)
}

return deviceResponse, nil
return resp, nil
}

func GetDeviceSessionToken(ctx context.Context, authenticator OAuthAuthenticator, credentialStore credentials.CredentialStore, deviceOAuthResponse *oauth2.DeviceAuthResponse) (string, error) {
const op = errors.Op("jujuapi.GetDeviceSessionToken")

if authenticator == nil {
return "", errors.E("nil authenticator")
}
// GetDeviceSessionToken polls an OIDC server while a user logs in and returns a session token scoped to the user's identity.
func (j *JIMM) GetDeviceSessionToken(ctx context.Context, deviceOAuthResponse *oauth2.DeviceAuthResponse) (string, error) {
const op = errors.Op("jimm.GetDeviceSessionToken")

if credentialStore == nil {
return "", errors.E("nil credential store")
}

token, err := authenticator.DeviceAccessToken(ctx, deviceOAuthResponse)
token, err := j.OAuthAuthenticator.DeviceAccessToken(ctx, deviceOAuthResponse)
if err != nil {
return "", errors.E(op, err)
}

idToken, err := authenticator.ExtractAndVerifyIDToken(ctx, token)
idToken, err := j.OAuthAuthenticator.ExtractAndVerifyIDToken(ctx, token)
if err != nil {
return "", errors.E(op, err)
}

email, err := authenticator.Email(idToken)
email, err := j.OAuthAuthenticator.Email(idToken)
if err != nil {
return "", errors.E(op, err)
}

if err := authenticator.UpdateIdentity(ctx, email, token); err != nil {
if err := j.OAuthAuthenticator.UpdateIdentity(ctx, email, token); err != nil {
return "", errors.E(op, err)
}

encToken, err := authenticator.MintSessionToken(email)
encToken, err := j.OAuthAuthenticator.MintSessionToken(email)
if err != nil {
return "", errors.E(op, err)
}

return string(encToken), nil
}

// LoginClientCredentials verifies a user's client ID and secret before the user is logged in.
func (j *JIMM) LoginClientCredentials(ctx context.Context, clientID string, clientSecret string) (*openfga.User, error) {
const op = errors.Op("jimm.LoginClientCredentials")
// We expect the client to send the service account ID "as-is" and because we know that this is a clientCredentials login,
// we can append the @serviceaccount domain to the clientID (if not already present).
clientIdWithDomain, err := names.EnsureValidServiceAccountId(clientID)
if err != nil {
return nil, errors.E(op, err)
}

err = j.OAuthAuthenticator.VerifyClientCredentials(ctx, clientID, clientSecret)
if err != nil {
return nil, errors.E(op, err)
}

return j.UserLogin(ctx, clientIdWithDomain)
}

// LoginWithSessionToken verifies a user's session token before the user is logged in.
func (j *JIMM) LoginWithSessionToken(ctx context.Context, sessionToken string) (*openfga.User, error) {
const op = errors.Op("jimm.LoginWithSessionToken")
jwtToken, err := j.OAuthAuthenticator.VerifySessionToken(sessionToken)
if err != nil {
return nil, errors.E(op, err)
}

email := jwtToken.Subject()
return j.UserLogin(ctx, email)
}

// LoginWithSessionCookie uses the identity ID expected to have come from a session cookie, to log the user in.
//
// The work to parse and store the user's identity from the session cookie takes place in internal/jimmhttp/websocket.go
// [WSHandler.ServerHTTP] during the upgrade from an HTTP connection to a websocket. The user's identity is stored
// and passed to this function with the assumption that the cookie contained a valid session. This function is far from
// the session cookie logic due to the separation between the HTTP layer and Juju's RPC mechanism.
func (j *JIMM) LoginWithSessionCookie(ctx context.Context, identityID string) (*openfga.User, error) {
const op = errors.Op("jimm.LoginWithSessionCookie")
if identityID == "" {
return nil, errors.E(op, "missing cookie identity")
}
return j.UserLogin(ctx, identityID)
}
137 changes: 137 additions & 0 deletions internal/jimm/admin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// Copyright 2024 Canonical.

package jimm_test

import (
"context"
"encoding/base64"
"testing"
"time"

qt "github.com/frankban/quicktest"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/lestrrat-go/jwx/v2/jwt"
"golang.org/x/oauth2"

"github.com/canonical/jimm/v3/internal/db"
"github.com/canonical/jimm/v3/internal/jimm"
"github.com/canonical/jimm/v3/internal/jimmtest"
)

func TestLoginDevice(t *testing.T) {
c := qt.New(t)
mockAuthenticator := jimmtest.NewMockOAuthAuthenticator(c, nil)
jimm := jimm.JIMM{
OAuthAuthenticator: &mockAuthenticator,
}
resp, err := jimm.LoginDevice(context.Background())
c.Assert(err, qt.IsNil)
c.Assert(*resp, qt.CmpEquals(cmpopts.IgnoreTypes(time.Time{})), oauth2.DeviceAuthResponse{
DeviceCode: "test-device-code",
UserCode: "test-user-code",
VerificationURI: "http://no-such-uri.canonical.com",
VerificationURIComplete: "http://no-such-uri.canonical.com",
Interval: int64(time.Minute.Seconds()),
})
}

func TestGetDeviceSessionToken(t *testing.T) {
c := qt.New(t)
pollingChan := make(chan string, 1)
mockAuthenticator := jimmtest.NewMockOAuthAuthenticator(c, pollingChan)
jimm := jimm.JIMM{
OAuthAuthenticator: &mockAuthenticator,
}
pollingChan <- "user-foo"
token, err := jimm.GetDeviceSessionToken(context.Background(), nil)
c.Assert(err, qt.IsNil)
c.Assert(token, qt.Not(qt.Equals), "")
decodedToken, err := base64.StdEncoding.DecodeString(token)
c.Assert(err, qt.IsNil)
parsedToken, err := jwt.ParseInsecure([]byte(decodedToken))
c.Assert(err, qt.IsNil)
c.Assert(parsedToken.Subject(), qt.Equals, "[email protected]")
}

func TestLoginClientCredentials(t *testing.T) {
c := qt.New(t)
mockAuthenticator := jimmtest.NewMockOAuthAuthenticator(c, nil)
client, _, _, err := jimmtest.SetupTestOFGAClient(c.Name(), t.Name())
c.Assert(err, qt.IsNil)
jimm := jimm.JIMM{
UUID: "foo",
Database: db.Database{
DB: jimmtest.PostgresDB(c, func() time.Time { return now }),
},
OAuthAuthenticator: &mockAuthenticator,
OpenFGAClient: client,
}
ctx := context.Background()
err = jimm.Database.Migrate(ctx, false)
c.Assert(err, qt.IsNil)
invalidClientID := "123@123@"
_, err = jimm.LoginClientCredentials(ctx, invalidClientID, "foo-secret")
c.Assert(err, qt.ErrorMatches, "invalid client ID")

validClientID := "my-svc-acc"
user, err := jimm.LoginClientCredentials(ctx, validClientID, "foo-secret")
c.Assert(err, qt.IsNil)
c.Assert(user.Name, qt.Equals, "my-svc-acc@serviceaccount")
}

func TestLoginWithSessionToken(t *testing.T) {
c := qt.New(t)
mockAuthenticator := jimmtest.NewMockOAuthAuthenticator(c, nil)
client, _, _, err := jimmtest.SetupTestOFGAClient(c.Name(), t.Name())
c.Assert(err, qt.IsNil)
jimm := jimm.JIMM{
UUID: "foo",
Database: db.Database{
DB: jimmtest.PostgresDB(c, func() time.Time { return now }),
},
OAuthAuthenticator: &mockAuthenticator,
OpenFGAClient: client,
}
ctx := context.Background()
err = jimm.Database.Migrate(ctx, false)
c.Assert(err, qt.IsNil)

token, err := jwt.NewBuilder().
Subject("[email protected]").
Build()
serialisedToken, err := jwt.NewSerializer().Serialize(token)
c.Assert(err, qt.IsNil)
b64Token := base64.StdEncoding.EncodeToString(serialisedToken)

_, err = jimm.LoginWithSessionToken(ctx, "invalid-token")
c.Assert(err, qt.ErrorMatches, "failed to decode token")

user, err := jimm.LoginWithSessionToken(ctx, b64Token)
c.Assert(err, qt.IsNil)
c.Assert(user.Name, qt.Equals, "[email protected]")
}

func TestLoginWithSessionCookie(t *testing.T) {
c := qt.New(t)
mockAuthenticator := jimmtest.NewMockOAuthAuthenticator(c, nil)
client, _, _, err := jimmtest.SetupTestOFGAClient(c.Name(), t.Name())
c.Assert(err, qt.IsNil)
jimm := jimm.JIMM{
UUID: "foo",
Database: db.Database{
DB: jimmtest.PostgresDB(c, func() time.Time { return now }),
},
OAuthAuthenticator: &mockAuthenticator,
OpenFGAClient: client,
}
ctx := context.Background()
err = jimm.Database.Migrate(ctx, false)
c.Assert(err, qt.IsNil)

_, err = jimm.LoginWithSessionCookie(ctx, "")
c.Assert(err, qt.ErrorMatches, "missing cookie identity")

user, err := jimm.LoginWithSessionCookie(ctx, "[email protected]")
c.Assert(err, qt.IsNil)
c.Assert(user.Name, qt.Equals, "[email protected]")
}
9 changes: 9 additions & 0 deletions internal/jimm/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/canonical/jimm/v3/internal/db"
"github.com/canonical/jimm/v3/internal/dbmodel"
"github.com/canonical/jimm/v3/internal/openfga"
)

var (
Expand Down Expand Up @@ -47,3 +48,11 @@ func NewWatcherWithDeltaProcessedChannel(db db.Database, dialer Dialer, pubsub P
func (j *JIMM) ListApplicationOfferUsers(ctx context.Context, offer names.ApplicationOfferTag, user *dbmodel.Identity, accessLevel string) ([]jujuparams.OfferUserDetails, error) {
return j.listApplicationOfferUsers(ctx, offer, user, accessLevel)
}

func (j *JIMM) GetUser(ctx context.Context, identifier string) (*openfga.User, error) {
return j.getUser(ctx, identifier)
}

func (j *JIMM) UpdateUserLastLogin(ctx context.Context, identifier string) error {
return j.updateUserLastLogin(ctx, identifier)
}
5 changes: 0 additions & 5 deletions internal/jimm/jimm.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,6 @@ type JIMM struct {
OAuthAuthenticator OAuthAuthenticator
}

// OAuthAuthenticationService returns the JIMM's authentication service.
func (j *JIMM) OAuthAuthenticationService() OAuthAuthenticator {
return j.OAuthAuthenticator
}

// ResourceTag returns JIMM's controller tag stating its UUID.
func (j *JIMM) ResourceTag() names.ControllerTag {
return names.NewControllerTag(j.UUID)
Expand Down
21 changes: 18 additions & 3 deletions internal/jimm/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,23 @@ import (
"github.com/canonical/jimm/v3/internal/openfga"
)

// GetUser fetches the user specified by the user's email or the service accounts ID
// UserLogin fetches a user based on their identityName and updates their last login time.
func (j *JIMM) UserLogin(ctx context.Context, identityName string) (*openfga.User, error) {
const op = errors.Op("jimm.UserLogin")
user, err := j.getUser(ctx, identityName)
if err != nil {
return nil, errors.E(op, err, errors.CodeUnauthorized)
}
err = j.updateUserLastLogin(ctx, identityName)
if err != nil {
return nil, errors.E(op, err, errors.CodeUnauthorized)
}
return user, nil
}

// getUser fetches the user specified by the user's email or the service accounts ID
// and returns an openfga User that can be used to verify user's permissions.
func (j *JIMM) GetUser(ctx context.Context, identifier string) (*openfga.User, error) {
func (j *JIMM) getUser(ctx context.Context, identifier string) (*openfga.User, error) {
const op = errors.Op("jimm.GetUser")

user, err := dbmodel.NewIdentity(identifier)
Expand All @@ -36,7 +50,8 @@ func (j *JIMM) GetUser(ctx context.Context, identifier string) (*openfga.User, e
return u, nil
}

func (j *JIMM) UpdateUserLastLogin(ctx context.Context, identifier string) error {
// updateUserLastLogin updates the user's last login time in the database.
func (j *JIMM) updateUserLastLogin(ctx context.Context, identifier string) error {
const op = errors.Op("jimm.UpdateUserLastLogin")
user, err := dbmodel.NewIdentity(identifier)
if err != nil {
Expand Down
26 changes: 1 addition & 25 deletions internal/jimmjwx/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,6 @@ func (jwks *JWKSService) StartJWKSRotator(ctx context.Context, checkRotateRequir

credStore := jwks.credentialStore

// For logging and monitoring purposes, we have the rotator spit errors into
// this buffered channel ((size * amount) * 2 of errors we are currently aware of and doubling it to prevent blocks)
errorChan := make(chan error, 8)

if err := rotateJWKS(ctx, credStore, initialRotateRequiredTime); err != nil {
zapctx.Error(ctx, "Rotate JWKS error", zap.Error(err))
return errors.E(op, err)
Expand All @@ -121,39 +117,19 @@ func (jwks *JWKSService) StartJWKSRotator(ctx context.Context, checkRotateRequir
//
// In this case we generate a new set, which should expire in 3 months.
go func() {
defer close(errorChan)
for {
select {
case <-checkRotateRequired:
if err := rotateJWKS(ctx, credStore, initialRotateRequiredTime); err != nil {
errorChan <- err
zapctx.Error(ctx, "security failure", zap.Any("op", op), zap.NamedError("jwks-error", err))
}

case <-ctx.Done():
zapctx.Debug(ctx, "Shutdown for JWKS rotator complete.")
return
}
}
}()

// If for any reason the rotator has an error, we simply receive the error
// in another routine dedicated to logging said errors.
go func(errChan <-chan error) {
for err := range errChan {
zapctx.Error(
ctx,
"security failure",
zap.Any("op", op),
zap.NamedError("jwks-error", err),
)
select {
case <-ctx.Done():
return
default:
}
}
}(errorChan)

return nil
}

Expand Down
3 changes: 3 additions & 0 deletions internal/jimmjwx/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ func startAndTestRotator(c *qt.C, ctx context.Context, store credentials.Credent
for i := 0; i < 60; i++ {
if ks == nil {
ks, err = store.GetJWKS(ctx)
if err != nil {
c.Logf("failed to get JWKS: %s", err)
}
time.Sleep(500 * time.Millisecond)
continue
}
Expand Down
Loading

0 comments on commit 7473eaa

Please sign in to comment.