Skip to content

Commit

Permalink
Merge pull request #66 from stevejalim/pluggable-serializers
Browse files Browse the repository at this point in the history
Support custom/pluggable serializers
  • Loading branch information
jacobtoppm committed Jan 20, 2021
2 parents cd9fd8e + 8c6b184 commit 37bad68
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 23 deletions.
33 changes: 33 additions & 0 deletions docs/settings.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,36 @@ def register_my_custom_adapter():
return {models.Field: MyCustomAdapter}

```


### `register_custom_serializers`

In exceptional cases, such as limiting the fields you export to only a subset of the content, you may need to use a custom serializer instead of the default `PageSerializer`.
You can register a custom serializer by using the `register_custom_serializers` hook.
A function registered with this hook should return a dictionary which maps model classes to serializer classes (note that with inheritance, the serializer registered with the closest ancestor class will be used).
For example, to register a custom serializer against `myapp.MyModel`:

```python
# <my_app>/wagtail_hooks.py


from wagtail.core import hooks
from wagtail_transfer.serializers import PageSerializer

from myapp.models import MyModel

class MyModelCustomSerializer(PageSerializer):

ignored_fields = PageSerializer.ignored_fields + [
'secret_field_1',
'environment_specific_data_field_123',
...
]
pass


@hooks.register('register_custom_serializers')
def register_my_custom_serializer():
return {MyModel: MyModelCustomSerializer}

```
7 changes: 4 additions & 3 deletions wagtail_transfer/field_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def get_object_deletions(self, instance, value, context):

return {child for child in self._get_related_objects(instance) if child.pk not in matched_destination_ids}
return set()

def get_objects_to_serialize(self, instance):
if self.is_parental:
return getattr(instance, self.name).all()
Expand Down Expand Up @@ -372,14 +372,15 @@ class AdapterRegistry:
def __init__(self):
self._scanned_for_adapters = False
self.adapters_by_field_class = {}

def _scan_for_adapters(self):
adapters = dict(self.BASE_ADAPTERS_BY_FIELD_CLASS)

for fn in hooks.get_hooks('register_field_adapters'):
adapters.update(fn())

self.adapters_by_field_class = adapters
self._scanned_for_adapters = True

@lru_cache(maxsize=None)
def get_field_adapter(self, field):
Expand Down
49 changes: 34 additions & 15 deletions wagtail_transfer/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from django.core.exceptions import ObjectDoesNotExist
from django.db import models
from django.db.models.constants import LOOKUP_SEP
from modelcluster.fields import ParentalKey
from treebeard.mp_tree import MP_Node
from wagtail.core import hooks
from wagtail.core.models import Page

from .field_adapters import adapter_registry
Expand Down Expand Up @@ -85,7 +85,7 @@ def __init__(self, model):

# ignore primary keys (including MTI parent pointers)
if getattr(field, 'primary_key', False):
continue
continue

adapter = adapter_registry.get_field_adapter(field)

Expand All @@ -95,7 +95,6 @@ def __init__(self, model):

self.field_adapters = [adapter for adapter in field_adapters if adapter.name not in adapter_managed_fields]


def get_objects_by_ids(self, ids):
"""
Given a list of IDs, return a list of model instances that we can
Expand Down Expand Up @@ -169,17 +168,37 @@ def get_objects_by_ids(self, ids):
return self.model.objects.filter(pk__in=ids).specific()


SERIALIZERS_BY_MODEL_CLASS = {
models.Model: ModelSerializer,
MP_Node: TreeModelSerializer,
Page: PageSerializer,
}
class SerializerRegistry:
BASE_SERIALIZERS_BY_MODEL_CLASS = {
models.Model: ModelSerializer,
MP_Node: TreeModelSerializer,
Page: PageSerializer,
}

def __init__(self):
self._scanned_for_serializers = False
self.serializers_by_model_class = {}

def _scan_for_serializers(self):
serializers = dict(self.BASE_SERIALIZERS_BY_MODEL_CLASS)

for fn in hooks.get_hooks('register_custom_serializers'):
serializers.update(fn())

self.serializers_by_model_class = serializers
self._scanned_for_serializers = True

@lru_cache(maxsize=None)
def get_model_serializer(self, model):
# find the serializer class for the most specific class in the model's inheritance tree

if not self._scanned_for_serializers:
self._scan_for_serializers()

for cls in model.__mro__:
if cls in self.serializers_by_model_class:
serializer_class = self.serializers_by_model_class[cls]
return serializer_class(model)


@lru_cache(maxsize=None)
def get_model_serializer(model):
# find the serializer class for the most specific class in the model's inheritance tree
for cls in model.__mro__:
if cls in SERIALIZERS_BY_MODEL_CLASS:
serializer_class = SERIALIZERS_BY_MODEL_CLASS[cls]
return serializer_class(model)
serializer_registry = SerializerRegistry()
10 changes: 5 additions & 5 deletions wagtail_transfer/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .locators import get_locator_for_model
from .models import get_model_for_path
from .operations import ImportPlanner
from .serializers import get_model_serializer
from .serializers import serializer_registry
from .vendor.wagtail_admin_api.serializers import AdminPageSerializer
from .vendor.wagtail_admin_api.views import PagesAdminAPIViewSet

Expand All @@ -43,7 +43,7 @@ def pages_for_export(request, root_page_id):

while models_to_serialize:
model = models_to_serialize.pop()
serializer = get_model_serializer(type(model))
serializer = serializer_registry.get_model_serializer(type(model))
objects.append(serializer.serialize(model))
object_references.update(serializer.get_object_references(model))
models_to_serialize.update(serializer.get_objects_to_serialize(model).difference(serialized_models))
Expand Down Expand Up @@ -92,7 +92,7 @@ def models_for_export(request, model_path, object_id=None):

while models_to_serialize:
model = models_to_serialize.pop()
serializer = get_model_serializer(type(model))
serializer = serializer_registry.get_model_serializer(type(model))
objects.append(serializer.serialize(model))
object_references.update(serializer.get_object_references(model))
models_to_serialize.update(serializer.get_objects_to_serialize(model).difference(serialized_models))
Expand Down Expand Up @@ -134,12 +134,12 @@ def objects_for_export(request):

for model_path, ids in request_data.items():
model = get_model_for_path(model_path)
serializer = get_model_serializer(model)
serializer = serializer_registry.get_model_serializer(model)

models_to_serialize.update(serializer.get_objects_by_ids(ids))
while models_to_serialize:
instance = models_to_serialize.pop()
serializer = get_model_serializer(type(instance))
serializer = serializer_registry.get_model_serializer(type(instance))
objects.append(serializer.serialize(instance))
object_references.update(serializer.get_object_references(instance))
models_to_serialize.update(serializer.get_objects_to_serialize(instance).difference(serialized_models))
Expand Down

0 comments on commit 37bad68

Please sign in to comment.