Skip to content

Commit

Permalink
Juju 7074/identities filter (#1434)
Browse files Browse the repository at this point in the history
* filter identities

* add tests

* update godoc

* fix test

* rename params

* address pr comments

---------

Co-authored-by: Ales Stimec <[email protected]>
  • Loading branch information
SimoneDutto and alesstimec authored Nov 15, 2024
1 parent 8001a25 commit a2562bc
Show file tree
Hide file tree
Showing 9 changed files with 83 additions and 75 deletions.
41 changes: 14 additions & 27 deletions internal/db/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,43 +120,30 @@ func (d *Database) GetIdentityCloudCredentials(ctx context.Context, u *dbmodel.I
return credentials, nil
}

// ForEachIdentity iterates through every identity calling the given function
// for each one. If the given function returns an error the iteration
// will stop immediately and the error will be returned unmodified.
func (d *Database) ForEachIdentity(ctx context.Context, limit, offset int, f func(*dbmodel.Identity) error) (err error) {
const op = errors.Op("db.ForEachUSer")
// ListIdentities returns a paginated list of identities defined by limit and offset.
// match is used to fuzzy find based on entries' name using the LIKE operator (ex. LIKE %<match>%).
func (d *Database) ListIdentities(ctx context.Context, limit, offset int, match string) (_ []dbmodel.Identity, err error) {
const op = errors.Op("db.ListIdentities")
if err := d.ready(); err != nil {
return errors.E(op, err)
return nil, errors.E(op, err)
}

durationObserver := servermon.DurationObserver(servermon.DBQueryDurationHistogram, string(op))
defer durationObserver()
defer servermon.ErrorCounter(servermon.DBQueryErrorCount, &err, string(op))

db := d.DB.WithContext(ctx)
rows, err := db.
Model(&dbmodel.Identity{}).
Order("name asc").
Limit(limit).
Offset(offset).
Rows()
if err != nil {
return errors.E(op, err)
if match != "" {
db = db.Where("name LIKE ?", "%"+match+"%")
}
defer rows.Close()
for rows.Next() {
var identity dbmodel.Identity
if err := db.ScanRows(rows, &identity); err != nil {
return errors.E(op, err)
}
if err := f(&identity); err != nil {
return err
}
db = db.Order("name asc")
db = db.Limit(limit)
db = db.Offset(offset)
var identities []dbmodel.Identity
if err := db.Find(&identities).Error; err != nil {
return nil, errors.E(op, dbError(err))
}
if err := rows.Err(); err != nil {
return errors.E(op, dbError(err))
}
return nil
return identities, nil
}

// CountIdentities counts the number of identities.
Expand Down
33 changes: 6 additions & 27 deletions internal/db/identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func (s *dbSuite) TestGetIdentityCloudCredentials(c *qt.C) {
c.Assert(credentials, qt.DeepEquals, []dbmodel.CloudCredential{cred1, cred2})
}

func (s *dbSuite) TestForEachIdentity(c *qt.C) {
func (s *dbSuite) TestListIdentities(c *qt.C) {
err := s.Database.Migrate(context.Background(), false)
c.Assert(err, qt.IsNil)

Expand All @@ -192,41 +192,20 @@ func (s *dbSuite) TestForEachIdentity(c *qt.C) {
err = s.Database.GetIdentity(context.Background(), id)
c.Assert(err, qt.IsNil)
}
firstIdentities := []*dbmodel.Identity{}
ctx := context.Background()
err = s.Database.ForEachIdentity(ctx, 5, 0, func(ge *dbmodel.Identity) error {
firstIdentities = append(firstIdentities, ge)
return nil
})
firstIdentities, err := s.Database.ListIdentities(ctx, 5, 0, "")
c.Assert(err, qt.IsNil)
for i := 0; i < 5; i++ {
c.Assert(firstIdentities[i].Name, qt.Equals, fmt.Sprintf("bob%[email protected]", i))
}
secondIdentities := []*dbmodel.Identity{}
err = s.Database.ForEachIdentity(ctx, 5, 5, func(ge *dbmodel.Identity) error {
secondIdentities = append(secondIdentities, ge)
return nil
})
secondIdentities, err := s.Database.ListIdentities(ctx, 5, 5, "")
c.Assert(err, qt.IsNil)
for i := 0; i < 5; i++ {
c.Assert(secondIdentities[i].Name, qt.Equals, fmt.Sprintf("bob%[email protected]", i+5))
}
}

func (s *dbSuite) TestForEachIdentityError(c *qt.C) {
err := s.Database.Migrate(context.Background(), false)
filteredIdentities, err := s.Database.ListIdentities(ctx, 5, 0, "bob0")
c.Assert(err, qt.IsNil)
ctx := context.Background()
// add one identity
id, _ := dbmodel.NewIdentity("[email protected]")
err = s.Database.GetIdentity(context.Background(), id)
c.Assert(err, qt.IsNil)

// test error is returned
errTest := errors.E("test-error")
err = s.Database.ForEachIdentity(ctx, 5, 0, func(ge *dbmodel.Identity) error {
return errTest
})
c.Assert(err, qt.IsNotNil)
c.Assert(err.Error(), qt.Equals, errTest.Error())
c.Assert(filteredIdentities, qt.HasLen, 1)
c.Assert(filteredIdentities[0].Name, qt.Equals, "[email protected]")
}
16 changes: 8 additions & 8 deletions internal/jimm/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,23 @@ func (j *JIMM) FetchIdentity(ctx context.Context, id string) (*openfga.User, err
}

// ListIdentities lists a page of users in our database and parse them into openfga entities.
func (j *JIMM) ListIdentities(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination) ([]openfga.User, error) {
// `match` will filter the list for fuzzy find on identity name.
func (j *JIMM) ListIdentities(ctx context.Context, user *openfga.User, pagination pagination.LimitOffsetPagination, match string) ([]openfga.User, error) {
const op = errors.Op("jimm.ListIdentities")

if !user.JimmAdmin {
return nil, errors.E(op, errors.CodeUnauthorized, "unauthorized")
}
identities, err := j.Database.ListIdentities(ctx, pagination.Limit(), pagination.Offset(), match)
var users []openfga.User

var identities []openfga.User
err := j.Database.ForEachIdentity(ctx, filter.Limit(), filter.Offset(), func(ge *dbmodel.Identity) error {
u := openfga.NewUser(ge, j.OpenFGAClient)
identities = append(identities, *u)
return nil
})
for _, id := range identities {
users = append(users, *openfga.NewUser(&id, j.OpenFGAClient))
}
if err != nil {
return nil, errors.E(op, err)
}
return identities, nil
return users, nil
}

// CountIdentities returns the count of all the identities in our database.
Expand Down
16 changes: 12 additions & 4 deletions internal/jimm/identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ func TestListIdentities(t *testing.T) {
u := openfga.NewUser(&dbmodel.Identity{Name: "[email protected]"}, ofgaClient)
u.JimmAdmin = true

filter := pagination.NewOffsetFilter(10, 0)
users, err := j.ListIdentities(ctx, u, filter)
pag := pagination.NewOffsetFilter(10, 0)
users, err := j.ListIdentities(ctx, u, pag, "")
c.Assert(err, qt.IsNil)
c.Assert(len(users), qt.Equals, 0)

Expand All @@ -89,6 +89,7 @@ func TestListIdentities(t *testing.T) {
desc string
limit int
offset int
match string
identities []string
}{
{
Expand All @@ -109,11 +110,18 @@ func TestListIdentities(t *testing.T) {
offset: 6,
identities: []string{},
},
{
desc: "test with match",
limit: 5,
offset: 0,
identities: []string{userNames[0]},
match: "bob1",
},
}
for _, t := range testCases {
c.Run(t.desc, func(c *qt.C) {
filter = pagination.NewOffsetFilter(t.limit, t.offset)
identities, err := j.ListIdentities(ctx, u, filter)
pag = pagination.NewOffsetFilter(t.limit, t.offset)
identities, err := j.ListIdentities(ctx, u, pag, t.match)
c.Assert(err, qt.IsNil)
c.Assert(identities, qt.HasLen, len(t.identities))
for i := range len(t.identities) {
Expand Down
7 changes: 5 additions & 2 deletions internal/jimmhttp/rebac_admin/identities.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ func (s *identitiesService) ListIdentities(ctx context.Context, params *resource
return nil, err
}
page, nextPage, pagination := pagination.CreatePagination(params.Size, params.Page, count)

users, err := s.jimm.ListIdentities(ctx, user, pagination)
match := ""
if params.Filter != nil && *params.Filter != "" {
match = *params.Filter
}
users, err := s.jimm.ListIdentities(ctx, user, pagination, match)
if err != nil {
return nil, err
}
Expand Down
31 changes: 31 additions & 0 deletions internal/jimmhttp/rebac_admin/identities_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,37 @@ type identitiesSuite struct {

var _ = gc.Suite(&identitiesSuite{})

func (s *identitiesSuite) TestIdentitiesList(c *gc.C) {
ctx := context.Background()
ctx = rebac_handlers.ContextWithIdentity(ctx, s.AdminUser)
identitySvc := rebac_admin.NewidentitiesService(s.JIMM)
for i := range 5 {
user := names.NewUserTag(fmt.Sprintf("test-user-match-%[email protected]", i))
s.AddUser(c, user.Id())
}
pageSize := 5
page := 0
params := &resources.GetIdentitiesParams{Size: &pageSize, Page: &page}
res, err := identitySvc.ListIdentities(ctx, params)
c.Assert(err, gc.IsNil)
c.Assert(res, gc.Not(gc.IsNil))
c.Assert(res.Meta.Size, gc.Equals, 5)

match := "test-user-match-1"
params.Filter = &match
res, err = identitySvc.ListIdentities(ctx, params)
c.Assert(err, gc.IsNil)
c.Assert(res, gc.Not(gc.IsNil))
c.Assert(len(res.Data), gc.Equals, 1)

match = "test-user"
params.Filter = &match
res, err = identitySvc.ListIdentities(ctx, params)
c.Assert(err, gc.IsNil)
c.Assert(res, gc.Not(gc.IsNil))
c.Assert(len(res.Data), gc.Equals, pageSize)
}

func (s *identitiesSuite) TestIdentityPatchGroups(c *gc.C) {
// initialization
ctx := context.Background()
Expand Down
6 changes: 3 additions & 3 deletions internal/jimmhttp/rebac_admin/identities_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ func TestListIdentities(t *testing.T) {
}
c := qt.New(t)
jimm := jimmtest.JIMM{
ListIdentities_: func(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination) ([]openfga.User, error) {
start := filter.Offset()
end := start + filter.Limit()
ListIdentities_: func(ctx context.Context, user *openfga.User, pagination pagination.LimitOffsetPagination, match string) ([]openfga.User, error) {
start := pagination.Offset()
end := start + pagination.Limit()
if end > len(testUsers) {
end = len(testUsers)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/jujuapi/controllerroot.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ type JIMM interface {
InitiateInternalMigration(ctx context.Context, user *openfga.User, modelTag names.ModelTag, targetController string) (jujuparams.InitiateMigrationResult, error)
InitiateMigration(ctx context.Context, user *openfga.User, spec jujuparams.MigrationSpec) (jujuparams.InitiateMigrationResult, error)
ListApplicationOffers(ctx context.Context, user *openfga.User, filters ...jujuparams.OfferFilter) ([]jujuparams.ApplicationOfferAdminDetailsV5, error)
ListIdentities(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination) ([]openfga.User, error)
ListIdentities(ctx context.Context, user *openfga.User, pagination pagination.LimitOffsetPagination, match string) ([]openfga.User, error)
ListResources(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination, namePrefixFilter, typeFilter string) ([]db.Resource, error)
Offer(ctx context.Context, user *openfga.User, offer jimm.AddApplicationOfferParams) error
PubSubHub() *pubsub.Hub
Expand Down
6 changes: 3 additions & 3 deletions internal/testutils/jimmtest/jimm_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ type JIMM struct {
GetJimmControllerAccess_ func(ctx context.Context, user *openfga.User, tag names.UserTag) (string, error)
FetchIdentity_ func(ctx context.Context, username string) (*openfga.User, error)
CountIdentities_ func(ctx context.Context, user *openfga.User) (int, error)
ListIdentities_ func(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination) ([]openfga.User, error)
ListIdentities_ func(ctx context.Context, user *openfga.User, pagination pagination.LimitOffsetPagination, match string) ([]openfga.User, error)
GetUserCloudAccess_ func(ctx context.Context, user *openfga.User, cloud names.CloudTag) (string, error)
GetUserControllerAccess_ func(ctx context.Context, user *openfga.User, controller names.ControllerTag) (string, error)
GetUserModelAccess_ func(ctx context.Context, user *openfga.User, model names.ModelTag) (string, error)
Expand Down Expand Up @@ -227,11 +227,11 @@ func (j *JIMM) CountIdentities(ctx context.Context, user *openfga.User) (int, er
}
return j.CountIdentities_(ctx, user)
}
func (j *JIMM) ListIdentities(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination) ([]openfga.User, error) {
func (j *JIMM) ListIdentities(ctx context.Context, user *openfga.User, pagination pagination.LimitOffsetPagination, match string) ([]openfga.User, error) {
if j.ListIdentities_ == nil {
return nil, errors.E(errors.CodeNotImplemented)
}
return j.ListIdentities_(ctx, user, filter)
return j.ListIdentities_(ctx, user, pagination, match)
}
func (j *JIMM) GetUserCloudAccess(ctx context.Context, user *openfga.User, cloud names.CloudTag) (string, error) {
if j.GetUserCloudAccess_ == nil {
Expand Down

0 comments on commit a2562bc

Please sign in to comment.