Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make jwt decode fucntion generic #354

Merged
merged 2 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ Change Log
Unreleased
----------

Added
~~~~~
* (`#354 <https://github.com/openedx/edx-drf-extensions/pull/354>`_) Implemented ``verify_jwk_signature_using_keyset`` function.
This function allows for easy verification of JSON Web Key (JWK) signatures using a provided keyset.

[8.9.3] - 2023-09-13
--------------------

Expand Down
83 changes: 66 additions & 17 deletions edx_rest_framework_extensions/auth/jwt/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,9 @@ def _verify_jwt_signature(token, jwt_issuer, decode_symmetric_token):
# can fully retire code paths for symmetric keys, as part of
# DEPR: Symmetric JWTs: https://github.com/openedx/public-engineering/issues/83

# Use add_symmetric_keys=False to only include asymmetric keys at first
key_set = _get_signing_jwk_key_set(jwt_issuer, add_symmetric_keys=False)
# Pass only asymmetric_keys to only include asymmetric keys at first
asymmetric_keys = settings.JWT_AUTH.get('JWT_PUBLIC_SIGNING_JWK_SET')
key_set = get_verification_jwk_key_set(asymmetric_keys=asymmetric_keys)
# .. custom_attribute_name: jwt_auth_verify_asymmetric_keys_count
# .. custom_attribute_description: Number of JWT verification keys in use for this
# verification. Should be same as number of asymmetric public keys. This is
Expand All @@ -203,7 +204,7 @@ def _verify_jwt_signature(token, jwt_issuer, decode_symmetric_token):
set_custom_attribute('jwt_auth_verify_asymmetric_keys_count', len(key_set))

try:
_verify_jwk_signature_using_keyset(token, key_set, jwt_issuer)
verify_jwk_signature_using_keyset(token, key_set, aud=jwt_issuer['AUDIENCE'])
# .. custom_attribute_name: jwt_auth_asymmetric_verified
# .. custom_attribute_description: Whether the JWT was successfully verified
# using an asymmetric key.
Expand All @@ -218,7 +219,8 @@ def _verify_jwt_signature(token, jwt_issuer, decode_symmetric_token):
# the asymmetric keys here is redundant and unnecessary, but this code is temporary and
# will be simplified once symmetric keys have been fully retired.

key_set = _get_signing_jwk_key_set(jwt_issuer, add_symmetric_keys=decode_symmetric_token)
secret_key = jwt_issuer['SECRET_KEY'] if decode_symmetric_token else None
key_set = get_verification_jwk_key_set(asymmetric_keys=asymmetric_keys, secret_key=secret_key)
# .. custom_attribute_name: jwt_auth_verify_all_keys_count
# .. custom_attribute_description: Number of JWT verification keys in use for this
# verification. Should be same as number of asymmetric public keys, plus one if
Expand All @@ -228,7 +230,7 @@ def _verify_jwt_signature(token, jwt_issuer, decode_symmetric_token):
set_custom_attribute('jwt_auth_verify_all_keys_count', len(key_set))

try:
_verify_jwk_signature_using_keyset(token, key_set, jwt_issuer)
verify_jwk_signature_using_keyset(token, key_set, aud=jwt_issuer['AUDIENCE'])
# .. custom_attribute_name: jwt_auth_symmetric_verified
# .. custom_attribute_description: Whether the JWT was successfully verified
# using a symmetric key.
Expand All @@ -248,7 +250,48 @@ def _verify_jwt_signature(token, jwt_issuer, decode_symmetric_token):
raise jwt.InvalidTokenError(exc_info[2]) from token_error


def _verify_jwk_signature_using_keyset(token, key_set, jwt_issuer):
def verify_jwk_signature_using_keyset(token, key_set, aud=None, iss=None, verify_signature=True, verify_exp=True):
"""
Verifies the signature of a JSON Web Token (JWT) using a provided JSON Web Key (PyJWK) key set.

Args:
token (str): The JWT to be verified.
key_set (list -> PyJWK): A list containing PyJWKs (JSON Web Keys)
for signature verification.
aud (str or None): The expected "aud" (audience) claim in the JWT.
If provided, the JWT's "aud" claim must match this value for
the verification to succeed.
iss (str or None): The expected "iss" (issuer) claim in the JWT.
If provided, the JWT's "iss" claim must match this value for
the verification to succeed.
verify_signature (bool): Whether to verify the JWT's digital signature.
Set to False if you want to skip signature verification
(e.g., if the JWT is already pre-verified).
verify_exp (bool): Whether to verify the JWT's expiration time ("exp" claim).
Set to False if you want to skip expiration time verification.

Returns:
data (dict): Decoded JWT.

Raises:
ValueError: If the token is not a valid JWT or if the key_set is empty
or improperly formatted.
jwt.ExpiredSignatureError: If the JWT has expired and verify_exp
is set to True.
jwt.InvalidIssuerError: If the "iss" claim does not match the expected
issuer and iss is provided.
jwt.InvalidAudienceError: If the "aud" claim does not match the expected
audience and aud is provided.
jwt.DecodeError: If the JWT decoding fails for any reason.
"""
options = {
'verify_signature': verify_signature,
'verify_exp': verify_exp,
'verify_aud': bool(aud),
'verify_iss': bool(iss)
}
data = None

for i in range(0, len(key_set)):
try:
algorithms = None
Expand All @@ -257,16 +300,19 @@ def _verify_jwk_signature_using_keyset(token, key_set, jwt_issuer):
elif key_set[i].key_type == 'oct':
algorithms = ['HS256',]

_ = jwt.decode(
data = jwt.decode(
token,
key=key_set[i].key,
algorithms=algorithms,
audience=jwt_issuer['AUDIENCE'],
issuer=iss,
audience=aud,
options=options
)
break
except Exception: # pylint: disable=broad-except
if i == len(key_set) - 1:
raise
return data


def _decode_and_verify_token(token, jwt_issuer):
Expand Down Expand Up @@ -315,21 +361,24 @@ def _decode_and_verify_token(token, jwt_issuer):
return decoded_token


def _get_signing_jwk_key_set(jwt_issuer, add_symmetric_keys=True):
def get_verification_jwk_key_set(asymmetric_keys=None, secret_key=None):
"""
Returns a JWK Keyset containing all active keys that are configured
for verifying signatures.
Creates a JWK Keyset containing the provided keys.

Args:
asymmetric_keys (list or None): List of asymmetric JWK verification keys,
each in JSON format.
secret_key (str or None): Secret key for symmetric JWT verification, as an
unencoded string.
"""
key_set = []

# asymmetric keys
signing_jwk_set = settings.JWT_AUTH.get('JWT_PUBLIC_SIGNING_JWK_SET')
if signing_jwk_set:
key_set.extend(PyJWKSet.from_json(signing_jwk_set).keys)
if asymmetric_keys:
key_set.extend(PyJWKSet.from_json(asymmetric_keys).keys)

if add_symmetric_keys:
if secret_key:
# symmetric key
encoded_secret_key = base64url_encode(jwt_issuer['SECRET_KEY'].encode('utf-8'))
encoded_secret_key = base64url_encode(secret_key.encode('utf-8'))
key_set.append(PyJWK({'k': encoded_secret_key, 'kty': 'oct'}))

return key_set