diff --git a/internal/jimm/access.go b/internal/jimm/access.go index bf3d20940..f9a2b2373 100644 --- a/internal/jimm/access.go +++ b/internal/jimm/access.go @@ -717,18 +717,27 @@ func (j *JIMM) CountGroups(ctx context.Context, user *openfga.User) (int, error) return count, nil } -// GetGroup returns a group based on the provided UUID. -func (j *JIMM) GetGroupByID(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) { - const op = errors.Op("jimm.AddGroup") +// getGroup returns a group based on the provided UUID or name. +func (j *JIMM) getGroup(ctx context.Context, user *openfga.User, group *dbmodel.GroupEntry) (*dbmodel.GroupEntry, error) { + const op = errors.Op("jimm.getGroup") if !user.JimmAdmin { return nil, errors.E(op, errors.CodeUnauthorized, "unauthorized") } - group := dbmodel.GroupEntry{UUID: uuid} - if err := j.Database.GetGroup(ctx, &group); err != nil { + if err := j.Database.GetGroup(ctx, group); err != nil { return nil, errors.E(op, err) } - return &group, nil + return group, nil +} + +// GetGroupByUUID returns a group based on the provided UUID. +func (j *JIMM) GetGroupByUUID(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) { + return j.getGroup(ctx, user, &dbmodel.GroupEntry{UUID: uuid}) +} + +// GetGroupByName returns a group based on the provided name. +func (j *JIMM) GetGroupByName(ctx context.Context, user *openfga.User, name string) (*dbmodel.GroupEntry, error) { + return j.getGroup(ctx, user, &dbmodel.GroupEntry{Name: name}) } // RenameGroup renames a group in JIMM's DB. diff --git a/internal/jimm/access_test.go b/internal/jimm/access_test.go index 509a452ad..9e0975869 100644 --- a/internal/jimm/access_test.go +++ b/internal/jimm/access_test.go @@ -806,7 +806,7 @@ func TestCountGroups(t *testing.T) { c.Assert(errors.ErrorCode(err), qt.Equals, errors.CodeAlreadyExists) } -func TestGetGroupByID(t *testing.T) { +func TestGetGroup(t *testing.T) { c := qt.New(t) ctx := context.Background() @@ -834,9 +834,19 @@ func TestGetGroupByID(t *testing.T) { c.Assert(err, qt.IsNil) c.Assert(groupEntry.UUID, qt.Not(qt.Equals), "") - gotGroup, err := j.GetGroupByID(ctx, u, groupEntry.UUID) + gotGroupUuid, err := j.GetGroupByUUID(ctx, u, groupEntry.UUID) c.Assert(err, qt.IsNil) - c.Assert(gotGroup, qt.DeepEquals, groupEntry) + c.Assert(gotGroupUuid, qt.DeepEquals, groupEntry) + + gotGroupName, err := j.GetGroupByName(ctx, u, groupEntry.Name) + c.Assert(err, qt.IsNil) + c.Assert(gotGroupName, qt.DeepEquals, groupEntry) + + _, err = j.GetGroupByUUID(ctx, u, "non-existent") + c.Assert(err, qt.Not(qt.IsNil)) + + _, err = j.GetGroupByName(ctx, u, "non-existent") + c.Assert(err, qt.Not(qt.IsNil)) } func TestRemoveGroup(t *testing.T) { diff --git a/internal/jimm/controller.go b/internal/jimm/controller.go index a81198d99..3916008f8 100644 --- a/internal/jimm/controller.go +++ b/internal/jimm/controller.go @@ -161,7 +161,8 @@ func (act *addControllerTransactor) setCloudRegionControllerPriorities(cloud dbm act.controller.CloudRegions = append(act.controller.CloudRegions, dbmodel.CloudRegionControllerPriority{ CloudRegion: reg, - Priority: uint(priority), + //nolint:gosec + Priority: uint(priority), }) } } diff --git a/internal/jimmtest/mocks/jimm_group_mock.go b/internal/jimmtest/mocks/jimm_group_mock.go index 8065635d5..4259405f1 100644 --- a/internal/jimmtest/mocks/jimm_group_mock.go +++ b/internal/jimmtest/mocks/jimm_group_mock.go @@ -18,12 +18,13 @@ import ( // GroupService is an implementation of the jujuapi.GroupService interface. type GroupService struct { - AddGroup_ func(ctx context.Context, user *openfga.User, name string) (*dbmodel.GroupEntry, error) - CountGroups_ func(ctx context.Context, user *openfga.User) (int, error) - GetGroupByID_ func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) - ListGroups_ func(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination) ([]dbmodel.GroupEntry, error) - RenameGroup_ func(ctx context.Context, user *openfga.User, oldName, newName string) error - RemoveGroup_ func(ctx context.Context, user *openfga.User, name string) error + AddGroup_ func(ctx context.Context, user *openfga.User, name string) (*dbmodel.GroupEntry, error) + CountGroups_ func(ctx context.Context, user *openfga.User) (int, error) + GetGroupByUUID_ func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) + GetGroupByName_ func(ctx context.Context, user *openfga.User, name string) (*dbmodel.GroupEntry, error) + ListGroups_ func(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination) ([]dbmodel.GroupEntry, error) + RenameGroup_ func(ctx context.Context, user *openfga.User, oldName, newName string) error + RemoveGroup_ func(ctx context.Context, user *openfga.User, name string) error } func (j *GroupService) AddGroup(ctx context.Context, u *openfga.User, name string) (*dbmodel.GroupEntry, error) { @@ -40,11 +41,18 @@ func (j *GroupService) CountGroups(ctx context.Context, user *openfga.User) (int return j.CountGroups_(ctx, user) } -func (j *GroupService) GetGroupByID(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) { - if j.GetGroupByID_ == nil { +func (j *GroupService) GetGroupByUUID(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) { + if j.GetGroupByUUID_ == nil { return nil, errors.E(errors.CodeNotImplemented) } - return j.GetGroupByID_(ctx, user, uuid) + return j.GetGroupByUUID_(ctx, user, uuid) +} + +func (j *GroupService) GetGroupByName(ctx context.Context, user *openfga.User, name string) (*dbmodel.GroupEntry, error) { + if j.GetGroupByName_ == nil { + return nil, errors.E(errors.CodeNotImplemented) + } + return j.GetGroupByName_(ctx, user, name) } func (j *GroupService) ListGroups(ctx context.Context, user *openfga.User, filters pagination.LimitOffsetPagination) ([]dbmodel.GroupEntry, error) { diff --git a/internal/jujuapi/access_control.go b/internal/jujuapi/access_control.go index 8218734a6..ba05b312e 100644 --- a/internal/jujuapi/access_control.go +++ b/internal/jujuapi/access_control.go @@ -28,7 +28,8 @@ const ( type GroupService interface { AddGroup(ctx context.Context, user *openfga.User, name string) (*dbmodel.GroupEntry, error) CountGroups(ctx context.Context, user *openfga.User) (int, error) - GetGroupByID(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) + GetGroupByUUID(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) + GetGroupByName(ctx context.Context, user *openfga.User, name string) (*dbmodel.GroupEntry, error) ListGroups(ctx context.Context, user *openfga.User, filter pagination.LimitOffsetPagination) ([]dbmodel.GroupEntry, error) RenameGroup(ctx context.Context, user *openfga.User, oldName, newName string) error RemoveGroup(ctx context.Context, user *openfga.User, name string) error @@ -58,15 +59,27 @@ func (r *controllerRoot) AddGroup(ctx context.Context, req apiparams.AddGroupReq return resp, nil } -// GetGroup returns group information based on a group ID. +// GetGroup returns group information based on a UUID or name. func (r *controllerRoot) GetGroup(ctx context.Context, req apiparams.GetGroupRequest) (apiparams.Group, error) { const op = errors.Op("jujuapi.GetGroup") - groupEntry, err := r.jimm.GetGroupByID(ctx, r.user, req.UUID) + var groupEntry *dbmodel.GroupEntry + var err error + switch { + case req.UUID != "" && req.Name != "": + return apiparams.Group{}, errors.E(op, errors.CodeBadRequest, "only one of UUID or Name should be provided") + case req.UUID != "": + groupEntry, err = r.jimm.GetGroupByUUID(ctx, r.user, req.UUID) + case req.Name != "": + groupEntry, err = r.jimm.GetGroupByName(ctx, r.user, req.Name) + default: + return apiparams.Group{}, errors.E(op, errors.CodeBadRequest, "no UUID or Name provided") + } if err != nil { zapctx.Error(ctx, "failed to get group", zaputil.Error(err)) return apiparams.Group{}, errors.E(op, err) } + return apiparams.Group{ UUID: groupEntry.UUID, Name: groupEntry.Name, diff --git a/internal/jujuapi/access_control_test.go b/internal/jujuapi/access_control_test.go index 544a97fed..f1c3127b4 100644 --- a/internal/jujuapi/access_control_test.go +++ b/internal/jujuapi/access_control_test.go @@ -70,12 +70,19 @@ func (s *accessControlSuite) TestGetGroup(c *gc.C) { created, err := client.AddGroup(&apiparams.AddGroupRequest{Name: "test-group"}) c.Assert(err, jc.ErrorIsNil) - retrieved, err := client.GetGroup(&apiparams.GetGroupRequest{UUID: created.UUID}) + retrievedUuid, err := client.GetGroup(&apiparams.GetGroupRequest{UUID: created.UUID}) c.Assert(err, jc.ErrorIsNil) - c.Assert(retrieved.Group, gc.DeepEquals, created.Group) + c.Assert(retrievedUuid.Group, gc.DeepEquals, created.Group) + + retrievedName, err := client.GetGroup(&apiparams.GetGroupRequest{Name: created.Name}) + c.Assert(err, jc.ErrorIsNil) + c.Assert(retrievedName.Group, gc.DeepEquals, created.Group) _, err = client.GetGroup(&apiparams.GetGroupRequest{UUID: "non-existent"}) c.Assert(err, gc.ErrorMatches, ".*not found.*") + + _, err = client.GetGroup(&apiparams.GetGroupRequest{Name: created.Name, UUID: created.UUID}) + c.Assert(err, gc.ErrorMatches, ".*only one of.*") } func (s *accessControlSuite) TestRemoveGroup(c *gc.C) { diff --git a/internal/rebac_admin/groups.go b/internal/rebac_admin/groups.go index 135a00cdc..70fbc1772 100644 --- a/internal/rebac_admin/groups.go +++ b/internal/rebac_admin/groups.go @@ -81,7 +81,7 @@ func (s *groupsService) GetGroup(ctx context.Context, groupId string) (*resource if err != nil { return nil, err } - group, err := s.jimm.GetGroupByID(ctx, user, groupId) + group, err := s.jimm.GetGroupByUUID(ctx, user, groupId) if err != nil { if errors.ErrorCode(err) == errors.CodeNotFound { return nil, v1.NewNotFoundError("failed to find group") @@ -100,7 +100,7 @@ func (s *groupsService) UpdateGroup(ctx context.Context, group *resources.Group) if group.Id == nil { return nil, v1.NewValidationError("missing group ID") } - existingGroup, err := s.jimm.GetGroupByID(ctx, user, *group.Id) + existingGroup, err := s.jimm.GetGroupByUUID(ctx, user, *group.Id) if err != nil { if errors.ErrorCode(err) == errors.CodeNotFound { return nil, v1.NewNotFoundError("failed to find group") @@ -123,7 +123,7 @@ func (s *groupsService) DeleteGroup(ctx context.Context, groupId string) (bool, if err != nil { return false, err } - existingGroup, err := s.jimm.GetGroupByID(ctx, user, groupId) + existingGroup, err := s.jimm.GetGroupByUUID(ctx, user, groupId) if err != nil { if errors.ErrorCode(err) == errors.CodeNotFound { return false, nil @@ -148,7 +148,7 @@ func (s *groupsService) GetGroupIdentities(ctx context.Context, groupId string, } filter := utils.CreateTokenPaginationFilter(params.Size, params.NextToken, params.NextPageToken) groupTag := jimmnames.NewGroupTag(groupId) - _, err = s.jimm.GetGroupByID(ctx, user, groupId) + _, err = s.jimm.GetGroupByUUID(ctx, user, groupId) if err != nil { if errors.ErrorCode(err) == errors.CodeNotFound { return nil, v1.NewNotFoundError("group not found") diff --git a/internal/rebac_admin/groups_test.go b/internal/rebac_admin/groups_test.go index 3fc964bbb..9dd1598d2 100644 --- a/internal/rebac_admin/groups_test.go +++ b/internal/rebac_admin/groups_test.go @@ -51,7 +51,7 @@ func TestUpdateGroup(t *testing.T) { var renameErr error jimm := jimmtest.JIMM{ GroupService: mocks.GroupService{ - GetGroupByID_: func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) { + GetGroupByUUID_: func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) { return &dbmodel.GroupEntry{UUID: groupID, Name: "test-group"}, nil }, RenameGroup_: func(ctx context.Context, user *openfga.User, oldName, newName string) error { @@ -120,7 +120,7 @@ func TestDeleteGroup(t *testing.T) { var deleteErr error jimm := jimmtest.JIMM{ GroupService: mocks.GroupService{ - GetGroupByID_: func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) { + GetGroupByUUID_: func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) { return &dbmodel.GroupEntry{UUID: uuid, Name: "test-group"}, nil }, RemoveGroup_: func(ctx context.Context, user *openfga.User, name string) error { @@ -155,7 +155,7 @@ func TestGetGroupIdentities(t *testing.T) { } jimm := jimmtest.JIMM{ GroupService: mocks.GroupService{ - GetGroupByID_: func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) { + GetGroupByUUID_: func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) { return nil, getGroupErr }, }, diff --git a/pkg/api/client.go b/pkg/api/client.go index 79307dc9a..bf4818441 100644 --- a/pkg/api/client.go +++ b/pkg/api/client.go @@ -123,7 +123,7 @@ func (c *Client) AddGroup(req *params.AddGroupRequest) (params.AddGroupResponse, return resp, err } -// GetGroup returns the group with the given UUID. +// GetGroup returns the group with the given UUID or name. Only one should be provided. func (c *Client) GetGroup(req *params.GetGroupRequest) (params.GetGroupResponse, error) { var resp params.GetGroupResponse err := c.caller.APICall("JIMM", 4, "", "GetGroup", req, &resp) diff --git a/pkg/api/params/params.go b/pkg/api/params/params.go index 60aa855bd..93c38fd56 100644 --- a/pkg/api/params/params.go +++ b/pkg/api/params/params.go @@ -285,10 +285,12 @@ type AddGroupResponse struct { Group } -// GetGroupRequest holds a request to get a group. +// GetGroupRequest holds a request to get a group by UUID or name. type GetGroupRequest struct { // UUID holds the UUID of the group to be retrieved. UUID string `json:"uuid"` + // Name holds the name of the group to be retrieved. + Name string `json:"name"` } // GetGroupResponse holds the details of the group.