Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Django 5 support. #537

Merged
merged 13 commits into from
May 29, 2024
Merged
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ jobs:
matrix:
os: [ubuntu-latest]
python-version:
- "3.7"
- "3.8"
- "3.9"
- "3.10"
- "3.11"
- "3.12"
steps:
- name: Checkout
uses: actions/checkout@v3
Expand Down Expand Up @@ -90,11 +90,11 @@ jobs:
matrix:
os: [ubuntu-latest]
python-version:
- "3.7"
- "3.8"
- "3.9"
- "3.10"
- "3.11"
- "3.12"
steps:
- name: Checkout
uses: actions/checkout@v3
Expand Down
20 changes: 8 additions & 12 deletions hostpolicy/api/v1/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class HostPolicyPermissionsUpdateDestroy(M2MPermissions,
permission_classes = (IsSuperOrHostPolicyAdminOrReadOnly, )


class HostPolicyAtomList(HostPolicyAtomLogMixin, MregListCreateAPIView):
class HostPolicyAtomList(HostPolicyAtomLogMixin, LowerCaseLookupMixin, MregListCreateAPIView):

queryset = HostPolicyAtom.objects.all()
serializer_class = serializers.HostPolicyAtomSerializer
Expand All @@ -75,11 +75,9 @@ class HostPolicyAtomList(HostPolicyAtomLogMixin, MregListCreateAPIView):
filterset_class = HostPolicyAtomFilterSet

def post(self, request, *args, **kwargs):
if "name" in request.data:
# Due to the overriding of get_queryset, we need to manually use lower()
if self.get_queryset().filter(name=request.data['name'].lower()).exists():
content = {'ERROR': 'name already in use'}
return Response(content, status=status.HTTP_409_CONFLICT)
if self.get_object_from_request(request):
content = {"ERROR": "name already in use"}
return Response(content, status=status.HTTP_409_CONFLICT)

return super().post(request, *args, **kwargs)

Expand All @@ -99,7 +97,7 @@ def _role_prefetcher(qs):
'atoms', queryset=HostPolicyAtom.objects.order_by('name')))


class HostPolicyRoleList(HostPolicyRoleLogMixin, MregListCreateAPIView):
class HostPolicyRoleList(HostPolicyRoleLogMixin, LowerCaseLookupMixin, MregListCreateAPIView):

queryset = HostPolicyRole.objects.all()
serializer_class = serializers.HostPolicyRoleSerializer
Expand All @@ -108,11 +106,9 @@ class HostPolicyRoleList(HostPolicyRoleLogMixin, MregListCreateAPIView):
filterset_class = HostPolicyRoleFilterSet

def post(self, request, *args, **kwargs):
if "name" in request.data:
# Due to the overriding of get_queryset, we need to manually use lower()
if self.get_queryset().filter(name=request.data['name'].lower()).exists():
content = {'ERROR': 'name already in use'}
return Response(content, status=status.HTTP_409_CONFLICT)
if self.get_object_from_request(request):
content = {"ERROR": "name already in use"}
return Response(content, status=status.HTTP_409_CONFLICT)
return super().post(request, *args, **kwargs)


Expand Down
6 changes: 6 additions & 0 deletions mreg/api/v1/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ class Meta:


class HostFilterSet(filters.FilterSet):

# It's weird that we have to define the id field here, but it's necessary for the filters to work.
id = filters.NumberFilter(field_name="id")
id__in = filters.BaseInFilter(field_name="id")
id__gt = filters.NumberFilter(field_name="id", lookup_expr="gt")
id__lt = filters.NumberFilter(field_name="id", lookup_expr="lt")
class Meta:
model = Host
fields = "__all__"
Expand Down
7 changes: 7 additions & 0 deletions mreg/api/v1/tests/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,10 @@ def test_change_label_name(self):
response = self.assert_get('/api/v1/labels/')
data = response.json()
self.assertEqual("newname", data['results'][0]['name'])

