Skip to content

Commit

Permalink
replace 403 with 401 status code (#302)
Browse files Browse the repository at this point in the history
  • Loading branch information
sinisaos authored Oct 4, 2024
1 parent d6ed009 commit bc7eefa
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 28 deletions.
2 changes: 1 addition & 1 deletion docs/source/jwt/middleware.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Setup
-----

``JWTMiddleware`` wraps an ASGI app, and ensures a valid token is passed in the header.
Otherwise a 403 error is returned. If the token is valid, the corresponding
Otherwise a 401 error is returned. If the token is valid, the corresponding
``user_id`` is added to the ASGI ``scope``.

blacklist
Expand Down
28 changes: 22 additions & 6 deletions piccolo_api/jwt_auth/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jwt
from piccolo.apps.user.tables import BaseUser
from starlette.exceptions import HTTPException
from starlette.status import HTTP_401_UNAUTHORIZED
from starlette.types import ASGIApp


Expand Down Expand Up @@ -126,7 +127,7 @@ async def get_user(
async def __call__(self, scope, receive, send):
"""
Add the user_id to the scope if a JWT token is available, and the user
is recognised, otherwise raise a 403 HTTP error.
is recognised, otherwise raise a 401 HTTP error.
"""
allow_unauthenticated = self.allow_unauthenticated

Expand All @@ -142,7 +143,10 @@ async def __call__(self, scope, receive, send):
)
return
else:
raise HTTPException(status_code=403, detail=error)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=error,
)

if await self.blacklist.in_blacklist(token):
error = JWTError.token_revoked.value
Expand All @@ -154,7 +158,10 @@ async def __call__(self, scope, receive, send):
)
return
else:
raise HTTPException(status_code=403, detail=error)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=error,
)

try:
token_dict = jwt.decode(token, self.secret, algorithms=["HS256"])
Expand All @@ -168,7 +175,10 @@ async def __call__(self, scope, receive, send):
)
return
else:
raise HTTPException(status_code=403, detail=error)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=error,
)
except jwt.exceptions.InvalidSignatureError:
error = JWTError.token_invalid.value
if allow_unauthenticated:
Expand All @@ -179,7 +189,10 @@ async def __call__(self, scope, receive, send):
)
return
else:
raise HTTPException(status_code=403, detail=error)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=error,
)

user = await self.get_user(token_dict)
if user is None:
Expand All @@ -192,7 +205,10 @@ async def __call__(self, scope, receive, send):
)
return
else:
raise HTTPException(status_code=403, detail=error)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=error,
)

await self.asgi(
extend_scope(scope, {"user_id": user.id}), receive, send
Expand Down
5 changes: 3 additions & 2 deletions piccolo_api/mfa/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from starlette.endpoints import HTTPEndpoint
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED

from piccolo_api.mfa.provider import MFAProvider
from piccolo_api.shared.auth.styles import Styles
Expand Down Expand Up @@ -64,7 +65,7 @@ def _render_cancel_template(
template = environment.get_template("mfa_cancel.html")

return HTMLResponse(
status_code=400,
status_code=HTTP_400_BAD_REQUEST,
content=template.render(
styles=self._styles,
csrftoken=request.scope.get("csrftoken"),
Expand Down Expand Up @@ -110,7 +111,7 @@ async def post(self, request: Request):
):
return self._render_register_template(
request=request,
status_code=403,
status_code=HTTP_401_UNAUTHORIZED,
extra_context={"error": "Incorrect password"},
)

Expand Down
30 changes: 19 additions & 11 deletions piccolo_api/session_auth/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
PlainTextResponse,
RedirectResponse,
)
from starlette.status import HTTP_303_SEE_OTHER
from starlette.status import HTTP_303_SEE_OTHER, HTTP_401_UNAUTHORIZED

from piccolo_api.mfa.provider import MFAProvider
from piccolo_api.session_auth.tables import SessionsBase
Expand Down Expand Up @@ -92,7 +92,8 @@ async def post(self, request: Request) -> Response:
cookie = request.cookies.get(self._cookie_name, None)
if not cookie:
raise HTTPException(
status_code=401, detail="The session cookie wasn't found."
status_code=HTTP_401_UNAUTHORIZED,
detail="The session cookie wasn't found.",
)
await self._session_table.remove_session(token=cookie)

Expand Down Expand Up @@ -204,11 +205,14 @@ def _get_error_response(
) -> Response:
if response_format == "html":
return self._render_template(
request, template_context={"error": error}, status_code=401
request,
template_context={"error": error},
status_code=HTTP_401_UNAUTHORIZED,
)
else:
return PlainTextResponse(
status_code=401, content=f"Login failed: {error}"
status_code=HTTP_401_UNAUTHORIZED,
content=f"Login failed: {error}",
)

async def get(self, request: Request) -> HTMLResponse:
Expand Down Expand Up @@ -261,7 +265,8 @@ async def post(self, request: Request) -> Response:
)
else:
raise HTTPException(
status_code=401, detail=validate_response
status_code=HTTP_401_UNAUTHORIZED,
detail=validate_response,
)

# Attempt login
Expand Down Expand Up @@ -314,7 +319,8 @@ async def post(self, request: Request) -> Response:
)
else:
raise HTTPException(
status_code=401, detail=message
status_code=HTTP_401_UNAUTHORIZED,
detail=message,
)

