From 24830070f6f3a1111295df6a9bbe09faf4424684 Mon Sep 17 00:00:00 2001 From: Nutomic Date: Mon, 11 Dec 2023 15:04:18 +0100 Subject: [PATCH] Add diesel feature, add ObjectId::dereference_forced (#88) * Add diesel feature This can simplify Lemmy code and avoid converting back and forth to DbUrl type all the time. * Also add diesel derives for CollectionId * Add ObjectId::dereference_forced * no deprecated code * fmt --- Cargo.toml | 2 + src/fetch/collection_id.rs | 89 +++++++++++++++++++++++++++ src/fetch/object_id.rs | 123 ++++++++++++++++++++++++++++++++++--- 3 files changed, 206 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e0f0cb8..da83e67 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ documentation = "https://docs.rs/activitypub_federation/" default = ["actix-web", "axum"] actix-web = ["dep:actix-web"] axum = ["dep:axum", "dep:tower", "dep:hyper", "dep:http-body-util"] +diesel = ["dep:diesel"] [dependencies] chrono = { version = "0.4.31", features = ["clock"], default-features = false } @@ -50,6 +51,7 @@ tokio = { version = "1.34.0", features = [ "rt-multi-thread", "time", ] } +diesel = { version = "2.1.4", features = ["postgres"], default-features = false, optional = true } futures = "0.3.29" moka = { version = "0.12.1", features = ["future"] } diff --git a/src/fetch/collection_id.rs b/src/fetch/collection_id.rs index ae17ca0..8c796f4 100644 --- a/src/fetch/collection_id.rs +++ b/src/fetch/collection_id.rs @@ -102,3 +102,92 @@ where self.0.eq(&other.0) && self.1 == other.1 } } + +#[cfg(feature = "diesel")] +const _IMPL_DIESEL_NEW_TYPE_FOR_COLLECTION_ID: () = { + use diesel::{ + backend::Backend, + deserialize::{FromSql, FromStaticSqlRow}, + expression::AsExpression, + internal::derives::as_expression::Bound, + pg::Pg, + query_builder::QueryId, + serialize, + serialize::{Output, ToSql}, + sql_types::{HasSqlType, SingleValue, Text}, + Expression, + Queryable, + }; + + // TODO: this impl only works for Postgres db because of to_string() call which requires reborrow + impl ToSql for CollectionId + where + Kind: Collection, + for<'de2> ::Kind: Deserialize<'de2>, + String: ToSql, + { + fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result { + let v = self.0.to_string(); + >::to_sql(&v, &mut out.reborrow()) + } + } + impl<'expr, Kind, ST> AsExpression for &'expr CollectionId + where + Kind: Collection, + for<'de2> ::Kind: Deserialize<'de2>, + Bound: Expression, + ST: SingleValue, + { + type Expression = Bound; + fn as_expression(self) -> Self::Expression { + Bound::new(self.0.as_str()) + } + } + impl AsExpression for CollectionId + where + Kind: Collection, + for<'de2> ::Kind: Deserialize<'de2>, + Bound: Expression, + ST: SingleValue, + { + type Expression = Bound; + fn as_expression(self) -> Self::Expression { + Bound::new(self.0.to_string()) + } + } + impl FromSql for CollectionId + where + Kind: Collection + Send + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + String: FromSql, + DB: Backend, + DB: HasSqlType, + { + fn from_sql( + raw: DB::RawValue<'_>, + ) -> Result> { + let string: String = FromSql::::from_sql(raw)?; + Ok(CollectionId::parse(&string)?) + } + } + impl Queryable for CollectionId + where + Kind: Collection + Send + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + String: FromStaticSqlRow, + DB: Backend, + DB: HasSqlType, + { + type Row = String; + fn build(row: Self::Row) -> diesel::deserialize::Result { + Ok(CollectionId::parse(&row)?) + } + } + impl QueryId for CollectionId + where + Kind: Collection + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + { + type QueryId = Self; + } +}; diff --git a/src/fetch/object_id.rs b/src/fetch/object_id.rs index 179a82c..782900d 100644 --- a/src/fetch/object_id.rs +++ b/src/fetch/object_id.rs @@ -57,12 +57,12 @@ where pub struct ObjectId(Box, PhantomData) where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>; + for<'de2> ::Kind: Deserialize<'de2>; impl ObjectId where Kind: Object + Send + Debug + 'static, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { /// Construct a new objectid instance pub fn parse(url: &str) -> Result { @@ -112,6 +112,24 @@ where } } + /// If this is a remote object, fetch it from origin instance unconditionally to get the + /// latest version, regardless of refresh interval. + pub async fn dereference_forced( + &self, + data: &Data<::DataType>, + ) -> Result::Error> + where + ::Error: From, + { + if data.config.is_local_url(&self.0) { + self.dereference_from_db(data) + .await + .map(|o| o.ok_or(Error::NotFound.into()))? + } else { + self.dereference_from_http(data, None).await + } + } + /// Fetch an object from the local db. Instead of falling back to http, this throws an error if /// the object is not found in the database. pub async fn dereference_local( @@ -163,7 +181,7 @@ where impl Clone for ObjectId where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn clone(&self) -> Self { ObjectId(self.0.clone(), self.1) @@ -190,7 +208,7 @@ fn should_refetch_object(last_refreshed: DateTime) -> bool { impl Display for ObjectId where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0.as_str()) @@ -200,7 +218,7 @@ where impl Debug for ObjectId where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0.as_str()) @@ -210,7 +228,7 @@ where impl From> for Url where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn from(id: ObjectId) -> Self { *id.0 @@ -220,7 +238,7 @@ where impl From for ObjectId where Kind: Object + Send + 'static, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn from(url: Url) -> Self { ObjectId(Box::new(url), PhantomData::) @@ -230,13 +248,102 @@ where impl PartialEq for ObjectId where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn eq(&self, other: &Self) -> bool { self.0.eq(&other.0) && self.1 == other.1 } } +#[cfg(feature = "diesel")] +const _IMPL_DIESEL_NEW_TYPE_FOR_OBJECT_ID: () = { + use diesel::{ + backend::Backend, + deserialize::{FromSql, FromStaticSqlRow}, + expression::AsExpression, + internal::derives::as_expression::Bound, + pg::Pg, + query_builder::QueryId, + serialize, + serialize::{Output, ToSql}, + sql_types::{HasSqlType, SingleValue, Text}, + Expression, + Queryable, + }; + + // TODO: this impl only works for Postgres db because of to_string() call which requires reborrow + impl ToSql for ObjectId + where + Kind: Object, + for<'de2> ::Kind: Deserialize<'de2>, + String: ToSql, + { + fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result { + let v = self.0.to_string(); + >::to_sql(&v, &mut out.reborrow()) + } + } + impl<'expr, Kind, ST> AsExpression for &'expr ObjectId + where + Kind: Object, + for<'de2> ::Kind: Deserialize<'de2>, + Bound: Expression, + ST: SingleValue, + { + type Expression = Bound; + fn as_expression(self) -> Self::Expression { + Bound::new(self.0.as_str()) + } + } + impl AsExpression for ObjectId + where + Kind: Object, + for<'de2> ::Kind: Deserialize<'de2>, + Bound: Expression, + ST: SingleValue, + { + type Expression = Bound; + fn as_expression(self) -> Self::Expression { + Bound::new(self.0.to_string()) + } + } + impl FromSql for ObjectId + where + Kind: Object + Send + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + String: FromSql, + DB: Backend, + DB: HasSqlType, + { + fn from_sql( + raw: DB::RawValue<'_>, + ) -> Result> { + let string: String = FromSql::::from_sql(raw)?; + Ok(ObjectId::parse(&string)?) + } + } + impl Queryable for ObjectId + where + Kind: Object + Send + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + String: FromStaticSqlRow, + DB: Backend, + DB: HasSqlType, + { + type Row = String; + fn build(row: Self::Row) -> diesel::deserialize::Result { + Ok(ObjectId::parse(&row)?) + } + } + impl QueryId for ObjectId + where + Kind: Object + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + { + type QueryId = Self; + } +}; + #[cfg(test)] pub mod tests { use super::*;