Skip to content

Commit

Permalink
fix: UpdateFilteredPolicies match ptype and ignore empty filed (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
huijiezheng authored Feb 5, 2023
1 parent e631a83 commit fb14004
Showing 1 changed file with 37 additions and 27 deletions.
64 changes: 37 additions & 27 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,22 +236,22 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int,
return a.WithTx(func(tx *ent.Tx) error {
cond := make([]predicate.CasbinRule, 0)
cond = append(cond, casbinrule.PtypeEQ(ptype))
if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) && len(fieldValues[0-fieldIndex]) > 0 {
cond = append(cond, casbinrule.V0EQ(fieldValues[0-fieldIndex]))
}
if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) && len(fieldValues[1-fieldIndex]) > 0 {
cond = append(cond, casbinrule.V1EQ(fieldValues[1-fieldIndex]))
}
if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) && len(fieldValues[2-fieldIndex]) > 0 {
cond = append(cond, casbinrule.V2EQ(fieldValues[2-fieldIndex]))
}
if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) && len(fieldValues[3-fieldIndex]) > 0 {
cond = append(cond, casbinrule.V3EQ(fieldValues[3-fieldIndex]))
}
if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) && len(fieldValues[4-fieldIndex]) > 0 {
cond = append(cond, casbinrule.V4EQ(fieldValues[4-fieldIndex]))
}
if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) && len(fieldValues[5-fieldIndex]) > 0 {
cond = append(cond, casbinrule.V5EQ(fieldValues[5-fieldIndex]))
}
_, err := tx.CasbinRule.Delete().Where(
Expand Down Expand Up @@ -446,37 +446,47 @@ func (a *Adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules []
func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) ([][]string, error) {
oldPolicies := make([][]string, 0)
err := a.WithTx(func(tx *ent.Tx) error {
line := tx.CasbinRule.Query()
if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
line = line.Where(casbinrule.V0EQ(fieldValues[0-fieldIndex]))
cond := make([]predicate.CasbinRule, 0)
cond = append(cond, casbinrule.PtypeEQ(ptype))
if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) && len(fieldValues[0-fieldIndex]) > 0 {
cond = append(cond, casbinrule.V0EQ(fieldValues[0-fieldIndex]))
}
if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
line = line.Where(casbinrule.V1EQ(fieldValues[1-fieldIndex]))
if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) && len(fieldValues[1-fieldIndex]) > 0 {
cond = append(cond, casbinrule.V1EQ(fieldValues[1-fieldIndex]))
}
if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
line = line.Where(casbinrule.V2EQ(fieldValues[2-fieldIndex]))
if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) && len(fieldValues[2-fieldIndex]) > 0 {
cond = append(cond, casbinrule.V2EQ(fieldValues[2-fieldIndex]))
}
if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
line = line.Where(casbinrule.V3EQ(fieldValues[3-fieldIndex]))
if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) && len(fieldValues[3-fieldIndex]) > 0 {
cond = append(cond, casbinrule.V3EQ(fieldValues[3-fieldIndex]))
}
if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
line = line.Where(casbinrule.V4EQ(fieldValues[4-fieldIndex]))
if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) && len(fieldValues[4-fieldIndex]) > 0 {
cond = append(cond, casbinrule.V4EQ(fieldValues[4-fieldIndex]))
}
if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
line = line.Where(casbinrule.V5EQ(fieldValues[5-fieldIndex]))
if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) && len(fieldValues[5-fieldIndex]) > 0 {
cond = append(cond, casbinrule.V5EQ(fieldValues[5-fieldIndex]))
}
rules, err := line.All(a.ctx)
rules, err := tx.CasbinRule.Query().
Where(cond...).
All(a.ctx)
if err != nil {
return err
}
for _, rule := range rules {
if _, err := tx.CasbinRule.Delete().Where(
casbinrule.IDEQ(rule.ID),
).Exec(a.ctx); err != nil {
return err
}
ruleIDs := make([]int, 0, len(rules))
for _, r := range rules {
ruleIDs = append(ruleIDs, r.ID)
}

_, err = tx.CasbinRule.Delete().
Where(casbinrule.IDIn(ruleIDs...)).
Exec(a.ctx)
if err != nil {
return err
}

if err := a.createPolicies(tx, ptype, newPolicies); err != nil {
return err
}
a.createPolicies(tx, ptype, newPolicies)
for _, rule := range rules {
oldPolicies = append(oldPolicies, CasbinRuleToStringArray(rule))
}
Expand Down

0 comments on commit fb14004

Please sign in to comment.