diff --git a/README.md b/README.md index d9e6ac3..4dba663 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,15 @@ The following settings are additionally recognised: By default, objects referenced within imported content will be recursively imported to ensure that those references are still valid on the destination site. However, this is not always desirable - for example, if this happened for the Page model, this would imply that any pages linked from an imported page would get imported as well, along with any pages linked from _those_ pages, and so on, leading to an unpredictable number of extra pages being added anywhere in the page tree as a side-effect of the import. Models listed in `WAGTAILTRANSFER_NO_FOLLOW_MODELS` will thus be skipped in this process, leaving the reference unresolved. The effect this has on the referencing page will vary according to the kind of relation: nullable foreign keys, one-to-many and many-to-many relations will simply omit the missing object; references in rich text and StreamField will become broken links (just as linking a page and then deleting it would); while non-nullable foreign keys will prevent the object from being created at all (meaning that any objects referencing _that_ object will end up with unresolved references, to be handled by the same set of rules). +* `WAGTAILTRANSFER_FOLLOWED_REVERSE_RELATIONS = [('wagtailimages.image', 'tagged_items')]` + + Specifies a list of models, their reverse relations to follow, and whether deletions should be synced, when identifying object references that should be imported to the destination site. Defaults to `[('wagtailimages.image', 'tagged_items', True)]`. + + By default, Wagtail Transfer will not follow reverse relations (other than importing child models of `ClusterableModel` subclasses) when identifying referenced models. Specifying a `(model, reverse_relationship_name, track_deletions)` in `WAGTAILTRANSFER_FOLLOWED_REVERSE_RELATIONS` means that when + encountering that model and relation, Wagtail Transfer will follow the reverse relationship from the specified model and add the models found to the import if they do not exist on the destination site. This is typically useful in cases such as tags on non-Page models. The `track_deletions` boolean, + if `True`, will delete any models in the reverse relation on the destination site that do not exist in the source site's reverse relation. As a result, + it should only be used for models that behave strictly like child models but do not use `ParentalKey` - for example, tags, where importing an image with deleted tags should delete those tag linking models on the destination site as well. + Note that these settings do not accept models that are defined as subclasses through [multi-table inheritance](https://docs.djangoproject.com/en/stable/topics/db/models/#multi-table-inheritance) - in particular, they cannot be used to define behaviour that only applies to specific subclasses of Page. * `WAGTAILTRANSFER_CHOOSER_API_PROXY_TIMEOUT = 5` diff --git a/docs/settings.md b/docs/settings.md index fabe13e..104c415 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -87,12 +87,30 @@ that object will end up with unresolved references, to be handled by the same se Note that these settings do not accept models that are defined as subclasses through multi-table inheritance - in particular, they cannot be used to define behaviour that only applies to specific subclasses of Page. + +### `WAGTAILTRANSFER_FOLLOWED_REVERSE_RELATIONS` + +```python +WAGTAILTRANSFER_FOLLOWED_REVERSE_RELATIONS = [('wagtailimages.image', 'tagged_items', True)] +``` + +Specifies a list of models, their reverse relations to follow, and whether deletions should be synced, when identifying object references that should be imported to the destination site. Defaults to `[('wagtailimages.image', 'tagged_items', True)]`. + +By default, Wagtail Transfer will not follow reverse relations (other than importing child models of `ClusterableModel` subclasses) when identifying referenced models. Specifying a `(model, reverse_relationship_name, track_deletions)` in `WAGTAILTRANSFER_FOLLOWED_REVERSE_RELATIONS` means that when +encountering that model and relation, Wagtail Transfer will follow the reverse relationship from the specified model and add the models found to the import if they do not exist on the destination site. This is typically useful in cases such as tags on non-Page models. The `track_deletions` boolean, +if `True`, will delete any models in the reverse relation on the destination site that do not exist in the source site's reverse relation. As a result, +it should only be used for models that behave strictly like child models but do not use `ParentalKey` - for example, tags, where importing an image with deleted tags should delete those tag linking models on the destination site as well. + + +### `WAGTAILTRANSFER_CHOOSER_API_PROXY_TIMEOUT` + ```python WAGTAILTRANSFER_CHOOSER_API_PROXY_TIMEOUT = 5 ``` By default, each API call made to browse the page tree on the source server has a timeout limit of 5 seconds. If you find this threshold is too low, you can increase it. This may be of particular use if you are running two local runservers to test or extend Wagtail Transfer. + ## Hooks ### `register_field_adapters` diff --git a/tests/migrations/0016_advert_tags.py b/tests/migrations/0016_advert_tags.py new file mode 100644 index 0000000..b704e29 --- /dev/null +++ b/tests/migrations/0016_advert_tags.py @@ -0,0 +1,20 @@ +# Generated by Django 3.0.5 on 2020-10-16 11:13 + +from django.db import migrations +import taggit.managers + + +class Migration(migrations.Migration): + + dependencies = [ + ('taggit', '0003_taggeditem_add_unique_index'), + ('tests', '0015_longadvert'), + ] + + operations = [ + migrations.AddField( + model_name='advert', + name='tags', + field=taggit.managers.TaggableManager(help_text='A comma-separated list of tags.', through='taggit.TaggedItem', to='taggit.Tag', verbose_name='Tags'), + ), + ] diff --git a/tests/models.py b/tests/models.py index 71333d7..be6ae5d 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,5 +1,6 @@ from django.db import models from modelcluster.fields import ParentalKey, ParentalManyToManyField +from taggit.managers import TaggableManager from wagtail.core.fields import RichTextField, StreamField from wagtail.core.models import Orderable, Page from wagtail.snippets.models import register_snippet @@ -13,6 +14,7 @@ class SimplePage(Page): class Advert(models.Model): slogan = models.CharField(max_length=255) + tags = TaggableManager() class LongAdvert(Advert): diff --git a/tests/settings.py b/tests/settings.py index 9f0e682..c2c24a5 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -134,6 +134,8 @@ } } +WAGTAILTRANSFER_FOLLOWED_REVERSE_RELATIONS = [('wagtailimages.image', 'tagged_items', True), ('tests.advert', 'tagged_items', True)] + WAGTAILTRANSFER_SECRET_KEY = 'i-am-the-local-secret-key' WAGTAILTRANSFER_UPDATE_RELATED_MODELS = ['wagtailimages.Image', 'tests.advert'] diff --git a/tests/tests/test_api.py b/tests/tests/test_api.py index 73f45f9..b39c3bf 100644 --- a/tests/tests/test_api.py +++ b/tests/tests/test_api.py @@ -195,15 +195,17 @@ def test_parental_keys(self): data = json.loads(response.content) page_data = None + section_data = [] for obj in data['objects']: if obj['model'] == 'tests.sectionedpage' and obj['pk'] == page.pk: page_data = obj - break + if obj['model'] == 'tests.sectionedpagesection': + section_data.append(obj) self.assertEqual(len(page_data['fields']['sections']), 2) - self.assertEqual(page_data['fields']['sections'][0]['model'], 'tests.sectionedpagesection') - self.assertEqual(page_data['fields']['sections'][0]['fields']['title'], "Create the universe") - section_id = page_data['fields']['sections'][0]['pk'] + self.assertEqual(section_data[0]['model'], 'tests.sectionedpagesection') + self.assertTrue(section_data[0]['fields']['title'] == "Create the universe") + section_id = page_data['fields']['sections'][0] # there should also be a uid mapping for the section matching_uids = [ @@ -529,6 +531,21 @@ def test_model_with_multi_table_inheritance(self): self.assertEqual(data['objects'][0]['model'], 'tests.longadvert') # the child object should be serialized + def test_model_with_tags(self): + # test that a reverse relation such as tagged_items is followed to obtain references to the + # tagged_items, if the model and relationship are specified in WAGTAILTRANSFER_FOLLOWED_REVERSE_RELATIONS + ad = Advert.objects.create(slogan='test') + ad.tags.add('test_tag') + + response = self.get({ + 'tests.advert': [ad.pk] + }) + self.assertEqual(response.status_code, 200) + data = json.loads(response.content) + + mapped_models = {mapping[0] for mapping in data['mappings']} + self.assertIn('taggit.taggeditem', mapped_models) + def test_image(self): with open(os.path.join(FIXTURES_DIR, 'wagtail.jpg'), 'rb') as f: image = Image.objects.create( diff --git a/tests/tests/test_import.py b/tests/tests/test_import.py index 25a4edd..f5ed40d 100644 --- a/tests/tests/test_import.py +++ b/tests/tests/test_import.py @@ -1,3 +1,4 @@ +import importlib import os.path import shutil from unittest import mock @@ -5,7 +6,7 @@ from django.conf import settings from django.contrib.contenttypes.models import ContentType from django.core.files.images import ImageFile -from django.test import TestCase +from django.test import TestCase, override_settings from wagtail.core.models import Collection, Page from wagtail.images.models import Image @@ -307,28 +308,27 @@ def test_import_page_with_child_models(self): "live": true, "slug": "how-to-boil-an-egg", "intro": "This is how to boil an egg", - "sections": [ - { - "model": "tests.sectionedpagesection", - "pk": 101, - "fields": { - "sort_order": 0, - "title": "Boil the outside of the egg", - "body": "...", - "page": 100 - } - }, - { - "model": "tests.sectionedpagesection", - "pk": 102, - "fields": { - "sort_order": 1, - "title": "Boil the rest of the egg", - "body": "...", - "page": 100 - } - } - ] + "sections": [101, 102] + } + }, + { + "model": "tests.sectionedpagesection", + "pk": 101, + "fields": { + "sort_order": 0, + "title": "Boil the outside of the egg", + "body": "...", + "page": 100 + } + }, + { + "model": "tests.sectionedpagesection", + "pk": 102, + "fields": { + "sort_order": 1, + "title": "Boil the rest of the egg", + "body": "...", + "page": 100 } } ] @@ -370,28 +370,27 @@ def test_import_page_with_child_models(self): "live": true, "slug": "how-to-boil-an-egg", "intro": "This is still how to boil an egg", - "sections": [ - { - "model": "tests.sectionedpagesection", - "pk": 102, - "fields": { - "sort_order": 0, - "title": "Boil the egg", - "body": "...", - "page": 100 - } - }, - { - "model": "tests.sectionedpagesection", - "pk": 103, - "fields": { - "sort_order": 1, - "title": "Eat the egg", - "body": "...", - "page": 100 - } - } - ] + "sections": [102, 103] + } + }, + { + "model": "tests.sectionedpagesection", + "pk": 102, + "fields": { + "sort_order": 0, + "title": "Boil the egg", + "body": "...", + "page": 100 + } + }, + { + "model": "tests.sectionedpagesection", + "pk": 103, + "fields": { + "sort_order": 1, + "title": "Eat the egg", + "body": "...", + "page": 100 } } ] @@ -622,7 +621,7 @@ def test_import_image_with_file(self, get): "file_size": 18521, "file_hash": "e4eab12cc50b6b9c619c9ddd20b61d8e6a961ada", "tags": "[]", - "tagged_items": "[]" + "tagged_items": [] } } ] @@ -686,7 +685,7 @@ def test_import_image_with_file_without_root_collection_mapping(self, get): "file_size": 18521, "file_hash": "e4eab12cc50b6b9c619c9ddd20b61d8e6a961ada", "tags": "[]", - "tagged_items": "[]" + "tagged_items": [] } } ] @@ -773,7 +772,7 @@ def test_existing_image_is_not_refetched(self, get): "file_size": 1160, "file_hash": "45c5db99aea04378498883b008ee07528f5ae416", "tags": "[]", - "tagged_items": "[]" + "tagged_items": [] } } ] @@ -853,7 +852,7 @@ def test_replace_image(self, get): "file_size": 27, "file_hash": "e4eab12cc50b6b9c619c9ddd20b61d8e6a961ada", "tags": "[]", - "tagged_items": "[]" + "tagged_items": [] } } ] @@ -1415,3 +1414,161 @@ def test_import_multi_table_model(self): self.assertIsNotNone(imported_ad) self.assertEqual(imported_ad.slogan, "test") self.assertEqual(imported_ad.description, "longertest") + + def test_import_model_with_generic_foreign_key(self): + # test importing a model with a generic foreign key by importing a model that implements tagging using standard taggit (not ParentalKey) + data = """{ + "ids_for_import": [["tests.advert", 4]], + "mappings": [ + ["taggit.tag", 152, "ac92b2ba-0fa6-11eb-800b-287fcf66f689"], + ["tests.advert", 4, "ac931726-0fa6-11eb-800c-287fcf66f689"], + ["taggit.taggeditem", 150, "ac938e5a-0fa6-11eb-800d-287fcf66f689"] + ], + "objects": [ + { + "model": "tests.advert", + "pk": 4, + "fields": {"longadvert": null, "sponsoredpage": null, "slogan": "test", "tags": "[]", "tagged_items": null} + }, + { + "model": "taggit.taggeditem", + "pk": 150, + "fields": {"content_object": ["tests.advert", 4], "tag": 152} + }, + { + "model": "taggit.tag", + "pk": 152, + "fields": {"name": "test_tag", "slug": "testtag"} + } + ] + }""" + + importer = ImportPlanner(root_page_source_pk=1, destination_parent_id=None) + importer.add_json(data) + importer.run() + + imported_ad = Advert.objects.filter(id=4).first() + self.assertIsNotNone(imported_ad) + self.assertEqual(imported_ad.tags.first().name, "test_tag") + + def test_import_model_with_deleted_reverse_related_models(self): + # test re-importing a model where WAGTAILTRANSFER_FOLLOWED_REVERSE_RELATIONS is used to track tag deletions + # will delete tags correctly + data = """{ + "ids_for_import": [["tests.advert", 4]], + "mappings": [ + ["taggit.tag", 152, "ac92b2ba-0fa6-11eb-800b-287fcf66f689"], + ["tests.advert", 4, "ac931726-0fa6-11eb-800c-287fcf66f689"], + ["taggit.taggeditem", 150, "ac938e5a-0fa6-11eb-800d-287fcf66f689"] + ], + "objects": [ + { + "model": "tests.advert", + "pk": 4, + "fields": {"longadvert": null, "sponsoredpage": null, "slogan": "test", "tags": "[]", "tagged_items": [150]} + }, + { + "model": "taggit.taggeditem", + "pk": 150, + "fields": {"content_object": ["tests.advert", 4], "tag": 152} + }, + { + "model": "taggit.tag", + "pk": 152, + "fields": {"name": "test_tag", "slug": "testtag"} + } + ] + }""" + + importer = ImportPlanner(root_page_source_pk=1, destination_parent_id=None) + importer.add_json(data) + importer.run() + + imported_ad = Advert.objects.filter(id=4).first() + self.assertIsNotNone(imported_ad) + self.assertEqual(imported_ad.tags.first().name, "test_tag") + + data = """{ + "ids_for_import": [["tests.advert", 4]], + "mappings": [ + ["tests.advert", 4, "ac931726-0fa6-11eb-800c-287fcf66f689"] + ], + "objects": [ + { + "model": "tests.advert", + "pk": 4, + "fields": {"longadvert": null, "sponsoredpage": null, "slogan": "test", "tags": "[]", "tagged_items": []} + } + ] + }""" + + importer = ImportPlanner(root_page_source_pk=1, destination_parent_id=None) + importer.add_json(data) + importer.run() + + imported_ad = Advert.objects.filter(id=4).first() + self.assertIsNotNone(imported_ad) + self.assertIsNone(imported_ad.tags.first()) + + @override_settings(WAGTAILTRANSFER_FOLLOWED_REVERSE_RELATIONS=[('tests.advert', 'tagged_items', False)]) + def test_import_model_with_untracked_deleted_reverse_related_models(self): + # test re-importing a model where WAGTAILTRANFER_FOLLOWED_REVERSE_RELATIONS is not used to track tag deletions + # will not delete tags + from wagtail_transfer import field_adapters + importlib.reload(field_adapters) + # force reload field adapters as followed/deleted variables are set on module load, so will not get new setting + data = """{ + "ids_for_import": [["tests.advert", 4]], + "mappings": [ + ["taggit.tag", 152, "ac92b2ba-0fa6-11eb-800b-287fcf66f689"], + ["tests.advert", 4, "ac931726-0fa6-11eb-800c-287fcf66f689"], + ["taggit.taggeditem", 150, "ac938e5a-0fa6-11eb-800d-287fcf66f689"] + ], + "objects": [ + { + "model": "tests.advert", + "pk": 4, + "fields": {"longadvert": null, "sponsoredpage": null, "slogan": "test", "tags": "[]", "tagged_items": [150]} + }, + { + "model": "taggit.taggeditem", + "pk": 150, + "fields": {"content_object": ["tests.advert", 4], "tag": 152} + }, + { + "model": "taggit.tag", + "pk": 152, + "fields": {"name": "test_tag", "slug": "testtag"} + } + ] + }""" + + importer = ImportPlanner(root_page_source_pk=1, destination_parent_id=None) + importer.add_json(data) + importer.run() + + imported_ad = Advert.objects.filter(id=4).first() + self.assertIsNotNone(imported_ad) + self.assertEqual(imported_ad.tags.first().name, "test_tag") + + data = """{ + "ids_for_import": [["tests.advert", 4]], + "mappings": [ + ["tests.advert", 4, "ac931726-0fa6-11eb-800c-287fcf66f689"] + ], + "objects": [ + { + "model": "tests.advert", + "pk": 4, + "fields": {"longadvert": null, "sponsoredpage": null, "slogan": "test", "tags": "[]", "tagged_items": []} + } + ] + }""" + + importer = ImportPlanner(root_page_source_pk=1, destination_parent_id=None) + importer.add_json(data) + importer.run() + + imported_ad = Advert.objects.filter(id=4).first() + self.assertIsNotNone(imported_ad) + self.assertIsNotNone(imported_ad.tags.first()) diff --git a/wagtail_transfer/field_adapters.py b/wagtail_transfer/field_adapters.py index 6d055fe..6f014a3 100644 --- a/wagtail_transfer/field_adapters.py +++ b/wagtail_transfer/field_adapters.py @@ -4,14 +4,19 @@ from urllib.parse import urlparse from django.conf import settings +from django.contrib.contenttypes.fields import GenericForeignKey +from django.contrib.contenttypes.models import ContentType from django.db import models from django.db.models.fields.reverse_related import ManyToOneRel +from django.utils.functional import cached_property +from modelcluster.fields import ParentalKey from taggit.managers import TaggableManager from wagtail.core import hooks from wagtail.core.fields import RichTextField, StreamField from .files import File, FileTransferError, get_file_hash, get_file_size -from .models import get_base_model +from .locators import get_locator_for_model +from .models import get_base_model, get_base_model_for_path from .richtext import get_reference_handler from .streamfield import get_object_references, update_object_ids @@ -19,6 +24,14 @@ from django.utils.encoding import is_protected_type +WAGTAILTRANSFER_FOLLOWED_REVERSE_RELATIONS = getattr(settings, "WAGTAILTRANSFER_FOLLOWED_REVERSE_RELATIONS", [('wagtailimages.image', 'tagged_items', True)]) +FOLLOWED_REVERSE_RELATIONS = { + (model_label.lower(), relation.lower()) for model_label, relation, _ in WAGTAILTRANSFER_FOLLOWED_REVERSE_RELATIONS +} +DELETED_REVERSE_RELATIONS = { + (model_label.lower(), relation.lower()) for model_label, relation, track_deletions in WAGTAILTRANSFER_FOLLOWED_REVERSE_RELATIONS if track_deletions +} + class FieldAdapter: def __init__(self, field): @@ -57,6 +70,12 @@ def get_dependencies(self, value): """ return set() + def get_object_deletions(self, instance, value, context): + """ + A set of (base_model_class, id) tuples for objects that must be deleted at the destination site + """ + return set() + def update_object_references(self, value, destination_ids_by_source): """ Return a modified version of value with object references replaced by their corresponding @@ -73,6 +92,25 @@ def populate_field(self, instance, value, context): value = self.update_object_references(value, context.destination_ids_by_source) setattr(instance, self.field.get_attname(), value) + def get_managed_fields(self): + """ + Normally, a FieldAdapter will adapt a single field. However, more complex fields like + GenericForeignKey may 'manage' several other fields. get_managed_fields returns a list of names + of managed fields, whose field adapters should not be used when serializing the model. Note + that if a managed field also has managed fields itself, these will also be ignored when + serializing the model - the current field adapter is expected to address all managed fields in + the chain. + """ + return [] + + def get_objects_to_serialize(self, instance): + """ + Return a set of (model_class, id) pairs for objects that should be serialized on export, before + it is known whether or not they exist or should be updated at the destination site + """ + return set() + + class ForeignKeyAdapter(FieldAdapter): def __init__(self, field): @@ -101,32 +139,105 @@ def update_object_references(self, value, destination_ids_by_source): return destination_ids_by_source.get((self.related_base_model, value)) +class GenericForeignKeyAdapter(FieldAdapter): + def serialize(self, instance): + linked_instance = getattr(instance, self.field.name, None) + if linked_instance: + # here we do not use the base model, as the GFK could be pointing specifically at the child + # which needs to be represented accurately + return (linked_instance._meta.label_lower, linked_instance.pk) + + def get_object_references(self, instance): + linked_instance = getattr(instance, self.field.name, None) + if linked_instance: + return {(get_base_model(linked_instance), linked_instance.pk)} + return set() + + def get_dependencies(self, value): + if value is None: + return set() + + model_path, model_id = value + base_model = get_base_model_for_path(model_path) + + # GenericForeignKey itself has no blank or null properties, so we need to determine its nullable status + # from the underlying fields it uses + options = self.field.model._meta + ct_field = options.get_field(self.field.ct_field) + fk_field = options.get_field(self.field.ct_field) + + if all((ct_field.blank, ct_field.null, fk_field.blank, fk_field.null)): + # field is nullable, so it's a soft dependency; we can leave the field empty in the + # case that the target object cannot be created + return {(base_model, model_id, False)} + else: + # this is a hard dependency + return {(base_model, model_id, True)} + + def update_object_references(self, value, destination_ids_by_source): + if value: + model_path, model_id = value + base_model = get_base_model_for_path(model_path) + return (model_path, destination_ids_by_source.get((base_model, model_id))) + + def populate_field(self, instance, value, context): + model_id, content_type = None, None + if value: + model_path, model_id = self.update_object_references(value, context.destination_ids_by_source) + content_type = ContentType.objects.get_by_natural_key(*model_path.split('.')) + + setattr(instance, instance._meta.get_field(self.field.ct_field).get_attname(), content_type.pk) + setattr(instance, self.field.fk_field, model_id) + + def get_managed_fields(self): + return [self.field.fk_field, self.field.ct_field] + + class ManyToOneRelAdapter(FieldAdapter): def __init__(self, field): super().__init__(field) - self.related_field = field.field - self.related_model = field.related_model - - from .serializers import get_model_serializer - self.related_model_serializer = get_model_serializer(self.related_model) + self.related_field = getattr(field, 'field', None) or getattr(field, 'remote_field', None) + self.related_base_model = get_base_model(field.related_model) + self.is_parental = isinstance(self.related_field, ParentalKey) + self.is_followed = (get_base_model(self.field.model)._meta.label_lower, self.name) in FOLLOWED_REVERSE_RELATIONS def _get_related_objects(self, instance): return getattr(instance, self.name).all() def serialize(self, instance): - return [ - self.related_model_serializer.serialize(obj) - for obj in self._get_related_objects(instance) - ] + if self.is_parental or self.is_followed: + return list(self._get_related_objects(instance).values_list('pk', flat=True)) def get_object_references(self, instance): refs = set() - for obj in self._get_related_objects(instance): - refs.update(self.related_model_serializer.get_object_references(obj)) + if self.is_parental or self.is_followed: + for pk in self._get_related_objects(instance).values_list('pk', flat=True): + refs.add((self.related_base_model, pk)) return refs + def get_object_deletions(self, instance, value, context): + if (self.is_parental or (get_base_model(self.field.model)._meta.label_lower, self.name) in DELETED_REVERSE_RELATIONS): + value = value or [] + uids = {context.uids_by_source[(self.related_base_model, pk)] for pk in value} + # delete any related objects on the existing object if they can't be mapped back + # to one of the uids in the new set + locator = get_locator_for_model(self.related_base_model) + matched_destination_ids = set() + for uid in uids: + child = locator.find(uid) + if child is not None: + matched_destination_ids.add(child.pk) + + return {child for child in self._get_related_objects(instance) if child.pk not in matched_destination_ids} + return set() + + def get_objects_to_serialize(self, instance): + if self.is_parental: + return getattr(instance, self.name).all() + return set() + def populate_field(self, instance, value, context): - raise Exception('populate_field is not supported on many-to-one relations') + pass class RichTextAdapter(FieldAdapter): @@ -240,10 +351,8 @@ def populate_field(self, instance, value, context): pass -class GenericRelationAdapter(FieldAdapter): - def populate_field(self, instance, value, context): - # TODO - pass +class GenericRelationAdapter(ManyToOneRelAdapter): + pass class AdapterRegistry: @@ -257,6 +366,7 @@ class AdapterRegistry: models.ManyToManyField: ManyToManyFieldAdapter, TaggableManager: TaggableManagerAdapter, GenericRelation: GenericRelationAdapter, + GenericForeignKey: GenericForeignKeyAdapter, } def __init__(self): @@ -283,7 +393,5 @@ def get_field_adapter(self, field): adapter_class = self.adapters_by_field_class[field_class] return adapter_class(field) - raise ValueError("No adapter found for field: %r" % field) - adapter_registry = AdapterRegistry() diff --git a/wagtail_transfer/operations.py b/wagtail_transfer/operations.py index 8d8e68f..058d0bd 100644 --- a/wagtail_transfer/operations.py +++ b/wagtail_transfer/operations.py @@ -1,4 +1,5 @@ import json +from copy import copy from django.conf import settings from django.core.exceptions import ImproperlyConfigured @@ -206,12 +207,30 @@ def add_json(self, json_data): """ data = json.loads(json_data) - # add source id -> uid mappings to the uids_by_source dict + # for each ID in the import list, add to base_import_ids as an object explicitly selected + # for import + for model_path, source_id in data['ids_for_import']: + model = get_base_model_for_path(model_path) + self.base_import_ids.add((model, source_id)) + + # add source id -> uid mappings to the uids_by_source dict, and add objectives + # for importing referenced models for model_path, source_id, jsonish_uid in data['mappings']: model = get_base_model_for_path(model_path) uid = get_locator_for_model(model).uid_from_json(jsonish_uid) self.context.uids_by_source[(model, source_id)] = uid + base_import = (model, source_id) in self.base_import_ids + + if base_import or model_path not in NO_FOLLOW_MODELS: + objective = Objective( + model, source_id, self.context, + must_update=(base_import or model_path in UPDATE_RELATED_MODELS) + ) + + # add to the set of objectives that need handling + self._add_objective(objective) + # add object data to the object_data_by_source dict for obj_data in data['objects']: self._add_object_data_to_lookup(obj_data) @@ -219,16 +238,6 @@ def add_json(self, json_data): # retry tasks that were previously postponed due to missing object data self._retry_tasks() - # for each ID in the import list, add to base_import_ids as an object explicitly selected - # for import, and add an objective to specify that we want an up-to-date copy of that - # object on the destination site - for model_path, source_id in data['ids_for_import']: - model = get_base_model_for_path(model_path) - self.base_import_ids.add((model, source_id)) - objective = Objective(model, source_id, self.context, must_update=True) - - # add to the set of objectives that need handling - self._add_objective(objective) # Process all unhandled objectives - which may trigger new objectives as dependencies of # the resulting operations - until no unhandled objectives remain @@ -243,10 +252,29 @@ def _add_object_data_to_lookup(self, obj_data): def _add_objective(self, objective): # add to the set of objectives that need handling, unless it's one we've already seen - # (in which case it's either in the queue to be handled, or has been handled already) - if objective not in self.objectives: - self.objectives.add(objective) - self.unhandled_objectives.add(objective) + # (in which case it's either in the queue to be handled, or has been handled already). + # An objective to update a model supercedes an objective to ensure it exists + + if not objective.must_update: + update_objective = copy(objective) + update_objective.must_update = True + else: + update_objective = objective + + if update_objective in self.objectives: + # We're already updating the model, so this objective isn't relevant + return + elif objective.must_update: + # We're going to add a new objective to update the model + # so we should remove any existing objective that doesn't update the model + no_update_objective = copy(objective) + no_update_objective.must_update = False + self.objectives.discard(no_update_objective) + self.unhandled_objectives.discard(no_update_objective) + + self.objectives.add(objective) + self.unhandled_objectives.add(objective) + def _handle_objective(self, objective): if not objective.exists_at_destination: @@ -355,35 +383,16 @@ def _handle_task(self, task): related_base_model = get_base_model(rel.related_model) child_uids = set() - for child_obj_data in object_data['fields'][rel.name]: - # Add child object data to the object_data_by_source lookup - self._add_object_data_to_lookup(child_obj_data) + for child_obj_pk in object_data['fields'][rel.name]: # Add an objective for handling the child object. Regardless of whether # this is a 'create' or 'update' task, we want the child objects to be at # their most up-to-date versions, so set the objective to 'must update' + self._add_objective( - Objective(related_base_model, child_obj_data['pk'], self.context, must_update=True) + Objective(related_base_model, child_obj_pk, self.context, must_update=True) ) - # look up the child object's UID - uid = self.context.uids_by_source[(related_base_model, child_obj_data['pk'])] - child_uids.add(uid) - - if action == 'update': - # delete any child objects on the existing object if they can't be mapped back - # to one of the uids in the new set - locator = get_locator_for_model(related_base_model) - matched_destination_ids = set() - for uid in child_uids: - child = locator.find(uid) - if child is not None: - matched_destination_ids.add(child.pk) - - for child in getattr(obj, rel.name).all(): - if child.pk not in matched_destination_ids: - self.operations.add(DeleteModel(child)) - if operation is not None: self.operations.add(operation) @@ -410,6 +419,9 @@ def _handle_task(self, task): Objective(model, source_id, self.context, must_update=(model._meta.label_lower in UPDATE_RELATED_MODELS)) ) + for instance in operation.deletions(self.context): + self.operations.add(DeleteModel(instance)) + def _retry_tasks(self): """ Retry tasks that were previously postponed due to missing object data @@ -449,6 +461,13 @@ def run(self): with transaction.atomic(): for operation in operation_order: operation.run(self.context) + + # pages must only have revisions saved after all child objects have been updated, imported, or deleted, otherwise + # they will capture outdated versions of child objects in the revision + for operation in operation_order: + if isinstance(operation.instance, Page): + operation.instance.save_revision() + def _check_satisfiable(self, operation, statuses): # Check whether the given operation's dependencies are satisfiable. statuses is a dict of @@ -556,6 +575,10 @@ def dependencies(self): """ return set() + def deletions(self, context): + # the set of objects that must be deleted when we import this object + return set() + class SaveOperationMixin: """ @@ -574,16 +597,15 @@ def base_model(self): def _populate_fields(self, context): for field in self.model._meta.get_fields(): - if not isinstance(field, models.Field): - # populate data for actual fields only; ignore reverse relations - continue - try: value = self.object_data['fields'][field.name] except KeyError: continue - adapter_registry.get_field_adapter(field).populate_field(self.instance, value, context) + adapter = adapter_registry.get_field_adapter(field) + + if adapter: + adapter.populate_field(self.instance, value, context) def _populate_many_to_many_fields(self, context): save_needed = False @@ -624,12 +646,25 @@ def dependencies(self): deps = super().dependencies for field in self.model._meta.get_fields(): - if isinstance(field, models.Field): - val = self.object_data['fields'].get(field.name) - deps.update(adapter_registry.get_field_adapter(field).get_dependencies(val)) + val = self.object_data['fields'].get(field.name) + adapter = adapter_registry.get_field_adapter(field) + if adapter: + deps.update(adapter.get_dependencies(val)) return deps + def deletions(self, context): + # the set of objects that must be deleted when we import this object + + deletions = super().deletions(context) + for field in self.model._meta.get_fields(): + val = self.object_data['fields'].get(field.name) + adapter = adapter_registry.get_field_adapter(field) + if adapter: + deletions.update(adapter.get_object_deletions(self.instance, val, context)) + + return deletions + class CreateModel(SaveOperationMixin, Operation): def __init__(self, model, object_data): @@ -685,10 +720,6 @@ def _save(self, context): # Add the page to the database as a child of parent parent.add_child(instance=self.instance) - if isinstance(self.instance, Page): - # Also save this as a revision, so that it exists in revision history - self.instance.save_revision(changed=False) - class UpdateModel(SaveOperationMixin, Operation): def __init__(self, instance, object_data): @@ -701,15 +732,6 @@ def run(self, context): self._save(context) self._populate_many_to_many_fields(context) - def _save(self, context): - super()._save(context) - if isinstance(self.instance, Page): - # Also save this as a revision, so that: - # * the edit-page view will pick up this imported version rather than any currently-existing drafts - # * it exists in revision history - # * the Page.draft_title field (as used in page listings in the admin) is updated to match the real title - self.instance.save_revision(changed=False) - class DeleteModel(Operation): def __init__(self, instance): diff --git a/wagtail_transfer/serializers.py b/wagtail_transfer/serializers.py index e4fb60a..bac37fb 100644 --- a/wagtail_transfer/serializers.py +++ b/wagtail_transfer/serializers.py @@ -77,30 +77,24 @@ def __init__(self, model): self.model = model self.base_model = get_base_model(model) - self.field_adapters = [] + field_adapters = [] + adapter_managed_fields = [] for field in self.model._meta.get_fields(): if field.name in self.ignored_fields: continue - if isinstance(field, models.Field): - # this is a genuine field rather than a reverse relation - - # ignore primary keys (including MTI parent pointers) - if field.primary_key: - continue - else: - # this is probably a reverse relation, so fetch its related field - try: - related_field = field.field - except AttributeError: - # we don't know what sort of pseudo-field this is, so skip it + # ignore primary keys (including MTI parent pointers) + if getattr(field, 'primary_key', False): continue - # ignore relations other than ParentalKey - if not isinstance(related_field, ParentalKey): - continue + adapter = adapter_registry.get_field_adapter(field) + + if adapter: + adapter_managed_fields = adapter_managed_fields + adapter.get_managed_fields() + field_adapters.append(adapter) + + self.field_adapters = [adapter for adapter in field_adapters if adapter.name not in adapter_managed_fields] - self.field_adapters.append(adapter_registry.get_field_adapter(field)) def get_objects_by_ids(self, ids): """ @@ -134,6 +128,12 @@ def get_object_references(self, instance): refs.update(f.get_object_references(instance)) return refs + def get_objects_to_serialize(self, instance): + objects = set() + for f in self.field_adapters: + objects.update(f.get_objects_to_serialize(instance)) + return objects + class TreeModelSerializer(ModelSerializer): ignored_fields = ['path', 'depth', 'numchild'] diff --git a/wagtail_transfer/views.py b/wagtail_transfer/views.py index a89a1a8..0ee85ac 100644 --- a/wagtail_transfer/views.py +++ b/wagtail_transfer/views.py @@ -38,10 +38,15 @@ def pages_for_export(request, root_page_id): objects = [] object_references = set() - for page in pages: - serializer = get_model_serializer(type(page)) - objects.append(serializer.serialize(page)) - object_references.update(serializer.get_object_references(page)) + models_to_serialize = set(pages) + serialized_models = set() + + while models_to_serialize: + model = models_to_serialize.pop() + serializer = get_model_serializer(type(model)) + objects.append(serializer.serialize(model)) + object_references.update(serializer.get_object_references(model)) + models_to_serialize.update(serializer.get_objects_to_serialize(model).difference(serialized_models)) mappings = [] for model, pk in object_references: @@ -82,10 +87,15 @@ def models_for_export(request, model_path, object_id=None): objects = [] object_references = set() - for model_object in model_objects: - serializer = get_model_serializer(type(model_object)) - objects.append(serializer.serialize(model_object)) - object_references.update(serializer.get_object_references(model_object)) + models_to_serialize = set(model_objects) + serialized_models = set() + + while models_to_serialize: + model = models_to_serialize.pop() + serializer = get_model_serializer(type(model)) + objects.append(serializer.serialize(model)) + object_references.update(serializer.get_object_references(model)) + models_to_serialize.update(serializer.get_objects_to_serialize(model).difference(serialized_models)) mappings = [] for model, pk in object_references: @@ -119,15 +129,20 @@ def objects_for_export(request): objects = [] object_references = set() + serialized_models = set() + models_to_serialize = set() for model_path, ids in request_data.items(): model = get_model_for_path(model_path) serializer = get_model_serializer(model) - for obj in serializer.get_objects_by_ids(ids): - instance_serializer = get_model_serializer(type(obj)) # noqa - objects.append(instance_serializer.serialize(obj)) - object_references.update(instance_serializer.get_object_references(obj)) + models_to_serialize.update(serializer.get_objects_by_ids(ids)) + while models_to_serialize: + instance = models_to_serialize.pop() + serializer = get_model_serializer(type(instance)) + objects.append(serializer.serialize(instance)) + object_references.update(serializer.get_object_references(instance)) + models_to_serialize.update(serializer.get_objects_to_serialize(instance).difference(serialized_models)) mappings = [] for model, pk in object_references: