Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Juju 7074/identities filter #1434

Merged
merged 7 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 14 additions & 27 deletions internal/db/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,43 +120,30 @@ func (d *Database) GetIdentityCloudCredentials(ctx context.Context, u *dbmodel.I
return credentials, nil
}

// ForEachIdentity iterates through every identity calling the given function
// for each one. If the given function returns an error the iteration
// will stop immediately and the error will be returned unmodified.
func (d *Database) ForEachIdentity(ctx context.Context, limit, offset int, f func(*dbmodel.Identity) error) (err error) {
const op = errors.Op("db.ForEachUSer")
// ListIdentities returns a paginated list of identities defined by limit and offset.
// match is used to fuzzy find based on entries' name using the LIKE operator (ex. LIKE %<match>%).
func (d *Database) ListIdentities(ctx context.Context, limit, offset int, match string) (_ []dbmodel.Identity, err error) {
const op = errors.Op("db.ListIdentities")
if err := d.ready(); err != nil {
return errors.E(op, err)
return nil, errors.E(op, err)
}

durationObserver := servermon.DurationObserver(servermon.DBQueryDurationHistogram, string(op))
defer durationObserver()
defer servermon.ErrorCounter(servermon.DBQueryErrorCount, &err, string(op))

db := d.DB.WithContext(ctx)
rows, err := db.
Model(&dbmodel.Identity{}).
Order("name asc").
Limit(limit).
Offset(offset).
Rows()
if err != nil {
return errors.E(op, err)
if match != "" {
db = db.Where("name LIKE ?", "%"+match+"%")
}
defer rows.Close()
for rows.Next() {
var identity dbmodel.Identity
if err := db.ScanRows(rows, &identity); err != nil {
return errors.E(op, err)
}
if err := f(&identity); err != nil {
return err
}
db = db.Order("name asc")
db = db.Limit(limit)
db = db.Offset(offset)
var identities []dbmodel.Identity
if err := db.Find(&identities).Error; err != nil {
return nil, errors.E(op, dbError(err))
}
if err := rows.Err(); err != nil {
return errors.E(op, dbError(err))
}
return nil
return identities, nil
}

// CountIdentities counts the number of identities.
Expand Down
33 changes: 6 additions & 27 deletions internal/db/identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func (s *dbSuite) TestGetIdentityCloudCredentials(c *qt.C) {
c.Assert(credentials, qt.DeepEquals, []dbmodel.CloudCredential{cred1, cred2})
}

func (s *dbSuite) TestForEachIdentity(c *qt.C) {
func (s *dbSuite) TestListIdentities(c *qt.C) {
err := s.Database.Migrate(context.Background(), false)
c.Assert(err, qt.IsNil)

Expand All @@ -192,41 +192,20 @@ func (s *dbSuite) TestForEachIdentity(c *qt.C) {
err = s.Database.GetIdentity(context.Background(), id)
c.Assert(err, qt.IsNil)
}
firstIdentities := []*dbmodel.Identity{}
ctx := context.Background()
err = s.Database.ForEachIdentity(ctx, 5, 0, func(ge *dbmodel.Identity) error {
firstIdentities = append(firstIdentities, ge)
return nil
})
firstIdentities, err := s.Database.ListIdentities(ctx, 5, 0, "")
c.Assert(err, qt.IsNil)
for i := 0; i < 5; i++ {
c.Assert(firstIdentities[i].Name, qt.Equals, fmt.Sprintf("bob%[email protected]", i))
}
secondIdentities := []*dbmodel.Identity{}
err = s.Database.ForEachIdentity(ctx, 5, 5, func(ge *dbmodel.Identity) error {
secondIdentities = append(secondIdentities, ge)
return nil
})
secondIdentities, err := s.Database.ListIdentities(ctx, 5, 5, "")
c.Assert(err, qt.IsNil)
for i := 0; i < 5; i++ {
c.Assert(secondIdentities[i].Name, qt.Equals, fmt.Sprintf("bob%[email protected]", i+5))
}
}

func (s *dbSuite) TestForEachIdentityError(c *qt.C) {
err := s.Database.Migrate(context.Background(), false)
filteredIdentities, err := s.Database.ListIdentities(ctx, 5, 0, "bob0")
c.Assert(err, qt.IsNil)
ctx := context.Background()
// add one identity
id, _ := dbmodel.NewIdentity("[email protected]")
err = s.Database.GetIdentity(context.Background(), id)
c.Assert(err, qt.IsNil)

// test error is returned
errTest := errors.E("test-error")
err = s.Database.ForEachIdentity(ctx, 5, 0, func(ge *dbmodel.Identity) error {
return errTest
})
c.Assert(err, qt.IsNotNil)
c.Assert(err.Error(), qt.Equals, errTest.Error())
c.Assert(filteredIdentities, qt.HasLen, 1)
c.Assert(filteredIdentities[0].Name, qt.Equals, "[email protected]")
}
16 changes: 8 additions & 8 deletions internal/jimm/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,23 @@ func (j *JIMM) FetchIdentity(ctx context.Context, id string) (*openfga.User, err
}

// ListIdentities lists a page of users in our database and parse them into openfga entities.
func (j *JIMM) ListIdentities(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination) ([]openfga.User, error) {
// `match` will filter the list for fuzzy find on identity name.
func (j *JIMM) ListIdentities(ctx context.Context, user *openfga.User, pagination pagination.LimitOffsetPagination, match string) ([]openfga.User, error) {
const op = errors.Op("jimm.ListIdentities")

if !user.JimmAdmin {
return nil, errors.E(op, errors.CodeUnauthorized, "unauthorized")
}
identities, err := j.Database.ListIdentities(ctx, pagination.Limit(), pagination.Offset(), match)
var users []openfga.User
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel like I'm reading that book aha. Love it.


var identities []openfga.User
err := j.Database.ForEachIdentity(ctx, filter.Limit(), filter.Offset(), func(ge *dbmodel.Identity) error {
u := openfga.NewUser(ge, j.OpenFGAClient)
identities = append(identities, *u)
return nil
})
for _, id := range identities {
users = append(users, *openfga.NewUser(&id, j.OpenFGAClient))
}
if err != nil {
return nil, errors.E(op, err)
}
return identities, nil
return users, nil
}

// CountIdentities returns the count of all the identities in our database.
Expand Down
16 changes: 12 additions & 4 deletions internal/jimm/identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ func TestListIdentities(t *testing.T) {
u := openfga.NewUser(&dbmodel.Identity{Name: "[email protected]"}, ofgaClient)
u.JimmAdmin = true

filter := pagination.NewOffsetFilter(10, 0)
users, err := j.ListIdentities(ctx, u, filter)
pag := pagination.NewOffsetFilter(10, 0)
users, err := j.ListIdentities(ctx, u, pag, "")
c.Assert(err, qt.IsNil)
c.Assert(len(users), qt.Equals, 0)

Expand All @@ -89,6 +89,7 @@ func TestListIdentities(t *testing.T) {
desc string
limit int
offset int
match string
identities []string
}{
{
Expand All @@ -109,11 +110,18 @@ func TestListIdentities(t *testing.T) {
offset: 6,
identities: []string{},
},
{
desc: "test with match",
limit: 5,
offset: 0,
identities: []string{userNames[0]},
match: "bob1",
},
}
for _, t := range testCases {
c.Run(t.desc, func(c *qt.C) {
filter = pagination.NewOffsetFilter(t.limit, t.offset)
identities, err := j.ListIdentities(ctx, u, filter)
pag = pagination.NewOffsetFilter(t.limit, t.offset)
identities, err := j.ListIdentities(ctx, u, pag, t.match)
c.Assert(err, qt.IsNil)
c.Assert(identities, qt.HasLen, len(t.identities))
for i := range len(t.identities) {
Expand Down
7 changes: 5 additions & 2 deletions internal/jimmhttp/rebac_admin/identities.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ func (s *identitiesService) ListIdentities(ctx context.Context, params *resource
return nil, err
}
page, nextPage, pagination := pagination.CreatePagination(params.Size, params.Page, count)

users, err := s.jimm.ListIdentities(ctx, user, pagination)
match := ""
if params.Filter != nil && *params.Filter != "" {
match = *params.Filter
}
users, err := s.jimm.ListIdentities(ctx, user, pagination, match)
if err != nil {
return nil, err
}
Expand Down
31 changes: 31 additions & 0 deletions internal/jimmhttp/rebac_admin/identities_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,37 @@ type identitiesSuite struct {

var _ = gc.Suite(&identitiesSuite{})

func (s *identitiesSuite) TestIdentitiesList(c *gc.C) {
ctx := context.Background()
ctx = rebac_handlers.ContextWithIdentity(ctx, s.AdminUser)
identitySvc := rebac_admin.NewidentitiesService(s.JIMM)
for i := range 5 {
user := names.NewUserTag(fmt.Sprintf("test-user-match-%[email protected]", i))
s.AddUser(c, user.Id())
}
pageSize := 5
page := 0
params := &resources.GetIdentitiesParams{Size: &pageSize, Page: &page}
res, err := identitySvc.ListIdentities(ctx, params)
c.Assert(err, gc.IsNil)
c.Assert(res, gc.Not(gc.IsNil))
c.Assert(res.Meta.Size, gc.Equals, 5)

match := "test-user-match-1"
params.Filter = &match
res, err = identitySvc.ListIdentities(ctx, params)
c.Assert(err, gc.IsNil)
c.Assert(res, gc.Not(gc.IsNil))
c.Assert(len(res.Data), gc.Equals, 1)

match = "test-user"
params.Filter = &match
res, err = identitySvc.ListIdentities(ctx, params)
c.Assert(err, gc.IsNil)
c.Assert(res, gc.Not(gc.IsNil))
c.Assert(len(res.Data), gc.Equals, pageSize)
}

func (s *identitiesSuite) TestIdentityPatchGroups(c *gc.C) {
// initialization
ctx := context.Background()
Expand Down
6 changes: 3 additions & 3 deletions internal/jimmhttp/rebac_admin/identities_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ func TestListIdentities(t *testing.T) {
}
c := qt.New(t)
jimm := jimmtest.JIMM{
ListIdentities_: func(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination) ([]openfga.User, error) {
start := filter.Offset()
end := start + filter.Limit()
ListIdentities_: func(ctx context.Context, user *openfga.User, pagination pagination.LimitOffsetPagination, match string) ([]openfga.User, error) {
start := pagination.Offset()
end := start + pagination.Limit()
if end > len(testUsers) {
end = len(testUsers)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/jujuapi/controllerroot.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ type JIMM interface {
InitiateInternalMigration(ctx context.Context, user *openfga.User, modelTag names.ModelTag, targetController string) (jujuparams.InitiateMigrationResult, error)
InitiateMigration(ctx context.Context, user *openfga.User, spec jujuparams.MigrationSpec) (jujuparams.InitiateMigrationResult, error)
ListApplicationOffers(ctx context.Context, user *openfga.User, filters ...jujuparams.OfferFilter) ([]jujuparams.ApplicationOfferAdminDetailsV5, error)
ListIdentities(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination) ([]openfga.User, error)
ListIdentities(ctx context.Context, user *openfga.User, pagination pagination.LimitOffsetPagination, match string) ([]openfga.User, error)
ListResources(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination, namePrefixFilter, typeFilter string) ([]db.Resource, error)
Offer(ctx context.Context, user *openfga.User, offer jimm.AddApplicationOfferParams) error
PubSubHub() *pubsub.Hub
Expand Down
6 changes: 3 additions & 3 deletions internal/testutils/jimmtest/jimm_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ type JIMM struct {
GetJimmControllerAccess_ func(ctx context.Context, user *openfga.User, tag names.UserTag) (string, error)
FetchIdentity_ func(ctx context.Context, username string) (*openfga.User, error)
CountIdentities_ func(ctx context.Context, user *openfga.User) (int, error)
ListIdentities_ func(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination) ([]openfga.User, error)
ListIdentities_ func(ctx context.Context, user *openfga.User, pagination pagination.LimitOffsetPagination, match string) ([]openfga.User, error)
GetUserCloudAccess_ func(ctx context.Context, user *openfga.User, cloud names.CloudTag) (string, error)
GetUserControllerAccess_ func(ctx context.Context, user *openfga.User, controller names.ControllerTag) (string, error)
GetUserModelAccess_ func(ctx context.Context, user *openfga.User, model names.ModelTag) (string, error)
Expand Down Expand Up @@ -227,11 +227,11 @@ func (j *JIMM) CountIdentities(ctx context.Context, user *openfga.User) (int, er
}
return j.CountIdentities_(ctx, user)
}
func (j *JIMM) ListIdentities(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination) ([]openfga.User, error) {
func (j *JIMM) ListIdentities(ctx context.Context, user *openfga.User, pagination pagination.LimitOffsetPagination, match string) ([]openfga.User, error) {
if j.ListIdentities_ == nil {
return nil, errors.E(errors.CodeNotImplemented)
}
return j.ListIdentities_(ctx, user, filter)
return j.ListIdentities_(ctx, user, pagination, match)
}
func (j *JIMM) GetUserCloudAccess(ctx context.Context, user *openfga.User, cloud names.CloudTag) (string, error) {
if j.GetUserCloudAccess_ == nil {
Expand Down
Loading