Skip to content

Commit

Permalink
Handle GitHub Refresh Token request errors
Browse files Browse the repository at this point in the history
  • Loading branch information
frankie567 committed Apr 17, 2024
1 parent bd536f0 commit 5bc4b44
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 1 deletion.
30 changes: 29 additions & 1 deletion httpx_oauth/clients/github.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import urllib.parse
from typing import Any, Dict, List, Optional, Tuple, TypedDict, cast

import httpx

from httpx_oauth.errors import GetIdEmailError
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import BaseOAuth2, OAuth2Token, RefreshTokenError

AUTHORIZE_ENDPOINT = "https://github.com/login/oauth/authorize"
ACCESS_TOKEN_ENDPOINT = "https://github.com/login/oauth/access_token"
Expand Down Expand Up @@ -47,6 +48,33 @@ def __init__(
base_scopes=scopes,
)

async def refresh_token(self, refresh_token: str):
assert self.refresh_token_endpoint is not None
async with self.get_httpx_client() as client:
response = await client.post(
self.refresh_token_endpoint,
data={
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": self.client_id,
"client_secret": self.client_secret,
},
headers=self.request_headers,
)

content_type = response.headers.get("content-type", "")
if content_type.startswith("application/json"):
data = response.json()
# GitHub sends errors with a 200 status code
# and a form-urlencoded content type 😕
elif content_type.startswith("application/x-www-form-urlencoded"):
data = urllib.parse.parse_qs(response.text)

if response.status_code >= 400 or "error" in data:
raise RefreshTokenError(data)

return OAuth2Token(data)

async def get_id_email(self, token: str) -> Tuple[str, Optional[str]]:
async with httpx.AsyncClient(
headers={**self.request_headers, "Authorization": f"token {token}"}
Expand Down
8 changes: 8 additions & 0 deletions tests/mock/github_success_refresh_token.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"access_token": "ghu_XXX",
"expires_in": 28800,
"refresh_token": "ghr_XXX",
"refresh_token_expires_in": 15897600,
"scope": "",
"token_type": "bearer"
}
48 changes: 48 additions & 0 deletions tests/test_clients_github.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import re
from urllib.parse import urlencode

import pytest
import respx
from httpx import Response

from httpx_oauth.clients.github import EMAILS_ENDPOINT, PROFILE_ENDPOINT, GitHubOAuth2
from httpx_oauth.errors import GetIdEmailError
from httpx_oauth.oauth2 import OAuth2Token, RefreshTokenError

client = GitHubOAuth2("CLIENT_ID", "CLIENT_SECRET")

Expand All @@ -26,6 +28,52 @@ def test_github_oauth2():
emails_response = [{"email": "[email protected]"}]


class TestGitHubRefreshToken:
@pytest.mark.asyncio
@respx.mock
async def test_refresh_token(self, load_mock, get_respx_call_args):
request = respx.post(client.refresh_token_endpoint).mock(
return_value=Response(200, json=load_mock("github_success_refresh_token"))
)
access_token = await client.refresh_token("REFRESH_TOKEN")

url, headers, content = await get_respx_call_args(request)
assert headers["Content-Type"] == "application/x-www-form-urlencoded"
assert headers["Accept"] == "application/json"
assert "grant_type=refresh_token" in content
assert "refresh_token=REFRESH_TOKEN" in content
assert "client_id=CLIENT_ID" in content
assert "client_secret=CLIENT_SECRET" in content

assert type(access_token) == OAuth2Token
assert "access_token" in access_token
assert "token_type" in access_token
assert access_token.is_expired() is False

@pytest.mark.asyncio
@respx.mock
async def test_refresh_token_error(self, load_mock):
error_response = {
"error": "bad_refresh_token",
"error_description": "The refresh token passed is incorrect or expired.",
"error_uri": "https://docs.github.com",
}
error_response_encoded = urlencode(error_response)

respx.post(client.refresh_token_endpoint).mock(
return_value=Response(
200,
headers={"content-type": "application/x-www-form-urlencoded"},
content=error_response_encoded,
)
)

with pytest.raises(RefreshTokenError) as excinfo:
await client.refresh_token("REFRESH_TOKEN")
assert isinstance(excinfo.value.args[0], dict)
assert "error" in excinfo.value.args[0]


class TestGitHubGetIdEmail:
@pytest.mark.asyncio
@respx.mock
Expand Down

0 comments on commit 5bc4b44

Please sign in to comment.