From bc7eefa9d63b018cf45853798acbfd0be0978754 Mon Sep 17 00:00:00 2001 From: sinisaos Date: Fri, 4 Oct 2024 21:06:26 +0200 Subject: [PATCH] replace 403 with 401 status code (#302) --- docs/source/jwt/middleware.rst | 2 +- piccolo_api/jwt_auth/middleware.py | 28 ++++++++++++++++----- piccolo_api/mfa/endpoints.py | 5 ++-- piccolo_api/session_auth/endpoints.py | 30 ++++++++++++++--------- piccolo_api/shared/middleware/junction.py | 3 ++- piccolo_api/token_auth/endpoints.py | 6 +++-- tests/jwt_auth/test_jwt_middleware.py | 10 ++++---- 7 files changed, 56 insertions(+), 28 deletions(-) diff --git a/docs/source/jwt/middleware.rst b/docs/source/jwt/middleware.rst index f9aa666c..74d0fe4e 100644 --- a/docs/source/jwt/middleware.rst +++ b/docs/source/jwt/middleware.rst @@ -11,7 +11,7 @@ Setup ----- ``JWTMiddleware`` wraps an ASGI app, and ensures a valid token is passed in the header. -Otherwise a 403 error is returned. If the token is valid, the corresponding +Otherwise a 401 error is returned. If the token is valid, the corresponding ``user_id`` is added to the ASGI ``scope``. blacklist diff --git a/piccolo_api/jwt_auth/middleware.py b/piccolo_api/jwt_auth/middleware.py index 04a610a8..26ce9194 100644 --- a/piccolo_api/jwt_auth/middleware.py +++ b/piccolo_api/jwt_auth/middleware.py @@ -6,6 +6,7 @@ import jwt from piccolo.apps.user.tables import BaseUser from starlette.exceptions import HTTPException +from starlette.status import HTTP_401_UNAUTHORIZED from starlette.types import ASGIApp @@ -126,7 +127,7 @@ async def get_user( async def __call__(self, scope, receive, send): """ Add the user_id to the scope if a JWT token is available, and the user - is recognised, otherwise raise a 403 HTTP error. + is recognised, otherwise raise a 401 HTTP error. """ allow_unauthenticated = self.allow_unauthenticated @@ -142,7 +143,10 @@ async def __call__(self, scope, receive, send): ) return else: - raise HTTPException(status_code=403, detail=error) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail=error, + ) if await self.blacklist.in_blacklist(token): error = JWTError.token_revoked.value @@ -154,7 +158,10 @@ async def __call__(self, scope, receive, send): ) return else: - raise HTTPException(status_code=403, detail=error) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail=error, + ) try: token_dict = jwt.decode(token, self.secret, algorithms=["HS256"]) @@ -168,7 +175,10 @@ async def __call__(self, scope, receive, send): ) return else: - raise HTTPException(status_code=403, detail=error) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail=error, + ) except jwt.exceptions.InvalidSignatureError: error = JWTError.token_invalid.value if allow_unauthenticated: @@ -179,7 +189,10 @@ async def __call__(self, scope, receive, send): ) return else: - raise HTTPException(status_code=403, detail=error) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail=error, + ) user = await self.get_user(token_dict) if user is None: @@ -192,7 +205,10 @@ async def __call__(self, scope, receive, send): ) return else: - raise HTTPException(status_code=403, detail=error) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail=error, + ) await self.asgi( extend_scope(scope, {"user_id": user.id}), receive, send diff --git a/piccolo_api/mfa/endpoints.py b/piccolo_api/mfa/endpoints.py index b4f7187b..fc0bf3eb 100644 --- a/piccolo_api/mfa/endpoints.py +++ b/piccolo_api/mfa/endpoints.py @@ -8,6 +8,7 @@ from starlette.endpoints import HTTPEndpoint from starlette.requests import Request from starlette.responses import HTMLResponse, JSONResponse +from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED from piccolo_api.mfa.provider import MFAProvider from piccolo_api.shared.auth.styles import Styles @@ -64,7 +65,7 @@ def _render_cancel_template( template = environment.get_template("mfa_cancel.html") return HTMLResponse( - status_code=400, + status_code=HTTP_400_BAD_REQUEST, content=template.render( styles=self._styles, csrftoken=request.scope.get("csrftoken"), @@ -110,7 +111,7 @@ async def post(self, request: Request): ): return self._render_register_template( request=request, - status_code=403, + status_code=HTTP_401_UNAUTHORIZED, extra_context={"error": "Incorrect password"}, ) diff --git a/piccolo_api/session_auth/endpoints.py b/piccolo_api/session_auth/endpoints.py index a1580231..ba8003d1 100644 --- a/piccolo_api/session_auth/endpoints.py +++ b/piccolo_api/session_auth/endpoints.py @@ -17,7 +17,7 @@ PlainTextResponse, RedirectResponse, ) -from starlette.status import HTTP_303_SEE_OTHER +from starlette.status import HTTP_303_SEE_OTHER, HTTP_401_UNAUTHORIZED from piccolo_api.mfa.provider import MFAProvider from piccolo_api.session_auth.tables import SessionsBase @@ -92,7 +92,8 @@ async def post(self, request: Request) -> Response: cookie = request.cookies.get(self._cookie_name, None) if not cookie: raise HTTPException( - status_code=401, detail="The session cookie wasn't found." + status_code=HTTP_401_UNAUTHORIZED, + detail="The session cookie wasn't found.", ) await self._session_table.remove_session(token=cookie) @@ -204,11 +205,14 @@ def _get_error_response( ) -> Response: if response_format == "html": return self._render_template( - request, template_context={"error": error}, status_code=401 + request, + template_context={"error": error}, + status_code=HTTP_401_UNAUTHORIZED, ) else: return PlainTextResponse( - status_code=401, content=f"Login failed: {error}" + status_code=HTTP_401_UNAUTHORIZED, + content=f"Login failed: {error}", ) async def get(self, request: Request) -> HTMLResponse: @@ -261,7 +265,8 @@ async def post(self, request: Request) -> Response: ) else: raise HTTPException( - status_code=401, detail=validate_response + status_code=HTTP_401_UNAUTHORIZED, + detail=validate_response, ) # Attempt login @@ -314,7 +319,8 @@ async def post(self, request: Request) -> Response: ) else: raise HTTPException( - status_code=401, detail=message + status_code=HTTP_401_UNAUTHORIZED, + detail=message, ) # Work out which MFA provider to use: @@ -325,7 +331,7 @@ async def post(self, request: Request) -> Response: if mfa_provider_name is None: raise HTTPException( - status_code=401, + status_code=HTTP_401_UNAUTHORIZED, detail="MFA provider must be specified", ) @@ -337,13 +343,13 @@ async def post(self, request: Request) -> Response: if len(filtered_mfa_providers) == 0: raise HTTPException( - status_code=401, + status_code=HTTP_401_UNAUTHORIZED, detail="MFA provider not recognised.", ) if len(filtered_mfa_providers) > 1: raise HTTPException( - status_code=401, + status_code=HTTP_401_UNAUTHORIZED, detail=( "Multiple matching MFA providers found." ), @@ -368,7 +374,7 @@ async def post(self, request: Request) -> Response: ) else: raise HTTPException( - status_code=401, + status_code=HTTP_401_UNAUTHORIZED, detail="MFA failed", ) @@ -404,7 +410,9 @@ async def post(self, request: Request) -> Response: }, ) else: - raise HTTPException(status_code=401, detail="Login failed") + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, detail="Login failed" + ) now = datetime.now() expiry_date = now + self._session_expiry diff --git a/piccolo_api/shared/middleware/junction.py b/piccolo_api/shared/middleware/junction.py index bf07eee6..ea506142 100644 --- a/piccolo_api/shared/middleware/junction.py +++ b/piccolo_api/shared/middleware/junction.py @@ -1,5 +1,6 @@ from starlette.exceptions import HTTPException from starlette.routing import Router +from starlette.status import HTTP_404_NOT_FOUND from starlette.types import Receive, Scope, Send @@ -22,4 +23,4 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): else: return - raise HTTPException(status_code=404) + raise HTTPException(status_code=HTTP_404_NOT_FOUND) diff --git a/piccolo_api/token_auth/endpoints.py b/piccolo_api/token_auth/endpoints.py index 26472ceb..88efc1bc 100644 --- a/piccolo_api/token_auth/endpoints.py +++ b/piccolo_api/token_auth/endpoints.py @@ -7,6 +7,7 @@ from starlette.endpoints import HTTPEndpoint from starlette.requests import Request from starlette.responses import JSONResponse, Response +from starlette.status import HTTP_401_UNAUTHORIZED from .tables import TokenAuth @@ -61,11 +62,12 @@ async def post(self, request: Request) -> Response: else: return Response( content="The credentials were incorrect", - status_code=401, + status_code=HTTP_401_UNAUTHORIZED, ) else: return Response( - content="No credentials were found.", status_code=401 + content="No credentials were found.", + status_code=HTTP_401_UNAUTHORIZED, ) diff --git a/tests/jwt_auth/test_jwt_middleware.py b/tests/jwt_auth/test_jwt_middleware.py index 9b91b641..a06455e1 100644 --- a/tests/jwt_auth/test_jwt_middleware.py +++ b/tests/jwt_auth/test_jwt_middleware.py @@ -46,7 +46,7 @@ def test_empty_token(self): with self.assertRaises(HTTPException): response = client.get("/") - self.assertEqual(response.status_code, 403) + self.assertEqual(response.status_code, 401) self.assertEqual( response.json()["detail"], JWTError.token_not_found.value ) @@ -103,7 +103,7 @@ def test_expired_token(self): with self.assertRaises(HTTPException): response = client.get("/", headers=headers) - self.assertEqual(response.status_code, 403) + self.assertEqual(response.status_code, 401) self.assertEqual( response.json()["detail"], JWTError.token_expired.value ) @@ -134,7 +134,7 @@ def test_wrong_secret(self): with self.assertRaises(HTTPException): response = client.get("/", headers=headers) - self.assertEqual(response.status_code, 403) + self.assertEqual(response.status_code, 401) self.assertEqual( response.json()["detail"], JWTError.token_invalid.value ) @@ -165,7 +165,7 @@ def test_missing_expiry(self): with self.assertRaises(HTTPException): response = client.get("/", headers=headers) - self.assertEqual(response.status_code, 403) + self.assertEqual(response.status_code, 401) self.assertEqual( response.json()["detail"], JWTError.token_expired.value ) @@ -188,7 +188,7 @@ def test_token_without_user_id(self): with self.assertRaises(HTTPException): response = client.get("/", headers=headers) - self.assertEqual(response.status_code, 403) + self.assertEqual(response.status_code, 401) self.assertEqual(response.content, b"") # allow_unauthenticated