def test_label_name_case_insensitive(self):
"""Test that label names are case insensitive."""
self.assert_post('/api/v1/labels/', {'name': 'case_insensitive', 'description': 'Case insensitive'})
self.assert_post_and_409('/api/v1/labels/', {'name': 'CASE_INSENSITIVE', 'description': 'Case insensitive'})
self.assert_get_and_200('/api/v1/labels/name/case_insensitive')
self.assert_get_and_200('/api/v1/labels/name/CASE_INSENSITIVE')
31 changes: 30 additions & 1 deletion mreg/api/v1/tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,9 +474,38 @@ def setUp(self):
clean_and_save(self.host_one)
clean_and_save(self.host_two)

def _one_hit_and_host_one(self, query: str):
"""Check that we only have one hit and it is host_one"""
response = self.assert_get(f"/hosts/?{query}")
hits = response.json()['results']
self.assertEqual(len(hits), 1)
self.assertEqual(hits[0]['name'], self.host_one.name)

def test_hosts_get_200_ok(self):
""""Getting an existing entry should return 200"""
self.assert_get('/hosts/%s' % self.host_one.name)
self.assert_get('/hosts/%s' % self.host_one.name)

def test_host_get_200_ok_by_id(self):
"""Getting an existing entry by id should return 200"""
self._one_hit_and_host_one(f"id={self.host_one.id}")

def test_host_get_200_ok_by_id_gt_and_lt(self):
"""Getting an existing entry by id should return 200"""
id = self.host_one.id
(id_after, id_before) = (id + 1, id - 1)
self._one_hit_and_host_one(f"id__gt={id_before}&id__lt={id_after}")

def test_host_get_200_ok_by_id_in(self):
"""Getting an existing entry by id should return 200"""
self._one_hit_and_host_one(f"id__in={self.host_one.id}")

def test_host_get_200_ok_by_contact(self):
"""Getting an existing entry by ip should return 200"""
self._one_hit_and_host_one(f"contact={self.host_one.contact}")

def test_host_get_200_ok_by_name(self):
"""Getting an existing entry by name should return 200"""
self._one_hit_and_host_one(f"name={self.host_one.name}")

def test_hosts_get_case_insensitive_200_ok(self):
""""Getting an existing entry should return 200"""
Expand Down
4 changes: 2 additions & 2 deletions mreg/api/v1/tests/tests_bacnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ def test_post_no_id_201_created(self):
response = self.assert_post(self.basepath, post_data)
response = self.assert_get(response['Location'])
self.assertIn('id', response.data)
self.assertEquals(response.data['host'], self.host_two.id)
self.assertEqual(response.data['host'], self.host_two.id)

def test_post_with_hostname_instead_of_id(self):
post_data = {'hostname': self.host_two.name}
response = self.assert_post(self.basepath, post_data)
response = self.assert_get(response['Location'])
self.assertIn('id', response.data)
self.assertEquals(response.data['host'], self.host_two.id)
self.assertEqual(response.data['host'], self.host_two.id)

def test_post_without_host_400(self):
"""Posting a new entry without specifying a host should return 400 bad request"""
Expand Down
14 changes: 7 additions & 7 deletions mreg/api/v1/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,13 +343,13 @@ def post(self, request, *args, **kwargs):
ipdata = {"host": host.pk, "ipaddress": ipkey}
ip = Ipaddress()
ipserializer = IpaddressSerializer(ip, data=ipdata)
if ipserializer.is_valid(raise_exception=True):
self.perform_create(ipserializer)
location = request.path + host.name
return Response(
status=status.HTTP_201_CREATED,
headers={"Location": location},
)
ipserializer.is_valid(raise_exception=True)
self.perform_create(ipserializer)
location = request.path + host.name
return Response(
status=status.HTTP_201_CREATED,
headers={"Location": location},
)
else:
host = Host()
hostserializer = HostSerializer(host, data=hostdata)
Expand Down
12 changes: 5 additions & 7 deletions mreg/api/v1/views_hostgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _hostgroup_prefetcher(qs):
'owners', queryset=Group.objects.order_by('name')))


