Skip to content

Commit

Permalink
add playwright tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dantownsend committed Aug 15, 2024
1 parent ff806e9 commit 605d09b
Show file tree
Hide file tree
Showing 18 changed files with 316 additions and 41 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ docs/source/_build/
example_projects/token_auth/
.env/
.venv/

# Playwright
videos/
Empty file added e2e/__init__.py
Empty file.
62 changes: 62 additions & 0 deletions e2e/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os
import time
from http.client import HTTPConnection
from subprocess import Popen

import pytest

HOST = "localhost"
PORT = 8000
BASE_URL = f"http://{HOST}:{PORT}"


@pytest.fixture
def browser_context_args():
return {"record_video_dir": "videos/"}


@pytest.fixture
def context(context):
# We don't need a really long timeout.
# The timeout determines how long Playwright waits for a HTML element to
# become available.
# By default it's 30 seconds, which is way too long when testing an app
# locally.
context.set_default_timeout(10000)
yield context


@pytest.fixture
def mfa_app():
"""
Running dev server and Playwright test in parallel.
More info https://til.simonwillison.net/pytest/playwright-pytest
"""
path = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"example_projects",
"mfa_demo",
)

process = Popen(
["python", "-m", "main", "--reset-db"],
cwd=path,
)
retries = 5
while retries > 0:
conn = HTTPConnection(f"{HOST}:{PORT}")
try:
conn.request("HEAD", "/")
response = conn.getresponse()
if response is not None:
yield process
break
except ConnectionRefusedError:
time.sleep(1)
retries -= 1

if not retries:
raise RuntimeError("Failed to start http server")
else:
process.terminate()
process.wait()
57 changes: 57 additions & 0 deletions e2e/pages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
By using pages we can make out test more scalable.
https://playwright.dev/docs/pom
"""

from playwright.sync_api import Page


class LoginPage:
url = "http://localhost:8000/login/"

def __init__(self, page: Page):
self.page = page
self.username_input = page.locator('input[name="username"]')
self.password_input = page.locator('input[name="password"]')
self.button = page.locator("button")

def reset(self):
self.page.goto(self.url)

def login(self):
self.username_input.fill("piccolo")
self.password_input.fill("piccolo123")
self.button.click()


class RegisterPage:
url = "http://localhost:8000/register/"

def __init__(self, page: Page):
self.page = page
self.username_input = page.locator("[name=username]")
self.email_input = page.locator("[name=email]")
self.password_input = page.locator("[name=password]")
self.confirm_password_input = page.locator("[name=confirm_password]")
self.button = page.locator("button")

def reset(self):
self.page.goto(self.url)

def login(self):
self.username_input.fill("piccolo")
self.email_input.fill("[email protected]")
self.password_input.fill("piccolo123")
self.confirm_password_input.fill("piccolo123")
self.button.click()


class MFARegisterPage:
url = "http://localhost:8000/private/mfa-register/"

def __init__(self, page: Page):
self.page = page

def reset(self):
self.page.goto(self.url)
19 changes: 19 additions & 0 deletions e2e/test_mfa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from playwright.async_api import Page

from .pages import LoginPage, MFARegisterPage, RegisterPage


def test_login(page: Page, mfa_app):
"""
Make sure we can register, sign up for MFA.
"""
register_page = RegisterPage(page=page)
register_page.reset()
register_page.login()

login_page = LoginPage(page=page)
login_page.reset()
login_page.login()

mfa_register_page = MFARegisterPage(page=page)
mfa_register_page.reset()
24 changes: 13 additions & 11 deletions example_projects/mfa_demo/app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os

from jinja2 import Environment, FileSystemLoader
from starlette.applications import Starlette
from starlette.endpoints import HTTPEndpoint
from starlette.middleware import Middleware
Expand All @@ -16,19 +19,18 @@
EXAMPLE_DB_ENCRYPTION_KEY = "wqsOqyTTEsrWppZeIMS8a3l90yPUtrqT48z7FS6_U8g="


environment = Environment(
loader=FileSystemLoader(
os.path.join(os.path.dirname(__file__), "templates"),
)
)

Check warning

Code scanning / CodeQL

Jinja2 templating with autoescape=False Medium

Using jinja2 templates with autoescape=False can potentially allow XSS attacks.


class HomeEndpoint(HTTPEndpoint):
async def get(self, request):
return HTMLResponse(
content=(
"<style>body{font-family: sans-serif;}</style>"
"<h1>MFA Demo</h1>"
'<p>First <a href="/register/">register</a></p>' # noqa: E501
'<p>Then <a href="/login/">login</a></p>' # noqa: E501
'<p>Then <a href="/private/mfa-register/">sign up for MFA</a></p>' # noqa: E501
'<p>Then try the <a href="/private/">private page</a></p>' # noqa: E501
'<p>And <a href="/private/logout/">logout</a></p>' # noqa: E501
)
)
home_template = environment.get_template("home.html")

return HTMLResponse(content=home_template.render())


class PrivateEndpoint(HTTPEndpoint):
Expand Down
16 changes: 16 additions & 0 deletions example_projects/mfa_demo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,24 @@
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))


def reset_db():
print("Resetting DB ...")

from piccolo.apps.user.tables import BaseUser

from piccolo_api.mfa.authenticator.tables import AuthenticatorSecret
from piccolo_api.session_auth.tables import SessionsBase

BaseUser.delete(force=True).run_sync()
AuthenticatorSecret.delete(force=True).run_sync()
SessionsBase.delete(force=True).run_sync()


if __name__ == "__main__":

if "--reset-db" in sys.argv:
reset_db()

import uvicorn

uvicorn.run("app:app", reload=True)
24 changes: 24 additions & 0 deletions example_projects/mfa_demo/templates/home.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
<html lang="en">

<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Home</title>

<style>
body {
font-family: sans-serif;
}
</style>
</head>

<body>
<h1>MFA Demo</h1>
<p>First <a href="/register/">register</a></p>
<p>Then <a href="/login/">login</a></p>
<p>Then <a href="/private/mfa-register/">sign up for MFA</a></p>
<p>Then try the <a href="/private/">private page</a></p>
<p>And <a href="/private/logout/">logout</a></p>
</body>

</html>
11 changes: 11 additions & 0 deletions piccolo_api/mfa/authenticator/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ async def get_registration_html(self, user: BaseUser) -> str:
When a user wants to register for MFA, this HTML is shown containing
instructions.
"""
# If the user is already enrolled, don't create a new secret.
if await self.secret_table.is_user_enrolled(user_id=user.id):
return self.register_template.render(
already_enrolled=True,
styles=self.styles,
)

