From 5b44ac4b9864946a8fab018b66883d70d31f6c99 Mon Sep 17 00:00:00 2001 From: GeoWill Date: Thu, 14 Nov 2024 13:42:39 +0000 Subject: [PATCH] Add a database flag to update_addressbase command --- .../management/commands/update_addressbase.py | 9 ++++++++- .../apps/addressbase/tests/test_update_addressbase.py | 10 +++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/polling_stations/apps/addressbase/management/commands/update_addressbase.py b/polling_stations/apps/addressbase/management/commands/update_addressbase.py index b10147302..50a60aab6 100644 --- a/polling_stations/apps/addressbase/management/commands/update_addressbase.py +++ b/polling_stations/apps/addressbase/management/commands/update_addressbase.py @@ -100,6 +100,11 @@ def add_arguments(self, parser): "--uprntocouncil-s3-uri", help="S3 URI for UPRN to Council data file", ) + parser.add_argument( + "--database", + default=PRINCIPAL_DB_NAME, + help="Database name. Defaults to PRINCIPAL_DB_NAME - i.e. RDS if you're on EC2", + ) def teardown(self): self.stdout.write( @@ -129,8 +134,10 @@ def handle(self, *args, **options): self.stdout.write(f"addressbase_path set to {addressbase_path}") self.stdout.write(f"uprntocouncil_path to {uprntocouncil_path}") + database_name = options["database"] + # Get the principal (i.e. RDS) DB - cursor = connections[PRINCIPAL_DB_NAME].cursor() + cursor = connections[database_name].cursor() # Create addressbase updater and set cursor addressbase_updater = AddressbaseUpdater() diff --git a/polling_stations/apps/addressbase/tests/test_update_addressbase.py b/polling_stations/apps/addressbase/tests/test_update_addressbase.py index 88c2a7145..68a4f31b9 100644 --- a/polling_stations/apps/addressbase/tests/test_update_addressbase.py +++ b/polling_stations/apps/addressbase/tests/test_update_addressbase.py @@ -2,6 +2,7 @@ from pathlib import Path from unittest.mock import patch +from django.core.management import call_command from django.db import connection from django.test import TestCase, TransactionTestCase from uk_geo_utils.base_importer import BaseImporter @@ -138,18 +139,13 @@ def test_success(self): self.assertEqual(Address.objects.count(), 2) self.assertEqual(UprnToCouncil.objects.count(), 2) - # setup command - cmd = UpdateAddressbaseCommand() - - # supress output - cmd.stdout = StringIO() - # import data opts = { "addressbase_path": addressbase_path, "uprntocouncil_path": uprntocouncil_path, + "stdout": StringIO(), } - cmd.handle(**opts) + call_command("update_addressbase", **opts) self.assertEqual(Address.objects.count(), 4) self.assertEqual(UprnToCouncil.objects.count(), 4)