Skip to content

Commit

Permalink
feat: add db alias (#21)
Browse files Browse the repository at this point in the history
* feat: add db alias

* feat: add db alias

* docs: add CASBIN_DB_ALIAS argument to readme
  • Loading branch information
BustDot authored Jul 31, 2023
1 parent 15e7887 commit a6d61da
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 26 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def hello(request):
A string containing the file location of your casbin model.

### `CASBIN_ADAPTER`
A string containing the adapter import path. Defaults to the django adapter shipped with this package: `casbin_adapter.adapter.Adapter`
A string containing the adapter import path. Default to the django adapter shipped with this package: `casbin_adapter.adapter.Adapter`

### `CASBIN_ADAPTER_ARGS`
A tuple of arguments to be passed into the constructor of the adapter specified
Expand All @@ -80,6 +80,9 @@ E.g. if you wish to use the file adapter
set the adapter to `casbin.persist.adapters.FileAdapter` and use
`CASBIN_ADAPTER_ARGS = ('path/to/policy_file.csv',)`

### `CASBIN_DB_ALIAS`
The database the adapter uses. Default to "default".

### `CASBIN_WATCHER`
Watcher instance to be set as the watcher on the enforcer instance.

Expand Down
13 changes: 8 additions & 5 deletions casbin_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
class Adapter(persist.Adapter):
"""the interface for Casbin adapters."""

def __init__(self, db_alias="default"):
self.db_alias = db_alias

def load_policy(self, model):
"""loads all policy rules from the storage."""
try:
lines = CasbinRule.objects.all()
lines = CasbinRule.objects.using(self.db_alias).all()

for line in lines:
persist.load_policy_line(str(line), model)
Expand All @@ -41,7 +44,7 @@ def save_policy(self, model):
"""saves all policy rules to the storage."""
# See https://casbin.org/docs/en/adapters#autosave
# for why this is deleting all rules
CasbinRule.objects.all().delete()
CasbinRule.objects.using(self.db_alias).all().delete()

lines = []
for sec in ["p", "g"]:
Expand All @@ -50,7 +53,7 @@ def save_policy(self, model):
for ptype, ast in model.model[sec].items():
for rule in ast.policy:
lines.append(self._create_policy_line(ptype, rule))
CasbinRule.objects.bulk_create(lines)
CasbinRule.objects.using(self.db_alias).bulk_create(lines)
return True

def add_policy(self, sec, ptype, rule):
Expand All @@ -63,7 +66,7 @@ def remove_policy(self, sec, ptype, rule):
query_params = {"ptype": ptype}
for i, v in enumerate(rule):
query_params["v{}".format(i)] = v
rows_deleted, _ = CasbinRule.objects.filter(**query_params).delete()
rows_deleted, _ = CasbinRule.objects.using(self.db_alias).filter(**query_params).delete()
return True if rows_deleted > 0 else False

def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
Expand All @@ -77,5 +80,5 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
return False
for i, v in enumerate(field_values):
query_params["v{}".format(i + field_index)] = v
rows_deleted, _ = CasbinRule.objects.filter(**query_params).delete()
rows_deleted, _ = CasbinRule.objects.using(self.db_alias).filter(**query_params).delete()
return True if rows_deleted > 0 else False
2 changes: 0 additions & 2 deletions casbin_adapter/apps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from django.apps import AppConfig
from django.db import connection
from django.db.utils import OperationalError, ProgrammingError


class CasbinAdapterConfig(AppConfig):
Expand Down
41 changes: 27 additions & 14 deletions casbin_adapter/enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

from casbin import Enforcer

from .adapter import Adapter
from .utils import import_class

logger = logging.getLogger(__name__)


class ProxyEnforcer(Enforcer):
_initialized = False
db_alias = "default"

def __init__(self, *args, **kwargs):
if self._initialized:
Expand All @@ -27,8 +27,9 @@ def _load(self):
model = getattr(settings, "CASBIN_MODEL")
adapter_loc = getattr(settings, "CASBIN_ADAPTER", "casbin_adapter.adapter.Adapter")
adapter_args = getattr(settings, "CASBIN_ADAPTER_ARGS", tuple())
self.db_alias = getattr(settings, "CASBIN_DB_ALIAS", "default")
Adapter = import_class(adapter_loc)
adapter = Adapter(*adapter_args)
adapter = Adapter(self.db_alias, *adapter_args)

super().__init__(model, adapter)
logger.debug("Casbin enforcer initialised")
Expand All @@ -44,7 +45,7 @@ def _load(self):
def __getattribute__(self, name):
safe_methods = ["__init__", "_load", "_initialized"]
if not super().__getattribute__("_initialized") and name not in safe_methods:
initialize_enforcer()
initialize_enforcer(self.db_alias)
if not super().__getattribute__("_initialized"):
raise Exception(
(
Expand All @@ -59,17 +60,29 @@ def __getattribute__(self, name):
enforcer = ProxyEnforcer()


def initialize_enforcer():
def initialize_enforcer(db_alias=None):
try:
with connection.cursor() as cursor:
cursor.execute(
"""
SELECT app, name applied FROM django_migrations
WHERE app = 'casbin_adapter' AND name = '0001_initial';
"""
)
row = cursor.fetchone()
if row:
enforcer._load()
row = None
if db_alias:
with connection[db_alias].cursor() as cursor:
cursor.execute(
"""
SELECT app, name applied FROM django_migrations
WHERE app = 'casbin_adapter' AND name = '0001_initial';
"""
)
row = cursor.fetchone()
else:
with connection.cursor() as cursor:
cursor.execute(
"""
SELECT app, name applied FROM django_migrations
WHERE app = 'casbin_adapter' AND name = '0001_initial';
"""
)
row = cursor.fetchone()

if row:
enforcer._load()
except (OperationalError, ProgrammingError):
pass
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
casbin==1.16.10
casbin>=1.16.10
Django

3 changes: 1 addition & 2 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
-r requirements.txt
setuptools==60.2.0
simpleeval==0.9.12
setuptools==60.2.0
1 change: 0 additions & 1 deletion tests/test_adapter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import casbin
import simpleeval
from unittest import TestCase

from django.test import TestCase
from casbin_adapter.models import CasbinRule
Expand Down

0 comments on commit a6d61da

Please sign in to comment.