diff --git a/httpx_oauth/clients/openid.py b/httpx_oauth/clients/openid.py index a54233c..099008b 100644 --- a/httpx_oauth/clients/openid.py +++ b/httpx_oauth/clients/openid.py @@ -1,9 +1,9 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, get_args import httpx from httpx_oauth.exceptions import GetIdEmailError -from httpx_oauth.oauth2 import BaseOAuth2, OAuth2RequestError +from httpx_oauth.oauth2 import BaseOAuth2, OAuth2ClientAuthMethod, OAuth2RequestError BASE_SCOPES = ["openid", "email"] @@ -70,6 +70,19 @@ def __init__( "revocation_endpoint_auth_methods_supported", ["client_secret_basic"] ) + supported_auth_methods = get_args(OAuth2ClientAuthMethod) + # check if there is any supported and select the first one + token_endpoint_auth_methods_supported = [ + method + for method in token_endpoint_auth_methods_supported + if method in supported_auth_methods + ] + revocation_endpoint_auth_methods_supported = [ + method + for method in revocation_endpoint_auth_methods_supported + if method in supported_auth_methods + ] + super().__init__( client_id, client_secret, @@ -80,11 +93,11 @@ def __init__( name=name, base_scopes=base_scopes, token_endpoint_auth_method=token_endpoint_auth_methods_supported[0], - revocation_endpoint_auth_method=revocation_endpoint_auth_methods_supported[ - 0 - ] - if revocation_endpoint - else None, + revocation_endpoint_auth_method=( + revocation_endpoint_auth_methods_supported[0] + if revocation_endpoint + else None + ), ) async def get_id_email(self, token: str) -> Tuple[str, Optional[str]]: diff --git a/tests/test_clients_facebook.py b/tests/test_clients_facebook.py index 2225432..d78e101 100644 --- a/tests/test_clients_facebook.py +++ b/tests/test_clients_facebook.py @@ -49,7 +49,7 @@ async def test_get_long_lived_access_token(self, load_mock, get_respx_call_args) assert f"client_id={CLIENT_ID}" in content assert f"client_secret={CLIENT_SECRET}" in content - assert type(access_token) == OAuth2Token + assert type(access_token) is OAuth2Token assert "access_token" in access_token assert "token_type" in access_token assert access_token.is_expired() is False diff --git a/tests/test_clients_github.py b/tests/test_clients_github.py index bb18e29..2876133 100644 --- a/tests/test_clients_github.py +++ b/tests/test_clients_github.py @@ -45,7 +45,7 @@ async def test_refresh_token(self, load_mock, get_respx_call_args): assert "client_id=CLIENT_ID" in content assert "client_secret=CLIENT_SECRET" in content - assert type(access_token) == OAuth2Token + assert type(access_token) is OAuth2Token assert "access_token" in access_token assert "token_type" in access_token assert access_token.is_expired() is False diff --git a/tests/test_oauth2.py b/tests/test_oauth2.py index 2246628..4bae710 100644 --- a/tests/test_oauth2.py +++ b/tests/test_oauth2.py @@ -182,7 +182,7 @@ async def test_get_access_token( assert f"client_id={CLIENT_ID}" in content assert f"client_secret={CLIENT_SECRET}" in content - assert type(access_token) == OAuth2Token + assert type(access_token) is OAuth2Token assert "access_token" in access_token assert "token_type" in access_token assert access_token.is_expired() is False @@ -245,7 +245,7 @@ async def test_refresh_token( assert f"client_id={CLIENT_ID}" in content assert f"client_secret={CLIENT_SECRET}" in content - assert type(access_token) == OAuth2Token + assert type(access_token) is OAuth2Token assert "access_token" in access_token assert "token_type" in access_token assert access_token.is_expired() is False