From 3d9271ea0d403c35413d74406d8d66b56a1c3f53 Mon Sep 17 00:00:00 2001 From: Szepesi Tibor Date: Sun, 7 Jul 2024 22:40:35 +0200 Subject: [PATCH] Simplify testing db --- libiam/src/testing/db.rs | 91 +++++++++++++--------------------------- 1 file changed, 28 insertions(+), 63 deletions(-) diff --git a/libiam/src/testing/db.rs b/libiam/src/testing/db.rs index dfb4c53..21e07f6 100644 --- a/libiam/src/testing/db.rs +++ b/libiam/src/testing/db.rs @@ -1,94 +1,59 @@ use async_trait::async_trait; -use futures::{channel::mpsc, StreamExt}; use sea_orm::{ - ConnectOptions, ConnectionTrait, DbBackend, DbErr, ExecResult, QueryResult, Statement, + ConnectOptions, ConnectionTrait, DatabaseConnection, DbBackend, DbErr, ExecResult, QueryResult, + Statement, }; -use tokio::{runtime, sync::oneshot}; +use std::sync::Arc; +use tokio::runtime::{self, Runtime}; use tracing::log::LevelFilter; #[derive(Clone)] pub struct Database { - channel: mpsc::UnboundedSender, + runtime: Arc, + conn: Arc, } impl Database { pub async fn connect(uri: &str) -> Self { - let (tx, mut rx) = mpsc::unbounded(); - - let uri = uri.to_owned(); - std::thread::spawn(move || { - let rt = runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("Failed to create tokio runtime"); - - rt.block_on(async move { - let mut opts = ConnectOptions::new(uri); - opts.sqlx_logging_level(LevelFilter::Debug); + let runtime = runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); - let conn = sea_orm::Database::connect(opts) - .await - .expect("failed to connect to database"); + let conn = { + let mut opts = ConnectOptions::new(uri.to_owned()); + opts.sqlx_logging_level(LevelFilter::Debug); - while let Some(msg) = rx.next().await { - match msg { - Message::Execute(tx, stmt) => { - let res = conn.execute(stmt).await; - tx.send(res).unwrap(); - } - Message::QueryOne(tx, stmt) => { - let res = conn.query_one(stmt).await; - tx.send(res).unwrap(); - } - Message::QueryAll(tx, stmt) => { - let res = conn.query_all(stmt).await; - tx.send(res).unwrap(); - } - } - } - }); - }); + sea_orm::Database::connect(opts) + .await + .expect("failed to connect to database") + }; - Self { channel: tx } + Self { + runtime: Arc::new(runtime), + conn: Arc::new(conn), + } } } -pub enum Message { - Execute(oneshot::Sender>, Statement), - QueryOne( - oneshot::Sender, DbErr>>, - Statement, - ), - QueryAll(oneshot::Sender, DbErr>>, Statement), -} - #[async_trait] impl ConnectionTrait for Database { fn get_database_backend(&self) -> DbBackend { - DbBackend::MySql + self.conn.get_database_backend() } async fn execute(&self, stmt: Statement) -> Result { - let (tx, rx) = oneshot::channel(); - self.channel - .unbounded_send(Message::Execute(tx, stmt)) - .unwrap(); - rx.await.unwrap() + let this: &'static Self = unsafe { std::mem::transmute(self) }; + self.runtime.spawn(this.conn.execute(stmt)).await.unwrap() } async fn query_one(&self, stmt: Statement) -> Result, DbErr> { - let (tx, rx) = oneshot::channel(); - self.channel - .unbounded_send(Message::QueryOne(tx, stmt)) - .unwrap(); - rx.await.unwrap() + let this: &'static Self = unsafe { std::mem::transmute(self) }; + self.runtime.spawn(this.conn.query_one(stmt)).await.unwrap() } async fn query_all(&self, stmt: Statement) -> Result, DbErr> { - let (tx, rx) = oneshot::channel(); - self.channel - .unbounded_send(Message::QueryAll(tx, stmt)) - .unwrap(); - rx.await.unwrap() + let this: &'static Self = unsafe { std::mem::transmute(self) }; + self.runtime.spawn(this.conn.query_all(stmt)).await.unwrap() } }