Skip to content

Commit

Permalink
implement Connection trait
Browse files Browse the repository at this point in the history
  • Loading branch information
rkusa committed Mar 13, 2024
1 parent 5908fb0 commit 1d5991f
Show file tree
Hide file tree
Showing 5 changed files with 257 additions and 84 deletions.
167 changes: 167 additions & 0 deletions postgres/src/connection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// TODO: remove once Rust's async lifetime in trait story got improved
#![allow(clippy::manual_async_fn)]

use std::future::Future;

use deadpool_postgres::GenericClient;
use tokio_postgres::types::ToSql;
use tokio_postgres::Row;

use crate::Error;

pub trait Connection: Send + Sync {
fn query_one<'a>(
&'a self,
query: &'a str,
parameters: &'a [&'a (dyn ToSql + Sync)],
) -> impl Future<Output = Result<Row, Error>> + Send + 'a;

fn query_opt<'a>(
&'a self,
query: &'a str,
parameters: &'a [&'a (dyn ToSql + Sync)],
) -> impl Future<Output = Result<Option<Row>, Error>> + Send + 'a;

fn query<'a>(
&'a self,
query: &'a str,
parameters: &'a [&'a (dyn ToSql + Sync)],
) -> impl Future<Output = Result<Vec<Row>, Error>> + Send + 'a;

fn execute<'a>(
&'a self,
query: &'a str,
parameters: &'a [&'a (dyn ToSql + Sync)],
) -> impl Future<Output = Result<(), Error>> + Send + 'a;
}

impl Connection for deadpool_postgres::Client {
fn query_one<'a>(
&'a self,
query: &'a str,
parameters: &'a [&'a (dyn ToSql + Sync)],
) -> impl Future<Output = Result<Row, Error>> + Send + 'a {
async move {
let stmt = self.prepare_cached(query).await?;
Ok(tokio_postgres::Client::query_one(self, &stmt, parameters).await?)
}
}

fn query_opt<'a>(
&'a self,
query: &'a str,
parameters: &'a [&'a (dyn ToSql + Sync)],
) -> impl Future<Output = Result<Option<Row>, Error>> + Send + 'a {
async move {
let stmt = self.prepare_cached(query).await?;
Ok(tokio_postgres::Client::query_opt(self, &stmt, parameters).await?)
}
}

fn query<'a>(
&'a self,
query: &'a str,
parameters: &'a [&'a (dyn ToSql + Sync)],
) -> impl Future<Output = Result<Vec<Row>, Error>> + Send + 'a {
async move {
let stmt = self.prepare_cached(query).await?;
Ok(tokio_postgres::Client::query(self, &stmt, parameters).await?)
}
}

fn execute<'a>(
&'a self,
query: &'a str,
parameters: &'a [&'a (dyn ToSql + Sync)],
) -> impl Future<Output = Result<(), Error>> + Send + 'a {
async move {
let stmt = self.prepare_cached(query).await?;
tokio_postgres::Client::execute(self, &stmt, parameters).await?;
Ok(())
}
}
}

impl<'t> Connection for deadpool_postgres::Transaction<'t> {
fn query_one<'a>(
&'a self,
query: &'a str,
parameters: &'a [&'a (dyn ToSql + Sync)],
) -> impl Future<Output = Result<Row, Error>> + Send + 'a {
async move {
let stmt = self.prepare_cached(query).await?;
Ok(tokio_postgres::Transaction::query_one(self, &stmt, parameters).await?)
}
}

fn query_opt<'a>(
&'a self,
query: &'a str,
parameters: &'a [&'a (dyn ToSql + Sync)],
) -> impl Future<Output = Result<Option<Row>, Error>> + Send + 'a {
async move {
let stmt = self.prepare_cached(query).await?;
Ok(tokio_postgres::Transaction::query_opt(self, &stmt, parameters).await?)
}
}

fn query<'a>(
&'a self,
query: &'a str,
parameters: &'a [&'a (dyn ToSql + Sync)],
) -> impl Future<Output = Result<Vec<Row>, Error>> + Send + 'a {
async move {
let stmt = self.prepare_cached(query).await?;
Ok(tokio_postgres::Transaction::query(self, &stmt, parameters).await?)
}
}