# Work out which MFA provider to use:
Expand All @@ -325,7 +331,7 @@ async def post(self, request: Request) -> Response:

if mfa_provider_name is None:
raise HTTPException(
status_code=401,
status_code=HTTP_401_UNAUTHORIZED,
detail="MFA provider must be specified",
)

Expand All @@ -337,13 +343,13 @@ async def post(self, request: Request) -> Response:

if len(filtered_mfa_providers) == 0:
raise HTTPException(
status_code=401,
status_code=HTTP_401_UNAUTHORIZED,
detail="MFA provider not recognised.",
)

if len(filtered_mfa_providers) > 1:
raise HTTPException(
status_code=401,
status_code=HTTP_401_UNAUTHORIZED,
detail=(
"Multiple matching MFA providers found."
),
Expand All @@ -368,7 +374,7 @@ async def post(self, request: Request) -> Response:
)
else:
raise HTTPException(
status_code=401,
status_code=HTTP_401_UNAUTHORIZED,
detail="MFA failed",
)

Expand Down Expand Up @@ -404,7 +410,9 @@ async def post(self, request: Request) -> Response:
},
)
else:
raise HTTPException(status_code=401, detail="Login failed")
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail="Login failed"
)

now = datetime.now()
expiry_date = now + self._session_expiry
Expand Down
3 changes: 2 additions & 1 deletion piccolo_api/shared/middleware/junction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from starlette.exceptions import HTTPException
from starlette.routing import Router
from starlette.status import HTTP_404_NOT_FOUND
from starlette.types import Receive, Scope, Send


Expand All @@ -22,4 +23,4 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
else:
return

raise HTTPException(status_code=404)
raise HTTPException(status_code=HTTP_404_NOT_FOUND)
6 changes: 4 additions & 2 deletions piccolo_api/token_auth/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from starlette.endpoints import HTTPEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.status import HTTP_401_UNAUTHORIZED

from .tables import TokenAuth

Expand Down Expand Up @@ -61,11 +62,12 @@ async def post(self, request: Request) -> Response:
else:
return Response(
content="The credentials were incorrect",
status_code=401,
status_code=HTTP_401_UNAUTHORIZED,
)
else:
return Response(
content="No credentials were found.", status_code=401
content="No credentials were found.",
status_code=HTTP_401_UNAUTHORIZED,
)


Expand Down
10 changes: 5 additions & 5 deletions tests/jwt_auth/test_jwt_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_empty_token(self):
with self.assertRaises(HTTPException):
response = client.get("/")

self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 401)
self.assertEqual(
response.json()["detail"], JWTError.token_not_found.value
)
Expand Down Expand Up @@ -103,7 +103,7 @@ def test_expired_token(self):
with self.assertRaises(HTTPException):
response = client.get("/", headers=headers)

self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 401)
self.assertEqual(
response.json()["detail"], JWTError.token_expired.value
)
Expand Down Expand Up @@ -134,7 +134,7 @@ def test_wrong_secret(self):
with self.assertRaises(HTTPException):
response = client.get("/", headers=headers)

self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 401)
self.assertEqual(
response.json()["detail"], JWTError.token_invalid.value
)
Expand Down Expand Up @@ -165,7 +165,7 @@ def test_missing_expiry(self):
with self.assertRaises(HTTPException):
response = client.get("/", headers=headers)

self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 401)
self.assertEqual(
response.json()["detail"], JWTError.token_expired.value
)
Expand All @@ -188,7 +188,7 @@ def test_token_without_user_id(self):
with self.assertRaises(HTTPException):
response = client.get("/", headers=headers)

self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 401)
self.assertEqual(response.content, b"")

# allow_unauthenticated
Expand Down

0 comments on commit bc7eefa

Please sign in to comment.