Skip to content

Commit

Permalink
Improve caching
Browse files Browse the repository at this point in the history
- Cache `get_identity_groups` results
- Wait before trying to refresh again after a failed request
  • Loading branch information
ThiefMaster committed Jun 27, 2024
1 parent 74f72ba commit adbe4e5
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 31 deletions.
55 changes: 36 additions & 19 deletions flask_multipass_cern.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def set(self, key, value, timeout=0, refresh_timeout=None):
if refresh_timeout:
self.cache.set(f'{key}:timestamp', datetime.now(), refresh_timeout)

def delay_refresh(self, key, timeout):
self.cache.set(f'{key}:timestamp', datetime.now(), timeout)

def should_refresh(self, key):
if self.cache is None:
return True
Expand Down Expand Up @@ -164,21 +167,12 @@ def get_members(self):
yield IdentityInfo(self.provider, identifier, extra_data, **res)

def has_member(self, identifier):
cache = self.provider.cache
logger = self.provider.logger
cache_key = f'flask-multipass-cern:{self.provider.name}:groups:{identifier}'
all_groups = cache.get(cache_key)

if all_groups is None or cache.should_refresh(cache_key):
try:
all_groups = {g.name.lower() for g in self.provider.get_identity_groups(identifier)}
cache.set(cache_key, all_groups, CACHE_LONG_TTL, CACHE_TTL)
except RequestException:
logger.warning('Refreshing user groups failed for %s', identifier)
if all_groups is None:
logger.error('Getting user groups failed for %s, access will be denied', identifier)
return False

try:
all_groups = {g.name.lower() for g in self.provider.get_identity_groups(identifier)}
except RequestException:
# request failed and could not be satisfied from cache
self.provider.logger.error('Getting user groups failed for %s, access will be denied', identifier)
return False
if self.provider.settings['cern_users_group'] and self.name.lower() == 'cern users':
return self.provider.settings['cern_users_group'].lower() in all_groups
return self.name.lower() in all_groups
Expand Down Expand Up @@ -352,6 +346,7 @@ def search_identities_ex(self, criteria, exact=False, limit=None):
except RequestException:
self.logger.warning('Refreshing identities failed for criteria %s (could not get API token)', criteria)
if use_cache and cached_data:
self.cache.delay_refresh(cache_key, CACHE_TTL)
return cached_results, cached_data[1]
else:
self.logger.error('Getting identities failed for criteria %s (could not get API token)', criteria)
Expand All @@ -364,6 +359,7 @@ def search_identities_ex(self, criteria, exact=False, limit=None):
except RequestException:
self.logger.warning('Refreshing identities failed for criteria %s', criteria)
if use_cache and cached_data:
self.cache.delay_refresh(cache_key, CACHE_TTL)
return cached_results, cached_data[1]
else:
self.logger.error('Getting identities failed for criteria %s', criteria)
Expand All @@ -388,19 +384,40 @@ def search_identities_ex(self, criteria, exact=False, limit=None):
self.cache.set(cache_key, (cache_data, total), CACHE_LONG_TTL, CACHE_TTL * 2)
return identities, total

def get_identity_groups(self, identifier):
def _fetch_identity_group_names(self, identifier):
with self._get_api_session() as api_session:
identifier = identifier.replace('/', '%2F') # edugain identifiers sometimes contain slashes
resp = api_session.get(f'{self.authz_api_base}/api/v1.0/IdentityMembership/{identifier}/precomputed')
if resp.status_code == 404:
return set()
resp.raise_for_status()
results = resp.json()['data']
groups = {self.group_class(self, res['groupIdentifier']) for res in results}
if self.settings['cern_users_group'] and any(g.name == self.settings['cern_users_group'] for g in groups):
groups.add(self.group_class(self, 'CERN Users'))
groups = {res['groupIdentifier'] for res in results}
if self.settings['cern_users_group'] and any(g == self.settings['cern_users_group'] for g in groups):
groups.add('CERN Users')
return groups

