diff --git a/adapter.go b/adapter.go index 3606c63..9b8d025 100644 --- a/adapter.go +++ b/adapter.go @@ -200,12 +200,7 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, // This is part of the Auto-Save feature. func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error { return a.WithTx(func(tx *ent.Tx) error { - lines := make([]*ent.CasbinRuleCreate, 0) - for _, rule := range rules { - lines = append(lines, a.savePolicyLine(tx, ptype, rule)) - } - _, err := tx.CasbinRule.CreateBulk(lines...).Save(a.ctx) - return err + return a.createPolicies(tx, ptype, rules) }) } @@ -215,7 +210,15 @@ func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) err return a.WithTx(func(tx *ent.Tx) error { for _, rule := range rules { instance := a.toInstance(ptype, rule) - if err := tx.CasbinRule.DeleteOne(instance).Exec(a.ctx); err != nil { + if _, err := tx.CasbinRule.Delete().Where( + casbinrule.PtypeEQ(instance.Ptype), + casbinrule.V0EQ(instance.V0), + casbinrule.V1EQ(instance.V1), + casbinrule.V2EQ(instance.V2), + casbinrule.V3EQ(instance.V3), + casbinrule.V4EQ(instance.V4), + casbinrule.V5EQ(instance.V5), + ).Exec(a.ctx); err != nil { return err } } @@ -319,3 +322,137 @@ func (a *Adapter) savePolicyLine(tx *ent.Tx, ptype string, rule []string) *ent.C return line } + +// UpdatePolicy updates a policy rule from storage. +// This is part of the Auto-Save feature. +func (a *Adapter) UpdatePolicy(sec string, ptype string, oldRule, newPolicy []string) error { + return a.WithTx(func(tx *ent.Tx) error { + rule := a.toInstance(ptype, oldRule) + line := tx.CasbinRule.Update().Where( + casbinrule.PtypeEQ(rule.Ptype), + casbinrule.V0EQ(rule.V0), + casbinrule.V1EQ(rule.V1), + casbinrule.V2EQ(rule.V2), + casbinrule.V3EQ(rule.V3), + casbinrule.V4EQ(rule.V4), + casbinrule.V5EQ(rule.V5), + ) + rule = a.toInstance(ptype, newPolicy) + line.SetV0(rule.V0) + line.SetV1(rule.V1) + line.SetV2(rule.V2) + line.SetV3(rule.V3) + line.SetV4(rule.V4) + line.SetV5(rule.V5) + _, err := line.Save(a.ctx) + return err + }) +} + +// UpdatePolicies updates some policy rules to storage, like db, redis. +func (a *Adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) error { + return a.WithTx(func(tx *ent.Tx) error { + for _, policy := range oldRules { + rule := a.toInstance(ptype, policy) + if _, err := tx.CasbinRule.Delete().Where( + casbinrule.PtypeEQ(rule.Ptype), + casbinrule.V0EQ(rule.V0), + casbinrule.V1EQ(rule.V1), + casbinrule.V2EQ(rule.V2), + casbinrule.V3EQ(rule.V3), + casbinrule.V4EQ(rule.V4), + casbinrule.V5EQ(rule.V5), + ).Exec(a.ctx); err != nil { + return err + } + } + lines := make([]*ent.CasbinRuleCreate, 0) + for _, policy := range newRules { + lines = append(lines, a.savePolicyLine(tx, ptype, policy)) + } + if _, err := tx.CasbinRule.CreateBulk(lines...).Save(a.ctx); err != nil { + return err + } + return nil + }) +} + +// UpdateFilteredPolicies deletes old rules and adds new rules. +func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) ([][]string, error) { + oldPolicies := make([][]string, 0) + err := a.WithTx(func(tx *ent.Tx) error { + line := tx.CasbinRule.Query() + if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) { + line = line.Where(casbinrule.V0EQ(fieldValues[0-fieldIndex])) + } + if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) { + line = line.Where(casbinrule.V1EQ(fieldValues[1-fieldIndex])) + } + if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) { + line = line.Where(casbinrule.V2EQ(fieldValues[2-fieldIndex])) + } + if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) { + line = line.Where(casbinrule.V3EQ(fieldValues[3-fieldIndex])) + } + if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) { + line = line.Where(casbinrule.V4EQ(fieldValues[4-fieldIndex])) + } + if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) { + line = line.Where(casbinrule.V5EQ(fieldValues[5-fieldIndex])) + } + rules, err := line.All(a.ctx) + if err != nil { + return err + } + for _, rule := range rules { + if _, err := tx.CasbinRule.Delete().Where( + casbinrule.IDEQ(rule.ID), + ).Exec(a.ctx); err != nil { + return err + } + } + a.createPolicies(tx, ptype, newPolicies) + for _, rule := range rules { + oldPolicies = append(oldPolicies, CasbinRuleToStringArray(rule)) + } + return nil + }) + if err != nil { + return nil, err + } + return oldPolicies, nil +} + +func (a *Adapter) createPolicies(tx *ent.Tx, ptype string, policies [][]string) error { + lines := make([]*ent.CasbinRuleCreate, 0) + for _, policy := range policies { + lines = append(lines, a.savePolicyLine(tx, ptype, policy)) + } + if _, err := tx.CasbinRule.CreateBulk(lines...).Save(a.ctx); err != nil { + return err + } + return nil +} + +func CasbinRuleToStringArray(rule *ent.CasbinRule) []string { + arr := make([]string, 0) + if rule.V0 != "" { + arr = append(arr, rule.V0) + } + if rule.V1 != "" { + arr = append(arr, rule.V1) + } + if rule.V2 != "" { + arr = append(arr, rule.V2) + } + if rule.V3 != "" { + arr = append(arr, rule.V3) + } + if rule.V4 != "" { + arr = append(arr, rule.V4) + } + if rule.V5 != "" { + arr = append(arr, rule.V5) + } + return arr +} diff --git a/adapter_test.go b/adapter_test.go index 00f159a..31d7e27 100644 --- a/adapter_test.go +++ b/adapter_test.go @@ -183,6 +183,10 @@ func testAutoSave(t *testing.T, a *Adapter) { e.RemoveFilteredPolicy(0, "data2_admin") e.LoadPolicy() testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}}) + + e.RemovePolicies([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}}) + e.LoadPolicy() + testGetPolicy(t, e, [][]string{}) } //func testFilteredPolicy(t *testing.T, a *Adapter) { @@ -225,7 +229,7 @@ func testUpdatePolicies(t *testing.T, a *Adapter) { e.EnableAutoSave(true) e.UpdatePolicies([][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}}, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}}) e.LoadPolicy() - testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) + testGetPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) } func testUpdateFilteredPolicies(t *testing.T, a *Adapter) { @@ -234,7 +238,7 @@ func testUpdateFilteredPolicies(t *testing.T, a *Adapter) { e.EnableAutoSave(true) e.UpdateFilteredPolicies([][]string{{"alice", "data1", "write"}}, 0, "alice", "data1", "read") - e.UpdateFilteredPolicies([][]string{{"bob", "data2", "read"}}, 0, "bob", "data2", "write") + e.UpdateFilteredPolicies([][]string{{"bob", "data2", "read"}}, 0, "bob", "data2") e.LoadPolicy() testGetPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data2", "read"}}) } @@ -263,4 +267,14 @@ func TestAdapters(t *testing.T) { a = initAdapterWithClientInstance(t, db) testAutoSave(t, a) testSaveLoad(t, a) + + a = initAdapter(t, "mysql", "root:@tcp(127.0.0.1:3306)/casbin") + testUpdatePolicy(t, a) + testUpdatePolicies(t, a) + testUpdateFilteredPolicies(t, a) + + a = initAdapter(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable dbname=casbin") + testUpdatePolicy(t, a) + testUpdatePolicies(t, a) + testUpdateFilteredPolicies(t, a) }