diff --git a/internal/jimm/admin.go b/internal/jimm/admin.go index bd2f79e96..844761918 100644 --- a/internal/jimm/admin.go +++ b/internal/jimm/admin.go @@ -8,55 +8,91 @@ import ( "golang.org/x/oauth2" "github.com/canonical/jimm/v3/internal/errors" - "github.com/canonical/jimm/v3/internal/jimm/credentials" + "github.com/canonical/jimm/v3/internal/openfga" + "github.com/canonical/jimm/v3/pkg/names" ) // LoginDevice starts the device login flow. -func LoginDevice(ctx context.Context, authenticator OAuthAuthenticator) (*oauth2.DeviceAuthResponse, error) { - const op = errors.Op("jujuapi.LoginDevice") - - deviceResponse, err := authenticator.Device(ctx) +func (j *JIMM) LoginDevice(ctx context.Context) (*oauth2.DeviceAuthResponse, error) { + const op = errors.Op("jimm.LoginDevice") + resp, err := j.OAuthAuthenticator.Device(ctx) if err != nil { return nil, errors.E(op, err) } - - return deviceResponse, nil + return resp, nil } -func GetDeviceSessionToken(ctx context.Context, authenticator OAuthAuthenticator, credentialStore credentials.CredentialStore, deviceOAuthResponse *oauth2.DeviceAuthResponse) (string, error) { - const op = errors.Op("jujuapi.GetDeviceSessionToken") - - if authenticator == nil { - return "", errors.E("nil authenticator") - } +// GetDeviceSessionToken polls an OIDC server while a user logs in and returns a session token scoped to the user's identity. +func (j *JIMM) GetDeviceSessionToken(ctx context.Context, deviceOAuthResponse *oauth2.DeviceAuthResponse) (string, error) { + const op = errors.Op("jimm.GetDeviceSessionToken") - if credentialStore == nil { - return "", errors.E("nil credential store") - } - - token, err := authenticator.DeviceAccessToken(ctx, deviceOAuthResponse) + token, err := j.OAuthAuthenticator.DeviceAccessToken(ctx, deviceOAuthResponse) if err != nil { return "", errors.E(op, err) } - idToken, err := authenticator.ExtractAndVerifyIDToken(ctx, token) + idToken, err := j.OAuthAuthenticator.ExtractAndVerifyIDToken(ctx, token) if err != nil { return "", errors.E(op, err) } - email, err := authenticator.Email(idToken) + email, err := j.OAuthAuthenticator.Email(idToken) if err != nil { return "", errors.E(op, err) } - if err := authenticator.UpdateIdentity(ctx, email, token); err != nil { + if err := j.OAuthAuthenticator.UpdateIdentity(ctx, email, token); err != nil { return "", errors.E(op, err) } - encToken, err := authenticator.MintSessionToken(email) + encToken, err := j.OAuthAuthenticator.MintSessionToken(email) if err != nil { return "", errors.E(op, err) } return string(encToken), nil } + +// LoginClientCredentials verifies a user's client ID and secret before the user is logged in. +func (j *JIMM) LoginClientCredentials(ctx context.Context, clientID string, clientSecret string) (*openfga.User, error) { + const op = errors.Op("jimm.LoginClientCredentials") + // We expect the client to send the service account ID "as-is" and because we know that this is a clientCredentials login, + // we can append the @serviceaccount domain to the clientID (if not already present). + clientIdWithDomain, err := names.EnsureValidServiceAccountId(clientID) + if err != nil { + return nil, errors.E(op, err) + } + + err = j.OAuthAuthenticator.VerifyClientCredentials(ctx, clientID, clientSecret) + if err != nil { + return nil, errors.E(op, err) + } + + return j.UserLogin(ctx, clientIdWithDomain) +} + +// LoginWithSessionToken verifies a user's session token before the user is logged in. +func (j *JIMM) LoginWithSessionToken(ctx context.Context, sessionToken string) (*openfga.User, error) { + const op = errors.Op("jimm.LoginWithSessionToken") + jwtToken, err := j.OAuthAuthenticator.VerifySessionToken(sessionToken) + if err != nil { + return nil, errors.E(op, err) + } + + email := jwtToken.Subject() + return j.UserLogin(ctx, email) +} + +// LoginWithSessionCookie uses the identity ID expected to have come from a session cookie, to log the user in. +// +// The work to parse and store the user's identity from the session cookie takes place in internal/jimmhttp/websocket.go +// [WSHandler.ServerHTTP] during the upgrade from an HTTP connection to a websocket. The user's identity is stored +// and passed to this function with the assumption that the cookie contained a valid session. This function is far from +// the session cookie logic due to the separation between the HTTP layer and Juju's RPC mechanism. +func (j *JIMM) LoginWithSessionCookie(ctx context.Context, identityID string) (*openfga.User, error) { + const op = errors.Op("jimm.LoginWithSessionCookie") + if identityID == "" { + return nil, errors.E(op, "missing cookie identity") + } + return j.UserLogin(ctx, identityID) +} diff --git a/internal/jimm/admin_test.go b/internal/jimm/admin_test.go new file mode 100644 index 000000000..13a8b8b4c --- /dev/null +++ b/internal/jimm/admin_test.go @@ -0,0 +1,137 @@ +// Copyright 2024 Canonical. + +package jimm_test + +import ( + "context" + "encoding/base64" + "testing" + "time" + + qt "github.com/frankban/quicktest" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/lestrrat-go/jwx/v2/jwt" + "golang.org/x/oauth2" + + "github.com/canonical/jimm/v3/internal/db" + "github.com/canonical/jimm/v3/internal/jimm" + "github.com/canonical/jimm/v3/internal/jimmtest" +) + +func TestLoginDevice(t *testing.T) { + c := qt.New(t) + mockAuthenticator := jimmtest.NewMockOAuthAuthenticator(c, nil) + jimm := jimm.JIMM{ + OAuthAuthenticator: &mockAuthenticator, + } + resp, err := jimm.LoginDevice(context.Background()) + c.Assert(err, qt.IsNil) + c.Assert(*resp, qt.CmpEquals(cmpopts.IgnoreTypes(time.Time{})), oauth2.DeviceAuthResponse{ + DeviceCode: "test-device-code", + UserCode: "test-user-code", + VerificationURI: "http://no-such-uri.canonical.com", + VerificationURIComplete: "http://no-such-uri.canonical.com", + Interval: int64(time.Minute.Seconds()), + }) +} + +func TestGetDeviceSessionToken(t *testing.T) { + c := qt.New(t) + pollingChan := make(chan string, 1) + mockAuthenticator := jimmtest.NewMockOAuthAuthenticator(c, pollingChan) + jimm := jimm.JIMM{ + OAuthAuthenticator: &mockAuthenticator, + } + pollingChan <- "user-foo" + token, err := jimm.GetDeviceSessionToken(context.Background(), nil) + c.Assert(err, qt.IsNil) + c.Assert(token, qt.Not(qt.Equals), "") + decodedToken, err := base64.StdEncoding.DecodeString(token) + c.Assert(err, qt.IsNil) + parsedToken, err := jwt.ParseInsecure([]byte(decodedToken)) + c.Assert(err, qt.IsNil) + c.Assert(parsedToken.Subject(), qt.Equals, "user-foo@canonical.com") +} + +func TestLoginClientCredentials(t *testing.T) { + c := qt.New(t) + mockAuthenticator := jimmtest.NewMockOAuthAuthenticator(c, nil) + client, _, _, err := jimmtest.SetupTestOFGAClient(c.Name(), t.Name()) + c.Assert(err, qt.IsNil) + jimm := jimm.JIMM{ + UUID: "foo", + Database: db.Database{ + DB: jimmtest.PostgresDB(c, func() time.Time { return now }), + }, + OAuthAuthenticator: &mockAuthenticator, + OpenFGAClient: client, + } + ctx := context.Background() + err = jimm.Database.Migrate(ctx, false) + c.Assert(err, qt.IsNil) + invalidClientID := "123@123@" + _, err = jimm.LoginClientCredentials(ctx, invalidClientID, "foo-secret") + c.Assert(err, qt.ErrorMatches, "invalid client ID") + + validClientID := "my-svc-acc" + user, err := jimm.LoginClientCredentials(ctx, validClientID, "foo-secret") + c.Assert(err, qt.IsNil) + c.Assert(user.Name, qt.Equals, "my-svc-acc@serviceaccount") +} + +func TestLoginWithSessionToken(t *testing.T) { + c := qt.New(t) + mockAuthenticator := jimmtest.NewMockOAuthAuthenticator(c, nil) + client, _, _, err := jimmtest.SetupTestOFGAClient(c.Name(), t.Name()) + c.Assert(err, qt.IsNil) + jimm := jimm.JIMM{ + UUID: "foo", + Database: db.Database{ + DB: jimmtest.PostgresDB(c, func() time.Time { return now }), + }, + OAuthAuthenticator: &mockAuthenticator, + OpenFGAClient: client, + } + ctx := context.Background() + err = jimm.Database.Migrate(ctx, false) + c.Assert(err, qt.IsNil) + + token, err := jwt.NewBuilder(). + Subject("alice@canonical.com"). + Build() + serialisedToken, err := jwt.NewSerializer().Serialize(token) + c.Assert(err, qt.IsNil) + b64Token := base64.StdEncoding.EncodeToString(serialisedToken) + + _, err = jimm.LoginWithSessionToken(ctx, "invalid-token") + c.Assert(err, qt.ErrorMatches, "failed to decode token") + + user, err := jimm.LoginWithSessionToken(ctx, b64Token) + c.Assert(err, qt.IsNil) + c.Assert(user.Name, qt.Equals, "alice@canonical.com") +} + +func TestLoginWithSessionCookie(t *testing.T) { + c := qt.New(t) + mockAuthenticator := jimmtest.NewMockOAuthAuthenticator(c, nil) + client, _, _, err := jimmtest.SetupTestOFGAClient(c.Name(), t.Name()) + c.Assert(err, qt.IsNil) + jimm := jimm.JIMM{ + UUID: "foo", + Database: db.Database{ + DB: jimmtest.PostgresDB(c, func() time.Time { return now }), + }, + OAuthAuthenticator: &mockAuthenticator, + OpenFGAClient: client, + } + ctx := context.Background() + err = jimm.Database.Migrate(ctx, false) + c.Assert(err, qt.IsNil) + + _, err = jimm.LoginWithSessionCookie(ctx, "") + c.Assert(err, qt.ErrorMatches, "missing cookie identity") + + user, err := jimm.LoginWithSessionCookie(ctx, "alice@canonical.com") + c.Assert(err, qt.IsNil) + c.Assert(user.Name, qt.Equals, "alice@canonical.com") +} diff --git a/internal/jimm/export_test.go b/internal/jimm/export_test.go index 411c81f6e..671e8e35c 100644 --- a/internal/jimm/export_test.go +++ b/internal/jimm/export_test.go @@ -10,6 +10,7 @@ import ( "github.com/canonical/jimm/v3/internal/db" "github.com/canonical/jimm/v3/internal/dbmodel" + "github.com/canonical/jimm/v3/internal/openfga" ) var ( @@ -47,3 +48,11 @@ func NewWatcherWithDeltaProcessedChannel(db db.Database, dialer Dialer, pubsub P func (j *JIMM) ListApplicationOfferUsers(ctx context.Context, offer names.ApplicationOfferTag, user *dbmodel.Identity, accessLevel string) ([]jujuparams.OfferUserDetails, error) { return j.listApplicationOfferUsers(ctx, offer, user, accessLevel) } + +func (j *JIMM) GetUser(ctx context.Context, identifier string) (*openfga.User, error) { + return j.getUser(ctx, identifier) +} + +func (j *JIMM) UpdateUserLastLogin(ctx context.Context, identifier string) error { + return j.updateUserLastLogin(ctx, identifier) +} diff --git a/internal/jimm/jimm.go b/internal/jimm/jimm.go index b9ec53e14..bcde7e929 100644 --- a/internal/jimm/jimm.go +++ b/internal/jimm/jimm.go @@ -87,11 +87,6 @@ type JIMM struct { OAuthAuthenticator OAuthAuthenticator } -// OAuthAuthenticationService returns the JIMM's authentication service. -func (j *JIMM) OAuthAuthenticationService() OAuthAuthenticator { - return j.OAuthAuthenticator -} - // ResourceTag returns JIMM's controller tag stating its UUID. func (j *JIMM) ResourceTag() names.ControllerTag { return names.NewControllerTag(j.UUID) diff --git a/internal/jimm/user.go b/internal/jimm/user.go index ba9e8165e..202be4278 100644 --- a/internal/jimm/user.go +++ b/internal/jimm/user.go @@ -12,9 +12,23 @@ import ( "github.com/canonical/jimm/v3/internal/openfga" ) -// GetUser fetches the user specified by the user's email or the service accounts ID +// UserLogin fetches a user based on their identityName and updates their last login time. +func (j *JIMM) UserLogin(ctx context.Context, identityName string) (*openfga.User, error) { + const op = errors.Op("jimm.UserLogin") + user, err := j.getUser(ctx, identityName) + if err != nil { + return nil, errors.E(op, err, errors.CodeUnauthorized) + } + err = j.updateUserLastLogin(ctx, identityName) + if err != nil { + return nil, errors.E(op, err, errors.CodeUnauthorized) + } + return user, nil +} + +// getUser fetches the user specified by the user's email or the service accounts ID // and returns an openfga User that can be used to verify user's permissions. -func (j *JIMM) GetUser(ctx context.Context, identifier string) (*openfga.User, error) { +func (j *JIMM) getUser(ctx context.Context, identifier string) (*openfga.User, error) { const op = errors.Op("jimm.GetUser") user, err := dbmodel.NewIdentity(identifier) @@ -36,7 +50,8 @@ func (j *JIMM) GetUser(ctx context.Context, identifier string) (*openfga.User, e return u, nil } -func (j *JIMM) UpdateUserLastLogin(ctx context.Context, identifier string) error { +// updateUserLastLogin updates the user's last login time in the database. +func (j *JIMM) updateUserLastLogin(ctx context.Context, identifier string) error { const op = errors.Op("jimm.UpdateUserLastLogin") user, err := dbmodel.NewIdentity(identifier) if err != nil { diff --git a/internal/jimmjwx/jwks.go b/internal/jimmjwx/jwks.go index 2657d66aa..0343b80be 100644 --- a/internal/jimmjwx/jwks.go +++ b/internal/jimmjwx/jwks.go @@ -106,10 +106,6 @@ func (jwks *JWKSService) StartJWKSRotator(ctx context.Context, checkRotateRequir credStore := jwks.credentialStore - // For logging and monitoring purposes, we have the rotator spit errors into - // this buffered channel ((size * amount) * 2 of errors we are currently aware of and doubling it to prevent blocks) - errorChan := make(chan error, 8) - if err := rotateJWKS(ctx, credStore, initialRotateRequiredTime); err != nil { zapctx.Error(ctx, "Rotate JWKS error", zap.Error(err)) return errors.E(op, err) @@ -121,14 +117,12 @@ func (jwks *JWKSService) StartJWKSRotator(ctx context.Context, checkRotateRequir // // In this case we generate a new set, which should expire in 3 months. go func() { - defer close(errorChan) for { select { case <-checkRotateRequired: if err := rotateJWKS(ctx, credStore, initialRotateRequiredTime); err != nil { - errorChan <- err + zapctx.Error(ctx, "security failure", zap.Any("op", op), zap.NamedError("jwks-error", err)) } - case <-ctx.Done(): zapctx.Debug(ctx, "Shutdown for JWKS rotator complete.") return @@ -136,24 +130,6 @@ func (jwks *JWKSService) StartJWKSRotator(ctx context.Context, checkRotateRequir } }() - // If for any reason the rotator has an error, we simply receive the error - // in another routine dedicated to logging said errors. - go func(errChan <-chan error) { - for err := range errChan { - zapctx.Error( - ctx, - "security failure", - zap.Any("op", op), - zap.NamedError("jwks-error", err), - ) - select { - case <-ctx.Done(): - return - default: - } - } - }(errorChan) - return nil } diff --git a/internal/jimmjwx/utils_test.go b/internal/jimmjwx/utils_test.go index f8111f5ea..91f315523 100644 --- a/internal/jimmjwx/utils_test.go +++ b/internal/jimmjwx/utils_test.go @@ -66,6 +66,9 @@ func startAndTestRotator(c *qt.C, ctx context.Context, store credentials.Credent for i := 0; i < 60; i++ { if ks == nil { ks, err = store.GetJWKS(ctx) + if err != nil { + c.Logf("failed to get JWKS: %s", err) + } time.Sleep(500 * time.Millisecond) continue } diff --git a/internal/jimmtest/auth.go b/internal/jimmtest/auth.go index 266399455..2eb24d476 100644 --- a/internal/jimmtest/auth.go +++ b/internal/jimmtest/auth.go @@ -67,7 +67,7 @@ func (a Authenticator) Authenticate(_ context.Context, _ *jujuparams.LoginReques return a.User, a.Err } -type MockOAuthAuthenticator struct { +type mockOAuthAuthenticator struct { jimm.OAuthAuthenticator c SimpleTester // PollingChan is used to simulate polling an OIDC server during the device flow. @@ -77,12 +77,15 @@ type MockOAuthAuthenticator struct { mockAccessToken string } -func NewMockOAuthAuthenticator(c SimpleTester, testChan <-chan string) MockOAuthAuthenticator { - return MockOAuthAuthenticator{c: c, PollingChan: testChan} +// NewMockOAuthAuthenticator creates a mock authenticator for tests. An channel can be passed in +// when testing the device flow to simulate polling an OIDC server. Provide a nil channel +// if the device flow will not be used in the test. +func NewMockOAuthAuthenticator(c SimpleTester, testChan <-chan string) mockOAuthAuthenticator { + return mockOAuthAuthenticator{c: c, PollingChan: testChan} } // Device is a mock implementation for the start of the device flow, returning dummy polling data. -func (m *MockOAuthAuthenticator) Device(ctx context.Context) (*oauth2.DeviceAuthResponse, error) { +func (m *mockOAuthAuthenticator) Device(ctx context.Context) (*oauth2.DeviceAuthResponse, error) { return &oauth2.DeviceAuthResponse{ DeviceCode: "test-device-code", UserCode: "test-user-code", @@ -95,7 +98,7 @@ func (m *MockOAuthAuthenticator) Device(ctx context.Context) (*oauth2.DeviceAuth // DeviceAccessToken is a mock implementation of the second step in the device flow where JIMM // polls an OIDC server for the device code. -func (m *MockOAuthAuthenticator) DeviceAccessToken(ctx context.Context, res *oauth2.DeviceAuthResponse) (*oauth2.Token, error) { +func (m *mockOAuthAuthenticator) DeviceAccessToken(ctx context.Context, res *oauth2.DeviceAuthResponse) (*oauth2.Token, error) { select { case username := <-m.PollingChan: m.polledUsername = username @@ -113,7 +116,7 @@ func (m *MockOAuthAuthenticator) DeviceAccessToken(ctx context.Context, res *oau // VerifySessionToken provides the mock implementation for verifying session tokens. // Allowing JIMM tests to create their own session tokens that will always be accepted. // Notice the use of jwt.ParseInsecure to skip JWT signature verification. -func (m *MockOAuthAuthenticator) VerifySessionToken(token string) (jwt.Token, error) { +func (m *mockOAuthAuthenticator) VerifySessionToken(token string) (jwt.Token, error) { errorFn := func(err error) error { return jimmerrors.E(err, jimmerrors.CodeUnauthorized) } @@ -136,7 +139,7 @@ func (m *MockOAuthAuthenticator) VerifySessionToken(token string) (jwt.Token, er // ExtractAndVerifyIDToken returns an ID token where the subject is equal to the username obtained during the device flow. // The auth token must match the one returned during the device flow. // If the polled username is empty it indicates an error that the device flow was not run prior to calling this function. -func (m *MockOAuthAuthenticator) ExtractAndVerifyIDToken(ctx context.Context, oauth2Token *oauth2.Token) (*oidc.IDToken, error) { +func (m *mockOAuthAuthenticator) ExtractAndVerifyIDToken(ctx context.Context, oauth2Token *oauth2.Token) (*oidc.IDToken, error) { if m.polledUsername == "" { return &oidc.IDToken{}, errors.New("unknown user for mock auth login") } @@ -147,25 +150,30 @@ func (m *MockOAuthAuthenticator) ExtractAndVerifyIDToken(ctx context.Context, oa } // Email returns the subject from an ID token. -func (m *MockOAuthAuthenticator) Email(idToken *oidc.IDToken) (string, error) { +func (m *mockOAuthAuthenticator) Email(idToken *oidc.IDToken) (string, error) { return idToken.Subject, nil } // UpdateIdentity is a no-op mock. -func (m *MockOAuthAuthenticator) UpdateIdentity(ctx context.Context, email string, token *oauth2.Token) error { +func (m *mockOAuthAuthenticator) UpdateIdentity(ctx context.Context, email string, token *oauth2.Token) error { return nil } // MintSessionToken creates an unsigned session token with the email provided. -func (m *MockOAuthAuthenticator) MintSessionToken(email string) (string, error) { +func (m *mockOAuthAuthenticator) MintSessionToken(email string) (string, error) { return newSessionToken(m.c, email, ""), nil } // AuthenticateBrowserSession always returns an error. -func (m *MockOAuthAuthenticator) AuthenticateBrowserSession(ctx context.Context, w http.ResponseWriter, req *http.Request) (context.Context, error) { +func (m *mockOAuthAuthenticator) AuthenticateBrowserSession(ctx context.Context, w http.ResponseWriter, req *http.Request) (context.Context, error) { return ctx, errors.New("authentication failed") } +// VerifyClientCredentials always returns a nil error. +func (m *mockOAuthAuthenticator) VerifyClientCredentials(ctx context.Context, clientID string, clientSecret string) error { + return nil +} + // newSessionToken returns a serialised JWT that can be used in tests. // Tests using a mock authenticator can provide an empty signatureSecret // while integration tests must provide the same secret used when verifying JWTs. diff --git a/internal/jimmtest/jimm_mock.go b/internal/jimmtest/jimm_mock.go index 04991f69a..5eff89a5e 100644 --- a/internal/jimmtest/jimm_mock.go +++ b/internal/jimmtest/jimm_mock.go @@ -17,6 +17,7 @@ import ( "github.com/canonical/jimm/v3/internal/errors" "github.com/canonical/jimm/v3/internal/jimm" jimmcreds "github.com/canonical/jimm/v3/internal/jimm/credentials" + "github.com/canonical/jimm/v3/internal/jimmtest/mocks" "github.com/canonical/jimm/v3/internal/openfga" ofganames "github.com/canonical/jimm/v3/internal/openfga/names" "github.com/canonical/jimm/v3/internal/pubsub" @@ -29,6 +30,7 @@ import ( // will delegate to the requested funcion or if the funcion is nil return // a NotImplemented error. type JIMM struct { + mocks.LoginService AddAuditLogEntry_ func(ale *dbmodel.AuditLogEntry) AddCloudToController_ func(ctx context.Context, user *openfga.User, controllerName string, tag names.CloudTag, cloud jujuparams.Cloud, force bool) error AddController_ func(ctx context.Context, u *openfga.User, ctl *dbmodel.Controller) error @@ -63,7 +65,6 @@ type JIMM struct { GetControllerConfig_ func(ctx context.Context, u *dbmodel.Identity) (*dbmodel.ControllerConfig, error) GetCredentialStore_ func() jimmcreds.CredentialStore GetJimmControllerAccess_ func(ctx context.Context, user *openfga.User, tag names.UserTag) (string, error) - GetUser_ func(ctx context.Context, username string) (*openfga.User, error) GetUserCloudAccess_ func(ctx context.Context, user *openfga.User, cloud names.CloudTag) (string, error) GetUserControllerAccess_ func(ctx context.Context, user *openfga.User, controller names.ControllerTag) (string, error) GetUserModelAccess_ func(ctx context.Context, user *openfga.User, model names.ModelTag) (string, error) @@ -109,7 +110,7 @@ type JIMM struct { UpdateCloud_ func(ctx context.Context, u *openfga.User, ct names.CloudTag, cloud jujuparams.Cloud) error UpdateCloudCredential_ func(ctx context.Context, u *openfga.User, args jimm.UpdateCloudCredentialArgs) ([]jujuparams.UpdateCredentialModelResult, error) UpdateMigratedModel_ func(ctx context.Context, user *openfga.User, modelTag names.ModelTag, targetControllerName string) error - UpdateUserLastLogin_ func(ctx context.Context, identifier string) error + UserLogin_ func(ctx context.Context, identityName string) (*openfga.User, error) ValidateModelUpgrade_ func(ctx context.Context, u *openfga.User, mt names.ModelTag, force bool) error WatchAllModelSummaries_ func(ctx context.Context, controller *dbmodel.Controller) (_ func() error, err error) } @@ -321,12 +322,6 @@ func (j *JIMM) GetJimmControllerAccess(ctx context.Context, user *openfga.User, } return j.GetJimmControllerAccess_(ctx, user, tag) } -func (j *JIMM) GetUser(ctx context.Context, username string) (*openfga.User, error) { - if j.GetUser_ == nil { - return nil, errors.E(errors.CodeNotImplemented) - } - return j.GetUser_(ctx, username) -} func (j *JIMM) GetUserCloudAccess(ctx context.Context, user *openfga.User, cloud names.CloudTag) (string, error) { if j.GetUserCloudAccess_ == nil { return "", errors.E(errors.CodeNotImplemented) @@ -593,11 +588,11 @@ func (j *JIMM) UpdateMigratedModel(ctx context.Context, user *openfga.User, mode } return j.UpdateMigratedModel_(ctx, user, modelTag, targetControllerName) } -func (j *JIMM) UpdateUserLastLogin(ctx context.Context, identifier string) error { - if j.UpdateUserLastLogin_ == nil { - return errors.E(errors.CodeNotImplemented) +func (j *JIMM) UserLogin(ctx context.Context, identityName string) (*openfga.User, error) { + if j.UserLogin_ == nil { + return nil, errors.E(errors.CodeNotImplemented) } - return j.UpdateUserLastLogin(ctx, identifier) + return j.UserLogin_(ctx, identityName) } func (j *JIMM) IdentityModelDefaults(ctx context.Context, user *dbmodel.Identity) (map[string]interface{}, error) { if j.IdentityModelDefaults_ == nil { diff --git a/internal/jimmtest/mocks/login.go b/internal/jimmtest/mocks/login.go new file mode 100644 index 000000000..0fa6fdf88 --- /dev/null +++ b/internal/jimmtest/mocks/login.go @@ -0,0 +1,54 @@ +// Copyright 2024 Canonical. +package mocks + +import ( + "context" + + "golang.org/x/oauth2" + + "github.com/canonical/jimm/v3/internal/errors" + "github.com/canonical/jimm/v3/internal/openfga" +) + +type LoginService struct { + LoginDevice_ func(ctx context.Context) (*oauth2.DeviceAuthResponse, error) + GetDeviceSessionToken_ func(ctx context.Context, deviceOAuthResponse *oauth2.DeviceAuthResponse) (string, error) + LoginClientCredentials_ func(ctx context.Context, clientID string, clientSecret string) (*openfga.User, error) + LoginWithSessionToken_ func(ctx context.Context, sessionToken string) (*openfga.User, error) + LoginWithSessionCookie_ func(ctx context.Context, identityID string) (*openfga.User, error) +} + +func (j *LoginService) LoginDevice(ctx context.Context) (*oauth2.DeviceAuthResponse, error) { + if j.LoginDevice_ == nil { + return nil, errors.E(errors.CodeNotImplemented) + } + return j.LoginDevice_(ctx) +} + +func (j *LoginService) GetDeviceSessionToken(ctx context.Context, deviceOAuthResponse *oauth2.DeviceAuthResponse) (string, error) { + if j.GetDeviceSessionToken_ == nil { + return "", errors.E(errors.CodeNotImplemented) + } + return j.GetDeviceSessionToken_(ctx, deviceOAuthResponse) +} + +func (j *LoginService) LoginClientCredentials(ctx context.Context, clientID string, clientSecret string) (*openfga.User, error) { + if j.LoginClientCredentials_ == nil { + return nil, errors.E(errors.CodeNotImplemented) + } + return j.LoginClientCredentials_(ctx, clientID, clientSecret) +} + +func (j *LoginService) LoginWithSessionToken(ctx context.Context, sessionToken string) (*openfga.User, error) { + if j.LoginWithSessionToken_ == nil { + return nil, errors.E(errors.CodeNotImplemented) + } + return j.LoginWithSessionToken_(ctx, sessionToken) +} + +func (j *LoginService) LoginWithSessionCookie(ctx context.Context, identityID string) (*openfga.User, error) { + if j.LoginWithSessionCookie_ == nil { + return nil, errors.E(errors.CodeNotImplemented) + } + return j.LoginWithSessionCookie_(ctx, identityID) +} diff --git a/internal/jujuapi/admin.go b/internal/jujuapi/admin.go index ed1b5ade8..246feb3bb 100644 --- a/internal/jujuapi/admin.go +++ b/internal/jujuapi/admin.go @@ -4,21 +4,32 @@ package jujuapi import ( "context" - stderrors "errors" "sort" "github.com/juju/juju/rpc" jujuparams "github.com/juju/juju/rpc/params" "github.com/juju/names/v5" + "golang.org/x/oauth2" - "github.com/canonical/jimm/v3/internal/auth" "github.com/canonical/jimm/v3/internal/errors" - "github.com/canonical/jimm/v3/internal/jimm" "github.com/canonical/jimm/v3/internal/openfga" "github.com/canonical/jimm/v3/pkg/api/params" - jimmnames "github.com/canonical/jimm/v3/pkg/names" ) +// LoginService defines the set of methods used for login to JIMM. +type LoginService interface { + // LoginDevice is step 1 in the device flow and returns the OIDC server that the client should use for login. + LoginDevice(ctx context.Context) (*oauth2.DeviceAuthResponse, error) + // GetDeviceSessionToken polls the OIDC server waiting for the client to login and return a user scoped session token. + GetDeviceSessionToken(ctx context.Context, deviceOAuthResponse *oauth2.DeviceAuthResponse) (string, error) + // LoginWithClientCredentials verifies a user by their client credentials. + LoginClientCredentials(ctx context.Context, clientID string, clientSecret string) (*openfga.User, error) + // LoginWithSessionToken verifies a user based on their session token. + LoginWithSessionToken(ctx context.Context, sessionToken string) (*openfga.User, error) + // LoginWithSessionCookie verifies a user based on an identity from a cookie obtained during websocket upgrade. + LoginWithSessionCookie(ctx context.Context, identityID string) (*openfga.User, error) +} + // unsupportedLogin returns an appropriate error for login attempts using // old version of the Admin facade. func unsupportedLogin() error { @@ -39,9 +50,9 @@ func (r *controllerRoot) LoginDevice(ctx context.Context) (params.LoginDeviceRes const op = errors.Op("jujuapi.LoginDevice") response := params.LoginDeviceResponse{} - deviceResponse, err := jimm.LoginDevice(ctx, r.jimm.OAuthAuthenticationService()) + deviceResponse, err := r.jimm.LoginDevice(ctx) if err != nil { - return response, errors.E(op, err) + return response, errors.E(op, err, errors.CodeUnauthorized) } // NOTE: As this is on the controller root struct, and a new controller root // is created per WS, it is EXPECTED that the subsequent call to GetDeviceSessionToken @@ -63,9 +74,9 @@ func (r *controllerRoot) GetDeviceSessionToken(ctx context.Context) (params.GetD const op = errors.Op("jujuapi.GetDeviceSessionToken") response := params.GetDeviceSessionTokenResponse{} - token, err := jimm.GetDeviceSessionToken(ctx, r.jimm.OAuthAuthenticationService(), r.jimm.GetCredentialStore(), r.deviceOAuthResponse) + token, err := r.jimm.GetDeviceSessionToken(ctx, r.deviceOAuthResponse) if err != nil { - return response, errors.E(op, err) + return response, errors.E(op, err, errors.CodeUnauthorized) } response.SessionToken = token @@ -81,19 +92,9 @@ func (r *controllerRoot) GetDeviceSessionToken(ctx context.Context) (params.GetD func (r *controllerRoot) LoginWithSessionCookie(ctx context.Context) (jujuparams.LoginResult, error) { const op = errors.Op("jujuapi.LoginWithSessionCookie") - // If no identity ID has come through, then no cookie was present - // and as such authentication has failed. - if r.identityId == "" { - return jujuparams.LoginResult{}, errors.E(op, &auth.AuthenticationError{}) - } - - user, err := r.jimm.GetUser(ctx, r.identityId) - if err != nil { - return jujuparams.LoginResult{}, errors.E(op, err) - } - err = r.jimm.UpdateUserLastLogin(ctx, r.identityId) + user, err := r.jimm.LoginWithSessionCookie(ctx, r.identityId) if err != nil { - return jujuparams.LoginResult{}, errors.E(op, err) + return jujuparams.LoginResult{}, errors.E(op, err, errors.CodeUnauthorized) } r.mu.Lock() @@ -122,36 +123,12 @@ func (r *controllerRoot) LoginWithSessionCookie(ctx context.Context) (jujuparams // such that subsequent facade method calls can access the authenticated user. func (r *controllerRoot) LoginWithSessionToken(ctx context.Context, req params.LoginWithSessionTokenRequest) (jujuparams.LoginResult, error) { const op = errors.Op("jujuapi.LoginWithSessionToken") - authenticationSvc := r.jimm.OAuthAuthenticationService() - // Verify the session token - jwtToken, err := authenticationSvc.VerifySessionToken(req.SessionToken) + user, err := r.jimm.LoginWithSessionToken(ctx, req.SessionToken) if err != nil { - var aerr *auth.AuthenticationError - if stderrors.As(err, &aerr) { - return aerr.LoginResult, nil - } return jujuparams.LoginResult{}, errors.E(op, err, errors.CodeUnauthorized) } - // Get an OpenFGA user to place on the controllerRoot for this WS - // such that: - // - // - Subsequent calls are aware of the user - // - Authorisation checks are done against the openfga.User - email := jwtToken.Subject() - - // At this point, we know the user exists, so simply just get - // the user to create the session token. - user, err := r.jimm.GetUser(ctx, email) - if err != nil { - return jujuparams.LoginResult{}, errors.E(op, err) - } - err = r.jimm.UpdateUserLastLogin(ctx, email) - if err != nil { - return jujuparams.LoginResult{}, errors.E(op, err) - } - // TODO(ale8k): This isn't needed I don't think as controller roots are unique // per WS, but if anyone knows different please let me know. r.mu.Lock() @@ -178,29 +155,11 @@ func (r *controllerRoot) LoginWithSessionToken(ctx context.Context, req params.L func (r *controllerRoot) LoginWithClientCredentials(ctx context.Context, req params.LoginWithClientCredentialsRequest) (jujuparams.LoginResult, error) { const op = errors.Op("jujuapi.LoginWithClientCredentials") - clientIdWithDomain, err := jimmnames.EnsureValidServiceAccountId(req.ClientID) - if err != nil { - return jujuparams.LoginResult{}, errors.E("invalid client ID") - } - - authenticationSvc := r.jimm.OAuthAuthenticationService() - if authenticationSvc == nil { - return jujuparams.LoginResult{}, errors.E("authentication service not specified") - } - err = authenticationSvc.VerifyClientCredentials(ctx, req.ClientID, req.ClientSecret) + user, err := r.jimm.LoginClientCredentials(ctx, req.ClientID, req.ClientSecret) if err != nil { return jujuparams.LoginResult{}, errors.E(err, errors.CodeUnauthorized) } - user, err := r.jimm.GetUser(ctx, clientIdWithDomain) - if err != nil { - return jujuparams.LoginResult{}, errors.E(op, err) - } - err = r.jimm.UpdateUserLastLogin(ctx, clientIdWithDomain) - if err != nil { - return jujuparams.LoginResult{}, errors.E(op, err) - } - r.mu.Lock() r.user = user r.mu.Unlock() diff --git a/internal/jujuapi/admin_test.go b/internal/jujuapi/admin_test.go index dbaa8e3e1..94e2db196 100644 --- a/internal/jujuapi/admin_test.go +++ b/internal/jujuapi/admin_test.go @@ -175,11 +175,7 @@ func (s *adminSuite) TestBrowserLoginNoCookie(c *gc.C) { lr := &jujuparams.LoginResult{} err := conn.APICall("Admin", 4, "", "LoginWithSessionCookie", nil, lr) - c.Assert( - err, - gc.ErrorMatches, - "authentication failed", - ) + c.Assert(err, gc.ErrorMatches, `missing cookie identity \(unauthorized access\)`) } // TestDeviceLogin takes a test user through the flow of logging into jimm diff --git a/internal/jujuapi/controllerroot.go b/internal/jujuapi/controllerroot.go index ed6ddd033..f0e620d1e 100644 --- a/internal/jujuapi/controllerroot.go +++ b/internal/jujuapi/controllerroot.go @@ -29,6 +29,7 @@ import ( ) type JIMM interface { + LoginService AddAuditLogEntry(ale *dbmodel.AuditLogEntry) AddCloudToController(ctx context.Context, user *openfga.User, controllerName string, tag names.CloudTag, cloud jujuparams.Cloud, force bool) error AddController(ctx context.Context, u *openfga.User, ctl *dbmodel.Controller) error @@ -36,7 +37,6 @@ type JIMM interface { AddGroup(ctx context.Context, user *openfga.User, name string) (*dbmodel.GroupEntry, error) AddModel(ctx context.Context, u *openfga.User, args *jimm.ModelCreateArgs) (_ *jujuparams.ModelInfo, err error) AddServiceAccount(ctx context.Context, u *openfga.User, clientId string) error - OAuthAuthenticationService() jimm.OAuthAuthenticator AuthorizationClient() *openfga.OFGAClient ChangeModelCredential(ctx context.Context, user *openfga.User, modelTag names.ModelTag, cloudCredentialTag names.CloudCredentialTag) error CopyServiceAccountCredential(ctx context.Context, u *openfga.User, svcAcc *openfga.User, cloudCredentialTag names.CloudCredentialTag) (names.CloudCredentialTag, []jujuparams.UpdateCredentialModelResult, error) @@ -62,7 +62,6 @@ type JIMM interface { GetControllerConfig(ctx context.Context, u *dbmodel.Identity) (*dbmodel.ControllerConfig, error) GetCredentialStore() credentials.CredentialStore GetJimmControllerAccess(ctx context.Context, user *openfga.User, tag names.UserTag) (string, error) - GetUser(ctx context.Context, username string) (*openfga.User, error) GetUserCloudAccess(ctx context.Context, user *openfga.User, cloud names.CloudTag) (string, error) GetUserControllerAccess(ctx context.Context, user *openfga.User, controller names.ControllerTag) (string, error) GetUserModelAccess(ctx context.Context, user *openfga.User, model names.ModelTag) (string, error) @@ -105,7 +104,7 @@ type JIMM interface { UpdateCloud(ctx context.Context, u *openfga.User, ct names.CloudTag, cloud jujuparams.Cloud) error UpdateCloudCredential(ctx context.Context, u *openfga.User, args jimm.UpdateCloudCredentialArgs) ([]jujuparams.UpdateCredentialModelResult, error) UpdateMigratedModel(ctx context.Context, user *openfga.User, modelTag names.ModelTag, targetControllerName string) error - UpdateUserLastLogin(ctx context.Context, identifier string) error + UserLogin(ctx context.Context, identityName string) (*openfga.User, error) ValidateModelUpgrade(ctx context.Context, u *openfga.User, mt names.ModelTag, force bool) error WatchAllModelSummaries(ctx context.Context, controller *dbmodel.Controller) (_ func() error, err error) } @@ -181,7 +180,7 @@ func (r *controllerRoot) masquerade(ctx context.Context, userTag string) (*openf if !r.user.JimmAdmin { return nil, errors.E(errors.CodeUnauthorized, "unauthorized") } - user, err := r.jimm.GetUser(ctx, ut.Id()) + user, err := r.jimm.UserLogin(ctx, ut.Id()) if err != nil { return nil, err } diff --git a/internal/jujuapi/websocket.go b/internal/jujuapi/websocket.go index 6bfd1fea1..df7c73318 100644 --- a/internal/jujuapi/websocket.go +++ b/internal/jujuapi/websocket.go @@ -151,7 +151,7 @@ func (s modelProxyServer) ServeWS(ctx context.Context, clientConn *websocket.Con TokenGen: &jwtGenerator, ConnectController: connectionFunc, AuditLog: auditLogger, - JIMM: s.jimm, + LoginService: s.jimm, AuthenticatedIdentityID: auth.SessionIdentityFromContext(ctx), } if err := jimmRPC.ProxySockets(ctx, proxyHelpers); err != nil { diff --git a/internal/rpc/client_test.go b/internal/rpc/client_test.go index 8fd7b26aa..092d3120b 100644 --- a/internal/rpc/client_test.go +++ b/internal/rpc/client_test.go @@ -251,7 +251,7 @@ func TestProxySockets(t *testing.T) { testTokenGen := testTokenGenerator{} f := func(context.Context) (rpc.WebsocketConnectionWithMetadata, error) { connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL) - c.Assert(err, qt.IsNil) + c.Check(err, qt.IsNil) return rpc.WebsocketConnectionWithMetadata{ Conn: connController, ModelName: "TestName", @@ -263,8 +263,10 @@ func TestProxySockets(t *testing.T) { TokenGen: &testTokenGen, ConnectController: f, AuditLog: auditLogger, + LoginService: &mockLoginService{}, } err := rpc.ProxySockets(ctx, proxyHelpers) + c.Check(err, qt.ErrorMatches, ".*use of closed network connection") errChan <- err return err }) @@ -280,8 +282,17 @@ func TestProxySockets(t *testing.T) { err = ws.WriteJSON(&msg) c.Assert(err, qt.IsNil) resp := rpc.Message{} - err = ws.ReadJSON(&resp) - c.Assert(err, qt.IsNil) + receiveChan := make(chan error) + go func() { + receiveChan <- ws.ReadJSON(&resp) + }() + select { + case err := <-receiveChan: + c.Assert(err, qt.IsNil) + case <-time.After(5 * time.Second): + c.Logf("took too long to read response") + c.FailNow() + } c.Assert(resp.Response, qt.DeepEquals, msg.Params) ws.Close() <-errChan // Ensure go routines are cleaned up @@ -299,7 +310,7 @@ func TestCancelProxySockets(t *testing.T) { testTokenGen := testTokenGenerator{} f := func(context.Context) (rpc.WebsocketConnectionWithMetadata, error) { connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL) - c.Assert(err, qt.IsNil) + c.Check(err, qt.IsNil) return rpc.WebsocketConnectionWithMetadata{ Conn: connController, ModelName: "TestName", @@ -311,8 +322,10 @@ func TestCancelProxySockets(t *testing.T) { TokenGen: &testTokenGen, ConnectController: f, AuditLog: auditLogger, + LoginService: &mockLoginService{}, } err := rpc.ProxySockets(ctx, proxyHelpers) + c.Check(err, qt.ErrorMatches, "Context cancelled") errChan <- err return err }) @@ -323,8 +336,7 @@ func TestCancelProxySockets(t *testing.T) { c.Assert(err, qt.IsNil) defer ws.Close() cancel() - err = <-errChan - c.Assert(err.Error(), qt.Equals, "Context cancelled") + <-errChan } func TestProxySocketsAuditLogs(t *testing.T) { @@ -337,10 +349,11 @@ func TestProxySocketsAuditLogs(t *testing.T) { errChan := make(chan error) srvJIMM := newServer(func(connClient *websocket.Conn) error { + defer connClient.Close() testTokenGen := testTokenGenerator{} f := func(context.Context) (rpc.WebsocketConnectionWithMetadata, error) { connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL) - c.Assert(err, qt.IsNil) + c.Check(err, qt.IsNil) return rpc.WebsocketConnectionWithMetadata{ Conn: connController, ModelName: "TestModelName", @@ -352,8 +365,10 @@ func TestProxySocketsAuditLogs(t *testing.T) { TokenGen: &testTokenGen, ConnectController: f, AuditLog: auditLogger, + LoginService: &mockLoginService{}, } err := rpc.ProxySockets(ctx, proxyHelpers) + c.Check(err, qt.ErrorMatches, ".*use of closed network connection") errChan <- err return err }) diff --git a/internal/rpc/proxy.go b/internal/rpc/proxy.go index f781510f9..4e72ec2c2 100644 --- a/internal/rpc/proxy.go +++ b/internal/rpc/proxy.go @@ -16,13 +16,10 @@ import ( "github.com/canonical/jimm/v3/internal/dbmodel" "github.com/canonical/jimm/v3/internal/errors" - "github.com/canonical/jimm/v3/internal/jimm" - "github.com/canonical/jimm/v3/internal/jimm/credentials" "github.com/canonical/jimm/v3/internal/openfga" "github.com/canonical/jimm/v3/internal/servermon" "github.com/canonical/jimm/v3/internal/utils" apiparams "github.com/canonical/jimm/v3/pkg/api/params" - jimmnames "github.com/canonical/jimm/v3/pkg/names" ) const ( @@ -60,12 +57,14 @@ type WebsocketConnectionWithMetadata struct { ModelName string } -// JIMM represents the JIMM interface used by the proxy. -type JIMM interface { - GetUser(ctx context.Context, identifier string) (*openfga.User, error) - UpdateUserLastLogin(ctx context.Context, identifier string) error - OAuthAuthenticationService() jimm.OAuthAuthenticator - GetCredentialStore() credentials.CredentialStore +// LoginService represents the LoginService interface used by the proxy. +// Currently this is a duplicate of the [jujuapi.LoginService]. +type LoginService interface { + LoginDevice(ctx context.Context) (*oauth2.DeviceAuthResponse, error) + GetDeviceSessionToken(ctx context.Context, deviceOAuthResponse *oauth2.DeviceAuthResponse) (string, error) + LoginClientCredentials(ctx context.Context, clientID string, clientSecret string) (*openfga.User, error) + LoginWithSessionToken(ctx context.Context, sessionToken string) (*openfga.User, error) + LoginWithSessionCookie(ctx context.Context, identityID string) (*openfga.User, error) } // ProxyHelpers contains all the necessary helpers for proxying a Juju client @@ -75,7 +74,7 @@ type ProxyHelpers struct { TokenGen TokenGenerator ConnectController func(context.Context) (WebsocketConnectionWithMetadata, error) AuditLog func(*dbmodel.AuditLogEntry) - JIMM JIMM + LoginService LoginService AuthenticatedIdentityID string } @@ -92,6 +91,10 @@ func ProxySockets(ctx context.Context, helpers ProxyHelpers) error { zapctx.Error(ctx, "Missing audit log function") return errors.E(op, "Missing audit log function") } + if helpers.LoginService == nil { + zapctx.Error(ctx, "Missing login service function") + return errors.E(op, "Missing login service function") + } errChan := make(chan error, 2) msgInFlight := inflightMsgs{messages: make(map[uint64]*message)} client := writeLockConn{conn: helpers.ConnClient} @@ -104,7 +107,7 @@ func ProxySockets(ctx context.Context, helpers ProxyHelpers) error { tokenGen: helpers.TokenGen, auditLog: helpers.AuditLog, conversationId: utils.NewConversationID(), - jimm: helpers.JIMM, + loginService: helpers.LoginService, authenticatedIdentityID: helpers.AuthenticatedIdentityID, }, errChan: errChan, @@ -242,7 +245,7 @@ type modelProxy struct { msgs *inflightMsgs auditLog func(*dbmodel.AuditLogEntry) tokenGen TokenGenerator - jimm JIMM + loginService LoginService modelName string conversationId string authenticatedIdentityID string @@ -609,7 +612,18 @@ func (p *clientProxy) handleAdminFacade(ctx context.Context, msg *message) (clie errorFnc := func(err error) (*message, *message, error) { return nil, nil, err } - controllerLoginMessageFnc := func(data []byte) (*message, *message, error) { + controllerLoginMessageFnc := func(user *openfga.User) (*message, *message, error) { + jwt, err := p.tokenGen.MakeLoginToken(ctx, user) + if err != nil { + return errorFnc(err) + } + data, err := json.Marshal(params.LoginRequest{ + AuthTag: names.NewUserTag(user.Name).String(), + Token: base64.StdEncoding.EncodeToString(jwt), + }) + if err != nil { + return errorFnc(err) + } m := *msg m.Type = "Admin" m.Request = "Login" @@ -619,7 +633,7 @@ func (p *clientProxy) handleAdminFacade(ctx context.Context, msg *message) (clie } switch msg.Request { case "LoginDevice": - deviceResponse, err := jimm.LoginDevice(ctx, p.jimm.OAuthAuthenticationService()) + deviceResponse, err := p.loginService.LoginDevice(ctx) if err != nil { return errorFnc(err) } @@ -635,7 +649,7 @@ func (p *clientProxy) handleAdminFacade(ctx context.Context, msg *message) (clie msg.Response = data return msg, nil, nil case "GetDeviceSessionToken": - sessionToken, err := jimm.GetDeviceSessionToken(ctx, p.jimm.OAuthAuthenticationService(), p.jimm.GetCredentialStore(), p.deviceOAuthResponse) + sessionToken, err := p.loginService.GetDeviceSessionToken(ctx, p.deviceOAuthResponse) if err != nil { return errorFnc(err) } @@ -654,95 +668,31 @@ func (p *clientProxy) handleAdminFacade(ctx context.Context, msg *message) (clie return errorFnc(err) } - // Verify the session token - token, err := p.jimm.OAuthAuthenticationService().VerifySessionToken(request.SessionToken) + user, err := p.loginService.LoginWithSessionToken(ctx, request.SessionToken) if err != nil { return errorFnc(err) } - email := token.Subject() - user, err := p.jimm.GetUser(ctx, email) - if err != nil { - return errorFnc(err) - } - err = p.jimm.UpdateUserLastLogin(ctx, email) - if err != nil { - return errorFnc(err) - } - - jwt, err := p.tokenGen.MakeLoginToken(ctx, user) - if err != nil { - return errorFnc(err) - } - data, err := json.Marshal(params.LoginRequest{ - AuthTag: names.NewUserTag(email).String(), - Token: base64.StdEncoding.EncodeToString(jwt), - }) - if err != nil { - return errorFnc(err) - } - return controllerLoginMessageFnc(data) + return controllerLoginMessageFnc(user) case "LoginWithClientCredentials": var request apiparams.LoginWithClientCredentialsRequest err := json.Unmarshal(msg.Params, &request) if err != nil { return errorFnc(err) } - clientIdWithDomain, err := jimmnames.EnsureValidServiceAccountId(request.ClientID) - if err != nil { - return errorFnc(err) - } - err = p.jimm.OAuthAuthenticationService().VerifyClientCredentials(ctx, request.ClientID, request.ClientSecret) + user, err := p.loginService.LoginClientCredentials(ctx, request.ClientID, request.ClientSecret) if err != nil { return errorFnc(err) } - user, err := p.jimm.GetUser(ctx, clientIdWithDomain) - if err != nil { - return errorFnc(err) - } - err = p.jimm.UpdateUserLastLogin(ctx, clientIdWithDomain) - if err != nil { - return errorFnc(err) - } - - jwt, err := p.tokenGen.MakeLoginToken(ctx, user) - if err != nil { - return errorFnc(err) - } - data, err := json.Marshal(params.LoginRequest{ - AuthTag: names.NewUserTag(clientIdWithDomain).String(), - Token: base64.StdEncoding.EncodeToString(jwt), - }) - if err != nil { - return errorFnc(err) - } - return controllerLoginMessageFnc(data) + return controllerLoginMessageFnc(user) case "LoginWithSessionCookie": - if p.modelProxy.authenticatedIdentityID == "" { - return errorFnc(errors.E(errors.CodeUnauthorized)) - } - user, err := p.jimm.GetUser(ctx, p.modelProxy.authenticatedIdentityID) - if err != nil { - return errorFnc(err) - } - err = p.jimm.UpdateUserLastLogin(ctx, p.modelProxy.authenticatedIdentityID) + user, err := p.loginService.LoginWithSessionCookie(ctx, p.modelProxy.authenticatedIdentityID) if err != nil { return errorFnc(err) } - jwt, err := p.tokenGen.MakeLoginToken(ctx, user) - if err != nil { - return errorFnc(err) - } - data, err := json.Marshal(params.LoginRequest{ - AuthTag: user.ResourceTag().String(), - Token: base64.StdEncoding.EncodeToString(jwt), - }) - if err != nil { - return errorFnc(err) - } - return controllerLoginMessageFnc(data) + return controllerLoginMessageFnc(user) case "Login": return errorFnc(errors.E("JIMM does not support login from old clients", errors.CodeNotSupported)) default: diff --git a/internal/rpc/proxy_test.go b/internal/rpc/proxy_test.go index 77c9d7a41..7d7ab6238 100644 --- a/internal/rpc/proxy_test.go +++ b/internal/rpc/proxy_test.go @@ -9,22 +9,18 @@ import ( "testing" "time" - "github.com/coreos/go-oidc/v3/oidc" qt "github.com/frankban/quicktest" "github.com/google/uuid" "github.com/juju/juju/rpc/params" "github.com/juju/names/v5" - "github.com/lestrrat-go/jwx/v2/jwt" "golang.org/x/oauth2" "github.com/canonical/jimm/v3/internal/dbmodel" "github.com/canonical/jimm/v3/internal/errors" - "github.com/canonical/jimm/v3/internal/jimm" - "github.com/canonical/jimm/v3/internal/jimm/credentials" - "github.com/canonical/jimm/v3/internal/jimmtest" "github.com/canonical/jimm/v3/internal/openfga" "github.com/canonical/jimm/v3/internal/rpc" apiparams "github.com/canonical/jimm/v3/pkg/api/params" + jimmnames "github.com/canonical/jimm/v3/pkg/names" ) type message struct { @@ -238,9 +234,11 @@ func TestProxySocketsAdminFacade(t *testing.T) { for _, test := range tests { t.Run(test.about, func(t *testing.T) { ctx := context.Background() + ctx, cancelFunc := context.WithCancel(ctx) + defer cancelFunc() clientWebsocket := newMockWebsocketConnection(10) controllerWebsocket := newMockWebsocketConnection(10) - authenticator := &mockOAuthAuthenticator{ + loginSvc := &mockLoginService{ email: "alice@wonderland.io", clientID: clientID, clientSecret: clientSecret, @@ -257,59 +255,54 @@ func TestProxySocketsAdminFacade(t *testing.T) { ControllerUUID: uuid.NewString(), }, nil }, - AuditLog: func(*dbmodel.AuditLogEntry) {}, - JIMM: &mockJIMM{ - authenticator: authenticator, - }, + AuditLog: func(*dbmodel.AuditLogEntry) {}, + LoginService: loginSvc, AuthenticatedIdentityID: test.authenticateEntityID, } + var wg sync.WaitGroup + wg.Add(1) go func() { + defer wg.Done() err = rpc.ProxySockets(ctx, helpers) - c.Assert(err, qt.IsNil) + c.Assert(err, qt.ErrorMatches, "Context cancelled") }() data, err := json.Marshal(test.messageToSend) c.Assert(err, qt.IsNil) - select { - case clientWebsocket.read <- data: - default: - c.Fatal("failed to send message") - } + clientWebsocket.read <- data if test.expectedClientResponse != nil { select { case data := <-clientWebsocket.write: c.Assert(string(data), qt.JSONEquals, test.expectedClientResponse) - case <-time.Tick(10 * time.Minute): - c.Fatal("time out waiting for response") + case <-time.Tick(2 * time.Second): + c.Fatal("timed out waiting for response") } } if test.expectedControllerMessage != nil { select { case data := <-controllerWebsocket.write: c.Assert(string(data), qt.JSONEquals, test.expectedControllerMessage) - case <-time.Tick(10 * time.Minute): - c.Fatal("time out waiting for response") + case <-time.Tick(2 * time.Second): + c.Fatal("timed out waiting for response") } } + cancelFunc() + wg.Wait() + t.Logf("completed test %s", t.Name()) }) } } -type mockOAuthAuthenticator struct { - jimm.OAuthAuthenticator - - err error - +type mockLoginService struct { + err error email string clientID string clientSecret string - - updatedEmail string } -func (m *mockOAuthAuthenticator) Device(ctx context.Context) (*oauth2.DeviceAuthResponse, error) { - if m.err != nil { - return nil, m.err +func (j *mockLoginService) LoginDevice(ctx context.Context) (*oauth2.DeviceAuthResponse, error) { + if j.err != nil { + return nil, j.err } return &oauth2.DeviceAuthResponse{ DeviceCode: "test-device-code", @@ -320,94 +313,48 @@ func (m *mockOAuthAuthenticator) Device(ctx context.Context) (*oauth2.DeviceAuth Interval: int64(time.Minute.Seconds()), }, nil } - -func (m *mockOAuthAuthenticator) DeviceAccessToken(ctx context.Context, res *oauth2.DeviceAuthResponse) (*oauth2.Token, error) { - if m.err != nil { - return nil, m.err - } - return &oauth2.Token{}, nil -} - -func (m *mockOAuthAuthenticator) ExtractAndVerifyIDToken(ctx context.Context, oauth2Token *oauth2.Token) (*oidc.IDToken, error) { - if m.err != nil { - return nil, m.err - } - return &oidc.IDToken{}, nil -} - -func (m *mockOAuthAuthenticator) Email(idToken *oidc.IDToken) (string, error) { - if m.err != nil { - return "", m.err - } - if m.email != "" { - return m.email, nil +func (j *mockLoginService) GetDeviceSessionToken(ctx context.Context, deviceOAuthResponse *oauth2.DeviceAuthResponse) (string, error) { + if j.err != nil { + return "", j.err } - return "", errors.E(errors.CodeNotFound) + return "test session token", nil } - -func (m *mockOAuthAuthenticator) UpdateIdentity(ctx context.Context, email string, token *oauth2.Token) error { - if m.err != nil { - return m.err +func (j *mockLoginService) LoginClientCredentials(ctx context.Context, clientID string, clientSecret string) (*openfga.User, error) { + if j.err != nil { + return nil, j.err } - m.updatedEmail = email - return nil -} - -func (m *mockOAuthAuthenticator) VerifyClientCredentials(ctx context.Context, clientID string, clientSecret string) error { - if m.err != nil { - return m.err + if clientID != j.clientID || clientSecret != j.clientSecret { + return nil, errors.E("invalid client credentials") } - if clientID == m.clientID && clientSecret == m.clientSecret { - return nil + clientIdWithDomain, err := jimmnames.EnsureValidServiceAccountId(clientID) + if err != nil { + return nil, errors.E("invalid client credential ID") } - return errors.E(errors.CodeUnauthorized) -} - -func (m *mockOAuthAuthenticator) MintSessionToken(email string) (string, error) { - if m.err != nil { - return "", m.err + identity, err := dbmodel.NewIdentity(clientIdWithDomain) + if err != nil { + return nil, err } - return "test session token", nil + return openfga.NewUser(identity, nil), nil } - -func (m *mockOAuthAuthenticator) VerifySessionToken(token string) (jwt.Token, error) { - if m.err != nil { - return nil, m.err +func (j *mockLoginService) LoginWithSessionToken(ctx context.Context, sessionToken string) (*openfga.User, error) { + if j.err != nil { + return nil, j.err } - t := jwt.New() - - if err := t.Set(jwt.SubjectKey, m.email); err != nil { + identity, err := dbmodel.NewIdentity(j.email) + if err != nil { return nil, err } - - return t, nil -} - -type mockJIMM struct { - authenticator *mockOAuthAuthenticator + return openfga.NewUser(identity, nil), nil } - -func (j *mockJIMM) OAuthAuthenticationService() jimm.OAuthAuthenticator { - return j.authenticator -} - -func (j *mockJIMM) GetUser(ctx context.Context, email string) (*openfga.User, error) { - identity, err := dbmodel.NewIdentity(email) +func (j *mockLoginService) LoginWithSessionCookie(ctx context.Context, identityID string) (*openfga.User, error) { + if j.err != nil { + return nil, j.err + } + identity, err := dbmodel.NewIdentity(j.email) if err != nil { return nil, err } - return openfga.NewUser( - identity, - nil, - ), nil -} - -func (j *mockJIMM) UpdateUserLastLogin(ctx context.Context, identifier string) error { - return nil -} - -func (j *mockJIMM) GetCredentialStore() credentials.CredentialStore { - return jimmtest.NewInMemoryCredentialStore() + return openfga.NewUser(identity, nil), nil } func newMockWebsocketConnection(capacity int) *mockWebsocketConnection { @@ -420,6 +367,7 @@ func newMockWebsocketConnection(capacity int) *mockWebsocketConnection { type mockWebsocketConnection struct { read chan []byte write chan []byte + once sync.Once } func (w *mockWebsocketConnection) ReadJSON(v interface{}) error { @@ -439,7 +387,7 @@ func (w *mockWebsocketConnection) WriteJSON(v interface{}) error { } func (w *mockWebsocketConnection) Close() error { - close(w.read) + w.once.Do(func() { close(w.read) }) return nil }