Skip to content

Commit

Permalink
address follow-up comments from previous PR
Browse files Browse the repository at this point in the history
  • Loading branch information
kian99 committed Aug 22, 2024
1 parent 0b38d20 commit f145d67
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 24 deletions.
8 changes: 5 additions & 3 deletions internal/common/pagination/entitlement.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,10 @@ func (c *comboToken) MarshalToken() (string, error) {
func (c *comboToken) UnmarshalToken(token string) error {
out, err := base64.StdEncoding.DecodeString(token)
if err != nil {
return fmt.Errorf("marshal entitlement token: %w", err)
return fmt.Errorf("unmarshal entitlement token: %w", err)
}

return json.Unmarshal(out, c)
if err := json.Unmarshal(out, c); err != nil {
return fmt.Errorf("failed to unmarshal combo token: %w", err)
}
return nil
}
7 changes: 6 additions & 1 deletion internal/common/pagination/entitlement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,12 @@ func TestUnmarshalEntitlementToken(t *testing.T) {
{
desc: "Invalid token",
in: "abc",
expectedError: "marshal entitlement token: illegal base64 data at input byte 0",
expectedError: "unmarshal entitlement token: illegal base64 data at input byte 0",
},
{
desc: "Invalid JSON in valid Base64 string",
in: "c29tZSBpbnZhbGlkIHRva2VuCg==",
expectedError: "failed to unmarshal combo token: invalid character 's' looking for beginning of value",
},
}

Expand Down
2 changes: 1 addition & 1 deletion internal/jimm/relation.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (j *JIMM) ListRelationshipTuples(ctx context.Context, user *openfga.User, t
//
// This functions provides a slightly higher-level abstraction in favor of ListRelationshipTuples.
func (j *JIMM) ListObjectRelations(ctx context.Context, user *openfga.User, object string, pageSize int32, entitlementToken pagination.EntitlementToken) ([]openfga.Tuple, pagination.EntitlementToken, error) {
const op = errors.Op("jimm.ListRelationshipTuples")
const op = errors.Op("jimm.ListObjectRelations")
var e pagination.EntitlementToken
if !user.JimmAdmin {
return nil, e, errors.E(op, errors.CodeUnauthorized, "unauthorized")
Expand Down
30 changes: 18 additions & 12 deletions internal/rebac_admin/groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,19 @@ func (s *groupsService) GetGroupIdentities(ctx context.Context, groupId string,
data = append(data, resources.Identity{Email: identifier})
}
originalToken := filter.Token()
return &resources.PaginatedResponse[resources.Identity]{
resp := resources.PaginatedResponse[resources.Identity]{
Meta: resources.ResponseMeta{
Size: len(data),
PageToken: &originalToken,
},
Next: resources.Next{
PageToken: &nextToken,
},
Data: data,
}, nil
}
if nextToken != "" {
resp.Next = resources.Next{
PageToken: &nextToken,
}
}
return &resp, nil
}

// PatchGroupIdentities performs addition or removal of identities to/from a Group identified by `groupId`.
Expand Down Expand Up @@ -256,17 +259,20 @@ func (s *groupsService) GetGroupEntitlements(ctx context.Context, groupId string
return nil, err
}
originalToken := filter.Token()
nextToken := nextEntitlmentToken.String()
return &resources.PaginatedResponse[resources.EntityEntitlement]{
resp := resources.PaginatedResponse[resources.EntityEntitlement]{
Meta: resources.ResponseMeta{
Size: len(tuples),
PageToken: &originalToken,
},
Next: resources.Next{
PageToken: &nextToken,
},
Data: utils.ToEntityEntitlements(tuples),
}, nil
}
if nextEntitlmentToken.String() != "" {
nextToken := nextEntitlmentToken.String()
resp.Next = resources.Next{
PageToken: &nextToken,
}
}
return &resp, nil
}

// PatchGroupEntitlements performs addition or removal of an Entitlement to/from a Group identified by `groupId`.
Expand All @@ -280,7 +286,7 @@ func (s *groupsService) PatchGroupEntitlements(ctx context.Context, groupId stri
}
groupTag := jimmnames.NewGroupTag(groupId)
tuple := apiparams.RelationshipTuple{
Object: ofganames.WithMemberRelation(groupTag.String()),
Object: ofganames.WithMemberRelation(groupTag),
}
var toRemove []apiparams.RelationshipTuple
var toAdd []apiparams.RelationshipTuple
Expand Down
35 changes: 28 additions & 7 deletions internal/rebac_admin/groups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ func TestGetGroupIdentities(t *testing.T) {
c := qt.New(t)
var listTuplesErr error
var getGroupErr error
var continuationToken string
testTuple := openfga.Tuple{
Object: &ofga.Entity{Kind: "user", ID: "foo"},
Relation: ofga.Relation("member"),
Expand All @@ -159,8 +160,8 @@ func TestGetGroupIdentities(t *testing.T) {
},
},
RelationService: mocks.RelationService{
ListRelationshipTuples_: func(ctx context.Context, user *openfga.User, tuple params.RelationshipTuple, pageSize int32, continuationToken string) ([]openfga.Tuple, string, error) {
return []openfga.Tuple{testTuple}, "continuation-token", listTuplesErr
ListRelationshipTuples_: func(ctx context.Context, user *openfga.User, tuple params.RelationshipTuple, pageSize int32, ct string) ([]openfga.Tuple, string, error) {
return []openfga.Tuple{testTuple}, continuationToken, listTuplesErr
},
},
}
Expand All @@ -178,12 +179,19 @@ func TestGetGroupIdentities(t *testing.T) {
c.Assert(err, qt.ErrorMatches, ".*group doesn't exist")
getGroupErr = nil

continuationToken = "continuation-token"
res, err := groupSvc.GetGroupIdentities(ctx, newUUID.String(), &resources.GetGroupsItemIdentitiesParams{})
c.Assert(err, qt.IsNil)
c.Assert(res, qt.IsNotNil)
c.Assert(res.Data, qt.HasLen, 1)
c.Assert(*res.Next.PageToken, qt.Equals, "continuation-token")

continuationToken = ""
res, err = groupSvc.GetGroupIdentities(ctx, newUUID.String(), &resources.GetGroupsItemIdentitiesParams{})
c.Assert(err, qt.IsNil)
c.Assert(res, qt.IsNotNil)
c.Assert(res.Next.PageToken, qt.IsNil)

listTuplesErr = errors.New("foo")
_, err = groupSvc.GetGroupIdentities(ctx, newUUID.String(), &resources.GetGroupsItemIdentitiesParams{})
c.Assert(err, qt.ErrorMatches, "foo")
Expand Down Expand Up @@ -232,15 +240,17 @@ func TestPatchGroupIdentities(t *testing.T) {

func TestGetGroupEntitlements(t *testing.T) {
c := qt.New(t)
var listRelationsErr error
var continuationToken string
testTuple := openfga.Tuple{
Object: &ofga.Entity{Kind: "user", ID: "foo"},
Relation: ofga.Relation("member"),
Target: &ofga.Entity{Kind: "group", ID: "my-group"},
}
jimm := jimmtest.JIMM{
RelationService: mocks.RelationService{
ListObjectRelations_: func(ctx context.Context, user *openfga.User, object string, pageSize int32, continuationToken pagination.EntitlementToken) ([]openfga.Tuple, pagination.EntitlementToken, error) {
return []openfga.Tuple{testTuple}, pagination.NewEntitlementToken("next-page-token"), nil
ListObjectRelations_: func(ctx context.Context, user *openfga.User, object string, pageSize int32, ct pagination.EntitlementToken) ([]openfga.Tuple, pagination.EntitlementToken, error) {
return []openfga.Tuple{testTuple}, pagination.NewEntitlementToken(continuationToken), listRelationsErr
},
},
}
Expand All @@ -252,16 +262,27 @@ func TestGetGroupEntitlements(t *testing.T) {
_, err := groupSvc.GetGroupEntitlements(ctx, "invalid-group-id", nil)
c.Assert(err, qt.ErrorMatches, ".* invalid group ID")

continuationToken = "random-token"
res, err := groupSvc.GetGroupEntitlements(ctx, uuid.New().String(), &resources.GetGroupsItemEntitlementsParams{})
c.Assert(err, qt.IsNil)
c.Assert(res, qt.IsNotNil)
c.Assert(res.Data, qt.HasLen, 1)
c.Assert(res.Next.PageToken, qt.Not(qt.Equals), "")
c.Assert(*res.Next.PageToken, qt.Equals, "random-token")

continuationToken = ""
res, err = groupSvc.GetGroupEntitlements(ctx, uuid.New().String(), &resources.GetGroupsItemEntitlementsParams{})
c.Assert(err, qt.IsNil)
c.Assert(res, qt.IsNotNil)
c.Assert(res.Next.PageToken, qt.IsNil)

// Test using the previous tokens page token.
res, err = groupSvc.GetGroupEntitlements(ctx, uuid.New().String(), &resources.GetGroupsItemEntitlementsParams{NextToken: res.Next.PageToken})
nextToken := "some-token"
res, err = groupSvc.GetGroupEntitlements(ctx, uuid.New().String(), &resources.GetGroupsItemEntitlementsParams{NextToken: &nextToken})
c.Assert(err, qt.IsNil)
c.Assert(res, qt.IsNotNil)

listRelationsErr = errors.New("foo")
_, err = groupSvc.GetGroupEntitlements(ctx, uuid.New().String(), &resources.GetGroupsItemEntitlementsParams{})
c.Assert(err, qt.ErrorMatches, "foo")
}

func TestPatchGroupEntitlements(t *testing.T) {
Expand Down

0 comments on commit f145d67

Please sign in to comment.