From 1c042e5a8dff971814fdcbad638d324a66fc1811 Mon Sep 17 00:00:00 2001 From: Pyreko <25498386+Pyreko@users.noreply.github.com> Date: Mon, 29 Jan 2024 03:00:59 -0500 Subject: [PATCH] use less-terrible way to keep track of count state (#26) * use less-terrible way to keep track of count state * bmp * add resync check * comments and unset dirty if resync --- server/Cargo.lock | 7 ++ server/Cargo.toml | 1 + server/clean_log.sh | 2 +- server/restart.sh | 4 +- server/src/main.rs | 192 ++++++++++++++++++++++---------------------- server/src/state.rs | 144 +++++++++++++++++++++++++++++++++ 6 files changed, 249 insertions(+), 101 deletions(-) create mode 100644 server/src/state.rs diff --git a/server/Cargo.lock b/server/Cargo.lock index f0f8727..c787fb6 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -70,6 +70,12 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" +[[package]] +name = "anyhow" +version = "1.0.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" + [[package]] name = "async-compression" version = "0.3.15" @@ -547,6 +553,7 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" name = "hows-the-volume-server" version = "0.1.0" dependencies = [ + "anyhow", "axum", "dotenv", "serde", diff --git a/server/Cargo.toml b/server/Cargo.toml index 1a65e54..cf12c6b 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" strip = "symbols" [dependencies] +anyhow = "1.0.79" axum = "0.5.17" # The webapp framework of choice. dotenv = "0.15.0" # For accessing env vars stored in .env serde = { version = "1.0.137", features = ["derive"] } # For JSON serialization. diff --git a/server/clean_log.sh b/server/clean_log.sh index 8d8a39b..854054f 100755 --- a/server/clean_log.sh +++ b/server/clean_log.sh @@ -2,4 +2,4 @@ set -eu -find /home/pyreko/htv-server/ -mtime +7 -name "volume.log*" -print -exec /bin/rm {} \; \ No newline at end of file +find ~/yc-server/ -mtime +7 -name "volume.log*" -print -exec /bin/rm {} \; diff --git a/server/restart.sh b/server/restart.sh index 1ca3f4e..599894f 100755 --- a/server/restart.sh +++ b/server/restart.sh @@ -1,8 +1,8 @@ #!/bin/bash -set -eu +set -eux -pkill -2 hows-the-volume +pkill -2 hows-the-volume || true while pgrep -u $UID -x hows-the-volume >/dev/null; do sleep 1; done cp ./target/release/hows-the-volume-server ~/htv-server/hows-the-volume-server cp -r ./assets ~/htv-server/ diff --git a/server/src/main.rs b/server/src/main.rs index d8136c1..5512eb5 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,5 +1,8 @@ -use std::{env, fs, sync::Arc}; +mod state; +use std::{env, fs, sync::Arc, time::Duration}; + +use anyhow::Result; use axum::{ body::Body, extract::{rejection::PathRejection, Path}, @@ -11,74 +14,16 @@ use axum::{ }; use dotenv::dotenv; use serde::Serialize; + use sqlx::{Pool, Sqlite, SqlitePool}; +use state::State; +use tokio::time::timeout; use tower::util::ServiceExt; use tower_http::{cors::CorsLayer, services::ServeDir}; -use tracing::{error, info, warn}; +use tracing::{error, info}; use tracing_subscriber::filter::EnvFilter; -#[tokio::main] -async fn main() { - dotenv().ok().unwrap(); - - let file_appender = tracing_appender::rolling::daily("./", "volume.log"); - let (non_blocking, _guard) = tracing_appender::non_blocking(file_appender); - - tracing_subscriber::fmt() - .with_env_filter(EnvFilter::from_default_env()) - .with_writer(non_blocking) - .with_ansi(false) - .init(); - - info!("Starting up HTV..."); - let pool = SqlitePool::connect(&env::var("DATABASE_URL").unwrap()) - .await - .unwrap(); - - let pool_rc = Arc::new(pool); - - let origins = [ - "http://localhost:3000".parse::().unwrap(), - "https://howsthevolu.me".parse::().unwrap(), - ]; - let cors = CorsLayer::new() - .allow_methods(vec![Method::GET, Method::POST]) - .allow_origin(origins); - - let app = Router::new() - .route("/sound/:id", get(sound)) - .route("/count", get(count)) - .route("/increment", post(increment)) - .route("/num-files", get(num_audio_tracks)) - .layer(cors) - .layer(Extension(pool_rc.clone())) - .fallback(not_found_handler.into_service()); - - let addr = "127.0.0.1:8080".parse().unwrap(); - - info!("Listening on {}", addr); - - if std::path::Path::new("assets/").exists() { - let num_files = fs::read_dir("assets/").unwrap().count(); - info!("Found {} files in assets!", num_files); - } else { - error!("Warning - no asset/ folder found! There should be one located near the binary!"); - } - - axum::Server::bind(&addr) - .serve(app.into_make_service()) - .with_graceful_shutdown(async { - tokio::signal::ctrl_c().await.unwrap(); - info!("Shutdown signal received."); - }) - .await - .unwrap(); - - info!("Closing SQLite connection."); - pool_rc.close().await; - - info!("Shutting down server."); -} +use crate::state::init_state; #[derive(Serialize)] struct EmptyJson {} @@ -91,8 +36,7 @@ fn not_found() -> StatusCode { StatusCode::NOT_FOUND } -type PoolExt = Arc>; - +/// A JSON message containing a count. #[derive(Serialize)] struct Count { count: u64, @@ -104,26 +48,9 @@ impl Count { } } -/// Returns the current global count, stored in the DB. -async fn count(Extension(pool): Extension) -> Json { - match pool.acquire().await { - Ok(mut conn) => { - match sqlx::query!("SELECT count FROM counts WHERE name = 'volume'") - .fetch_one(&mut conn) - .await - { - Ok(query) => Json(Count::new(query.count as u64)), - Err(err) => { - warn!("SQLite query for count failed - err: {}", err); - Json(Count::new(0)) - } - } - } - Err(err) => { - error!("Failed to get pool connection to SQLite DB - err: {}", err); - Json(Count::new(0)) - } - } +async fn count(Extension(state): Extension) -> Json { + let val = state.lock().await; + Json(Count::new(val.count)) } /// Returns the selected sound file if it exists. @@ -132,7 +59,7 @@ async fn sound(id_result: Result, PathRejection>) -> Response { const PREFIX: &str = "volume_"; const SUFFIX: &str = ".mp3"; - let uri = format!("/{}{:0>2}{}", PREFIX, id, SUFFIX); + let uri = format!("/{PREFIX}{id:0>2}{SUFFIX}"); match Request::builder().uri(&uri).body(Body::empty()) { Ok(req) => match ServeDir::new("assets/").oneshot(req).await { Ok(resp) => { @@ -141,11 +68,11 @@ async fn sound(id_result: Result, PathRejection>) -> Response { } } Err(err) => { - error!("Failed to get a response for file {} - err: {}", uri, err); + error!("Failed to get a response for file {uri} - err: {err}"); } }, Err(err) => { - error!("Failed to build a request for file {} - err: {}", uri, err); + error!("Failed to build a request for file {uri} - err: {err}"); } } } @@ -154,17 +81,19 @@ async fn sound(id_result: Result, PathRejection>) -> Response { } /// Increments the count. -async fn increment(Extension(pool): Extension) { - let mut conn = pool.acquire().await.unwrap(); +async fn increment(Extension(state): Extension) { + let mut val = state.lock().await; + val.count += 1; + val.dirty = true; +} - match sqlx::query!("UPDATE counts SET count = count + 1 WHERE name = 'volume'") - .execute(&mut conn) - .await +/// Open the SQLite pool. +async fn open_pool() -> Result>> { { - Ok(_) => {} - Err(err) => { - error!("Failed to increment in DB - err: {}", err); - } + let url = env::var("DATABASE_URL")?; + let pool = SqlitePool::connect(&url).await?; + + Ok(Arc::new(pool)) } } @@ -185,3 +114,70 @@ async fn num_audio_tracks() -> Json { Json(Count::new(num_tracks)) } + +#[tokio::main] +async fn main() -> Result<()> { + dotenv().ok().unwrap(); + + let file_appender = tracing_appender::rolling::daily("./", "volume.log"); + let (non_blocking, _guard) = tracing_appender::non_blocking(file_appender); + + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .with_writer(non_blocking) + .with_ansi(false) + .init(); + + info!("Starting up HTV..."); + + let pool = open_pool().await?; + let (state, sync_task, shutdown) = init_state(pool.clone()).await?; + + let origins = [ + "http://localhost:3000".parse::()?, + "https://howsthevolu.me".parse::()?, + ]; + + let cors = CorsLayer::new() + .allow_methods(vec![Method::GET, Method::POST]) + .allow_origin(origins); + + let app = Router::new() + .route("/sound/:id", get(sound)) + .route("/count", get(count)) + .route("/increment", post(increment)) + .route("/num-files", get(num_audio_tracks)) + .layer(cors) + .layer(Extension(state)) + .fallback(not_found_handler.into_service()); + + let addr = "127.0.0.1:8080".parse()?; + + info!("Listening on {addr}"); + + if std::path::Path::new("assets/").exists() { + let num_files = fs::read_dir("assets/").unwrap().count(); + info!("Found {num_files} files in assets!"); + } else { + error!("Warning - no asset/ folder found! There should be one located near the binary!"); + } + + axum::Server::bind(&addr) + .serve(app.into_make_service()) + .with_graceful_shutdown(async { + tokio::signal::ctrl_c().await.unwrap(); + info!("Shutdown signal received."); + }) + .await?; + + info!("Stopping sync task."); + let _ = shutdown.send(()); + let _ = timeout(Duration::from_secs(15), sync_task).await; + + info!("Closing SQLite connection."); + pool.close().await; + + info!("Cleanup complete, shutting down HTV server."); + + Ok(()) +} diff --git a/server/src/state.rs b/server/src/state.rs new file mode 100644 index 0000000..2484dcf --- /dev/null +++ b/server/src/state.rs @@ -0,0 +1,144 @@ +use std::{sync::Arc, time::Duration}; + +use anyhow::Result; +use sqlx::{Pool, Sqlite}; +use tokio::{ + sync::{mpsc::UnboundedSender, Mutex}, + task::JoinHandle, +}; +use tracing::{error, info, warn}; + +#[derive(Clone)] +pub(crate) struct Value { + pub(crate) count: u64, + pub(crate) dirty: bool, +} + +/// Represents the current count and a connection to the SQLite pool. +pub type State = Arc>; + +/// Returns the current global count, from the DB. This is the true source of truth! +async fn get_db_count(pool: &Pool) -> Result { + let mut conn = pool.acquire().await?; + let query = sqlx::query!("SELECT count FROM counts WHERE name = 'volume'") + .fetch_one(&mut conn) + .await?; + + Ok(query.count as u64) +} + +/// Try to update the DB count. If the query was successful, it returns whether any rows were changed. +async fn update_db_count(pool: Arc>, new_count: u64) -> Result { + let mut conn = pool.acquire().await.unwrap(); + + // Note that u64 values aren't supported by SQLX for whatever reason, + // so we instead use an i64 internally. + let new_count = new_count as i64; + + // Update to either the max of the current value stored and the new value. + match sqlx::query!( + "UPDATE counts SET count = MAX(count, ?) WHERE name = 'volume'", + new_count + ) + .execute(&mut conn) + .await + { + Ok(result) => Ok(result.rows_affected() > 0), + Err(err) => { + error!("Failed to increment in DB - err: {err}"); + Err(err.into()) + } + } +} + +#[inline] +async fn sync_counts(pool: &Arc>, state: &State) { + let mut value = state.lock().await; + if value.dirty { + // Get values out and immediately clear dirty flag, then + // release the mutex. + + let new_val = value.count; + value.dirty = false; + drop(value); // XXX: YOU MUST CALL THIS, OR YOU WILL DEADLOCK THE INCREMENT JOB! + + if let Ok(did_update) = update_db_count(pool.clone(), new_val).await { + if !did_update { + // No update means that we might have to sync the state to be correct if the state is currently + // showing a LOWER value than the DB. + // + // This won't fire if the values are _equal_ though (or it shouldn't), due to the dirty + // bit check - if they're equal and the states are synced, then the DB update shouldn't + // have ever fired! + + if let Ok(current_db_value) = get_db_count(pool).await { + let mut value = state.lock().await; + if value.count < current_db_value { + value.count = current_db_value; + + // Unset dirty as they're now equal, so no need for the sync job to fire. + value.dirty = false; + + drop(value); + + warn!( + "The state's value was lower than what was in the DB - state resynced." + ); + } + } + } + } + } +} + +/// Initializes the state, along with the update task and shutdown sender. +pub(crate) async fn init_state( + pool: Arc>, +) -> Result<(State, JoinHandle<()>, UnboundedSender<()>)> { + let state = { + let count = match get_db_count(&pool).await { + Ok(val) => val, + Err(err) => { + warn!("Error while trying to get the currently stored DB value of count, defaulting to 0: {err}"); + 0 + } + }; + + info!("Count on startup: {count}"); + + Arc::new(Mutex::new(Value { + count, + dirty: false, + })) + }; + + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + + let sync_task = { + let state = state.clone(); + let pool = pool.clone(); + + tokio::task::spawn(async move { + let mut rx = rx; + let mut interval = tokio::time::interval(Duration::from_secs(10)); + + loop { + tokio::select! { + _ = rx.recv() => { + // Sync one last time, then stop. + warn!("Syncing update task before shutting down..."); + sync_counts(&pool, &state).await; + warn!("Synced. Update task shutting down."); + + break; + } + _ = interval.tick() => { + sync_counts(&pool, &state).await; + } + } + } + }) + }; + + Ok((state, sync_task, tx)) +}