Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve retries + caching #21

Merged
merged 3 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 40 additions & 22 deletions flask_multipass_cern.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
CACHE_LONG_TTL = 86400 * 7
CACHE_TTL = 1800
CERN_OIDC_WELLKNOWN_URL = 'https://auth.cern.ch/auth/realms/cern/.well-known/openid-configuration'
HTTP_RETRY_COUNT = 5

# It is unclear whether retries are still needed. Retrying without a backoff means that, if the API
# keeps failing with an error, we do not risk taking down the site that uses this library.
HTTP_RETRY_COUNT = 2
retry_config = HTTPAdapter(max_retries=Retry(total=HTTP_RETRY_COUNT,
backoff_factor=0.5,
backoff_factor=0,
status_forcelist=[503, 504],
allowed_methods=frozenset(['GET']),
raise_on_status=False))
Expand Down Expand Up @@ -64,6 +65,9 @@ def set(self, key, value, timeout=0, refresh_timeout=None):
if refresh_timeout:
self.cache.set(f'{key}:timestamp', datetime.now(), refresh_timeout)

def delay_refresh(self, key, timeout):
    """Rewrite the timestamp entry that accompanies ``key``.

    Stores the current time under ``f'{key}:timestamp'`` with ``timeout``
    as its expiry, without touching the value stored under ``key`` itself.
    Callers use this after a failed refresh so that the stale cached value
    keeps being served for a while (presumably ``should_refresh`` checks
    this timestamp entry -- confirm against its full body).

    :param key: cache key whose timestamp entry should be rewritten
    :param timeout: expiry passed to the underlying cache for the new
                    timestamp entry
    """
    if self.cache is None:
        # No cache backend configured; `should_refresh` handles this case
        # explicitly as well, so there is nothing to delay here.
        return
    self.cache.set(f'{key}:timestamp', datetime.now(), timeout)

def should_refresh(self, key):
if self.cache is None:
return True
Expand Down Expand Up @@ -163,21 +167,12 @@ def get_members(self):
yield IdentityInfo(self.provider, identifier, extra_data, **res)

def has_member(self, identifier):
cache = self.provider.cache
logger = self.provider.logger
cache_key = f'flask-multipass-cern:{self.provider.name}:groups:{identifier}'
all_groups = cache.get(cache_key)

if all_groups is None or cache.should_refresh(cache_key):
try:
all_groups = {g.name.lower() for g in self.provider.get_identity_groups(identifier)}
cache.set(cache_key, all_groups, CACHE_LONG_TTL, CACHE_TTL)
except RequestException:
logger.warning('Refreshing user groups failed for %s', identifier)
if all_groups is None:
logger.error('Getting user groups failed for %s, access will be denied', identifier)
return False

try:
all_groups = {g.name.lower() for g in self.provider.get_identity_groups(identifier)}
except RequestException:
# request failed and could not be satisfied from cache
self.provider.logger.error('Getting user groups failed for %s, access will be denied', identifier)
return False
if self.provider.settings['cern_users_group'] and self.name.lower() == 'cern users':
return self.provider.settings['cern_users_group'].lower() in all_groups
return self.name.lower() in all_groups
Expand Down Expand Up @@ -351,6 +346,7 @@ def search_identities_ex(self, criteria, exact=False, limit=None):
except RequestException:
self.logger.warning('Refreshing identities failed for criteria %s (could not get API token)', criteria)
if use_cache and cached_data:
self.cache.delay_refresh(cache_key, CACHE_TTL)
return cached_results, cached_data[1]
else:
self.logger.error('Getting identities failed for criteria %s (could not get API token)', criteria)
ThiefMaster marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -363,6 +359,7 @@ def search_identities_ex(self, criteria, exact=False, limit=None):
except RequestException:
self.logger.warning('Refreshing identities failed for criteria %s', criteria)
if use_cache and cached_data:
self.cache.delay_refresh(cache_key, CACHE_TTL)
return cached_results, cached_data[1]
else:
self.logger.error('Getting identities failed for criteria %s', criteria)
Expand All @@ -387,19 +384,40 @@ def search_identities_ex(self, criteria, exact=False, limit=None):
self.cache.set(cache_key, (cache_data, total), CACHE_LONG_TTL, CACHE_TTL * 2)
return identities, total

def get_identity_groups(self, identifier):
def _fetch_identity_group_names(self, identifier):
with self._get_api_session() as api_session:
identifier = identifier.replace('/', '%2F') # edugain identifiers sometimes contain slashes
resp = api_session.get(f'{self.authz_api_base}/api/v1.0/IdentityMembership/{identifier}/precomputed')
if resp.status_code == 404:
return set()
resp.raise_for_status()
results = resp.json()['data']
groups = {self.group_class(self, res['groupIdentifier']) for res in results}
if self.settings['cern_users_group'] and any(g.name == self.settings['cern_users_group'] for g in groups):
groups.add(self.group_class(self, 'CERN Users'))
groups = {res['groupIdentifier'] for res in results}
if self.settings['cern_users_group'] and any(g == self.settings['cern_users_group'] for g in groups):
groups.add('CERN Users')
return groups

