Skip to content

Commit

Permalink
Add a way to register custom field adapters with the 'register_field_…
Browse files Browse the repository at this point in the history
…adapter' hook, as well as docs
  • Loading branch information
jacobtoppm committed Sep 9, 2020
1 parent cb48db9 commit 7832be2
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 26 deletions.
33 changes: 32 additions & 1 deletion docs/settings.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Settings
# Settings and Hooks

## Settings

### `WAGTAILTRANSFER_SECRET_KEY`

Expand Down Expand Up @@ -84,3 +86,32 @@ that object will end up with unresolved references, to be handled by the same se

Note that these settings do not accept models that are defined as subclasses through multi-table inheritance - in
particular, they cannot be used to define behaviour that only applies to specific subclasses of Page.

## Hooks

### `register_field_adapters`

Field adapters are classes used by Wagtail Transfer to serialize and identify references from fields when exporting,
and repopulate them with the serialised data when importing. You can register a custom field adapter by using the
`register_field_adapters` hook. A function registered with this hook should return a dictionary which maps field classes
to field adapter classes (note that with inheritance, the field adapter registered with the closest ancestor class will be used).
For example, to register a custom field adapter against Django's `models.Field`:

```python
# <my_app>/wagtail_hooks.py

from django.db import models

from wagtail.core import hooks
from wagtail_transfer.field_adapters import FieldAdapter


class MyCustomAdapter(FieldAdapter):
pass


@hooks.register('register_field_adapters')
def register_my_custom_adapter():
return {models.Field: MyCustomAdapter}

```
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ nav:
- Basic Usage: basic_usage.md
- How It Works: how_it_works.md
- Management commands: management_commands.md
- Settings Reference: settings.md
- Settings and Hooks: settings.md
57 changes: 38 additions & 19 deletions wagtail_transfer/field_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from django.db import models
from django.db.models.fields.reverse_related import ManyToOneRel
from taggit.managers import TaggableManager
from wagtail.core import hooks
from wagtail.core.fields import RichTextField, StreamField

from .files import File, FileTransferError, get_file_hash, get_file_size
Expand Down Expand Up @@ -240,26 +241,44 @@ def populate_field(self, instance, value, context):
pass


ADAPTERS_BY_FIELD_CLASS = {
models.Field: FieldAdapter,
models.ForeignKey: ForeignKeyAdapter,
ManyToOneRel: ManyToOneRelAdapter,
RichTextField: RichTextAdapter,
StreamField: StreamFieldAdapter,
models.FileField: FileAdapter,
models.ManyToManyField: ManyToManyFieldAdapter,
TaggableManager: TaggableManagerAdapter,
GenericRelation: GenericRelationAdapter,
}
class AdapterRegistry:
BASE_ADAPTERS_BY_FIELD_CLASS = {
models.Field: FieldAdapter,
models.ForeignKey: ForeignKeyAdapter,
ManyToOneRel: ManyToOneRelAdapter,
RichTextField: RichTextAdapter,
StreamField: StreamFieldAdapter,
models.FileField: FileAdapter,
models.ManyToManyField: ManyToManyFieldAdapter,
TaggableManager: TaggableManagerAdapter,
GenericRelation: GenericRelationAdapter,
}

def __init__(self):
self._scanned_for_adapters = False
self.adapters_by_field_class = {}

def _scan_for_adapters(self):
adapters = dict(self.BASE_ADAPTERS_BY_FIELD_CLASS)

@lru_cache(maxsize=None)
def get_field_adapter(field):
# find the adapter class for the most specific class in the field's inheritance tree
for fn in hooks.get_hooks('register_field_adapters'):
adapters.update(fn())

self.adapters_by_field_class = adapters

for field_class in type(field).__mro__:
if field_class in ADAPTERS_BY_FIELD_CLASS:
adapter_class = ADAPTERS_BY_FIELD_CLASS[field_class]
return adapter_class(field)
@lru_cache(maxsize=None)
def get_field_adapter(self, field):
# find the adapter class for the most specific class in the field's inheritance tree

raise ValueError("No adapter found for field: %r" % field)
if not self._scanned_for_adapters:
self._scan_for_adapters()

for field_class in type(field).__mro__:
if field_class in self.adapters_by_field_class:
adapter_class = self.adapters_by_field_class[field_class]
return adapter_class(field)

raise ValueError("No adapter found for field: %r" % field)


adapter_registry = AdapterRegistry()
6 changes: 3 additions & 3 deletions wagtail_transfer/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from treebeard.mp_tree import MP_Node
from wagtail.core.models import Page

from .field_adapters import get_field_adapter
from .field_adapters import adapter_registry
from .locators import get_locator_for_model
from .models import get_base_model, get_base_model_for_path, get_model_for_path

Expand Down Expand Up @@ -583,7 +583,7 @@ def _populate_fields(self, context):
except KeyError:
continue

get_field_adapter(field).populate_field(self.instance, value, context)
adapter_registry.get_field_adapter(field).populate_field(self.instance, value, context)

def _populate_many_to_many_fields(self, context):
save_needed = False
Expand Down Expand Up @@ -626,7 +626,7 @@ def dependencies(self):
for field in self.model._meta.get_fields():
if isinstance(field, models.Field):
val = self.object_data['fields'].get(field.name)
deps.update(get_field_adapter(field).get_dependencies(val))
deps.update(adapter_registry.get_field_adapter(field).get_dependencies(val))

return deps

Expand Down
4 changes: 2 additions & 2 deletions wagtail_transfer/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from treebeard.mp_tree import MP_Node
from wagtail.core.models import Page

from .field_adapters import get_field_adapter
from .field_adapters import adapter_registry
from .models import get_base_model


Expand Down Expand Up @@ -39,7 +39,7 @@ def __init__(self, model):
if not isinstance(related_field, ParentalKey):
continue

self.field_adapters.append(get_field_adapter(field))
self.field_adapters.append(adapter_registry.get_field_adapter(field))

def get_objects_by_ids(self, ids):
"""
Expand Down

0 comments on commit 7832be2

Please sign in to comment.