Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify testing db #75

Merged
merged 1 commit into from
Jul 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 28 additions & 63 deletions libiam/src/testing/db.rs
Original file line number Diff line number Diff line change
@@ -1,94 +1,59 @@
use async_trait::async_trait;
use futures::{channel::mpsc, StreamExt};
use sea_orm::{
ConnectOptions, ConnectionTrait, DbBackend, DbErr, ExecResult, QueryResult, Statement,
ConnectOptions, ConnectionTrait, DatabaseConnection, DbBackend, DbErr, ExecResult, QueryResult,
Statement,
};
use tokio::{runtime, sync::oneshot};
use std::sync::Arc;
use tokio::runtime::{self, Runtime};
use tracing::log::LevelFilter;

#[derive(Clone)]
pub struct Database {
channel: mpsc::UnboundedSender<Message>,
runtime: Arc<Runtime>,
conn: Arc<DatabaseConnection>,
}

impl Database {
pub async fn connect(uri: &str) -> Self {
let (tx, mut rx) = mpsc::unbounded();

let uri = uri.to_owned();
std::thread::spawn(move || {
let rt = runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to create tokio runtime");

rt.block_on(async move {
let mut opts = ConnectOptions::new(uri);
opts.sqlx_logging_level(LevelFilter::Debug);
let runtime = runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();

let conn = sea_orm::Database::connect(opts)
.await
.expect("failed to connect to database");
let conn = {
let mut opts = ConnectOptions::new(uri.to_owned());
opts.sqlx_logging_level(LevelFilter::Debug);

while let Some(msg) = rx.next().await {
match msg {
Message::Execute(tx, stmt) => {
let res = conn.execute(stmt).await;
tx.send(res).unwrap();
}
Message::QueryOne(tx, stmt) => {
let res = conn.query_one(stmt).await;
tx.send(res).unwrap();
}
Message::QueryAll(tx, stmt) => {
let res = conn.query_all(stmt).await;
tx.send(res).unwrap();
}
}
}
});
});
sea_orm::Database::connect(opts)
.await
.expect("failed to connect to database")
};

Self { channel: tx }
Self {
runtime: Arc::new(runtime),
conn: Arc::new(conn),
}
}
}

pub enum Message {
Execute(oneshot::Sender<Result<ExecResult, DbErr>>, Statement),
QueryOne(
oneshot::Sender<Result<Option<QueryResult>, DbErr>>,
Statement,
),
QueryAll(oneshot::Sender<Result<Vec<QueryResult>, DbErr>>, Statement),
}

#[async_trait]
impl ConnectionTrait for Database {
fn get_database_backend(&self) -> DbBackend {
DbBackend::MySql
self.conn.get_database_backend()
}

async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
let (tx, rx) = oneshot::channel();
self.channel
.unbounded_send(Message::Execute(tx, stmt))
.unwrap();
rx.await.unwrap()
let this: &'static Self = unsafe { std::mem::transmute(self) };
self.runtime.spawn(this.conn.execute(stmt)).await.unwrap()
}

async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
let (tx, rx) = oneshot::channel();
self.channel
.unbounded_send(Message::QueryOne(tx, stmt))
.unwrap();
rx.await.unwrap()
let this: &'static Self = unsafe { std::mem::transmute(self) };
self.runtime.spawn(this.conn.query_one(stmt)).await.unwrap()
}

async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
let (tx, rx) = oneshot::channel();
self.channel
.unbounded_send(Message::QueryAll(tx, stmt))
.unwrap();
rx.await.unwrap()
let this: &'static Self = unsafe { std::mem::transmute(self) };
self.runtime.spawn(this.conn.query_all(stmt)).await.unwrap()
}
}
Loading