secret, recovery_codes = await self.secret_table.create_new(
user_id=user.id,
db_encryption_key=self.db_encryption_key,
Expand Down Expand Up @@ -129,3 +136,7 @@ async def get_registration_json(self, user: BaseUser) -> dict:
)

return {"qrcode_image": qrcode_image, "recovery_codes": recovery_codes}

async def delete_registration(self, user: BaseUser) -> str:
await self.secret_table.revoke_all(user_id=user.id)
return "<p>Successfully deleted</p>"
12 changes: 11 additions & 1 deletion piccolo_api/mfa/authenticator/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,14 @@ async def create_new(

return (instance, recovery_codes)

@classmethod
async def revoke_all(cls, user_id: int):
now = datetime.datetime.now(tz=datetime.timezone.utc)
await cls.update({cls.revoked_at: now}).where(
cls.user_id == user_id,
cls.revoked_at.is_null(),
)

@classmethod
async def authenticate(
cls, user_id: int, code: str, db_encryption_key: str
Expand Down Expand Up @@ -241,7 +249,9 @@ async def authenticate(

@classmethod
async def is_user_enrolled(cls, user_id: int) -> bool:
return await cls.exists().where(cls.user_id == user_id)
return await cls.exists().where(
cls.user_id == user_id, cls.revoked_at.is_null()
)

def get_authentication_setup_uri(
self,
Expand Down
63 changes: 48 additions & 15 deletions piccolo_api/mfa/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import typing as t
from abc import ABCMeta, abstractmethod
from json import JSONDecodeError

from piccolo.apps.user.tables import BaseUser
from starlette.endpoints import HTTPEndpoint
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse
Expand All @@ -16,23 +18,54 @@ def _provider(self) -> MFAProvider:
raise NotImplementedError

async def get(self, request: Request):
piccolo_user = request.user.user

if request.query_params.get("format") == "json":
json_content = await self._provider.get_registration_json(
user=piccolo_user
)
return JSONResponse(content=json_content)
else:
html_content = await self._provider.get_registration_html(
user=piccolo_user
)
return HTMLResponse(content=html_content)
return HTMLResponse(
content=f"""
<form method="post">
<input type="hidden" name="action" value="register" />
<input type="hidden" name="csrftoken" value="{request.scope['csrftoken']}" />
<button>Register</button>
</form>
""" # noqa: E501
)

async def post(self, request: Request):
# TODO - we might need the user to confirm once they're setup.
# We could embed the ID of the row in the HTML response (in a form).
pass
piccolo_user: BaseUser = request.user.user

# Some middleware (for example CSRF) has already awaited the request
# body, and adds it to the request.
body: t.Any = request.scope.get("form")

if not body:
try:
body = await request.json()
except JSONDecodeError:
body = await request.form()

if action := body.get("action"):
if action == "register":
if body.get("format") == "json":
json_content = await self._provider.get_registration_json(
user=piccolo_user
)
return JSONResponse(content=json_content)
else:
html_content = await self._provider.get_registration_html(
user=piccolo_user
)
return HTMLResponse(content=html_content)
elif action == "revoke":
if password := body.get("password"):
if await piccolo_user.__class__.login(
username=piccolo_user.username, password=password
):
html_content = (
await self._provider.delete_registration(
user=piccolo_user
)
)
return HTMLResponse(content=html_content)

return HTMLResponse(content="<p>Error</p>")


def mfa_register_endpoint(provider: MFAProvider) -> t.Type[HTTPEndpoint]:
Expand Down
7 changes: 7 additions & 0 deletions piccolo_api/mfa/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,10 @@ async def get_registration_json(self, user: BaseUser) -> dict:
response, rather than HTML, if they want to render the UI themselves.
"""
pass

@abstractmethod
async def delete_registration(self, user: BaseUser) -> str:
"""
Used to remove the MFA.
"""
pass
Loading

0 comments on commit 605d09b

Please sign in to comment.