diff --git a/Cargo.lock b/Cargo.lock index 613625f..ef5f871 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -389,6 +389,17 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -472,6 +483,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" dependencies = [ "powerfmt", + "serde", ] [[package]] @@ -571,6 +583,20 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -615,6 +641,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -635,6 +672,7 @@ checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -977,6 +1015,7 @@ checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" dependencies = [ "autocfg", "scopeguard", + "serde", ] [[package]] @@ -1005,7 +1044,9 @@ dependencies = [ "serde_json", "sqlx", "thiserror", + "time", "tokio", + "tower-sessions", "tracing", "tracing-subscriber", ] @@ -2289,6 +2330,23 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-cookies" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fd0118512cf0b3768f7fcccf0bef1ae41d68f2b45edc1e77432b36c97c56c6d" +dependencies = [ + "async-trait", + "axum-core", + "cookie", + "futures-util", + "http", + "parking_lot", + "pin-project-lite", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.3" @@ -2301,6 +2359,57 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" +[[package]] +name = "tower-sessions" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65856c81ee244e0f8a55ab0f7b769b72fbde387c235f0a73cd97c579818d05eb" +dependencies = [ + "async-trait", + "http", + "time", + "tokio", + "tower-cookies", + "tower-layer", + "tower-service", + "tower-sessions-core", + "tower-sessions-memory-store", + "tracing", +] + +[[package]] +name = "tower-sessions-core" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb6abbfcaf6436ec5a772cd9f965401da12db793e404ae6134eac066fa5a04f3" +dependencies = [ + "async-trait", + "axum-core", + "base64", + "futures", + "http", + "parking_lot", + "rand", + "serde", + "serde_json", + "thiserror", + "time", + "tokio", + "tracing", +] + +[[package]] +name = "tower-sessions-memory-store" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fad75660c8afbe74f4e7cbbe8e9090171a056b57370ea4d7d5e9eb3e4af3092" +dependencies = [ + "async-trait", + "time", + "tokio", + "tower-sessions-core", +] + [[package]] name = "tracing" version = "0.1.40" diff --git a/Cargo.toml b/Cargo.toml index 7d2781d..9d75864 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,3 +28,5 @@ maud = "0.26.0" sentry = { version = "0.34.0", features = ["tracing", "reqwest", "rustls"], default-features = false } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } +tower-sessions = "0.13.0" +time = "0.3.36" diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..3e0d538 --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,33 @@ +use axum::async_trait; +use axum::extract::FromRequestParts; +use axum::http::{request::Parts, StatusCode}; +use tower_sessions::Session; + +use crate::AccountPk; +use crate::ResponseError; + +pub const SESSION_COOKIE_KEY: &str = "auth"; + +pub struct LoggedIn { + pub account: Option, +} + +impl LoggedIn { + pub fn account(&self) -> Result { + self.account.clone().ok_or(ResponseError::NeedsAuth) + } +} + +#[async_trait] +impl FromRequestParts for LoggedIn +where + S: Send + Sync, +{ + type Rejection = (StatusCode, &'static str); + + async fn from_request_parts(req: &mut Parts, state: &S) -> Result { + let session = Session::from_request_parts(req, state).await?; + let account: Option = session.get(SESSION_COOKIE_KEY).await.unwrap(); + Ok(LoggedIn { account }) + } +} diff --git a/src/error.rs b/src/error.rs index 3261763..014744a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ use axum::{ http::StatusCode, - response::{IntoResponse, Response}, + response::{IntoResponse, Redirect, Response}, }; use reqwest::header::InvalidHeaderValue; use tokio::task::JoinError; @@ -17,11 +17,20 @@ pub enum ResponseError { InvalidHeader(#[from] InvalidHeaderValue), #[error("invalid JSON input: {0}")] Json(#[from] serde_json::Error), + #[error("failed to update session")] + Session(#[from] tower_sessions::session::Error), + #[error("no login found")] + NeedsAuth, } impl IntoResponse for ResponseError { fn into_response(self) -> Response { - tracing::error!("error while serving request: {}", self); - (StatusCode::INTERNAL_SERVER_ERROR, format!("{}\n", self)).into_response() + match self { + ResponseError::NeedsAuth => Redirect::to("/").into_response(), + _ => { + tracing::error!("error while serving request: {}", self); + (StatusCode::INTERNAL_SERVER_ERROR, format!("{}\n", self)).into_response() + } + } } } diff --git a/src/main.rs b/src/main.rs index fc68959..ee981e9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,18 +8,20 @@ use axum::{ body::Body, debug_handler, extract::{Host, Query, State}, - response::{Html, IntoResponse, Response}, + response::{Html, IntoResponse, Redirect, Response}, routing::{get, post}, Form, Router, }; use clap::Parser; use maud::Markup; use serde::{Deserialize, Serialize}; +use tower_sessions::{Expiry, Session}; mod api_cache; mod api_client; mod api_helpers; mod api_models; +mod auth; mod config; mod error; mod list_manager; @@ -30,6 +32,9 @@ use config::Server; use config::{Cli, Subcommand}; use error::ResponseError; use store::{AccountPk, RegisterAccount, SyncImmediateResult}; +use tower_sessions::{MemoryStore, SessionManagerLayer}; + +use crate::auth::{LoggedIn, SESSION_COOKIE_KEY}; #[tokio::main] async fn main() -> Result<(), Error> { @@ -86,6 +91,14 @@ async fn serve(server_cli: Server) -> Result<(), Error> { } }); + let session_store = MemoryStore::default(); + let session_layer = SessionManagerLayer::new(session_store) + .with_secure(!cfg!(debug_assertions)) + // https://bugzilla.mozilla.org/show_bug.cgi?id=1465402 + // https://issues.chromium.org/issues/40508226#comment2 + .with_same_site(tower_sessions::cookie::SameSite::Lax) + .with_expiry(Expiry::OnInactivity(time::Duration::seconds(3600))); + let app = Router::new() .route("/", get(index)) // If this line is failing compilation, you need to run 'npm install && npm run build' to get your CSS bundle. @@ -108,8 +121,11 @@ async fn serve(server_cli: Server) -> Result<(), Error> { }), ) .route("/account/login", post(account_login)) + .route("/account/logout", post(account_logout)) .route("/account/sync-immediate", post(sync_immediate)) - .route("/account", get(account)) + .route("/account/oauth-redirect", get(account_redirect)) + .route("/account/admin", get(account_admin)) + .layer(session_layer) .with_state(state); tracing::info!("listening on {}", socketaddr_str); @@ -152,13 +168,14 @@ async fn index() -> Response { async fn sync_immediate( State(state): State, - Form(account_pk): Form, + login: LoggedIn, ) -> Result { + let account_pk = login.account()?; let body = state.store.sync_immediate(account_pk).await?; let html: maud::Markup = match body { SyncImmediateResult::Ok => maud::html! { - p { "Done syncing! Future updates to your lists will happen automatically." } + p { "Done syncing! Refresh the page to see results. Future updates to your lists will happen automatically." } }, SyncImmediateResult::Error { value } => maud::html! { p.red { "Error: "(value) } @@ -186,12 +203,20 @@ struct OauthState { host: String, } +fn get_service_uri(Host(self_host): Host) -> String { + if cfg!(debug_assertions) { + format!("http://{self_host}") + } else { + format!("https://{self_host}") + } +} + async fn account_login( - Host(self_host): Host, + self_host: Host, Form(AccountRegister { host }): Form, ) -> Result { - let service_uri = format!("https://{self_host}"); - let self_redirect_uri = format!("{service_uri}/account"); + let service_uri = get_service_uri(self_host); + let self_redirect_uri = format!("{service_uri}/account/oauth-redirect"); let client = ApiClient::new(&host, None).unwrap(); let scopes = "read:follows read:lists read:accounts write:lists"; @@ -240,6 +265,12 @@ async fn account_login( .unwrap()) } +#[debug_handler] +async fn account_logout(session: Session) -> Result { + session.remove::(SESSION_COOKIE_KEY).await?; + Ok(Redirect::to("/").into_response()) +} + #[derive(Deserialize)] struct OauthAccountRedirect { code: String, @@ -247,16 +278,17 @@ struct OauthAccountRedirect { } #[debug_handler] -async fn account( - Host(self_host): Host, +async fn account_redirect( + session: Session, + self_host: Host, State(state): State, Query(OauthAccountRedirect { code, state: oauth_state, }): Query, ) -> Result { - let service_uri = format!("https://{self_host}"); - let self_redirect_uri = format!("{service_uri}/account"); + let service_uri = get_service_uri(self_host); + let self_redirect_uri = format!("{service_uri}/account/oauth-redirect"); let OauthState { client_id, @@ -292,14 +324,30 @@ async fn account( }; let account = state.store.register(register_account).await?; + session + .insert(SESSION_COOKIE_KEY, account.primary_key()) + .await?; + Ok(Redirect::to("/account/admin").into_response()) +} + +#[debug_handler] +async fn account_admin( + State(state): State, + login: LoggedIn, +) -> Result { + let account_pk = login.account()?; + let account = state.store.get_account(account_pk).await?; let html = maud::html! { div { - // hide account credentials in query string from browser history - script { "history.replaceState({}, '', '/');" } - p.green { "Hello "(account.username)"@"(account.host)"!" } + form.pure-form + method="post" + action="/account/logout" { + input type="submit" value="Logout"; + } + @if account.failure_count > 0 { p.red { "We have encountered "(account.failure_count)" fatal errors when trying to sync. After 10 attempts, we will stop synchronizing." @@ -320,7 +368,7 @@ async fn account( } p { - "Your lists will be updated once every day automatically. Take a look at the " a href="https://github.com/untitaker/mastodon-list-bot#how-to-use" { "README" } " to see which list names are supported. After that, click Sync Now." + "Your lists will be updated once per day. Take a look at the " a href="https://github.com/untitaker/mastodon-list-bot#how-to-use" { "README" } " to see which list names are supported. After that, click Sync Now." } form.pure-form @@ -331,8 +379,6 @@ async fn account( data-hx-swap="innerHTML" data-hx-target="#sync-result" data-hx-disabled-elt="input[type=submit]" { - input type="hidden" name="host" value=(account.host); - input type="hidden" name="username" value=(account.username); input type="submit" value="Sync now"; p id="sync-result"; } diff --git a/src/store.rs b/src/store.rs index b524a29..f6475c7 100644 --- a/src/store.rs +++ b/src/store.rs @@ -106,11 +106,15 @@ impl Store { account.host, account.token, account.username, account.created_at, account.last_success_at, account.last_error, account.failure_count, account.list_count, ).execute(&self.pool).await?; + Ok(account) + } + + pub async fn get_account(&self, pk: AccountPk) -> Result { let account = sqlx::query_as!( Account, "select * from accounts where host = ?1 and username = ?2", - account.host, - account.username + pk.host, + pk.username ) .fetch_one(&self.pool) .await?;