From 14b80c66a73cdf4f3391d5ebe4af4b181d3e986b Mon Sep 17 00:00:00 2001 From: Tangel Date: Thu, 14 Dec 2023 07:18:33 +0000 Subject: [PATCH] merge upstream --- Cargo.toml | 22 +++- docs/06_http_endpoints_axum.md | 2 +- examples/live_federation/http.rs | 2 +- examples/local_federation/actix_web/http.rs | 2 +- examples/local_federation/axum/http.rs | 2 +- src/activity_sending.rs | 38 ++---- src/actix_web/inbox.rs | 67 ++++++++--- src/axum/inbox.rs | 13 +-- src/axum/json.rs | 2 +- src/error.rs | 10 +- src/fetch/collection_id.rs | 89 ++++++++++++++ src/fetch/object_id.rs | 123 ++++++++++++++++++-- src/fetch/webfinger.rs | 67 +++++++---- src/lib.rs | 41 +++++++ src/traits.rs | 2 +- 15 files changed, 387 insertions(+), 95 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a0bee78..34c82e7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "activitypub_federation" -version = "0.5.0-beta.5" +version = "0.5.0-beta.6" edition = "2021" description = "High-level Activitypub framework" keywords = ["activitypub", "activitystreams", "federation", "fediverse"] @@ -8,18 +8,27 @@ license = "AGPL-3.0" repository = "https://github.com/LemmyNet/activitypub-federation-rust" documentation = "https://docs.rs/activitypub_federation/" +[features] +default = ["actix-web", "axum"] +actix-web = ["dep:actix-web"] +axum = ["dep:axum", "dep:tower", "dep:hyper", "dep:http-body-util"] +diesel = ["dep:diesel"] + [dependencies] chrono = { version = "0.4.31", features = ["clock"], default-features = false } serde = { version = "1.0.193", features = ["derive"] } +serde = { version = "1.0.193", features = ["derive"] } async-trait = "0.1.74" url = { version = "2.5.0", features = ["serde"] } serde_json = { version = "1.0.108", features = ["preserve_order"] } +url = { version = "2.5.0", features = ["serde"] } +serde_json = { version = "1.0.108", features = ["preserve_order"] } reqwest = { version = "0.11.22", features = ["json", "stream"] } reqwest-middleware = "0.2.4" tracing = "0.1.40" base64 = "0.21.5" -openssl = "0.10.60" -once_cell = "1.18.0" +openssl = "0.10.61" +once_cell = "1.19.0" http = "0.2.11" sha2 = "0.10.8" thiserror = "1.0.50" @@ -32,6 +41,7 @@ http-signature-normalization-reqwest = { version = "0.10.0", default-features = "default-spawner", "sha-2", "middleware", + "default-spawner", ] } http-signature-normalization = "0.7.0" bytes = "1.5.0" @@ -48,6 +58,9 @@ tokio = { version = "1.34.0", features = [ "rt-multi-thread", "time", ] } +diesel = { version = "2.1.4", features = ["postgres"], default-features = false, optional = true } +futures = "0.3.29" +moka = { version = "0.12.1", features = ["future"] } # Actix-web actix-web = { version = "4.4.0", default-features = false, optional = true } @@ -58,8 +71,7 @@ axum = { git = "https://github.com/tokio-rs/axum.git", rev = "30afe97e99303fffc4 ], default-features = false, optional = true } tower = { version = "*", optional = true } hyper = { version = "*", optional = true } -futures = "*" -moka = { version = "0.12.1", features = ["future"] } +http-body-util = {version = "0.1.0", optional = true } [features] default = ["actix-web", "axum"] diff --git a/docs/06_http_endpoints_axum.md b/docs/06_http_endpoints_axum.md index 8ebbcc8..3a33410 100644 --- a/docs/06_http_endpoints_axum.md +++ b/docs/06_http_endpoints_axum.md @@ -48,7 +48,7 @@ async fn http_get_user( ) -> impl IntoResponse { let accept = header_map.get("accept").map(|v| v.to_str().unwrap()); if accept == Some(FEDERATION_CONTENT_TYPE) { - let db_user = data.read_local_user(name).await.unwrap(); + let db_user = data.read_local_user(&name).await.unwrap(); let json_user = db_user.into_json(&data).await.unwrap(); FederationJson(WithContext::new_default(json_user)).into_response() } diff --git a/examples/live_federation/http.rs b/examples/live_federation/http.rs index d626396..e0d2869 100644 --- a/examples/live_federation/http.rs +++ b/examples/live_federation/http.rs @@ -61,7 +61,7 @@ pub async fn webfinger( data: Data, ) -> Result, Error> { let name = extract_webfinger_name(&query.resource, &data)?; - let db_user = data.read_user(&name)?; + let db_user = data.read_user(name)?; Ok(Json(build_webfinger_response( query.resource, db_user.ap_id.into_inner(), diff --git a/examples/local_federation/actix_web/http.rs b/examples/local_federation/actix_web/http.rs index 12a750f..6298014 100644 --- a/examples/local_federation/actix_web/http.rs +++ b/examples/local_federation/actix_web/http.rs @@ -89,7 +89,7 @@ pub async fn webfinger( data: Data, ) -> Result { let name = extract_webfinger_name(&query.resource, &data)?; - let db_user = data.read_user(&name)?; + let db_user = data.read_user(name)?; Ok(HttpResponse::Ok().json(build_webfinger_response( query.resource.clone(), db_user.ap_id.into_inner(), diff --git a/examples/local_federation/axum/http.rs b/examples/local_federation/axum/http.rs index 16c5f0e..cf6469d 100644 --- a/examples/local_federation/axum/http.rs +++ b/examples/local_federation/axum/http.rs @@ -75,7 +75,7 @@ async fn webfinger( data: Data, ) -> Result, Error> { let name = extract_webfinger_name(&query.resource, &data)?; - let db_user = data.read_user(&name)?; + let db_user = data.read_user(name)?; Ok(Json(build_webfinger_response( query.resource, db_user.ap_id.into_inner(), diff --git a/src/activity_sending.rs b/src/activity_sending.rs index 357ca08..ec438f4 100644 --- a/src/activity_sending.rs +++ b/src/activity_sending.rs @@ -13,17 +13,15 @@ use crate::{ use bytes::Bytes; use futures::StreamExt; -use http::{header::HeaderName, HeaderMap, HeaderValue}; use httpdate::fmt_http_date; use itertools::Itertools; use openssl::pkey::{PKey, Private}; -use reqwest::Request; -use reqwest_middleware::ClientWithMiddleware; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use serde::Serialize; use std::{ self, fmt::{Debug, Display}, - time::{Duration, SystemTime}, + time::SystemTime, }; use tracing::debug; use url::Url; @@ -94,33 +92,19 @@ impl SendActivityTask<'_> { /// convert a sendactivitydata to a request, signing and sending it pub async fn sign_and_send(&self, data: &Data) -> Result<(), Error> { - let req = self - .sign(&data.config.client, data.config.request_timeout) - .await?; - self.send(&data.config.client, req).await - } - async fn sign( - &self, - client: &ClientWithMiddleware, - timeout: Duration, - ) -> Result { - let task = self; + let client = &data.config.client; let request_builder = client - .post(task.inbox.to_string()) - .timeout(timeout) - .headers(generate_request_headers(&task.inbox)); + .post(self.inbox.to_string()) + .timeout(data.config.request_timeout) + .headers(generate_request_headers(&self.inbox)); let request = sign_request( request_builder, - task.actor_id, - task.activity.clone(), - task.private_key.clone(), - task.http_signature_compat, + self.actor_id, + self.activity.clone(), + self.private_key.clone(), + self.http_signature_compat, ) .await?; - Ok(request) - } - - async fn send(&self, client: &ClientWithMiddleware, request: Request) -> Result<(), Error> { let response = client.execute(request).await?; match response { @@ -287,7 +271,7 @@ mod tests { let start = Instant::now(); for _ in 0..num_messages { - message.sign_and_send(&data).await?; + message.clone().sign_and_send(&data).await?; } info!("Queue Sent: {:?}", start.elapsed()); diff --git a/src/actix_web/inbox.rs b/src/actix_web/inbox.rs index e7e19e8..a2b55d4 100644 --- a/src/actix_web/inbox.rs +++ b/src/actix_web/inbox.rs @@ -3,8 +3,8 @@ use crate::{ config::Data, error::Error, - fetch::object_id::ObjectId, http_signatures::{verify_body_hash, verify_signature}, + parse_received_activity, traits::{ActivityHandler, Actor, Object}, }; use actix_web::{web::Bytes, HttpRequest, HttpResponse}; @@ -29,11 +29,7 @@ where { verify_body_hash(request.headers().get("Digest"), &body)?; - let activity: Activity = serde_json::from_slice(&body).map_err(Error::Json)?; - data.config.verify_url_and_domain(&activity).await?; - let actor = ObjectId::::from(activity.actor().clone()) - .dereference(data) - .await?; + let (activity, actor) = parse_received_activity::(&body, data).await?; verify_signature( request.headers(), @@ -54,12 +50,14 @@ mod test { use crate::{ activity_sending::generate_request_headers, config::FederationConfig, + fetch::object_id::ObjectId, http_signatures::sign_request, traits::tests::{DbConnection, DbUser, Follow, DB_USER_KEYPAIR}, }; use actix_web::test::TestRequest; use reqwest::Client; use reqwest_middleware::ClientWithMiddleware; + use serde_json::json; use url::Url; #[tokio::test] @@ -105,22 +103,49 @@ mod test { assert_eq!(&err, &Error::ActivitySignatureInvalid) } - async fn setup_receive_test() -> (Bytes, TestRequest, FederationConfig) { + #[tokio::test] + async fn test_receive_unparseable_activity() { + let (_, _, config) = setup_receive_test().await; + + let actor = Url::parse("http://ds9.lemmy.ml/u/lemmy_alpha").unwrap(); + let id = "http://localhost:123/1"; + let activity = json!({ + "actor": actor.as_str(), + "to": ["https://www.w3.org/ns/activitystreams#Public"], + "object": "http://ds9.lemmy.ml/post/1", + "cc": ["http://enterprise.lemmy.ml/c/main"], + "type": "Delete", + "id": id + } + ); + let body: Bytes = serde_json::to_vec(&activity).unwrap().into(); + let incoming_request = construct_request(&body, &actor).await; + + // intentionally cause a parse error by using wrong type for deser + let res = receive_activity::( + incoming_request.to_http_request(), + body, + &config.to_request_data(), + ) + .await; + + match res { + Err(Error::ParseReceivedActivity(url, _)) => { + assert_eq!(id, url.as_str()); + } + _ => unreachable!(), + } + } + + async fn construct_request(body: &Bytes, actor: &Url) -> TestRequest { let inbox = "https://example.com/inbox"; let headers = generate_request_headers(&Url::parse(inbox).unwrap()); let request_builder = ClientWithMiddleware::from(Client::default()) .post(inbox) .headers(headers); - let activity = Follow { - actor: ObjectId::parse("http://localhost:123").unwrap(), - object: ObjectId::parse("http://localhost:124").unwrap(), - kind: Default::default(), - id: "http://localhost:123/1".try_into().unwrap(), - }; - let body: Bytes = serde_json::to_vec(&activity).unwrap().into(); let outgoing_request = sign_request( request_builder, - &activity.actor.into_inner(), + actor, body.clone(), DB_USER_KEYPAIR.private_key().unwrap(), false, @@ -131,6 +156,18 @@ mod test { for h in outgoing_request.headers() { incoming_request = incoming_request.append_header(h); } + incoming_request + } + + async fn setup_receive_test() -> (Bytes, TestRequest, FederationConfig) { + let activity = Follow { + actor: ObjectId::parse("http://localhost:123").unwrap(), + object: ObjectId::parse("http://localhost:124").unwrap(), + kind: Default::default(), + id: "http://localhost:123/1".try_into().unwrap(), + }; + let body: Bytes = serde_json::to_vec(&activity).unwrap().into(); + let incoming_request = construct_request(&body, activity.actor.inner()).await; let config = FederationConfig::builder() .domain("localhost:8002") diff --git a/src/axum/inbox.rs b/src/axum/inbox.rs index c0995a7..890e8f3 100644 --- a/src/axum/inbox.rs +++ b/src/axum/inbox.rs @@ -5,8 +5,8 @@ use crate::{ config::Data, error::Error, - fetch::object_id::ObjectId, - http_signatures::{verify_body_hash, verify_signature}, + http_signatures::verify_signature, + parse_received_activity, traits::{ActivityHandler, Actor, Object}, }; use axum::{ @@ -33,13 +33,8 @@ where ::Error: From, Datatype: Clone, { - verify_body_hash(activity_data.headers.get("Digest"), &activity_data.body)?; - - let activity: Activity = serde_json::from_slice(&activity_data.body).map_err(Error::Json)?; - data.config.verify_url_and_domain(&activity).await?; - let actor = ObjectId::::from(activity.actor().clone()) - .dereference(data) - .await?; + let (activity, actor) = + parse_received_activity::(&activity_data.body, data).await?; // verify_signature( // &activity_data.headers, diff --git a/src/axum/json.rs b/src/axum/json.rs index f8a649e..f99c8bd 100644 --- a/src/axum/json.rs +++ b/src/axum/json.rs @@ -9,7 +9,7 @@ //! # use activitypub_federation::traits::Object; //! # use activitypub_federation::traits::tests::{DbConnection, DbUser, Person}; //! async fn http_get_user(Path(name): Path, data: Data) -> Result>, Error> { -//! let user: DbUser = data.read_local_user(name).await?; +//! let user: DbUser = data.read_local_user(&name).await?; //! let person = user.into_json(&data).await?; //! //! Ok(FederationJson(WithContext::new_default(person))) diff --git a/src/error.rs b/src/error.rs index c66e16c..89f6abf 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,6 +6,8 @@ use http_signature_normalization_reqwest::SignError; use openssl::error::ErrorStack; use url::Url; +use crate::fetch::webfinger::WebFingerError; + /// Error messages returned by this library #[derive(thiserror::Error, Debug)] pub enum Error { @@ -32,13 +34,13 @@ pub enum Error { ActivitySignatureInvalid, /// Failed to resolve actor via webfinger #[error("Failed to resolve actor via webfinger")] - WebfingerResolveFailed, - /// Failed to resolve actor via webfinger - #[error("Webfinger regex failed to match")] - WebfingerRegexFailed, + WebfingerResolveFailed(#[from] WebFingerError), /// JSON Error #[error(transparent)] Json(#[from] serde_json::Error), + /// Failed to parse an activity received from another instance + #[error("Failed to parse incoming activity with id {0}: {1}")] + ParseReceivedActivity(Url, serde_json::Error), /// Reqwest Middleware Error #[error(transparent)] ReqwestMiddleware(#[from] reqwest_middleware::Error), diff --git a/src/fetch/collection_id.rs b/src/fetch/collection_id.rs index ae17ca0..8c796f4 100644 --- a/src/fetch/collection_id.rs +++ b/src/fetch/collection_id.rs @@ -102,3 +102,92 @@ where self.0.eq(&other.0) && self.1 == other.1 } } + +#[cfg(feature = "diesel")] +const _IMPL_DIESEL_NEW_TYPE_FOR_COLLECTION_ID: () = { + use diesel::{ + backend::Backend, + deserialize::{FromSql, FromStaticSqlRow}, + expression::AsExpression, + internal::derives::as_expression::Bound, + pg::Pg, + query_builder::QueryId, + serialize, + serialize::{Output, ToSql}, + sql_types::{HasSqlType, SingleValue, Text}, + Expression, + Queryable, + }; + + // TODO: this impl only works for Postgres db because of to_string() call which requires reborrow + impl ToSql for CollectionId + where + Kind: Collection, + for<'de2> ::Kind: Deserialize<'de2>, + String: ToSql, + { + fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result { + let v = self.0.to_string(); + >::to_sql(&v, &mut out.reborrow()) + } + } + impl<'expr, Kind, ST> AsExpression for &'expr CollectionId + where + Kind: Collection, + for<'de2> ::Kind: Deserialize<'de2>, + Bound: Expression, + ST: SingleValue, + { + type Expression = Bound; + fn as_expression(self) -> Self::Expression { + Bound::new(self.0.as_str()) + } + } + impl AsExpression for CollectionId + where + Kind: Collection, + for<'de2> ::Kind: Deserialize<'de2>, + Bound: Expression, + ST: SingleValue, + { + type Expression = Bound; + fn as_expression(self) -> Self::Expression { + Bound::new(self.0.to_string()) + } + } + impl FromSql for CollectionId + where + Kind: Collection + Send + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + String: FromSql, + DB: Backend, + DB: HasSqlType, + { + fn from_sql( + raw: DB::RawValue<'_>, + ) -> Result> { + let string: String = FromSql::::from_sql(raw)?; + Ok(CollectionId::parse(&string)?) + } + } + impl Queryable for CollectionId + where + Kind: Collection + Send + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + String: FromStaticSqlRow, + DB: Backend, + DB: HasSqlType, + { + type Row = String; + fn build(row: Self::Row) -> diesel::deserialize::Result { + Ok(CollectionId::parse(&row)?) + } + } + impl QueryId for CollectionId + where + Kind: Collection + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + { + type QueryId = Self; + } +}; diff --git a/src/fetch/object_id.rs b/src/fetch/object_id.rs index 179a82c..782900d 100644 --- a/src/fetch/object_id.rs +++ b/src/fetch/object_id.rs @@ -57,12 +57,12 @@ where pub struct ObjectId(Box, PhantomData) where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>; + for<'de2> ::Kind: Deserialize<'de2>; impl ObjectId where Kind: Object + Send + Debug + 'static, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { /// Construct a new objectid instance pub fn parse(url: &str) -> Result { @@ -112,6 +112,24 @@ where } } + /// If this is a remote object, fetch it from origin instance unconditionally to get the + /// latest version, regardless of refresh interval. + pub async fn dereference_forced( + &self, + data: &Data<::DataType>, + ) -> Result::Error> + where + ::Error: From, + { + if data.config.is_local_url(&self.0) { + self.dereference_from_db(data) + .await + .map(|o| o.ok_or(Error::NotFound.into()))? + } else { + self.dereference_from_http(data, None).await + } + } + /// Fetch an object from the local db. Instead of falling back to http, this throws an error if /// the object is not found in the database. pub async fn dereference_local( @@ -163,7 +181,7 @@ where impl Clone for ObjectId where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn clone(&self) -> Self { ObjectId(self.0.clone(), self.1) @@ -190,7 +208,7 @@ fn should_refetch_object(last_refreshed: DateTime) -> bool { impl Display for ObjectId where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0.as_str()) @@ -200,7 +218,7 @@ where impl Debug for ObjectId where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0.as_str()) @@ -210,7 +228,7 @@ where impl From> for Url where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn from(id: ObjectId) -> Self { *id.0 @@ -220,7 +238,7 @@ where impl From for ObjectId where Kind: Object + Send + 'static, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn from(url: Url) -> Self { ObjectId(Box::new(url), PhantomData::) @@ -230,13 +248,102 @@ where impl PartialEq for ObjectId where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn eq(&self, other: &Self) -> bool { self.0.eq(&other.0) && self.1 == other.1 } } +#[cfg(feature = "diesel")] +const _IMPL_DIESEL_NEW_TYPE_FOR_OBJECT_ID: () = { + use diesel::{ + backend::Backend, + deserialize::{FromSql, FromStaticSqlRow}, + expression::AsExpression, + internal::derives::as_expression::Bound, + pg::Pg, + query_builder::QueryId, + serialize, + serialize::{Output, ToSql}, + sql_types::{HasSqlType, SingleValue, Text}, + Expression, + Queryable, + }; + + // TODO: this impl only works for Postgres db because of to_string() call which requires reborrow + impl ToSql for ObjectId + where + Kind: Object, + for<'de2> ::Kind: Deserialize<'de2>, + String: ToSql, + { + fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result { + let v = self.0.to_string(); + >::to_sql(&v, &mut out.reborrow()) + } + } + impl<'expr, Kind, ST> AsExpression for &'expr ObjectId + where + Kind: Object, + for<'de2> ::Kind: Deserialize<'de2>, + Bound: Expression, + ST: SingleValue, + { + type Expression = Bound; + fn as_expression(self) -> Self::Expression { + Bound::new(self.0.as_str()) + } + } + impl AsExpression for ObjectId + where + Kind: Object, + for<'de2> ::Kind: Deserialize<'de2>, + Bound: Expression, + ST: SingleValue, + { + type Expression = Bound; + fn as_expression(self) -> Self::Expression { + Bound::new(self.0.to_string()) + } + } + impl FromSql for ObjectId + where + Kind: Object + Send + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + String: FromSql, + DB: Backend, + DB: HasSqlType, + { + fn from_sql( + raw: DB::RawValue<'_>, + ) -> Result> { + let string: String = FromSql::::from_sql(raw)?; + Ok(ObjectId::parse(&string)?) + } + } + impl Queryable for ObjectId + where + Kind: Object + Send + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + String: FromStaticSqlRow, + DB: Backend, + DB: HasSqlType, + { + type Row = String; + fn build(row: Self::Row) -> diesel::deserialize::Result { + Ok(ObjectId::parse(&row)?) + } + } + impl QueryId for ObjectId + where + Kind: Object + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + { + type QueryId = Self; + } +}; + #[cfg(test)] pub mod tests { use super::*; diff --git a/src/fetch/webfinger.rs b/src/fetch/webfinger.rs index a345fd4..68b110d 100644 --- a/src/fetch/webfinger.rs +++ b/src/fetch/webfinger.rs @@ -1,17 +1,38 @@ use crate::{ config::Data, - error::{Error, Error::WebfingerResolveFailed}, + error::Error, fetch::{fetch_object_http_with_accept, object_id::ObjectId}, traits::{Actor, Object}, FEDERATION_CONTENT_TYPE, }; use itertools::Itertools; +use once_cell::sync::Lazy; use regex::Regex; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use std::{collections::HashMap, fmt::Display}; use tracing::debug; use url::Url; +/// Errors relative to webfinger handling +#[derive(thiserror::Error, Debug)] +pub enum WebFingerError { + /// The webfinger identifier is invalid + #[error("The webfinger identifier is invalid")] + WrongFormat, + /// The webfinger identifier doesn't match the expected instance domain name + #[error("The webfinger identifier doesn't match the expected instance domain name")] + WrongDomain, + /// The wefinger object did not contain any link to an activitypub item + #[error("The webfinger object did not contain any link to an activitypub item")] + NoValidLink, +} + +impl WebFingerError { + fn into_crate_error(self) -> Error { + self.into() + } +} + /// Takes an identifier of the form `name@example.com`, and returns an object of `Kind`. /// /// For this the identifier is first resolved via webfinger protocol to an Activitypub ID. This ID @@ -23,12 +44,12 @@ pub async fn webfinger_resolve_actor( where Kind: Object + Actor + Send + 'static + Object, for<'de2> ::Kind: serde::Deserialize<'de2>, - ::Error: From + Send + Sync, + ::Error: From + Send + Sync + Display, { let (_, domain) = identifier .splitn(2, '@') .collect_tuple() - .ok_or(WebfingerResolveFailed)?; + .ok_or(WebFingerError::WrongFormat.into_crate_error())?; let protocol = if data.config.debug { "http" } else { "https" }; let fetch_url = format!("{protocol}://{domain}/.well-known/webfinger?resource=acct:{identifier}"); @@ -55,13 +76,15 @@ where }) .filter_map(|l| l.href.clone()) .collect(); + for l in links { let object = ObjectId::::from(l).dereference(data).await; - if object.is_ok() { - return object; + match object { + Ok(obj) => return Ok(obj), + Err(error) => debug!(%error, "Failed to dereference link"), } } - Err(WebfingerResolveFailed.into()) + Err(WebFingerError::NoValidLink.into_crate_error().into()) } /// Extracts username from a webfinger resource parameter. @@ -89,22 +112,24 @@ where /// # Ok::<(), anyhow::Error>(()) /// }).unwrap(); ///``` -pub fn extract_webfinger_name(query: &str, data: &Data) -> Result +pub fn extract_webfinger_name<'i, T>(query: &'i str, data: &Data) -> Result<&'i str, Error> where T: Clone, { + static WEBFINGER_REGEX: Lazy = + Lazy::new(|| Regex::new(r"^acct:([\p{L}0-9_]+)@(.*)$").expect("compile regex")); // Regex to extract usernames from webfinger query. Supports different alphabets using `\p{L}`. - // TODO: would be nice if we could implement this without regex and remove the dependency - let result = Regex::new(&format!(r"^acct:([\p{{L}}0-9_]+)@{}$", data.domain())) - .map_err(|_| Error::WebfingerRegexFailed) - .and_then(|regex| { - regex - .captures(query) - .and_then(|c| c.get(1)) - .ok_or_else(|| Error::WebfingerRegexFailed) - })?; + // TODO: This should use a URL parser + let captures = WEBFINGER_REGEX + .captures(query) + .ok_or(WebFingerError::WrongFormat)?; - return Ok(result.as_str().to_string()); + let account_name = captures.get(1).ok_or(WebFingerError::WrongFormat)?; + + if captures.get(2).map(|m| m.as_str()) != Some(data.domain()) { + return Err(WebFingerError::WrongDomain.into()); + } + Ok(account_name.as_str()) } /// Builds a basic webfinger response for the actor. @@ -252,15 +277,15 @@ mod tests { request_counter: Default::default(), }; assert_eq!( - Ok("test123".to_string()), + Ok("test123"), extract_webfinger_name("acct:test123@example.com", &data) ); assert_eq!( - Ok("Владимир".to_string()), + Ok("Владимир"), extract_webfinger_name("acct:Владимир@example.com", &data) ); assert_eq!( - Ok("تجريب".to_string()), + Ok("تجريب"), extract_webfinger_name("acct:تجريب@example.com", &data) ); Ok(()) diff --git a/src/lib.rs b/src/lib.rs index c660253..42da8df 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,7 +23,48 @@ pub mod protocol; pub(crate) mod reqwest_shim; pub mod traits; +use crate::{ + config::Data, + error::Error, + fetch::object_id::ObjectId, + traits::{ActivityHandler, Actor, Object}, +}; pub use activitystreams_kinds as kinds; +use serde::{de::DeserializeOwned, Deserialize}; +use url::Url; + /// Mime type for Activitypub data, used for `Accept` and `Content-Type` HTTP headers pub static FEDERATION_CONTENT_TYPE: &str = "application/activity+json"; + +/// Deserialize incoming inbox activity to the given type, perform basic +/// validation and extract the actor. +async fn parse_received_activity( + body: &[u8], + data: &Data, +) -> Result<(Activity, ActorT), ::Error> +where + Activity: ActivityHandler + DeserializeOwned + Send + 'static, + ActorT: Object + Actor + Send + 'static, + for<'de2> ::Kind: serde::Deserialize<'de2>, + ::Error: From + From<::Error>, + ::Error: From, + Datatype: Clone, +{ + let activity: Activity = serde_json::from_slice(body).map_err(|e| { + // Attempt to include activity id in error message + #[derive(Deserialize)] + struct Id { + id: Url, + } + match serde_json::from_slice::(body) { + Ok(id) => Error::ParseReceivedActivity(id.id, e), + Err(e) => Error::Json(e), + } + })?; + data.config.verify_url_and_domain(&activity).await?; + let actor = ObjectId::::from(activity.actor().clone()) + .dereference(data) + .await?; + Ok((activity, actor)) +} diff --git a/src/traits.rs b/src/traits.rs index 21a1540..9fdec27 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -356,7 +356,7 @@ pub mod tests { pub async fn read_post_from_json_id(&self, _: Url) -> Result, Error> { Ok(None) } - pub async fn read_local_user(&self, _: String) -> Result { + pub async fn read_local_user(&self, _: &str) -> Result { todo!() } pub async fn upsert(&self, _: &T) -> Result<(), Error> {