def get_identity_groups(self, identifier):
    """Return the groups of ``identifier`` as a set of group objects.

    Group names are looked up in the cache first. When nothing is cached,
    or the cache reports that the entry should be refreshed, the names are
    fetched via ``_fetch_identity_group_names`` and written back to the
    cache. If that fetch raises ``RequestException`` and a cached value
    exists, the cached value is used and its refresh is delayed; without a
    cached value the exception is re-raised.

    :param identifier: identifier of the identity whose groups are wanted
    :return: a set of ``self.group_class`` instances, one per group name
    :raises RequestException: if fetching fails and nothing is cached
    """
    cache_key = f'flask-multipass-cern:{self.name}:groups:{identifier}'
    names = self.cache.get(cache_key)

    # `should_refresh` is only consulted when there is a cached value
    needs_fetch = names is None or self.cache.should_refresh(cache_key)
    if needs_fetch:
        try:
            names = self._fetch_identity_group_names(identifier)
            self.cache.set(cache_key, names, CACHE_LONG_TTL, CACHE_TTL)
        except RequestException:
            self.logger.warning('Refreshing user groups failed for %s', identifier)
            if names is None:
                # nothing to fall back to, let the caller deal with it
                self.logger.error('Getting user groups failed for %s, request will fail', identifier)
                raise
            # keep serving the stale entry instead of retrying on every call
            self.cache.delay_refresh(cache_key, CACHE_TTL)

    users_group = self.settings['cern_users_group']
    if users_group and any(name == users_group for name in names):
        names.add('CERN Users')

    return {self.group_class(self, name) for name in names}

def get_group(self, name):
    """Build a group object for ``name`` bound to this provider.

    :param name: name of the group
    :return: the result of calling ``self.group_class`` with this provider
             and ``name``
    """
    group = self.group_class(self, name)
    return group

Expand Down
21 changes: 9 additions & 12 deletions tests/test_has_member.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from datetime import datetime
from unittest.mock import MagicMock

import pytest
from requests import Session
Expand All @@ -17,17 +16,15 @@ def mock_get_api_session(mocker):


@pytest.fixture
def mock_get_identity_groups(mocker):
get_identity_groups = mocker.patch('flask_multipass_cern.CERNIdentityProvider.get_identity_groups')
group = MagicMock()
group.name = 'cern users'
get_identity_groups.return_value = {group}
def mock_fetch_identity_group_names(mocker):
get_identity_groups = mocker.patch('flask_multipass_cern.CERNIdentityProvider._fetch_identity_group_names')
get_identity_groups.return_value = {'cern users'}
return get_identity_groups


@pytest.fixture
def mock_get_identity_groups_fail(mocker):
get_identity_groups = mocker.patch('flask_multipass_cern.CERNIdentityProvider.get_identity_groups')
get_identity_groups = mocker.patch('flask_multipass_cern.CERNIdentityProvider._fetch_identity_group_names')
get_identity_groups.side_effect = RequestException()
return get_identity_groups

Expand All @@ -37,7 +34,7 @@ def spy_cache_set(mocker):
return mocker.spy(MemoryCache, 'set')


@pytest.mark.usefixtures('mock_get_identity_groups')
@pytest.mark.usefixtures('mock_fetch_identity_group_names')
def test_has_member_cache(provider):
test_group = CERNGroup(provider, 'cern users')
test_group.has_member('12345')
Expand All @@ -46,24 +43,24 @@ def test_has_member_cache(provider):
assert test_group.provider.cache.get('flask-multipass-cern:cip:groups:12345:timestamp')


@pytest.mark.usefixtures('mock_get_identity_groups')
@pytest.mark.usefixtures('mock_fetch_identity_group_names')
def test_has_member_cache_miss(provider, spy_cache_set):
test_group = CERNGroup(provider, 'cern users')
test_group.has_member('12345')

assert spy_cache_set.call_count == 2


def test_has_member_cache_hit(provider, mock_get_identity_groups):
def test_has_member_cache_hit(provider, mock_fetch_identity_group_names):
test_group = CERNGroup(provider, 'cern users')
test_group.provider.cache.set('flask-multipass-cern:cip:groups:12345', 'cern users')
test_group.provider.cache.set('flask-multipass-cern:cip:groups:12345:timestamp', datetime.now())
test_group.has_member('12345')

assert not mock_get_identity_groups.called
assert not mock_fetch_identity_group_names.called


@pytest.mark.usefixtures('mock_get_identity_groups')
@pytest.mark.usefixtures('mock_fetch_identity_group_names')
def test_has_member_request_fails(provider, mock_get_identity_groups_fail):
test_group = CERNGroup(provider, 'cern users')
res = test_group.has_member('12345')
Expand Down
19 changes: 19 additions & 0 deletions tests/test_search_identities_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,22 @@ def test_search_identities_cache_hit_stale(provider, mock_data, freeze_time):
assert mock_data[data_key] == identities[0][0].data.get(identities_key)
assert isinstance(identities[0][0], IdentityInfo)
assert identities[1] == 1


@pytest.mark.usefixtures('httpretty_enabled')
def test_search_identities_cache_hit_broken_sso(mocker, provider, mock_data, freeze_time):
get_api_session = mocker.patch('flask_multipass_cern.CERNIdentityProvider._get_api_session')
get_api_session.side_effect = RequestException()

test_uri = f'{provider.settings.get("authz_api")}/api/v1.0/Identity'
httpretty.register_uri(httpretty.GET, test_uri, status=401)
cache_key = 'flask-multipass-cern:cip:email-identities:[email protected]'
provider.cache.set(cache_key, ([mock_data], 1), 2000, 10)
freeze_time(datetime.now() + timedelta(seconds=100))

identities = provider.search_identities_ex({'primaryAccountEmail': {'[email protected]'}}, True)

for identities_key, data_key in provider.settings.get('mapping').items():
assert mock_data[data_key] == identities[0][0].data.get(identities_key)
assert isinstance(identities[0][0], IdentityInfo)
assert identities[1] == 1