diff --git a/crates/db_schema/src/utils.rs b/crates/db_schema/src/utils.rs index 1dc2f9afa..e4f5d6cae 100644 --- a/crates/db_schema/src/utils.rs +++ b/crates/db_schema/src/utils.rs @@ -6,6 +6,7 @@ use crate::{ SortType, }; use activitypub_federation::{fetch::object_id::ObjectId, traits::Object}; +use async_trait::async_trait; use chrono::NaiveDateTime; use deadpool::Runtime; use diesel::{ @@ -48,8 +49,33 @@ const POOL_TIMEOUT: Option = Some(Duration::from_secs(5)); pub type DbPool = Pool; -pub async fn get_conn(pool: &DbPool) -> Result, DieselError> { - pool.get().await.map_err(|e| QueryBuilderError(e.into())) +#[async_trait] +pub trait GetConn { + type Conn: std::ops::Deref; + + async fn get_conn(self) -> Result; +} + +#[async_trait] +impl<'a> GetConn for &'a DbPool { + type Conn = PooledConnection; + + async fn get_conn(self) -> Result { + self.get().await.map_err(|e| QueryBuilderError(e.into())) + } +} + +#[async_trait] +impl<'a> GetConn for &'a mut AsyncPgConnection { + type Conn = Self; + + async fn get_conn(self) -> Result { + Ok(self) + } +} + +pub async fn get_conn(getter: T) -> Result { + getter.get_conn().await } pub fn get_database_url_from_env() -> Result {