def get_identity_groups(self, identifier):
cache_key = f'flask-multipass-cern:{self.name}:groups:{identifier}'
group_names = self.cache.get(cache_key)

if group_names is None or self.cache.should_refresh(cache_key):
try:
group_names = self._fetch_identity_group_names(identifier)
self.cache.set(cache_key, group_names, CACHE_LONG_TTL, CACHE_TTL)
except RequestException:
self.logger.warning('Refreshing user groups failed for %s', identifier)
if group_names is not None:
self.cache.delay_refresh(cache_key, CACHE_TTL)
else:
self.logger.error('Getting user groups failed for %s, request will fail', identifier)
raise

if self.settings['cern_users_group'] and any(g == self.settings['cern_users_group'] for g in group_names):
group_names.add('CERN Users')

return {self.group_class(self, g) for g in group_names}

def get_group(self, name):
return self.group_class(self, name)

Expand Down
21 changes: 9 additions & 12 deletions tests/test_has_member.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from datetime import datetime
from unittest.mock import MagicMock

import pytest
from requests import Session
Expand All @@ -17,17 +16,15 @@ def mock_get_api_session(mocker):


@pytest.fixture
def mock_get_identity_groups(mocker):
get_identity_groups = mocker.patch('flask_multipass_cern.CERNIdentityProvider.get_identity_groups')
group = MagicMock()
group.name = 'cern users'
get_identity_groups.return_value = {group}
def mock_fetch_identity_group_names(mocker):
get_identity_groups = mocker.patch('flask_multipass_cern.CERNIdentityProvider._fetch_identity_group_names')
get_identity_groups.return_value = {'cern users'}
return get_identity_groups


@pytest.fixture
def mock_get_identity_groups_fail(mocker):
get_identity_groups = mocker.patch('flask_multipass_cern.CERNIdentityProvider.get_identity_groups')
get_identity_groups = mocker.patch('flask_multipass_cern.CERNIdentityProvider._fetch_identity_group_names')
get_identity_groups.side_effect = RequestException()
return get_identity_groups

Expand All @@ -37,7 +34,7 @@ def spy_cache_set(mocker):
return mocker.spy(MemoryCache, 'set')


@pytest.mark.usefixtures('mock_get_identity_groups')
@pytest.mark.usefixtures('mock_fetch_identity_group_names')
def test_has_member_cache(provider):
test_group = CERNGroup(provider, 'cern users')
test_group.has_member('12345')
Expand All @@ -46,24 +43,24 @@ def test_has_member_cache(provider):
assert test_group.provider.cache.get('flask-multipass-cern:cip:groups:12345:timestamp')


@pytest.mark.usefixtures('mock_get_identity_groups')
@pytest.mark.usefixtures('mock_fetch_identity_group_names')
def test_has_member_cache_miss(provider, spy_cache_set):
test_group = CERNGroup(provider, 'cern users')
test_group.has_member('12345')

assert spy_cache_set.call_count == 2


def test_has_member_cache_hit(provider, mock_get_identity_groups):
def test_has_member_cache_hit(provider, mock_fetch_identity_group_names):
test_group = CERNGroup(provider, 'cern users')
test_group.provider.cache.set('flask-multipass-cern:cip:groups:12345', 'cern users')
test_group.provider.cache.set('flask-multipass-cern:cip:groups:12345:timestamp', datetime.now())
test_group.has_member('12345')

assert not mock_get_identity_groups.called
assert not mock_fetch_identity_group_names.called


@pytest.mark.usefixtures('mock_get_identity_groups')
@pytest.mark.usefixtures('mock_fetch_identity_group_names')
def test_has_member_request_fails(provider, mock_get_identity_groups_fail):
test_group = CERNGroup(provider, 'cern users')
res = test_group.has_member('12345')
Expand Down

0 comments on commit adbe4e5

Please sign in to comment.