Skip to content

Commit

Permalink
Simplify testing db
Browse files Browse the repository at this point in the history
  • Loading branch information
smrtrfszm committed Jul 7, 2024
1 parent 447312c commit 3d9271e
Showing 1 changed file with 28 additions and 63 deletions.
91 changes: 28 additions & 63 deletions libiam/src/testing/db.rs
Original file line number Diff line number Diff line change
@@ -1,94 +1,59 @@
use async_trait::async_trait;
use futures::{channel::mpsc, StreamExt};
use sea_orm::{
ConnectOptions, ConnectionTrait, DbBackend, DbErr, ExecResult, QueryResult, Statement,
ConnectOptions, ConnectionTrait, DatabaseConnection, DbBackend, DbErr, ExecResult, QueryResult,
Statement,
};
use tokio::{runtime, sync::oneshot};
use std::sync::Arc;
use tokio::runtime::{self, Runtime};
use tracing::log::LevelFilter;

#[derive(Clone)]
pub struct Database {
channel: mpsc::UnboundedSender<Message>,
runtime: Arc<Runtime>,
conn: Arc<DatabaseConnection>,
}

impl Database {
pub async fn connect(uri: &str) -> Self {
let (tx, mut rx) = mpsc::unbounded();

let uri = uri.to_owned();
std::thread::spawn(move || {
let rt = runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to create tokio runtime");

rt.block_on(async move {
let mut opts = ConnectOptions::new(uri);
opts.sqlx_logging_level(LevelFilter::Debug);
let runtime = runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();

let conn = sea_orm::Database::connect(opts)
.await
.expect("failed to connect to database");
let conn = {
let mut opts = ConnectOptions::new(uri.to_owned());
opts.sqlx_logging_level(LevelFilter::Debug);

while let Some(msg) = rx.next().await {
match msg {
Message::Execute(tx, stmt) => {
let res = conn.execute(stmt).await;
tx.send(res).unwrap();
}
Message::QueryOne(tx, stmt) => {
let res = conn.query_one(stmt).await;
tx.send(res).unwrap();
}
Message::QueryAll(tx, stmt) => {
let res = conn.query_all(stmt).await;
tx.send(res).unwrap();
}
}
}
});
});
sea_orm::Database::connect(opts)
.await
.expect("failed to connect to database")
};

Self { channel: tx }
Self {
runtime: Arc::new(runtime),
conn: Arc::new(conn),
}
}
}

pub enum Message {
Execute(oneshot::Sender<Result<ExecResult, DbErr>>, Statement),
QueryOne(
oneshot::Sender<Result<Option<QueryResult>, DbErr>>,
Statement,
),
QueryAll(oneshot::Sender<Result<Vec<QueryResult>, DbErr>>, Statement),
}

#[async_trait]
impl ConnectionTrait for Database {
fn get_database_backend(&self) -> DbBackend {
DbBackend::MySql
self.conn.get_database_backend()
}

async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
let (tx, rx) = oneshot::channel();
self.channel
.unbounded_send(Message::Execute(tx, stmt))
.unwrap();
rx.await.unwrap()
let this: &'static Self = unsafe { std::mem::transmute(self) };
self.runtime.spawn(this.conn.execute(stmt)).await.unwrap()
}

async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
let (tx, rx) = oneshot::channel();
self.channel
.unbounded_send(Message::QueryOne(tx, stmt))
.unwrap();
rx.await.unwrap()
let this: &'static Self = unsafe { std::mem::transmute(self) };
self.runtime.spawn(this.conn.query_one(stmt)).await.unwrap()
}

async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
let (tx, rx) = oneshot::channel();
self.channel
.unbounded_send(Message::QueryAll(tx, stmt))
.unwrap();
rx.await.unwrap()
let this: &'static Self = unsafe { std::mem::transmute(self) };
self.runtime.spawn(this.conn.query_all(stmt)).await.unwrap()
}
}

0 comments on commit 3d9271e

Please sign in to comment.