Skip to content

Commit

Permalink
Makes default rules overridable in UserAuthProtocol (#2751)
Browse files Browse the repository at this point in the history
* Make default rules overridable

The intention here is to be able to extend this class and skip/change the default auth rules from the child class

NHUB-572

* Add `include_fields` to `ESQuery`

NHUB-572
  • Loading branch information
eos87 authored Nov 14, 2024
1 parent de1cfbc commit 2a80550
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
19 changes: 11 additions & 8 deletions superdesk/core/auth/user_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ class UserAuthProtocol:
async def authenticate(self, request: Request) -> Any | None:
raise SuperdeskApiError.unauthorizedError()

def get_default_auth_rules(self) -> list[AuthRule]:
from .rules import login_required_auth_rule, endpoint_intrinsic_auth_rule

default_rules: list[AuthRule] = [
login_required_auth_rule,
endpoint_intrinsic_auth_rule,
]

return default_rules

async def authorize(self, request: Request) -> Any | None:
endpoint_rules = request.endpoint.get_auth_rules()
if endpoint_rules is False:
Expand All @@ -17,14 +27,7 @@ async def authorize(self, request: Request) -> Any | None:
elif isinstance(endpoint_rules, dict):
endpoint_rules = cast(list[AuthRule], endpoint_rules.get(request.method) or [])

from .rules import login_required_auth_rule, endpoint_intrinsic_auth_rule

default_rules: list[AuthRule] = [
login_required_auth_rule,
endpoint_intrinsic_auth_rule,
]

for rule in default_rules + (endpoint_rules or []):
for rule in self.get_default_auth_rules() + (endpoint_rules or []):
response = await rule(request)
if response is not None:
return response
Expand Down
6 changes: 5 additions & 1 deletion superdesk/core/types/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class ESQuery:
query: ESBoolQuery = Field(default_factory=ESBoolQuery)
post_filter: ESBoolQuery = Field(default_factory=ESBoolQuery)
aggs: dict[str, Any] = Field(default_factory=dict)
include_fields: list[str] = Field(default_factory=list)
exclude_fields: list[str] = Field(default_factory=list)

def generate_query_dict(self, query: dict[str, Any] | None = None) -> dict[str, Any]:
Expand Down Expand Up @@ -122,7 +123,10 @@ def generate_query_dict(self, query: dict[str, Any] | None = None) -> dict[str,
if self.post_filter.filter:
query["post_filter"]["bool"]["filter"] = self.post_filter.filter

if self.exclude_fields:
# these two are mutually exclusive
if self.include_fields:
query.setdefault("_source", {}).setdefault("includes", []).extend(self.include_fields)
elif self.exclude_fields:
query.setdefault("_source", {}).setdefault("excludes", []).extend(self.exclude_fields)

return query
Expand Down

0 comments on commit 2a80550

Please sign in to comment.