diff --git a/docs/settings.md b/docs/settings.md index d2bc9d4..fe9940b 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -1,4 +1,6 @@ -# Settings +# Settings and Hooks + +## Settings ### `WAGTAILTRANSFER_SECRET_KEY` @@ -84,3 +86,32 @@ that object will end up with unresolved references, to be handled by the same se Note that these settings do not accept models that are defined as subclasses through multi-table inheritance - in particular, they cannot be used to define behaviour that only applies to specific subclasses of Page. + +## Hooks + +### `register_field_adapters` + +Field adapters are classes used by Wagtail Transfer to serialize and identify references from fields when exporting, +and repopulate them with the serialised data when importing. You can register a custom field adapter by using the +`register_field_adapters` hook. A function registered with this hook should return a dictionary which maps field classes +to field adapter classes (note that with inheritance, the field adapter registered with the closest ancestor class will be used). +For example, to register a custom field adapter against Django's `models.Field`: + +```python +# /wagtail_hooks.py + +from django.db import models + +from wagtail.core import hooks +from wagtail_transfer.field_adapters import FieldAdapter + + +class MyCustomAdapter(FieldAdapter): + pass + + +@hooks.register('register_field_adapters') +def register_my_custom_adapter(): + return {models.Field: MyCustomAdapter} + +``` diff --git a/mkdocs.yml b/mkdocs.yml index 4522734..0b0e5c7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -14,4 +14,4 @@ nav: - Basic Usage: basic_usage.md - How It Works: how_it_works.md - Management commands: management_commands.md - - Settings Reference: settings.md + - Settings and Hooks: settings.md diff --git a/wagtail_transfer/field_adapters.py b/wagtail_transfer/field_adapters.py index c40a41a..e085d2c 100644 --- a/wagtail_transfer/field_adapters.py +++ b/wagtail_transfer/field_adapters.py @@ -7,6 +7,7 @@ from django.db import models from django.db.models.fields.reverse_related import ManyToOneRel from taggit.managers import TaggableManager +from wagtail.core import hooks from wagtail.core.fields import RichTextField, StreamField from .files import File, FileTransferError, get_file_hash, get_file_size @@ -240,26 +241,44 @@ def populate_field(self, instance, value, context): pass -ADAPTERS_BY_FIELD_CLASS = { - models.Field: FieldAdapter, - models.ForeignKey: ForeignKeyAdapter, - ManyToOneRel: ManyToOneRelAdapter, - RichTextField: RichTextAdapter, - StreamField: StreamFieldAdapter, - models.FileField: FileAdapter, - models.ManyToManyField: ManyToManyFieldAdapter, - TaggableManager: TaggableManagerAdapter, - GenericRelation: GenericRelationAdapter, -} +class AdapterRegistry: + BASE_ADAPTERS_BY_FIELD_CLASS = { + models.Field: FieldAdapter, + models.ForeignKey: ForeignKeyAdapter, + ManyToOneRel: ManyToOneRelAdapter, + RichTextField: RichTextAdapter, + StreamField: StreamFieldAdapter, + models.FileField: FileAdapter, + models.ManyToManyField: ManyToManyFieldAdapter, + TaggableManager: TaggableManagerAdapter, + GenericRelation: GenericRelationAdapter, + } + def __init__(self): + self._scanned_for_adapters = False + self.adapters_by_field_class = {} + + def _scan_for_adapters(self): + adapters = dict(self.BASE_ADAPTERS_BY_FIELD_CLASS) -@lru_cache(maxsize=None) -def get_field_adapter(field): - # find the adapter class for the most specific class in the field's inheritance tree + for fn in hooks.get_hooks('register_field_adapters'): + adapters.update(fn()) + + self.adapters_by_field_class = adapters - for field_class in type(field).__mro__: - if field_class in ADAPTERS_BY_FIELD_CLASS: - adapter_class = ADAPTERS_BY_FIELD_CLASS[field_class] - return adapter_class(field) + @lru_cache(maxsize=None) + def get_field_adapter(self, field): + # find the adapter class for the most specific class in the field's inheritance tree - raise ValueError("No adapter found for field: %r" % field) + if not self._scanned_for_adapters: + self._scan_for_adapters() + + for field_class in type(field).__mro__: + if field_class in self.adapters_by_field_class: + adapter_class = self.adapters_by_field_class[field_class] + return adapter_class(field) + + raise ValueError("No adapter found for field: %r" % field) + + +adapter_registry = AdapterRegistry() diff --git a/wagtail_transfer/operations.py b/wagtail_transfer/operations.py index e91d001..8d8e68f 100644 --- a/wagtail_transfer/operations.py +++ b/wagtail_transfer/operations.py @@ -7,7 +7,7 @@ from treebeard.mp_tree import MP_Node from wagtail.core.models import Page -from .field_adapters import get_field_adapter +from .field_adapters import adapter_registry from .locators import get_locator_for_model from .models import get_base_model, get_base_model_for_path, get_model_for_path @@ -583,7 +583,7 @@ def _populate_fields(self, context): except KeyError: continue - get_field_adapter(field).populate_field(self.instance, value, context) + adapter_registry.get_field_adapter(field).populate_field(self.instance, value, context) def _populate_many_to_many_fields(self, context): save_needed = False @@ -626,7 +626,7 @@ def dependencies(self): for field in self.model._meta.get_fields(): if isinstance(field, models.Field): val = self.object_data['fields'].get(field.name) - deps.update(get_field_adapter(field).get_dependencies(val)) + deps.update(adapter_registry.get_field_adapter(field).get_dependencies(val)) return deps diff --git a/wagtail_transfer/serializers.py b/wagtail_transfer/serializers.py index 7943e83..2ef8240 100644 --- a/wagtail_transfer/serializers.py +++ b/wagtail_transfer/serializers.py @@ -5,7 +5,7 @@ from treebeard.mp_tree import MP_Node from wagtail.core.models import Page -from .field_adapters import get_field_adapter +from .field_adapters import adapter_registry from .models import get_base_model @@ -39,7 +39,7 @@ def __init__(self, model): if not isinstance(related_field, ParentalKey): continue - self.field_adapters.append(get_field_adapter(field)) + self.field_adapters.append(adapter_registry.get_field_adapter(field)) def get_objects_by_ids(self, ids): """