Skip to content

Commit

Permalink
Google OAuth endpoints and UI login (#804)
Browse files Browse the repository at this point in the history
  • Loading branch information
samredai authored Oct 4, 2023
1 parent 394bf30 commit d5616a4
Show file tree
Hide file tree
Showing 21 changed files with 513 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,4 @@ djqs.db
.pdm.toml

# oauth credentials
client_secret.json
client_secret*
2 changes: 2 additions & 0 deletions datajunction-server/.coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ omit =
*/datajunction_server/sql/parsing/backends/grammar/generated/*
*/datajunction_server/sql/parsing/backends/antlr4.py
*/datajunction_server/sql/parsing/ast.py
*/datajunction_server/internal/authentication/google.py
*/datajunction_server/api/authentication/google.py

[paths]
source =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session, select

from datajunction_server.constants import DJ_AUTH_COOKIE, DJ_LOGGED_IN_FLAG_COOKIE
from datajunction_server.constants import AUTH_COOKIE, LOGGED_IN_FLAG_COOKIE
from datajunction_server.errors import DJError, DJException, ErrorCode
from datajunction_server.internal.authentication.basic import (
get_password_hash,
Expand Down Expand Up @@ -70,23 +70,23 @@ async def login(
)
response = Response(status_code=HTTPStatus.OK)
response.set_cookie(
DJ_AUTH_COOKIE,
AUTH_COOKIE,
create_token({"username": user.username}, expires_delta=timedelta(days=365)),
httponly=True,
)
response.set_cookie(
DJ_LOGGED_IN_FLAG_COOKIE,
LOGGED_IN_FLAG_COOKIE,
"true",
)
return response


@router.post("/basic/logout/")
@router.post("/logout/")
async def logout():
"""
Logout a user by deleting the auth cookie
"""
response = Response(status_code=HTTPStatus.OK)
response.delete_cookie(DJ_AUTH_COOKIE, httponly=True)
response.delete_cookie(DJ_LOGGED_IN_FLAG_COOKIE)
response.delete_cookie(AUTH_COOKIE, httponly=True)
response.delete_cookie(LOGGED_IN_FLAG_COOKIE)
return response
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fastapi import APIRouter, Response
from fastapi.responses import JSONResponse, RedirectResponse

from datajunction_server.constants import DJ_AUTH_COOKIE, DJ_LOGGED_IN_FLAG_COOKIE
from datajunction_server.constants import AUTH_COOKIE, LOGGED_IN_FLAG_COOKIE
from datajunction_server.errors import DJError, DJException, ErrorCode
from datajunction_server.internal.authentication import github
from datajunction_server.internal.authentication.tokens import create_token
Expand Down Expand Up @@ -97,12 +97,12 @@ async def get_access_token(
),
)
response.set_cookie(
DJ_AUTH_COOKIE,
AUTH_COOKIE,
create_token({"username": user.username}, expires_delta=timedelta(days=365)),
httponly=True,
)
response.set_cookie(
DJ_LOGGED_IN_FLAG_COOKIE,
LOGGED_IN_FLAG_COOKIE,
"true",
)
return response
108 changes: 108 additions & 0 deletions datajunction-server/datajunction_server/api/authentication/google.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
Google OAuth Router
"""
import logging
import secrets
from datetime import timedelta
from http import HTTPStatus
from typing import Optional
from urllib.parse import urljoin, urlparse

import google.auth.transport.requests
import google.oauth2.credentials
import requests
from fastapi import APIRouter, Depends, Request
from google.oauth2 import id_token
from sqlmodel import Session, select
from starlette.responses import RedirectResponse

from datajunction_server.constants import AUTH_COOKIE, LOGGED_IN_FLAG_COOKIE
from datajunction_server.errors import DJException
from datajunction_server.internal.authentication.basic import get_password_hash
from datajunction_server.internal.authentication.google import flow, get_authorize_url
from datajunction_server.internal.authentication.tokens import create_token
from datajunction_server.models.user import OAuthProvider, User
from datajunction_server.utils import Settings, get_session, get_settings

_logger = logging.getLogger(__name__)
router = APIRouter(tags=["Google OAuth"])
settings = get_settings()


@router.get("/google/login/", status_code=HTTPStatus.FOUND)
async def login(target: Optional[str] = None):
"""
Login using Google OAuth
"""
return RedirectResponse(
url=get_authorize_url(state=target),
status_code=HTTPStatus.FOUND,
)


@router.get("/google/token/")
async def get_access_token(
request: Request,
state: Optional[str] = None,
error: Optional[str] = None,
session: Session = Depends(get_session),
setting: Settings = Depends(get_settings),
):
"""
Perform a token exchange, exchanging a google auth code for a google access token.
The google access token is then used to request user information and return a JWT
cookie. If the user does not already exist, a new user is created.
"""
if error:
raise DJException(
http_status_code=HTTPStatus.UNAUTHORIZED,
message="Ran into an error during Google auth: {error}",
)
hostname = urlparse(settings.url).hostname
url = str(request.url)
flow.fetch_token(authorization_response=url)
credentials = flow.credentials
request_session = requests.session()
token_request = google.auth.transport.requests.Request(session=request_session)
user_data = id_token.verify_oauth2_token(
id_token=credentials._id_token, # pylint: disable=protected-access
request=token_request,
audience=setting.google_oauth_client_id,
)

existing_user = session.exec(
select(User).where(User.email == user_data["email"]),
).one_or_none()
if existing_user:
_logger.info("OAuth user found")
user = existing_user
else:
_logger.info("OAuth user does not exist, creating a new user")
new_user = User(
username=user_data["email"],
email=user_data["email"],
password=get_password_hash(secrets.token_urlsafe(13)),
name=user_data["name"],
oauth_provider=OAuthProvider.GOOGLE,
)
session.add(new_user)
session.commit()
session.refresh(new_user)
user = new_user
response = RedirectResponse(url=urljoin(settings.frontend_host, state)) # type: ignore
response.set_cookie(
AUTH_COOKIE,
create_token({"username": user.email}, expires_delta=timedelta(days=365)),
httponly=True,
samesite="none",
secure=True,
domain=hostname,
)
response.set_cookie(
LOGGED_IN_FLAG_COOKIE,
"true",
samesite="none",
secure=True,
domain=hostname,
)
return response
19 changes: 16 additions & 3 deletions datajunction-server/datajunction_server/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from datajunction_server.api.attributes import default_attribute_types
from datajunction_server.api.authentication import whoami
from datajunction_server.api.graphql.main import graphql_app
from datajunction_server.constants import DJ_AUTH_COOKIE, DJ_LOGGED_IN_FLAG_COOKIE
from datajunction_server.constants import AUTH_COOKIE, LOGGED_IN_FLAG_COOKIE
from datajunction_server.errors import DJException
from datajunction_server.models.catalog import Catalog
from datajunction_server.models.column import Column
Expand Down Expand Up @@ -121,8 +121,8 @@ async def dj_exception_handler( # pylint: disable=unused-argument
)
# If unauthorized, clear out any DJ cookies
if exc.http_status_code == HTTPStatus.UNAUTHORIZED:
response.delete_cookie(DJ_AUTH_COOKIE, httponly=True)
response.delete_cookie(DJ_LOGGED_IN_FLAG_COOKIE)
response.delete_cookie(AUTH_COOKIE, httponly=True)
response.delete_cookie(LOGGED_IN_FLAG_COOKIE)
return response


Expand All @@ -143,3 +143,16 @@ async def dj_exception_handler( # pylint: disable=unused-argument
from datajunction_server.api.authentication import github

app.include_router(github.router)

# Only mount google auth router if a google oauth is configured
if all(
[
settings.secret,
settings.google_oauth_client_id,
settings.google_oauth_client_secret,
settings.google_oauth_client_secret_file,
],
): # pragma: no cover
from datajunction_server.api.authentication import google

app.include_router(google.router)
13 changes: 12 additions & 1 deletion datajunction-server/datajunction_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class Settings(
# DJ UI host, used for OAuth redirection
frontend_host: Optional[str] = "http://localhost:3000"

# Library to use when transpiling SQL to other dialects
sql_transpilation_library: Optional[str] = None

# DJ secret, used to encrypt passwords and JSON web tokens
secret: Optional[str] = None

Expand All @@ -72,7 +75,15 @@ class Settings(

# GitHub OAuth application client secret
github_oauth_client_secret: Optional[str] = None
sql_transpilation_library: Optional[str] = None

# Google OAuth application client ID
google_oauth_client_id: Optional[str] = None

# Google OAuth application client secret
google_oauth_client_secret: Optional[str] = None

# Google OAuth application client secret file
google_oauth_client_secret_file: Optional[str] = None

@property
def celery(self) -> Celery:
Expand Down
4 changes: 2 additions & 2 deletions datajunction-server/datajunction_server/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
QUERY_EXECUTE_TIMEOUT = timedelta(seconds=60)
GET_COLUMNS_TIMEOUT = timedelta(seconds=60)

DJ_AUTH_COOKIE = "__dj"
DJ_LOGGED_IN_FLAG_COOKIE = "__djlif"
AUTH_COOKIE = "__dj"
LOGGED_IN_FLAG_COOKIE = "__djlif"
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
Google OAuth helper functions
"""
import logging
import secrets
from http import HTTPStatus
from typing import Optional
from urllib.parse import urljoin

import google_auth_oauthlib.flow
import requests
from google.auth.external_account_authorized_user import Credentials
from sqlmodel import select

from datajunction_server.errors import DJException
from datajunction_server.internal.authentication.basic import get_password_hash
from datajunction_server.models.user import OAuthProvider, User
from datajunction_server.utils import get_session, get_settings

_logger = logging.getLogger(__name__)

settings = get_settings()
flow = ( # pylint: disable=invalid-name
google_auth_oauthlib.flow.Flow.from_client_secrets_file(
settings.google_oauth_client_secret_file,
scopes=[
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/userinfo.email",
"openid",
],
redirect_uri=urljoin(settings.url, "/google/token/"),
)
if settings.google_oauth_client_secret_file
else None
)


def get_authorize_url(
state: Optional[str] = None,
) -> google_auth_oauthlib.flow.Flow:
"""
Get the authorize url for a Google OAuth app
"""
authorization_url, _ = flow.authorization_url(
access_type="offline",
include_granted_scopes="true",
prompt="consent",
state=state,
)
return authorization_url


def get_google_access_token(
authorization_response_url: str,
) -> Credentials:
"""
Exchange an authorization token for an access token
"""
flow.fetch_token(authorization_response=authorization_response_url)
return flow.credentials


def get_google_user(token: str) -> User:
"""
Get the google user using an access token
"""
headers = {"Accept": "application/json", "Authorization": f"Bearer {token}"}
response = requests.get(
"https://www.googleapis.com/oauth2/v2/userinfo?alt=json",
headers=headers,
timeout=10,
)
if not response.ok:
raise DJException(
http_status_code=HTTPStatus.FORBIDDEN,
message=f"Error retrieving Google user: {response.text}",
)
user_data = response.json()
if "message" in user_data and user_data["message"] == "Bad credentials":
raise DJException(
http_status_code=HTTPStatus.FORBIDDEN,
message=f"Error retrieving Google user: {response.text}",
)
session = next(get_session())
existing_user = session.exec(
select(User).where(User.email == user_data["login"]),
).one_or_none()
if existing_user:
_logger.info("OAuth user found")
user = existing_user
else:
_logger.info("OAuth user does not exist, creating a new user")
new_user = User(
username=user_data["email"],
email=user_data["email"],
password=get_password_hash(secrets.token_urlsafe(13)),
name=user_data["name"],
oauth_provider=OAuthProvider.GOOGLE,
)
session.add(new_user)
session.commit()
session.refresh(new_user)
user = new_user
return user
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sqlmodel import Session
from starlette.requests import Request

from datajunction_server.constants import DJ_AUTH_COOKIE
from datajunction_server.constants import AUTH_COOKIE
from datajunction_server.errors import DJError, DJException, ErrorCode
from datajunction_server.internal.authentication.basic import get_user
from datajunction_server.internal.authentication.tokens import decode_token
Expand All @@ -30,7 +30,7 @@ async def __call__(
session: Session = Depends(get_session),
) -> None:
# First check for a JWT sent in a cookie
jwt = request.cookies.get(DJ_AUTH_COOKIE)
jwt = request.cookies.get(AUTH_COOKIE)
if jwt:
try:
jwt_data = await decode_token(jwt)
Expand Down
1 change: 1 addition & 0 deletions datajunction-server/datajunction_server/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class OAuthProvider(Enum):

BASIC = "basic"
GITHUB = "github"
GOOGLE = "google"


class User(SQLModel, table=True): # type: ignore
Expand Down
Loading

0 comments on commit d5616a4

Please sign in to comment.