From 8c3c76216c1143ff50b822d1b8fca109c70561a9 Mon Sep 17 00:00:00 2001 From: closetool Date: Sun, 25 Jul 2021 20:20:31 +0800 Subject: [PATCH] feat: filtered interface implementation Signed-off-by: closetool --- adapter.go | 65 +++++++++++++++++++++++++++++++++++++++++++++++++ adapter_test.go | 27 ++++++++++++++++++++ go.mod | 1 + 3 files changed, 93 insertions(+) diff --git a/adapter.go b/adapter.go index 185c9fc..2aa4dda 100644 --- a/adapter.go +++ b/adapter.go @@ -17,6 +17,8 @@ package entadapter import ( "context" "database/sql" + "fmt" + "reflect" "strings" "entgo.io/ent/dialect" @@ -30,6 +32,7 @@ import ( _ "github.com/go-sql-driver/mysql" _ "github.com/jackc/pgx/v4/stdlib" _ "github.com/lib/pq" + //_ "github.com/mattn/go-sqlite3" "github.com/pkg/errors" ) @@ -42,6 +45,18 @@ const ( type Adapter struct { client *ent.Client ctx context.Context + + filtered bool +} + +type Filter struct { + Ptype []string + V0 []string + V1 []string + V2 []string + V3 []string + V4 []string + V5 []string } type Option func(a *Adapter) error @@ -111,6 +126,56 @@ func (a *Adapter) LoadPolicy(model model.Model) error { return nil } +// LoadFilteredPolicy loads only policy rules that match the filter. +// Filter is a map[string][]string, key denotes ptype, []string is policy +func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) error { + + filterValue, ok := filter.(Filter) + if !ok { + return fmt.Errorf("invalid filter type: %v", reflect.TypeOf(filter)) + } + + session := a.client.CasbinRule.Query() + if len(filterValue.Ptype) != 0 { + session.Where(casbinrule.PtypeIn(filterValue.Ptype...)) + } + if len(filterValue.V0) != 0 { + session.Where(casbinrule.V0In(filterValue.V0...)) + } + if len(filterValue.V1) != 0 { + session.Where(casbinrule.V1In(filterValue.V1...)) + } + if len(filterValue.V2) != 0 { + session.Where(casbinrule.V2In(filterValue.V2...)) + } + if len(filterValue.V3) != 0 { + session.Where(casbinrule.V3In(filterValue.V3...)) + } + if len(filterValue.V4) != 0 { + session.Where(casbinrule.V4In(filterValue.V4...)) + } + if len(filterValue.V5) != 0 { + session.Where(casbinrule.V5In(filterValue.V5...)) + } + + lines, err := session.All(a.ctx) + if err != nil { + return err + } + + for _, line := range lines { + loadPolicyLine(line, model) + } + a.filtered = true + + return nil +} + +// IsFiltered returns true if the loaded policy has been filtered. +func (a *Adapter) IsFiltered() bool { + return a.filtered +} + // SavePolicy saves all policy rules to the storage. func (a *Adapter) SavePolicy(model model.Model) error { return a.WithTx(func(tx *ent.Tx) error { diff --git a/adapter_test.go b/adapter_test.go index 31d7e27..fd7754d 100644 --- a/adapter_test.go +++ b/adapter_test.go @@ -21,6 +21,7 @@ import ( "github.com/casbin/casbin/v2" "github.com/casbin/casbin/v2/util" "github.com/casbin/ent-adapter/ent" + "github.com/stretchr/testify/assert" ) func testGetPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) { @@ -243,6 +244,32 @@ func testUpdateFilteredPolicies(t *testing.T, a *Adapter) { testGetPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data2", "read"}}) } +func testFilteredPolicy(t *testing.T, a *Adapter) { + // NewEnforcer() without an adapter will not auto load the policy + e, _ := casbin.NewEnforcer("examples/rbac_model.conf", "examples/rbac_policy.csv") + + // Now set the adapter + e.SetAdapter(a) + + assert.Nil(t, e.SavePolicy()) + + // Load only alice's policies + assert.Nil(t, e.LoadFilteredPolicy(Filter{V0: []string{"alice"}})) + testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}}) + + // Load only bob's policies + assert.Nil(t, e.LoadFilteredPolicy(Filter{V0: []string{"bob"}})) + testGetPolicy(t, e, [][]string{{"bob", "data2", "write"}}) + + // Load policies for data2_admin + assert.Nil(t, e.LoadFilteredPolicy(Filter{V0: []string{"data2_admin"}})) + testGetPolicy(t, e, [][]string{{"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) + + // Load policies for alice and bob + assert.Nil(t, e.LoadFilteredPolicy(Filter{V0: []string{"alice", "bob"}})) + testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}}) +} + func TestAdapters(t *testing.T) { a := initAdapter(t, "mysql", "root:@tcp(127.0.0.1:3306)/casbin") testAutoSave(t, a) diff --git a/go.mod b/go.mod index 17a7dd0..e7d009a 100644 --- a/go.mod +++ b/go.mod @@ -10,4 +10,5 @@ require ( github.com/lib/pq v1.10.0 //github.com/mattn/go-sqlite3 v1.14.6 github.com/pkg/errors v0.9.1 + github.com/stretchr/testify v1.7.0 )