fn execute<'a>(
&'a self,
query: &'a str,
parameters: &'a [&'a (dyn ToSql + Sync)],
) -> impl Future<Output = Result<(), Error>> + Send + 'a {
async move {
let stmt = self.prepare_cached(query).await?;
tokio_postgres::Transaction::execute(self, &stmt, parameters).await?;
Ok(())
}
}
}

impl<'b, C> Connection for &'b C
where
C: Connection,
{
fn query_one<'a>(
&'a self,
query: &'a str,
parameters: &'a [&'a (dyn ToSql + Sync)],
) -> impl Future<Output = Result<Row, Error>> + Send + 'a {
(*self).query_one(query, parameters)
}

fn query_opt<'a>(
&'a self,
query: &'a str,
parameters: &'a [&'a (dyn ToSql + Sync)],
) -> impl Future<Output = Result<Option<Row>, Error>> + Send + 'a {
(*self).query_opt(query, parameters)
}

fn query<'a>(
&'a self,
query: &'a str,
parameters: &'a [&'a (dyn ToSql + Sync)],
) -> impl Future<Output = Result<Vec<Row>, Error>> + Send + 'a {
(*self).query(query, parameters)
}

fn execute<'a>(
&'a self,
query: &'a str,
parameters: &'a [&'a (dyn ToSql + Sync)],
) -> impl Future<Output = Result<(), Error>> + Send + 'a {
(*self).execute(query, parameters)
}
}
67 changes: 61 additions & 6 deletions postgres/src/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,33 @@ where
type IntoFuture = SqlFuture<'a, T>;

fn into_future(self) -> Self::IntoFuture {
SqlFuture::new(self)
}
}

pub struct SqlFuture<'a, T> {
future: Pin<Box<dyn Future<Output = Result<T, Error>> + Send + 'a>>,
marker: PhantomData<&'a ()>,
}

