Skip to content

Commit

Permalink
Use the new enginefacade from oslo_db
Browse files Browse the repository at this point in the history
As per blueprint [1], the existing use of oslo_db session
handling (e.g., context.session.begin()) introduces potential issues.
Notably, unit tests failed during the Caracal release, though
no definitive deployment impact has been identified yet.

To future-proof the code and align with recommended practices,
we are migrating to the enginefacade pattern now.
This involves replacing:
with context.session.begin():
   context.session.add(obj)

with 'db_api.CONTEXT_WRITER.using(context)'

[1] https://blueprints.launchpad.net/neutron/+spec/enginefacade-switch
[2] Oslo db spec: http://specs.openstack.org/openstack/oslo-specs/specs/kilo/make-enginefacade-a-facade.html
  • Loading branch information
sven-rosenzweig committed Nov 22, 2024
1 parent 0586f85 commit 66f6e82
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 36 deletions.
57 changes: 29 additions & 28 deletions networking_aci/plugins/ml2/drivers/mech_aci/allocations_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,18 @@ def _allocate_vlan_segment(self, network, host_id, level, host_config):
segment_type = host_config.get('segment_type', 'vlan')
segment_physnet = host_config.get('physical_network', None)

session = db_api.get_writer_session()
ctx = context.get_admin_context()
segmentation_id = self._get_provider_attribute(network, "provider:segmentation_id")
network_id = network["id"]
segment = session.query(ml2_models.NetworkSegment).filter_by(segmentation_id=segmentation_id,
with db_api.CONTEXT_READER.using(ctx):
segment = ctx.session.query(ml2_models.NetworkSegment).filter_by(segmentation_id=segmentation_id,
physical_network=segment_physnet,
network_type=segment_type,
network_id=network_id,
level=level).first()

if not segment:
with session.begin(subtransactions=True):
with db_api.CONTEXT_WRITER.using(ctx):
segment = ml2_models.NetworkSegment(
id=uuidutils.generate_uuid(),
network_id=network_id,
Expand All @@ -123,7 +124,7 @@ def _allocate_vlan_segment(self, network, host_id, level, host_config):
segment_index=level,
is_dynamic=False
)
session.add(segment)
ctx.session.add(segment)

return AllocationsModel(host=host_id, level=level, segment_type=segment_type, segmentation_id=segmentation_id,
segment_id=segment.id, network_id=network_id)
Expand All @@ -135,21 +136,21 @@ def _allocate_vxlan_segment(self, network, host_id, level, host_config):
segment_physnet = host_config.get('physical_network', None)
network_id = network['id']

session = db_api.get_writer_session()
with db_api.exc_to_retry(sa.exc.IntegrityError), session.begin(subtransactions=True):
ctx = context.get_admin_context()
with db_api.exc_to_retry(sa.exc.IntegrityError), db_api.CONTEXT_WRITER.using(ctx):
LOG.debug("Searching for available allocation for host id %(host_id)s "
"segment_type %(segment_type)s network_id %(network_id)s segment_physnet %(segment_physnet)s",
{"host_id": host_id, "segment_type": segment_type, "segment_physnet": segment_physnet,
"network_id": network_id}
)

alloc = session.query(AllocationsModel).filter_by(host=host_id, level=level, segment_type=segment_type,
alloc = ctx.session.query(AllocationsModel).filter_by(host=host_id, level=level, segment_type=segment_type,
network_id=network_id).first()
if alloc and alloc.segment_id:
return alloc

# we regard a segment as unallocated if its segment_id is None
select = (session.query(AllocationsModel).
select = (ctx.session.query(AllocationsModel).
filter_by(host=host_id, level=level, segment_type=segment_type, segment_id=None))

# Selected segment can be allocated before update by someone else,
Expand All @@ -171,7 +172,7 @@ def _allocate_vxlan_segment(self, network, host_id, level, host_config):
segment_index=level,
is_dynamic=False
)
session.add(segment)
ctx.session.add(segment)

raw_segment = {
'host': alloc.host,
Expand All @@ -182,7 +183,7 @@ def _allocate_vxlan_segment(self, network, host_id, level, host_config):
LOG.debug("%(type)s segment allocated from pool with %(segment)s ",
{"type": alloc.segment_type, "segment": alloc.segmentation_id})

count = (session.query(AllocationsModel).
count = (ctx.session.query(AllocationsModel).
filter_by(segment_id=None, **raw_segment).
update({"network_id": network_id, 'segment_id': segment.id}))

Expand All @@ -207,10 +208,10 @@ def release_segment(self, network, host_config, level, segment):
def _release_vlan_segment(self, network, host_config, level, segment):
LOG.debug("Checking release for segment %(segment)s with top level VLAN segment", {"segment": segment})

session = db_api.get_writer_session()
with session.begin(subtransactions=True):
ctx = context.get_admin_context()
with db_api.CONTEXT_WRITER.using(ctx):
# Delete the network segment
query = (session.query(ml2_models.NetworkSegment).
query = (ctx.session.query(ml2_models.NetworkSegment).
filter_by(id=segment['id'], network_id=network['id'], network_type=segment['network_type'],
segmentation_id=segment['segmentation_id'], segment_index=level))
query.delete()
Expand All @@ -222,9 +223,9 @@ def _release_vxlan_segment(self, network, host_config, level, segment):
segmentation_id = segment['segmentation_id']
network_id = network['id']

session = db_api.get_writer_session()
with session.begin(subtransactions=True):
select = (session.query(models.PortBindingLevel).
ctx = context.get_admin_context()
with db_api.CONTEXT_WRITER.using(ctx):
select = (ctx.session.query(models.PortBindingLevel).
filter_by(segment_id=segment_id, level=level))

if select.count() > 0:
Expand All @@ -234,7 +235,7 @@ def _release_vxlan_segment(self, network, host_config, level, segment):

segmentation_ids = self._segmentation_ids(host_config)
inside = segmentation_id in segmentation_ids
query = (session.query(AllocationsModel).
query = (ctx.session.query(AllocationsModel).
filter_by(network_id=network_id, level=level, segment_type=segment_type,
segment_id=segment_id))
if inside:
Expand All @@ -243,7 +244,7 @@ def _release_vxlan_segment(self, network, host_config, level, segment):
query.delete()

# Delete the network segment
query = (session.query(ml2_models.NetworkSegment).
query = (ctx.session.query(ml2_models.NetworkSegment).
filter_by(id=segment_id, network_id=network_id, network_type=segment_type,
segmentation_id=segmentation_id, segment_index=level))

Expand Down Expand Up @@ -281,15 +282,15 @@ def allocate_baremetal_segment(self, context, network, hostgroup, level, segment
_release_vxlan_segment().
"""
is_access = segmentation_id is None
session = context.session
ctx = context.get_admin_context()
segment_type = hostgroup.get('segment_type', 'vlan')
segment_physnet = hostgroup.get('physical_network')
network_id = network['id']
access_id_pool = common.get_set_from_ranges(hostgroup['baremetal_access_vlan_ranges'])

with db_api.exc_to_retry(sa.exc.IntegrityError), session.begin(subtransactions=True):
with db_api.exc_to_retry(sa.exc.IntegrityError), db_api.CONTEXT_WRITER.using(ctx):
# 1. check if segment exists
existing_segments = (session.query(ml2_models.NetworkSegment)
existing_segments = (ctx.session.query(ml2_models.NetworkSegment)
.filter_by(network_id=network_id, physical_network=segment_physnet,
segment_index=level, network_type=segment_type)
.all())
Expand Down Expand Up @@ -323,7 +324,7 @@ def allocate_baremetal_segment(self, context, network, hostgroup, level, segment
segment_id=far_segment_id)
else:
# for trunk mode: check segmentation_id is not already in use in physnet
existing_segments = (session.query(ml2_models.NetworkSegment)
existing_segments = (ctx.session.query(ml2_models.NetworkSegment)
.filter_by(segmentation_id=segmentation_id, physical_network=segment_physnet,
segment_index=level, network_type=segment_type)
.all())
Expand All @@ -334,7 +335,7 @@ def allocate_baremetal_segment(self, context, network, hostgroup, level, segment
# 3. no segment exists, allocate one
if is_access:
# find a free vlan id from the pool
physnet_segments = (session.query(ml2_models.NetworkSegment)
physnet_segments = (ctx.session.query(ml2_models.NetworkSegment)
.filter_by(physical_network=segment_physnet)
.all())
used_ids = set(n.segmentation_id for n in physnet_segments)
Expand All @@ -353,7 +354,7 @@ def allocate_baremetal_segment(self, context, network, hostgroup, level, segment
segment_index=level,
is_dynamic=False
)
session.add(segment)
ctx.session.add(segment)

return segment

Expand Down Expand Up @@ -443,16 +444,16 @@ def _sync_allocations(self):
def _sync_hostgroup_modes(self):
LOG.info("Preparing hostgroup modes sync")

session = db_api.get_writer_session()
with session.begin(subtransactions=True):
ctx = context.get_admin_context()
with db_api.CONTEXT_WRITER.using(ctx):
# fetch all mode-hostgroups from db
db_groups = []
for db_entry in (session.query(HostgroupModeModel).with_for_update()):
for db_entry in (ctx.session.query(HostgroupModeModel).with_for_update()):
db_groups.append(db_entry.hostgroup)

for hg_name, hg in ACI_CONFIG.hostgroups.items():
if hg['direct_mode'] and hg_name not in db_groups:
LOG.info("Adding %s to hostgroup db", hg_name)
hgmm = HostgroupModeModel(hostgroup=hg_name)
session.add(hgmm)
ctx.session.add(hgmm)
LOG.info("Hostgroup modes synced")
26 changes: 20 additions & 6 deletions networking_aci/plugins/ml2/drivers/mech_aci/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from neutron.db.models import segment as segment_models
from neutron.db.models import tag as tag_models
from neutron.db import segments_db as ml2_db
from neutron_lib.db import api as db_api
from neutron.plugins.ml2 import models as ml2_models
import neutron.services.trunk.models as trunk_models
from oslo_config import cfg
Expand All @@ -49,14 +50,15 @@ class DBPlugin(db_base_plugin_v2.NeutronDbPluginV2,
def __init__(self):
pass

@db_api.CONTEXT_READER
def get_ports_with_binding(self, context, network_id):
with context.session.begin(subtransactions=True):
query = context.session.query(models_v2.Port)
query1 = query.join(ml2_models.PortBinding)
bind_ports = query1.filter(models_v2.Port.network_id == network_id)
query = context.session.query(models_v2.Port)
query1 = query.join(ml2_models.PortBinding)
bind_ports = query1.filter(models_v2.Port.network_id == network_id)

return bind_ports
return bind_ports

@db_api.CONTEXT_READER
def get_network_ids(self, context):
result = []
query = context.session.query(models_v2.Network.id).order_by(models_v2.Network.id)
Expand Down Expand Up @@ -87,6 +89,7 @@ def get_address_scope_name(self, context, subnet_pool_id):

return scope.get('name')

@db_api.CONTEXT_READER
def get_hostgroup_modes(self, context, hostgroup_names=None):
hg_modes = {}
query = context.session.query(HostgroupModeModel)
Expand All @@ -100,15 +103,17 @@ def get_hostgroup_mode(self, context, hostgroup_name):
hg_modes = self.get_hostgroup_modes(context, [hostgroup_name])
return hg_modes.get(hostgroup_name)

@db_api.CONTEXT_READER
def set_hostgroup_mode(self, context, hostgroup_name, hostgroup_mode):
with context.session.begin(subtransactions=True):
with db_api.CONTEXT_WRITER.using(context):
query = context.session.query(HostgroupModeModel).filter(HostgroupModeModel.hostgroup == hostgroup_name)
hg = query.first()
if not hg:
return False
hg.mode = hostgroup_mode
return True

@db_api.CONTEXT_READER
def get_hosts_on_segment(self, context, segment_id, level=None):
"""Get all binding hosts (from host or binding_profile) present on a segment"""
# get all ports bound to segment, extract their host
Expand All @@ -125,6 +130,7 @@ def get_hosts_on_segment(self, context, segment_id, level=None):
hosts.add(host)
return hosts

@db_api.CONTEXT_READER
def get_hosts_on_network(self, context, network_id, level=None, with_segment=False, transit_hostgroups=None):
"""Get all binding hosts (from host or binding_profile) present on a network"""
fields = [ml2_models.PortBinding.host, ml2_models.PortBinding.profile]
Expand Down Expand Up @@ -167,6 +173,7 @@ def get_hosts_on_network(self, context, network_id, level=None, with_segment=Fal

return hosts

@db_api.CONTEXT_READER
def get_hosts_on_physnet(self, context, physical_network, level=None, with_segment=False, with_segmentation=False):
"""Get all binding hosts (from host or binding_profile) present on a network
Expand Down Expand Up @@ -201,6 +208,7 @@ def get_hosts_on_physnet(self, context, physical_network, level=None, with_segme
hosts.add(host)
return hosts

@db_api.CONTEXT_READER
def get_segment_ids_by_physnet(self, context, physical_network, fuzzy_match=False):
query = context.session.query(segment_models.NetworkSegment.id)
if fuzzy_match:
Expand All @@ -209,6 +217,7 @@ def get_segment_ids_by_physnet(self, context, physical_network, fuzzy_match=Fals
query = query.filter(segment_models.NetworkSegment.physical_network == physical_network)
return [seg.id for seg in query.all()]

@db_api.CONTEXT_READER
def get_ports_on_network_by_physnet_prefix(self, context, network_id, physical_network_prefix):
# get all ports for a network that are on a segment with a physnet prefix
fields = [
Expand All @@ -233,6 +242,7 @@ def get_ports_on_network_by_physnet_prefix(self, context, network_id, physical_n

return result

@db_api.CONTEXT_READER
def get_bound_projects_by_physnet_prefix(self, context, physical_network_prefix):
# get all projects that have a port bound to a segment with this prefix
query = context.session.query(models_v2.Port.project_id)
Expand All @@ -244,6 +254,7 @@ def get_bound_projects_by_physnet_prefix(self, context, physical_network_prefix)

return [entry.project_id for entry in query.all()]

@db_api.CONTEXT_READER
def get_trunk_vlan_usage_on_project(self, context, project_id, segmentation_id=None):
# return vlan --> networks mapping for aci trunk ports inside a project
query = context.session.query(models_v2.Port.network_id, trunk_models.SubPort.segmentation_id)
Expand All @@ -262,6 +273,7 @@ def get_trunk_vlan_usage_on_project(self, context, project_id, segmentation_id=N

return vlan_map

@db_api.CONTEXT_READER
def get_az_aware_external_subnets(self, context):
if not cfg.CONF.ml2_aci.handle_all_l3_gateways:
return []
Expand Down Expand Up @@ -298,6 +310,7 @@ def get_az_aware_external_subnets(self, context):

return subnets

@db_api.CONTEXT_READER
def get_external_subnet_nullroute_mapping(self, context, level=1):
if not cfg.CONF.ml2_aci.handle_all_l3_gateways:
return {}
Expand Down Expand Up @@ -417,6 +430,7 @@ def get_external_subnet_nullroute_mapping(self, context, level=1):

return subnets

@db_api.CONTEXT_READER
def get_subnetpool_details(self, context, subnetpool_ids):
# get az from tags
fields = [models_v2.SubnetPool.id, tag_models.Tag.tag]
Expand Down
2 changes: 2 additions & 0 deletions networking_aci/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from networking_aci.plugins.ml2.drivers.mech_aci import constants

from unittest import mock
from networking_aci.plugins.ml2.drivers.mech_aci.allocations_manager import AllocationsManager

class NetworkingAciMechanismDriverTestBase(test_plugin.Ml2PluginV2TestCase, base.BaseTestCase):
"""Test case base class for all unit tests."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from neutron.db.models import tag as tag_models
from neutron.db import models_v2
from neutron_lib.api.definitions import external_net as extnet_api
from neutron_lib.db import api as db_api
from neutron_lib import context
from neutron.tests.common import helpers as neutron_test_helpers
from neutron.tests.unit.plugins.ml2 import test_plugin
Expand All @@ -23,7 +24,7 @@ def setUp(self):
super().setUp()
self._register_azs()
ctx = context.get_admin_context()
with ctx.session.begin(subtransactions=True):
with db_api.CONTEXT_WRITER.using(ctx):
self._address_scope = ascope_models.AddressScope(name="the-open-sea", ip_version=4)
ctx.session.add(self._address_scope)

Expand Down Expand Up @@ -57,7 +58,7 @@ def test_create_subnet_network_no_az_snp_az_fails(self):
with self.subnetpool(["1.1.0.0/16", "1.2.0.0/24"], address_scope_id=self._address_scope['id'], name="foo",
tenant_id="foo", admin=True) as snp:
ctx = context.get_admin_context()
with ctx.session.begin():
with db_api.CONTEXT_WRITER.using(ctx):
snp_db = ctx.session.query(models_v2.SubnetPool).get(snp['subnetpool']['id'])
ctx.session.add(tag_models.Tag(standard_attr_id=snp_db.standard_attr_id,
tag="availability-zone::qa-de-1a"))
Expand Down

0 comments on commit 66f6e82

Please sign in to comment.