Skip to content

Commit

Permalink
Merge pull request #11 from closetool/master
Browse files Browse the repository at this point in the history
feat: filtered interface implementation
  • Loading branch information
hsluoyz authored Jul 25, 2021
2 parents 3edabda + 8c3c762 commit 1ed4f45
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 0 deletions.
65 changes: 65 additions & 0 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package entadapter
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"

"entgo.io/ent/dialect"
Expand All @@ -30,6 +32,7 @@ import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/jackc/pgx/v4/stdlib"
_ "github.com/lib/pq"

//_ "github.com/mattn/go-sqlite3"
"github.com/pkg/errors"
)
Expand All @@ -42,6 +45,18 @@ const (
type Adapter struct {
client *ent.Client
ctx context.Context

filtered bool
}

type Filter struct {
Ptype []string
V0 []string
V1 []string
V2 []string
V3 []string
V4 []string
V5 []string
}

type Option func(a *Adapter) error
Expand Down Expand Up @@ -111,6 +126,56 @@ func (a *Adapter) LoadPolicy(model model.Model) error {
return nil
}

// LoadFilteredPolicy loads only policy rules that match the filter.
// Filter is a map[string][]string, key denotes ptype, []string is policy
func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) error {

filterValue, ok := filter.(Filter)
if !ok {
return fmt.Errorf("invalid filter type: %v", reflect.TypeOf(filter))
}

session := a.client.CasbinRule.Query()
if len(filterValue.Ptype) != 0 {
session.Where(casbinrule.PtypeIn(filterValue.Ptype...))
}
if len(filterValue.V0) != 0 {
session.Where(casbinrule.V0In(filterValue.V0...))
}
if len(filterValue.V1) != 0 {
session.Where(casbinrule.V1In(filterValue.V1...))
}
if len(filterValue.V2) != 0 {
session.Where(casbinrule.V2In(filterValue.V2...))
}
if len(filterValue.V3) != 0 {
session.Where(casbinrule.V3In(filterValue.V3...))
}
if len(filterValue.V4) != 0 {
session.Where(casbinrule.V4In(filterValue.V4...))
}
if len(filterValue.V5) != 0 {
session.Where(casbinrule.V5In(filterValue.V5...))
}

lines, err := session.All(a.ctx)
if err != nil {
return err
}

for _, line := range lines {
loadPolicyLine(line, model)
}
a.filtered = true

return nil
}

// IsFiltered returns true if the loaded policy has been filtered.
func (a *Adapter) IsFiltered() bool {
return a.filtered
}

// SavePolicy saves all policy rules to the storage.
func (a *Adapter) SavePolicy(model model.Model) error {
return a.WithTx(func(tx *ent.Tx) error {
Expand Down
27 changes: 27 additions & 0 deletions adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/util"
"github.com/casbin/ent-adapter/ent"
"github.com/stretchr/testify/assert"
)

func testGetPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) {
Expand Down Expand Up @@ -243,6 +244,32 @@ func testUpdateFilteredPolicies(t *testing.T, a *Adapter) {
testGetPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data2", "read"}})
}

func testFilteredPolicy(t *testing.T, a *Adapter) {
// NewEnforcer() without an adapter will not auto load the policy
e, _ := casbin.NewEnforcer("examples/rbac_model.conf", "examples/rbac_policy.csv")

// Now set the adapter
e.SetAdapter(a)

assert.Nil(t, e.SavePolicy())

// Load only alice's policies
assert.Nil(t, e.LoadFilteredPolicy(Filter{V0: []string{"alice"}}))
testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}})

// Load only bob's policies
assert.Nil(t, e.LoadFilteredPolicy(Filter{V0: []string{"bob"}}))
testGetPolicy(t, e, [][]string{{"bob", "data2", "write"}})

// Load policies for data2_admin
assert.Nil(t, e.LoadFilteredPolicy(Filter{V0: []string{"data2_admin"}}))
testGetPolicy(t, e, [][]string{{"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})

// Load policies for alice and bob
assert.Nil(t, e.LoadFilteredPolicy(Filter{V0: []string{"alice", "bob"}}))
testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
}

func TestAdapters(t *testing.T) {
a := initAdapter(t, "mysql", "root:@tcp(127.0.0.1:3306)/casbin")
testAutoSave(t, a)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ require (
github.com/lib/pq v1.10.0
//github.com/mattn/go-sqlite3 v1.14.6
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.7.0
)

0 comments on commit 1ed4f45

Please sign in to comment.