class HostGroupList(HostGroupLogMixin, MregListCreateAPIView):
class HostGroupList(HostGroupLogMixin, LowerCaseLookupMixin, MregListCreateAPIView):
"""
get:
Lists all hostgroups in use.
Expand All @@ -76,14 +76,12 @@ class HostGroupList(HostGroupLogMixin, MregListCreateAPIView):
serializer_class = serializers.HostGroupSerializer
permission_classes = (IsSuperOrGroupAdminOrReadOnly, )
filterset_class = HostGroupFilterSet
lookup_field = 'name'

def post(self, request, *args, **kwargs):
if "name" in request.data:
# We need to manually use lower() here due to the overriden get_queryset()
if self.get_queryset().filter(name=request.data['name'].lower()).exists():
content = {'ERROR': 'hostgroup name already in use'}
return Response(content, status=status.HTTP_409_CONFLICT)
self.lookup_field = 'name'
if self.get_object_from_request(request):
content = {'ERROR': 'hostgroup name already in use'}
return Response(content, status=status.HTTP_409_CONFLICT)
return super().post(request, *args, **kwargs)


Expand Down
16 changes: 8 additions & 8 deletions mreg/api/v1/views_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,17 @@
from .filters import LabelFilterSet


class LabelList(MregListCreateAPIView):
class LabelList(MregListCreateAPIView, LowerCaseLookupMixin):
queryset = Label.objects.all()
serializer_class = serializers.LabelSerializer
permission_classes = (IsSuperOrAdminOrReadOnly,)
filterset_class = LabelFilterSet
lookup_field = "name"

def post(self, request, *args, **kwargs):
if "name" in request.data:
if self.get_queryset().filter(name=request.data["name"]).exists():
content = {"ERROR": "Label name already in use"}
return Response(content, status=status.HTTP_409_CONFLICT)
self.lookup_field = "name"
def post(self, request, *args, **kwargs):
if self.get_object_from_request(request):
content = {"ERROR": "Label name already in use"}
return Response(content, status=status.HTTP_409_CONFLICT)
return super().post(request, *args, **kwargs)


Expand All @@ -43,8 +42,9 @@ class LabelDetail(LowerCaseLookupMixin, MregRetrieveUpdateDestroyAPIView):
permission_classes = (IsSuperOrAdminOrReadOnly,)


class LabelDetailByName(MregRetrieveUpdateDestroyAPIView):
class LabelDetailByName(LowerCaseLookupMixin, MregRetrieveUpdateDestroyAPIView):
queryset = Label.objects.all()
serializer_class = serializers.LabelSerializer
permission_classes = (IsSuperOrAdminOrReadOnly,)
filterset_class = LabelFilterSet
lookup_field = "name"
30 changes: 23 additions & 7 deletions mreg/managers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@

from typing import Any, Dict, Type
from django.db import models

from .fields import LowerCaseCharField


class LowerCaseManager(models.Manager):
class LowerCaseManager(models.Manager[Any]):
"""A manager that lowercases all values of LowerCaseCharFields in filter/exclude/get calls."""

@property
def lowercase_fields(self):
"""A list of field names that are LowerCaseCharFields.

Note: This is a cached property to avoid recalculating the list every time it is accessed.
We are making the assumption that the model's fields do not change during runtime...
"""

if not hasattr(self, "_lowercase_fields_cache"):
self._lowercase_fields_cache = [
field.name
Expand All @@ -16,27 +24,35 @@ def lowercase_fields(self):
]
return self._lowercase_fields_cache

def _lowercase_fields(self, **kwargs):
lower_kwargs = {}
def _lowercase_fields(self, **kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Lowercase all values of LowerCaseCharFields in kwargs."""