impl<'a, T> SqlFuture<'a, T> {
pub fn new<Cols>(sql: Sql<'a, Cols, T>) -> Self
where
T: Query<Cols> + Send + Sync + 'a,
Cols: Send + Sync + 'a,
{
let span =
tracing::debug_span!("sql query", query = self.query, parameters = ?self.parameters);
tracing::debug_span!("sql query", query = sql.query, parameters = ?sql.parameters);
let start = Instant::now();

SqlFuture {
future: Box::pin(
// Note: changes here must be applied to `with_connection` below too!
async move {
let mut i = 1;
loop {
match T::query(&self).await {
let conn = super::connect().await?;
match T::query(&sql, &conn).await {
Ok(r) => {
let elapsed = start.elapsed();
tracing::trace!(?elapsed, "sql query finished");
Expand All @@ -55,11 +72,49 @@ where
marker: PhantomData,
}
}
}

pub struct SqlFuture<'a, T> {
future: Pin<Box<dyn Future<Output = Result<T, Error>> + Send + 'a>>,
marker: PhantomData<&'a ()>,
pub fn with_connection<Cols>(sql: Sql<'a, Cols, T>, conn: impl super::Connection + 'a) -> Self
where
T: Query<Cols> + Send + Sync + 'a,
Cols: Send + Sync + 'a,
{
let span =
tracing::debug_span!("sql query", query = sql.query, parameters = ?sql.parameters);
let start = Instant::now();

SqlFuture {
future: Box::pin(
// Note: changes here must be applied to `bew` above too!
async move {
let mut i = 1;
loop {
match T::query(&sql, &conn).await {
Ok(r) => {
let elapsed = start.elapsed();
tracing::trace!(?elapsed, "sql query finished");
return Ok(r);
}
Err(Error {
kind: ErrorKind::Postgres(err),
..
}) if err.is_closed() && i <= 5 => {
// retry pool size + 1 times if connection is closed (might have
// received a closed one from the connection pool)
i += 1;
tracing::trace!("retry due to connection closed error");
continue;
}
Err(err) => {
return Err(err);
}
}
}
}
.instrument(span),
),
marker: PhantomData,
}
}
}

impl<'a, T> Future for SqlFuture<'a, T> {
Expand Down
79 changes: 10 additions & 69 deletions postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#[cfg(test)]
extern crate self as sqlm_postgres;

mod connection;
mod error;
mod future;
pub mod internal;
Expand All @@ -17,8 +18,9 @@ use std::marker::PhantomData;
use std::str::FromStr;
use std::sync::Arc;

pub use connection::Connection;
pub use deadpool_postgres::Transaction;
use deadpool_postgres::{ClientWrapper, Manager, ManagerConfig, Object, Pool, RecyclingMethod};
use deadpool_postgres::{ClientWrapper, Manager, ManagerConfig, Pool, RecyclingMethod};
pub use error::Error;
use error::ErrorKind;
pub use future::SqlFuture;
Expand All @@ -35,10 +37,8 @@ pub use types::SqlType;

static POOL: OnceCell<Pool> = OnceCell::new();

pub type Connection = Object;

#[tracing::instrument]
pub async fn connect() -> Result<Connection, Error> {
pub async fn connect() -> Result<deadpool_postgres::Client, Error> {
// Don't trace connect, as this would create an endless loop of connecting again and
// again when persisting the connect trace!
let pool = POOL.get_or_try_init(|| {
Expand Down Expand Up @@ -90,71 +90,12 @@ pub struct Sql<'a, Cols, T> {
}

impl<'a, Cols, T> Sql<'a, Cols, T> {
pub fn with(mut self, tx: &'a ClientWrapper) -> Self {
self.connection = Some(tx);
self
}

pub fn with_transaction(mut self, tx: &'a Transaction<'a>) -> Self {
self.transaction = Some(tx);
self
}

async fn query_one(&self) -> Result<tokio_postgres::Row, Error> {
if let Some(tx) = self.transaction {
let stmt = tx.prepare_cached(self.query).await?;
Ok(tx.query_one(&stmt, self.parameters).await?)
} else if let Some(conn) = self.connection {
let stmt = conn.prepare_cached(self.query).await?;
Ok(conn.query_one(&stmt, self.parameters).await?)
} else {
let conn = connect().await?;
let stmt = conn.prepare_cached(self.query).await?;
Ok(conn.query_one(&stmt, self.parameters).await?)
}
}

async fn query_opt(&self) -> Result<Option<tokio_postgres::Row>, Error> {
if let Some(tx) = self.transaction {
let stmt = tx.prepare_cached(self.query).await?;
Ok(tx.query_opt(&stmt, self.parameters).await?)
} else if let Some(conn) = self.connection {
let stmt = conn.prepare_cached(self.query).await?;
Ok(conn.query_opt(&stmt, self.parameters).await?)
} else {
let conn = connect().await?;
let stmt = conn.prepare_cached(self.query).await?;
Ok(conn.query_opt(&stmt, self.parameters).await?)
}
}

async fn query(&self) -> Result<Vec<tokio_postgres::Row>, Error> {
if let Some(tx) = self.transaction {
let stmt = tx.prepare_cached(self.query).await?;
Ok(tx.query(&stmt, self.parameters).await?)
} else if let Some(conn) = self.connection {
let stmt = conn.prepare_cached(self.query).await?;
Ok(conn.query(&stmt, self.parameters).await?)
} else {
let conn = connect().await?;
let stmt = conn.prepare_cached(self.query).await?;
Ok(conn.query(&stmt, self.parameters).await?)
}
}

async fn execute(&self) -> Result<(), Error> {
if let Some(tx) = self.transaction {
let stmt = tx.prepare_cached(self.query).await?;
tx.execute(&stmt, self.parameters).await?;
} else if let Some(conn) = self.connection {
let stmt = conn.prepare_cached(self.query).await?;
conn.execute(&stmt, self.parameters).await?;
} else {
let conn = connect().await?;
let stmt = conn.prepare_cached(self.query).await?;
conn.execute(&stmt, self.parameters).await?;
}
Ok(())
pub fn run_with(self, conn: impl Connection + 'a) -> SqlFuture<'a, T>
where
T: Query<Cols> + Send + Sync + 'a,
Cols: Send + Sync + 'a,
{
SqlFuture::with_connection(self, conn)
}
}

Expand Down
Loading

0 comments on commit 1d5991f

Please sign in to comment.