diff --git a/Cargo.lock b/Cargo.lock index 5ba47c0..552f0e8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -16,6 +16,7 @@ dependencies = [ "base64", "chrono", "derive_builder", + "dyn-clone", "env_logger", "http", "http-signature-normalization-actix", @@ -530,6 +531,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c97b9233581d84b8e1e689cdd3a47b6f69770084fc246e86a7f78b0d9c1d4a5" +[[package]] +name = "dyn-clone" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f94fa09c2aeea5b8839e414b7b841bf429fd25b9c522116ac97ee87856d88b2" + [[package]] name = "either" version = "1.8.0" diff --git a/Cargo.toml b/Cargo.toml index 8cc4e97..d774254 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ background-jobs = "0.13.0" thiserror = "1.0.37" derive_builder = "0.11.2" itertools = "0.10.5" +dyn-clone = "1.0.9" [dev-dependencies] activitystreams-kinds = "0.2.1" diff --git a/examples/federation/instance.rs b/examples/federation/instance.rs index 653fa91..c3eec7f 100644 --- a/examples/federation/instance.rs +++ b/examples/federation/instance.rs @@ -13,9 +13,11 @@ use activitypub_federation::{ traits::ApubObject, InstanceSettings, LocalInstance, + UrlVerifier, APUB_JSON_CONTENT_TYPE, }; use actix_web::{web, App, HttpRequest, HttpResponse, HttpServer}; +use async_trait::async_trait; use http_signature_normalization_actix::prelude::VerifyDigest; use reqwest::Client; use sha2::{Digest, Sha256}; @@ -37,9 +39,27 @@ pub struct Instance { pub posts: Mutex>, } +/// Use this to store your federation blocklist, or a database connection needed to retrieve it. +#[derive(Clone)] +struct MyUrlVerifier(); + +#[async_trait] +impl UrlVerifier for MyUrlVerifier { + async fn verify(&self, url: &Url) -> Result<(), &'static str> { + if url.domain() == Some("malicious.com") { + Err("malicious domain") + } else { + Ok(()) + } + } +} + impl Instance { pub fn new(hostname: String) -> Result { - let settings = InstanceSettings::builder().debug(true).build()?; + let settings = InstanceSettings::builder() + .debug(true) + .url_verifier(Box::new(MyUrlVerifier())) + .build()?; let local_instance = LocalInstance::new(hostname.clone(), Client::default().into(), settings); let local_user = MyUser::new(generate_object_id(&hostname)?, generate_actor_keypair()?); diff --git a/src/core/activity_queue.rs b/src/core/activity_queue.rs index 861b2ea..704e9f5 100644 --- a/src/core/activity_queue.rs +++ b/src/core/activity_queue.rs @@ -51,12 +51,11 @@ where .into_iter() .unique() .filter(|i| !instance.is_local_url(i)) - .filter(|i| verify_url_valid(i, &instance.settings).is_ok()) .collect(); let activity_queue = &instance.activity_queue; for inbox in inboxes { - if verify_url_valid(&inbox, &instance.settings).is_err() { + if verify_url_valid(&inbox, &instance.settings).await.is_err() { continue; } let message = SendActivityTask { diff --git a/src/core/inbox.rs b/src/core/inbox.rs index a42aefe..b8dde54 100644 --- a/src/core/inbox.rs +++ b/src/core/inbox.rs @@ -26,7 +26,7 @@ where ::Error: From + From, { verify_domains_match(activity.id(), activity.actor())?; - verify_url_valid(activity.id(), &local_instance.settings)?; + verify_url_valid(activity.id(), &local_instance.settings).await?; if local_instance.is_local_url(activity.id()) { return Err(Error::UrlVerificationError("Activity was sent from local instance").into()); } diff --git a/src/lib.rs b/src/lib.rs index 7c0fa5c..5d641ab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,8 @@ use crate::core::activity_queue::create_activity_queue; +use async_trait::async_trait; use background_jobs::Manager; use derive_builder::Builder; +use dyn_clone::{clone_trait_object, DynClone}; use reqwest_middleware::ClientWithMiddleware; use std::time::Duration; use url::Url; @@ -23,6 +25,12 @@ pub struct LocalInstance { settings: InstanceSettings, } +#[async_trait] +pub trait UrlVerifier: DynClone + Send { + async fn verify(&self, url: &Url) -> Result<(), &'static str>; +} +clone_trait_object!(UrlVerifier); + // Use InstanceSettingsBuilder to initialize this #[derive(Builder)] pub struct InstanceSettings { @@ -45,12 +53,13 @@ pub struct InstanceSettings { /// Function used to verify that urls are valid, used when receiving activities or fetching remote /// objects. Use this to implement functionality like federation blocklists. In case verification /// fails, it should return an error message. - #[builder(default = "|_| { Ok(()) }")] - verify_url_function: fn(&Url) -> Result<(), &'static str>, + #[builder(default = "Box::new(DefaultUrlVerifier())")] + url_verifier: Box, /// Enable to sign HTTP signatures according to draft 10, which does not include (created) and /// (expires) fields. This is required for compatibility with some software like Pleroma. /// https://datatracker.ietf.org/doc/html/draft-cavage-http-signatures-10 /// https://git.pleroma.social/pleroma/pleroma/-/issues/2939 + #[builder(default = "false")] http_signature_compat: bool, } @@ -61,6 +70,16 @@ impl InstanceSettings { } } +#[derive(Clone)] +struct DefaultUrlVerifier(); + +#[async_trait] +impl UrlVerifier for DefaultUrlVerifier { + async fn verify(&self, _url: &Url) -> Result<(), &'static str> { + Ok(()) + } +} + impl LocalInstance { pub fn new(domain: String, client: ClientWithMiddleware, settings: InstanceSettings) -> Self { let activity_queue = create_activity_queue(client.clone(), &settings); diff --git a/src/utils.rs b/src/utils.rs index 76d21db..45e4037 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -11,7 +11,7 @@ pub async fn fetch_object_http( ) -> Result { // dont fetch local objects this way debug_assert!(url.domain() != Some(&instance.hostname)); - verify_url_valid(url, &instance.settings)?; + verify_url_valid(url, &instance.settings).await?; info!("Fetching remote object {}", url.to_string()); *request_counter += 1; @@ -55,7 +55,7 @@ pub fn verify_urls_match(a: &Url, b: &Url) -> Result<(), Error> { /// [`InstanceSettings.verify_url_function`]. /// /// https://www.w3.org/TR/activitypub/#security-considerations -pub fn verify_url_valid(url: &Url, settings: &InstanceSettings) -> Result<(), Error> { +pub async fn verify_url_valid(url: &Url, settings: &InstanceSettings) -> Result<(), Error> { match url.scheme() { "https" => {} "http" => { @@ -78,7 +78,11 @@ pub fn verify_url_valid(url: &Url, settings: &InstanceSettings) -> Result<(), Er )); } - (settings.verify_url_function)(url).map_err(Error::UrlVerificationError)?; + settings + .url_verifier + .verify(url) + .await + .map_err(Error::UrlVerificationError)?; Ok(()) }