lower_kwargs: Dict[str, Any] = {}
for key, value in kwargs.items():
field_name = key.split("__")[0]
if field_name in self.lowercase_fields and isinstance(value, str):
value = value.lower()
lower_kwargs[key] = value
return lower_kwargs

def filter(self, **kwargs):
def filter(self, **kwargs: Dict[str, Any]):
"""Lowercase all values of LowerCaseCharFields in kwargs during filtering."""
return super().filter(**self._lowercase_fields(**kwargs))

def exclude(self, **kwargs):
def exclude(self, **kwargs: Dict[str, Any]):
"""Lowercase all values of LowerCaseCharFields in kwargs during excluding."""
return super().exclude(**self._lowercase_fields(**kwargs))

def get(self, **kwargs):
def get(self, **kwargs: Dict[str, Any]):
"""Lowercase all values of LowerCaseCharFields in kwargs during get."""
return super().get(**self._lowercase_fields(**kwargs))


def lower_case_manager_factory(base_manager):
def lower_case_manager_factory(base_manager: Type[models.Manager[Any]]):
"""A factory function to create a LowerCaseManager for a given base_manager."""

class LowerCaseBaseManager(base_manager, LowerCaseManager):
"""A manager that lowercases all values of LowerCaseCharFields in filter/exclude/get calls."""
pass

return LowerCaseBaseManager
67 changes: 63 additions & 4 deletions mreg/mixins.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,40 @@
from typing import Protocol, Any, Dict, Union
from django.shortcuts import get_object_or_404
from django.http import HttpRequest
from rest_framework.request import Request
from django.db.models import QuerySet


class DetailViewProtocol(Protocol):
"""Protocol that defines the expected methods and attributes for the mixin."""

request: HttpRequest
kwargs: Dict[str, Any]
lookup_field: str

def get_queryset(self) -> QuerySet[Any]:
"""Method to get the queryset."""
...

def filter_queryset(self, queryset: QuerySet[Any]) -> QuerySet[Any]:
"""Method to filter the queryset."""
...

def check_object_permissions(self, request: HttpRequest, obj: Any) -> None:
"""Method to check object permissions."""
...


class LowerCaseLookupMixin:
"""A mixin to make DRF detail view lookup case insensitive."""

def get_object(self):
def get_object(self: DetailViewProtocol) -> Any:
"""Returns the object the view is displaying.

This method is overriden to make the lookup case insensitive.
"""
This method is overridden to make the lookup case insensitive.

:returns: The object the view is displaying.
"""
queryset = self.filter_queryset(self.get_queryset())
filter_kwargs = {self.lookup_field: self.kwargs[self.lookup_field].lower()}

Expand All @@ -18,4 +43,38 @@ def get_object(self):
# May raise a permission denied
self.check_object_permissions(self.request, obj)

return obj
return obj

def get_object_from_request(
self: DetailViewProtocol, request: Request, field: Union[str, None] = None
) -> Union[Any, None]:
"""Return an object from the queryset based on data from the request, if any.

The object is found in the queryset by querying with field = request.data[field]. If the field
is not defined, and the view offers a self.lookup_field, that field is used as a fallback.

Note: This is part of the LowerCaseLookupMixin, so the value of the field in request.data will
be lowercased when querying.

:param request: The request object.
:param field: The field to use for the lookup. If None, the view's lookup_field is used.

:returns: The object from the queryset or None.

:raises AttributeError: If no field is specified and the view does not have a lookup_field defined.
"""
if not field and not self.lookup_field:
raise AttributeError("If not specifying a field, the view must have lookup_field defined.")

lfield: str = field if field else self.lookup_field

if not request.data or not isinstance(request.data, dict):
return None

if self.lookup_field not in request.data:
return None

queryset = self.filter_queryset(self.get_queryset())
filter_kwargs: Dict[str, str] = {lfield: request.data[lfield].lower()}

return queryset.filter(**filter_kwargs).first()
Loading
Loading