Skip to content

Commit

Permalink
Merge pull request #1378 from pkulik0/get-group-by-name
Browse files Browse the repository at this point in the history
Get group by name
  • Loading branch information
pkulik0 authored Oct 1, 2024
2 parents 85c03a9 + 5929ff6 commit 07af23c
Show file tree
Hide file tree
Showing 10 changed files with 83 additions and 33 deletions.
21 changes: 15 additions & 6 deletions internal/jimm/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -717,18 +717,27 @@ func (j *JIMM) CountGroups(ctx context.Context, user *openfga.User) (int, error)
return count, nil
}

// GetGroup returns a group based on the provided UUID.
func (j *JIMM) GetGroupByID(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) {
const op = errors.Op("jimm.AddGroup")
// getGroup returns a group based on the provided UUID or name.
func (j *JIMM) getGroup(ctx context.Context, user *openfga.User, group *dbmodel.GroupEntry) (*dbmodel.GroupEntry, error) {
const op = errors.Op("jimm.getGroup")

if !user.JimmAdmin {
return nil, errors.E(op, errors.CodeUnauthorized, "unauthorized")
}
group := dbmodel.GroupEntry{UUID: uuid}
if err := j.Database.GetGroup(ctx, &group); err != nil {
if err := j.Database.GetGroup(ctx, group); err != nil {
return nil, errors.E(op, err)
}
return &group, nil
return group, nil
}

// GetGroupByUUID returns a group based on the provided UUID.
func (j *JIMM) GetGroupByUUID(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) {
return j.getGroup(ctx, user, &dbmodel.GroupEntry{UUID: uuid})
}

// GetGroupByName returns a group based on the provided name.
func (j *JIMM) GetGroupByName(ctx context.Context, user *openfga.User, name string) (*dbmodel.GroupEntry, error) {
return j.getGroup(ctx, user, &dbmodel.GroupEntry{Name: name})
}

// RenameGroup renames a group in JIMM's DB.
Expand Down
16 changes: 13 additions & 3 deletions internal/jimm/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ func TestCountGroups(t *testing.T) {
c.Assert(errors.ErrorCode(err), qt.Equals, errors.CodeAlreadyExists)
}

func TestGetGroupByID(t *testing.T) {
func TestGetGroup(t *testing.T) {
c := qt.New(t)
ctx := context.Background()

Expand Down Expand Up @@ -834,9 +834,19 @@ func TestGetGroupByID(t *testing.T) {
c.Assert(err, qt.IsNil)
c.Assert(groupEntry.UUID, qt.Not(qt.Equals), "")

gotGroup, err := j.GetGroupByID(ctx, u, groupEntry.UUID)
gotGroupUuid, err := j.GetGroupByUUID(ctx, u, groupEntry.UUID)
c.Assert(err, qt.IsNil)
c.Assert(gotGroup, qt.DeepEquals, groupEntry)
c.Assert(gotGroupUuid, qt.DeepEquals, groupEntry)

gotGroupName, err := j.GetGroupByName(ctx, u, groupEntry.Name)
c.Assert(err, qt.IsNil)
c.Assert(gotGroupName, qt.DeepEquals, groupEntry)

_, err = j.GetGroupByUUID(ctx, u, "non-existent")
c.Assert(err, qt.Not(qt.IsNil))

_, err = j.GetGroupByName(ctx, u, "non-existent")
c.Assert(err, qt.Not(qt.IsNil))
}

func TestRemoveGroup(t *testing.T) {
Expand Down
3 changes: 2 additions & 1 deletion internal/jimm/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ func (act *addControllerTransactor) setCloudRegionControllerPriorities(cloud dbm

act.controller.CloudRegions = append(act.controller.CloudRegions, dbmodel.CloudRegionControllerPriority{
CloudRegion: reg,
Priority: uint(priority),
//nolint:gosec
Priority: uint(priority),
})
}
}
Expand Down
26 changes: 17 additions & 9 deletions internal/jimmtest/mocks/jimm_group_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ import (

// GroupService is an implementation of the jujuapi.GroupService interface.
type GroupService struct {
AddGroup_ func(ctx context.Context, user *openfga.User, name string) (*dbmodel.GroupEntry, error)
CountGroups_ func(ctx context.Context, user *openfga.User) (int, error)
GetGroupByID_ func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error)
ListGroups_ func(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination) ([]dbmodel.GroupEntry, error)
RenameGroup_ func(ctx context.Context, user *openfga.User, oldName, newName string) error
RemoveGroup_ func(ctx context.Context, user *openfga.User, name string) error
AddGroup_ func(ctx context.Context, user *openfga.User, name string) (*dbmodel.GroupEntry, error)
CountGroups_ func(ctx context.Context, user *openfga.User) (int, error)
GetGroupByUUID_ func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error)
GetGroupByName_ func(ctx context.Context, user *openfga.User, name string) (*dbmodel.GroupEntry, error)
ListGroups_ func(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination) ([]dbmodel.GroupEntry, error)
RenameGroup_ func(ctx context.Context, user *openfga.User, oldName, newName string) error
RemoveGroup_ func(ctx context.Context, user *openfga.User, name string) error
}

func (j *GroupService) AddGroup(ctx context.Context, u *openfga.User, name string) (*dbmodel.GroupEntry, error) {
Expand All @@ -40,11 +41,18 @@ func (j *GroupService) CountGroups(ctx context.Context, user *openfga.User) (int
return j.CountGroups_(ctx, user)
}

func (j *GroupService) GetGroupByID(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) {
if j.GetGroupByID_ == nil {
func (j *GroupService) GetGroupByUUID(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) {
if j.GetGroupByUUID_ == nil {
return nil, errors.E(errors.CodeNotImplemented)
}
return j.GetGroupByID_(ctx, user, uuid)
return j.GetGroupByUUID_(ctx, user, uuid)
}

func (j *GroupService) GetGroupByName(ctx context.Context, user *openfga.User, name string) (*dbmodel.GroupEntry, error) {
if j.GetGroupByName_ == nil {
return nil, errors.E(errors.CodeNotImplemented)
}
return j.GetGroupByName_(ctx, user, name)
}

func (j *GroupService) ListGroups(ctx context.Context, user *openfga.User, filters pagination.LimitOffsetPagination) ([]dbmodel.GroupEntry, error) {
Expand Down
19 changes: 16 additions & 3 deletions internal/jujuapi/access_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ const (
type GroupService interface {
AddGroup(ctx context.Context, user *openfga.User, name string) (*dbmodel.GroupEntry, error)
CountGroups(ctx context.Context, user *openfga.User) (int, error)
GetGroupByID(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error)
GetGroupByUUID(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error)
GetGroupByName(ctx context.Context, user *openfga.User, name string) (*dbmodel.GroupEntry, error)
ListGroups(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination) ([]dbmodel.GroupEntry, error)
RenameGroup(ctx context.Context, user *openfga.User, oldName, newName string) error
RemoveGroup(ctx context.Context, user *openfga.User, name string) error
Expand Down Expand Up @@ -58,15 +59,27 @@ func (r *controllerRoot) AddGroup(ctx context.Context, req apiparams.AddGroupReq
return resp, nil
}

// GetGroup returns group information based on a group ID.
// GetGroup returns group information based on a UUID or name.
func (r *controllerRoot) GetGroup(ctx context.Context, req apiparams.GetGroupRequest) (apiparams.Group, error) {
const op = errors.Op("jujuapi.GetGroup")

groupEntry, err := r.jimm.GetGroupByID(ctx, r.user, req.UUID)
var groupEntry *dbmodel.GroupEntry
var err error
switch {
case req.UUID != "" && req.Name != "":
return apiparams.Group{}, errors.E(op, errors.CodeBadRequest, "only one of UUID or Name should be provided")
case req.UUID != "":
groupEntry, err = r.jimm.GetGroupByUUID(ctx, r.user, req.UUID)
case req.Name != "":
groupEntry, err = r.jimm.GetGroupByName(ctx, r.user, req.Name)
default:
return apiparams.Group{}, errors.E(op, errors.CodeBadRequest, "no UUID or Name provided")
}
if err != nil {
zapctx.Error(ctx, "failed to get group", zaputil.Error(err))
return apiparams.Group{}, errors.E(op, err)
}

return apiparams.Group{
UUID: groupEntry.UUID,
Name: groupEntry.Name,
Expand Down
11 changes: 9 additions & 2 deletions internal/jujuapi/access_control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,19 @@ func (s *accessControlSuite) TestGetGroup(c *gc.C) {
created, err := client.AddGroup(&apiparams.AddGroupRequest{Name: "test-group"})
c.Assert(err, jc.ErrorIsNil)

retrieved, err := client.GetGroup(&apiparams.GetGroupRequest{UUID: created.UUID})
retrievedUuid, err := client.GetGroup(&apiparams.GetGroupRequest{UUID: created.UUID})
c.Assert(err, jc.ErrorIsNil)
c.Assert(retrieved.Group, gc.DeepEquals, created.Group)
c.Assert(retrievedUuid.Group, gc.DeepEquals, created.Group)

retrievedName, err := client.GetGroup(&apiparams.GetGroupRequest{Name: created.Name})
c.Assert(err, jc.ErrorIsNil)
c.Assert(retrievedName.Group, gc.DeepEquals, created.Group)

_, err = client.GetGroup(&apiparams.GetGroupRequest{UUID: "non-existent"})
c.Assert(err, gc.ErrorMatches, ".*not found.*")

_, err = client.GetGroup(&apiparams.GetGroupRequest{Name: created.Name, UUID: created.UUID})
c.Assert(err, gc.ErrorMatches, ".*only one of.*")
}

func (s *accessControlSuite) TestRemoveGroup(c *gc.C) {
Expand Down
8 changes: 4 additions & 4 deletions internal/rebac_admin/groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (s *groupsService) GetGroup(ctx context.Context, groupId string) (*resource
if err != nil {
return nil, err
}
group, err := s.jimm.GetGroupByID(ctx, user, groupId)
group, err := s.jimm.GetGroupByUUID(ctx, user, groupId)
if err != nil {
if errors.ErrorCode(err) == errors.CodeNotFound {
return nil, v1.NewNotFoundError("failed to find group")
Expand All @@ -100,7 +100,7 @@ func (s *groupsService) UpdateGroup(ctx context.Context, group *resources.Group)
if group.Id == nil {
return nil, v1.NewValidationError("missing group ID")
}
existingGroup, err := s.jimm.GetGroupByID(ctx, user, *group.Id)
existingGroup, err := s.jimm.GetGroupByUUID(ctx, user, *group.Id)
if err != nil {
if errors.ErrorCode(err) == errors.CodeNotFound {
return nil, v1.NewNotFoundError("failed to find group")
Expand All @@ -123,7 +123,7 @@ func (s *groupsService) DeleteGroup(ctx context.Context, groupId string) (bool,
if err != nil {
return false, err
}
existingGroup, err := s.jimm.GetGroupByID(ctx, user, groupId)
existingGroup, err := s.jimm.GetGroupByUUID(ctx, user, groupId)
if err != nil {
if errors.ErrorCode(err) == errors.CodeNotFound {
return false, nil
Expand All @@ -148,7 +148,7 @@ func (s *groupsService) GetGroupIdentities(ctx context.Context, groupId string,
}
filter := utils.CreateTokenPaginationFilter(params.Size, params.NextToken, params.NextPageToken)
groupTag := jimmnames.NewGroupTag(groupId)
_, err = s.jimm.GetGroupByID(ctx, user, groupId)
_, err = s.jimm.GetGroupByUUID(ctx, user, groupId)
if err != nil {
if errors.ErrorCode(err) == errors.CodeNotFound {
return nil, v1.NewNotFoundError("group not found")
Expand Down
6 changes: 3 additions & 3 deletions internal/rebac_admin/groups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func TestUpdateGroup(t *testing.T) {
var renameErr error
jimm := jimmtest.JIMM{
GroupService: mocks.GroupService{
GetGroupByID_: func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) {
GetGroupByUUID_: func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) {
return &dbmodel.GroupEntry{UUID: groupID, Name: "test-group"}, nil
},
RenameGroup_: func(ctx context.Context, user *openfga.User, oldName, newName string) error {
Expand Down Expand Up @@ -120,7 +120,7 @@ func TestDeleteGroup(t *testing.T) {
var deleteErr error
jimm := jimmtest.JIMM{
GroupService: mocks.GroupService{
GetGroupByID_: func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) {
GetGroupByUUID_: func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) {
return &dbmodel.GroupEntry{UUID: uuid, Name: "test-group"}, nil
},
RemoveGroup_: func(ctx context.Context, user *openfga.User, name string) error {
Expand Down Expand Up @@ -155,7 +155,7 @@ func TestGetGroupIdentities(t *testing.T) {
}
jimm := jimmtest.JIMM{
GroupService: mocks.GroupService{
GetGroupByID_: func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) {
GetGroupByUUID_: func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) {
return nil, getGroupErr
},
},
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (c *Client) AddGroup(req *params.AddGroupRequest) (params.AddGroupResponse,
return resp, err
}

// GetGroup returns the group with the given UUID.
// GetGroup returns the group with the given UUID or name. Only one should be provided.
func (c *Client) GetGroup(req *params.GetGroupRequest) (params.GetGroupResponse, error) {
var resp params.GetGroupResponse
err := c.caller.APICall("JIMM", 4, "", "GetGroup", req, &resp)
Expand Down
4 changes: 3 additions & 1 deletion pkg/api/params/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,12 @@ type AddGroupResponse struct {
Group
}

// GetGroupRequest holds a request to get a group.
// GetGroupRequest holds a request to get a group by UUID or name.
type GetGroupRequest struct {
// UUID holds the UUID of the group to be retrieved.
UUID string `json:"uuid"`
// Name holds the name of the group to be retrieved.
Name string `json:"name"`
}

// GetGroupResponse holds the details of the group.
Expand Down

0 comments on commit 07af23c

Please sign in to comment.