Remove anyhow from trait definitions (#82)

This commit is contained in:
cetra3 2023-11-20 21:12:47 +10:30 committed by GitHub
parent 679228873a
commit 098a4299f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 137 additions and 137 deletions

View file

@ -14,7 +14,6 @@ serde = { version = "1.0.189", features = ["derive"] }
async-trait = "0.1.74" async-trait = "0.1.74"
url = { version = "2.4.1", features = ["serde"] } url = { version = "2.4.1", features = ["serde"] }
serde_json = { version = "1.0.107", features = ["preserve_order"] } serde_json = { version = "1.0.107", features = ["preserve_order"] }
anyhow = "1.0.75"
reqwest = { version = "0.11.22", features = ["json", "stream"] } reqwest = { version = "0.11.22", features = ["json", "stream"] }
reqwest-middleware = "0.2.3" reqwest-middleware = "0.2.3"
tracing = "0.1.40" tracing = "0.1.40"
@ -65,6 +64,7 @@ actix-web = ["dep:actix-web"]
axum = ["dep:axum", "dep:tower", "dep:hyper"] axum = ["dep:axum", "dep:tower", "dep:hyper"]
[dev-dependencies] [dev-dependencies]
anyhow = "1.0.75"
rand = "0.8.5" rand = "0.8.5"
env_logger = "0.10.0" env_logger = "0.10.0"
tower-http = { version = "0.4.4", features = ["map-request-body", "util"] } tower-http = { version = "0.4.4", features = ["map-request-body", "util"] }

View file

@ -49,9 +49,11 @@ struct MyUrlVerifier();
#[async_trait] #[async_trait]
impl UrlVerifier for MyUrlVerifier { impl UrlVerifier for MyUrlVerifier {
async fn verify(&self, url: &Url) -> Result<(), anyhow::Error> { async fn verify(&self, url: &Url) -> Result<(), activitypub_federation::error::Error> {
if url.domain() == Some("malicious.com") { if url.domain() == Some("malicious.com") {
Err(anyhow!("malicious domain")) Err(activitypub_federation::error::Error::Other(
"malicious domain".into(),
))
} else { } else {
Ok(()) Ok(())
} }

View file

@ -107,7 +107,7 @@ impl DbUser {
activity: Activity, activity: Activity,
recipients: Vec<Url>, recipients: Vec<Url>,
data: &Data<DatabaseHandle>, data: &Data<DatabaseHandle>,
) -> Result<(), <Activity as ActivityHandler>::Error> ) -> Result<(), Error>
where where
Activity: ActivityHandler + Serialize + Debug + Send + Sync, Activity: ActivityHandler + Serialize + Debug + Send + Sync,
<Activity as ActivityHandler>::Error: From<anyhow::Error> + From<serde_json::Error>, <Activity as ActivityHandler>::Error: From<anyhow::Error> + From<serde_json::Error>,

View file

@ -10,7 +10,6 @@ use crate::{
traits::{ActivityHandler, Actor}, traits::{ActivityHandler, Actor},
FEDERATION_CONTENT_TYPE, FEDERATION_CONTENT_TYPE,
}; };
use anyhow::{anyhow, Context};
use bytes::Bytes; use bytes::Bytes;
use futures::StreamExt; use futures::StreamExt;
@ -57,17 +56,16 @@ impl SendActivityTask<'_> {
actor: &ActorType, actor: &ActorType,
inboxes: Vec<Url>, inboxes: Vec<Url>,
data: &Data<Datatype>, data: &Data<Datatype>,
) -> Result<Vec<SendActivityTask<'a>>, <Activity as ActivityHandler>::Error> ) -> Result<Vec<SendActivityTask<'a>>, Error>
where where
Activity: ActivityHandler + Serialize, Activity: ActivityHandler + Serialize,
<Activity as ActivityHandler>::Error: From<anyhow::Error> + From<serde_json::Error>,
Datatype: Clone, Datatype: Clone,
ActorType: Actor, ActorType: Actor,
{ {
let config = &data.config; let config = &data.config;
let actor_id = activity.actor(); let actor_id = activity.actor();
let activity_id = activity.id(); let activity_id = activity.id();
let activity_serialized: Bytes = serde_json::to_vec(&activity)?.into(); let activity_serialized: Bytes = serde_json::to_vec(&activity).map_err(Error::Json)?.into();
let private_key = get_pkey_cached(data, actor).await?; let private_key = get_pkey_cached(data, actor).await?;
Ok(futures::stream::iter( Ok(futures::stream::iter(
@ -95,10 +93,7 @@ impl SendActivityTask<'_> {
} }
/// convert a sendactivitydata to a request, signing and sending it /// convert a sendactivitydata to a request, signing and sending it
pub async fn sign_and_send<Datatype: Clone>( pub async fn sign_and_send<Datatype: Clone>(&self, data: &Data<Datatype>) -> Result<(), Error> {
&self,
data: &Data<Datatype>,
) -> Result<(), anyhow::Error> {
let req = self let req = self
.sign(&data.config.client, data.config.request_timeout) .sign(&data.config.client, data.config.request_timeout)
.await?; .await?;
@ -108,7 +103,7 @@ impl SendActivityTask<'_> {
&self, &self,
client: &ClientWithMiddleware, client: &ClientWithMiddleware,
timeout: Duration, timeout: Duration,
) -> Result<Request, anyhow::Error> { ) -> Result<Request, Error> {
let task = self; let task = self;
let request_builder = client let request_builder = client
.post(task.inbox.to_string()) .post(task.inbox.to_string())
@ -121,36 +116,31 @@ impl SendActivityTask<'_> {
task.private_key.clone(), task.private_key.clone(),
task.http_signature_compat, task.http_signature_compat,
) )
.await .await?;
.context("signing request")?;
Ok(request) Ok(request)
} }
async fn send( async fn send(&self, client: &ClientWithMiddleware, request: Request) -> Result<(), Error> {
&self, let response = client.execute(request).await?;
client: &ClientWithMiddleware,
request: Request,
) -> Result<(), anyhow::Error> {
let response = client.execute(request).await;
match response { match response {
Ok(o) if o.status().is_success() => { o if o.status().is_success() => {
debug!("Activity {self} delivered successfully"); debug!("Activity {self} delivered successfully");
Ok(()) Ok(())
} }
Ok(o) if o.status().is_client_error() => { o if o.status().is_client_error() => {
let text = o.text_limited().await.map_err(Error::other)?; let text = o.text_limited().await?;
debug!("Activity {self} was rejected, aborting: {text}"); debug!("Activity {self} was rejected, aborting: {text}");
Ok(()) Ok(())
} }
Ok(o) => { o => {
let status = o.status(); let status = o.status();
let text = o.text_limited().await.map_err(Error::other)?; let text = o.text_limited().await?;
Err(anyhow!(
Err(Error::Other(format!(
"Activity {self} failure with status {status}: {text}", "Activity {self} failure with status {status}: {text}",
)) )))
} }
Err(e) => Err(anyhow!("Activity {self} connection failure: {e}")),
} }
} }
} }
@ -158,7 +148,7 @@ impl SendActivityTask<'_> {
async fn get_pkey_cached<ActorType>( async fn get_pkey_cached<ActorType>(
data: &Data<impl Clone>, data: &Data<impl Clone>,
actor: &ActorType, actor: &ActorType,
) -> Result<PKey<Private>, anyhow::Error> ) -> Result<PKey<Private>, Error>
where where
ActorType: Actor, ActorType: Actor,
{ {
@ -168,20 +158,23 @@ where
.actor_pkey_cache .actor_pkey_cache
.try_get_with_by_ref(&actor_id, async { .try_get_with_by_ref(&actor_id, async {
let private_key_pem = actor.private_key_pem().ok_or_else(|| { let private_key_pem = actor.private_key_pem().ok_or_else(|| {
anyhow!("Actor {actor_id} does not contain a private key for signing") Error::Other(format!(
"Actor {actor_id} does not contain a private key for signing"
))
})?; })?;
// This is a mostly expensive blocking call, we don't want to tie up other tasks while this is happening // This is a mostly expensive blocking call, we don't want to tie up other tasks while this is happening
let pkey = tokio::task::spawn_blocking(move || { let pkey = tokio::task::spawn_blocking(move || {
PKey::private_key_from_pem(private_key_pem.as_bytes()) PKey::private_key_from_pem(private_key_pem.as_bytes()).map_err(|err| {
.map_err(|err| anyhow!("Could not create private key from PEM data:{err}")) Error::Other(format!("Could not create private key from PEM data:{err}"))
})
}) })
.await .await
.map_err(|err| anyhow!("Error joining: {err}"))??; .map_err(|err| Error::Other(format!("Error joining: {err}")))??;
std::result::Result::<PKey<Private>, anyhow::Error>::Ok(pkey) std::result::Result::<PKey<Private>, Error>::Ok(pkey)
}) })
.await .await
.map_err(|e| anyhow!("cloned error: {e}")) .map_err(|e| Error::Other(format!("cloned error: {e}")))
} }
pub(crate) fn generate_request_headers(inbox_url: &Url) -> HeaderMap { pub(crate) fn generate_request_headers(inbox_url: &Url) -> HeaderMap {

View file

@ -8,7 +8,6 @@ use crate::{
traits::{ActivityHandler, Actor, Object}, traits::{ActivityHandler, Actor, Object},
}; };
use actix_web::{web::Bytes, HttpRequest, HttpResponse}; use actix_web::{web::Bytes, HttpRequest, HttpResponse};
use anyhow::Context;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use tracing::debug; use tracing::debug;
@ -24,17 +23,13 @@ where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static, Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + 'static, ActorT: Object<DataType = Datatype> + Actor + Send + 'static,
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<anyhow::Error> <Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
+ From<Error> <ActorT as Object>::Error: From<Error>,
+ From<<ActorT as Object>::Error>
+ From<serde_json::Error>,
<ActorT as Object>::Error: From<Error> + From<anyhow::Error>,
Datatype: Clone, Datatype: Clone,
{ {
verify_body_hash(request.headers().get("Digest"), &body)?; verify_body_hash(request.headers().get("Digest"), &body)?;
let activity: Activity = serde_json::from_slice(&body) let activity: Activity = serde_json::from_slice(&body).map_err(Error::Json)?;
.with_context(|| format!("deserializing body: {}", String::from_utf8_lossy(&body)))?;
data.config.verify_url_and_domain(&activity).await?; data.config.verify_url_and_domain(&activity).await?;
let actor = ObjectId::<ActorT>::from(activity.actor().clone()) let actor = ObjectId::<ActorT>::from(activity.actor().clone())
.dereference(data) .dereference(data)
@ -91,8 +86,7 @@ mod test {
.err() .err()
.unwrap(); .unwrap();
let e = err.root_cause().downcast_ref::<Error>().unwrap(); assert_eq!(&err, &Error::ActivityBodyDigestInvalid)
assert_eq!(e, &Error::ActivityBodyDigestInvalid)
} }
#[tokio::test] #[tokio::test]
@ -108,8 +102,7 @@ mod test {
.err() .err()
.unwrap(); .unwrap();
let e = err.root_cause().downcast_ref::<Error>().unwrap(); assert_eq!(&err, &Error::ActivitySignatureInvalid)
assert_eq!(e, &Error::ActivitySignatureInvalid)
} }
async fn setup_receive_test() -> (Bytes, TestRequest, FederationConfig<DbConnection>) { async fn setup_receive_test() -> (Bytes, TestRequest, FederationConfig<DbConnection>) {

View file

@ -22,7 +22,7 @@ pub async fn signing_actor<A>(
) -> Result<A, <A as Object>::Error> ) -> Result<A, <A as Object>::Error>
where where
A: Object + Actor, A: Object + Actor,
<A as Object>::Error: From<Error> + From<anyhow::Error>, <A as Object>::Error: From<Error>,
for<'de2> <A as Object>::Kind: Deserialize<'de2>, for<'de2> <A as Object>::Kind: Deserialize<'de2>,
{ {
verify_body_hash(request.headers().get("Digest"), &body.unwrap_or_default())?; verify_body_hash(request.headers().get("Digest"), &body.unwrap_or_default())?;

View file

@ -29,16 +29,13 @@ where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static, Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + 'static, ActorT: Object<DataType = Datatype> + Actor + Send + 'static,
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<anyhow::Error> <Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
+ From<Error> <ActorT as Object>::Error: From<Error>,
+ From<<ActorT as Object>::Error>
+ From<serde_json::Error>,
<ActorT as Object>::Error: From<Error> + From<anyhow::Error>,
Datatype: Clone, Datatype: Clone,
{ {
verify_body_hash(activity_data.headers.get("Digest"), &activity_data.body)?; verify_body_hash(activity_data.headers.get("Digest"), &activity_data.body)?;
let activity: Activity = serde_json::from_slice(&activity_data.body)?; let activity: Activity = serde_json::from_slice(&activity_data.body).map_err(Error::Json)?;
data.config.verify_url_and_domain(&activity).await?; data.config.verify_url_and_domain(&activity).await?;
let actor = ObjectId::<ActorT>::from(activity.actor().clone()) let actor = ObjectId::<ActorT>::from(activity.actor().clone())
.dereference(data) .dereference(data)

View file

@ -19,7 +19,6 @@ use crate::{
protocol::verification::verify_domains_match, protocol::verification::verify_domains_match,
traits::{ActivityHandler, Actor}, traits::{ActivityHandler, Actor},
}; };
use anyhow::anyhow;
use async_trait::async_trait; use async_trait::async_trait;
use derive_builder::Builder; use derive_builder::Builder;
use dyn_clone::{clone_trait_object, DynClone}; use dyn_clone::{clone_trait_object, DynClone};
@ -104,9 +103,9 @@ impl<T: Clone> FederationConfig<T> {
verify_domains_match(activity.id(), activity.actor())?; verify_domains_match(activity.id(), activity.actor())?;
self.verify_url_valid(activity.id()).await?; self.verify_url_valid(activity.id()).await?;
if self.is_local_url(activity.id()) { if self.is_local_url(activity.id()) {
return Err(Error::UrlVerificationError(anyhow!( return Err(Error::UrlVerificationError(
"Activity was sent from local instance" "Activity was sent from local instance",
))); ));
} }
Ok(()) Ok(())
@ -129,12 +128,12 @@ impl<T: Clone> FederationConfig<T> {
"https" => {} "https" => {}
"http" => { "http" => {
if !self.allow_http_urls { if !self.allow_http_urls {
return Err(Error::UrlVerificationError(anyhow!( return Err(Error::UrlVerificationError(
"Http urls are only allowed in debug mode" "Http urls are only allowed in debug mode",
))); ));
} }
} }
_ => return Err(Error::UrlVerificationError(anyhow!("Invalid url scheme"))), _ => return Err(Error::UrlVerificationError("Invalid url scheme")),
}; };
// Urls which use our local domain are not a security risk, no further verification needed // Urls which use our local domain are not a security risk, no further verification needed
@ -143,21 +142,16 @@ impl<T: Clone> FederationConfig<T> {
} }
if url.domain().is_none() { if url.domain().is_none() {
return Err(Error::UrlVerificationError(anyhow!( return Err(Error::UrlVerificationError("Url must have a domain"));
"Url must have a domain"
)));
} }
if url.domain() == Some("localhost") && !self.debug { if url.domain() == Some("localhost") && !self.debug {
return Err(Error::UrlVerificationError(anyhow!( return Err(Error::UrlVerificationError(
"Localhost is only allowed in debug mode" "Localhost is only allowed in debug mode",
))); ));
} }
self.url_verifier self.url_verifier.verify(url).await?;
.verify(url)
.await
.map_err(Error::UrlVerificationError)?;
Ok(()) Ok(())
} }
@ -227,7 +221,7 @@ impl<T: Clone> Deref for FederationConfig<T> {
/// # use async_trait::async_trait; /// # use async_trait::async_trait;
/// # use url::Url; /// # use url::Url;
/// # use activitypub_federation::config::UrlVerifier; /// # use activitypub_federation::config::UrlVerifier;
/// # use anyhow::anyhow; /// # use activitypub_federation::error::Error;
/// # #[derive(Clone)] /// # #[derive(Clone)]
/// # struct DatabaseConnection(); /// # struct DatabaseConnection();
/// # async fn get_blocklist(_: &DatabaseConnection) -> Vec<String> { /// # async fn get_blocklist(_: &DatabaseConnection) -> Vec<String> {
@ -240,11 +234,11 @@ impl<T: Clone> Deref for FederationConfig<T> {
/// ///
/// #[async_trait] /// #[async_trait]
/// impl UrlVerifier for Verifier { /// impl UrlVerifier for Verifier {
/// async fn verify(&self, url: &Url) -> Result<(), anyhow::Error> { /// async fn verify(&self, url: &Url) -> Result<(), Error> {
/// let blocklist = get_blocklist(&self.db_connection).await; /// let blocklist = get_blocklist(&self.db_connection).await;
/// let domain = url.domain().unwrap().to_string(); /// let domain = url.domain().unwrap().to_string();
/// if blocklist.contains(&domain) { /// if blocklist.contains(&domain) {
/// Err(anyhow!("Domain is blocked")) /// Err(Error::Other("Domain is blocked".into()))
/// } else { /// } else {
/// Ok(()) /// Ok(())
/// } /// }
@ -254,7 +248,7 @@ impl<T: Clone> Deref for FederationConfig<T> {
#[async_trait] #[async_trait]
pub trait UrlVerifier: DynClone + Send { pub trait UrlVerifier: DynClone + Send {
/// Should return Ok iff the given url is valid for processing. /// Should return Ok iff the given url is valid for processing.
async fn verify(&self, url: &Url) -> Result<(), anyhow::Error>; async fn verify(&self, url: &Url) -> Result<(), Error>;
} }
/// Default URL verifier which does nothing. /// Default URL verifier which does nothing.
@ -263,7 +257,7 @@ struct DefaultUrlVerifier();
#[async_trait] #[async_trait]
impl UrlVerifier for DefaultUrlVerifier { impl UrlVerifier for DefaultUrlVerifier {
async fn verify(&self, _url: &Url) -> Result<(), anyhow::Error> { async fn verify(&self, _url: &Url) -> Result<(), Error> {
Ok(()) Ok(())
} }
} }

View file

@ -1,5 +1,11 @@
//! Error messages returned by this library //! Error messages returned by this library
use std::string::FromUtf8Error;
use http_signature_normalization_reqwest::SignError;
use openssl::error::ErrorStack;
use url::Url;
/// Error messages returned by this library /// Error messages returned by this library
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum Error { pub enum Error {
@ -13,11 +19,11 @@ pub enum Error {
#[error("Response body limit was reached during fetch")] #[error("Response body limit was reached during fetch")]
ResponseBodyLimit, ResponseBodyLimit,
/// Object to be fetched was deleted /// Object to be fetched was deleted
#[error("Object to be fetched was deleted")] #[error("Fetched remote object {0} which was deleted")]
ObjectDeleted, ObjectDeleted(Url),
/// url verification error /// url verification error
#[error("URL failed verification: {0}")] #[error("URL failed verification: {0}")]
UrlVerificationError(anyhow::Error), UrlVerificationError(&'static str),
/// Incoming activity has invalid digest for body /// Incoming activity has invalid digest for body
#[error("Incoming activity has invalid digest for body")] #[error("Incoming activity has invalid digest for body")]
ActivityBodyDigestInvalid, ActivityBodyDigestInvalid,
@ -27,17 +33,35 @@ pub enum Error {
/// Failed to resolve actor via webfinger /// Failed to resolve actor via webfinger
#[error("Failed to resolve actor via webfinger")] #[error("Failed to resolve actor via webfinger")]
WebfingerResolveFailed, WebfingerResolveFailed,
/// other error /// Failed to resolve actor via webfinger
#[error("Webfinger regex failed to match")]
WebfingerRegexFailed,
/// JSON Error
#[error(transparent)] #[error(transparent)]
Other(#[from] anyhow::Error), Json(#[from] serde_json::Error),
/// Reqwest Middleware Error
#[error(transparent)]
ReqwestMiddleware(#[from] reqwest_middleware::Error),
/// Reqwest Error
#[error(transparent)]
Reqwest(#[from] reqwest::Error),
/// UTF-8 error
#[error(transparent)]
Utf8(#[from] FromUtf8Error),
/// Url Parse
#[error(transparent)]
UrlParse(#[from] url::ParseError),
/// Signing errors
#[error(transparent)]
SignError(#[from] SignError),
/// Other generic errors
#[error("{0}")]
Other(String),
} }
impl Error { impl From<ErrorStack> for Error {
pub(crate) fn other<T>(error: T) -> Self fn from(value: ErrorStack) -> Self {
where Error::Other(value.to_string())
T: Into<anyhow::Error>,
{
Error::Other(error.into())
} }
} }

View file

@ -83,13 +83,13 @@ async fn fetch_object_http_with_accept<T: Clone, Kind: DeserializeOwned>(
data.config.http_signature_compat, data.config.http_signature_compat,
) )
.await?; .await?;
config.client.execute(req).await.map_err(Error::other)? config.client.execute(req).await?
} else { } else {
req.send().await.map_err(Error::other)? req.send().await?
}; };
if res.status() == StatusCode::GONE { if res.status() == StatusCode::GONE {
return Err(Error::ObjectDeleted); return Err(Error::ObjectDeleted(url.clone()));
} }
let url = res.url().clone(); let url = res.url().clone();

View file

@ -1,5 +1,4 @@
use crate::{config::Data, error::Error, fetch::fetch_object_http, traits::Object}; use crate::{config::Data, error::Error, fetch::fetch_object_http, traits::Object};
use anyhow::anyhow;
use chrono::{DateTime, Duration as ChronoDuration, Utc}; use chrono::{DateTime, Duration as ChronoDuration, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{ use std::{
@ -90,7 +89,7 @@ where
data: &Data<<Kind as Object>::DataType>, data: &Data<<Kind as Object>::DataType>,
) -> Result<Kind, <Kind as Object>::Error> ) -> Result<Kind, <Kind as Object>::Error>
where where
<Kind as Object>::Error: From<Error> + From<anyhow::Error>, <Kind as Object>::Error: From<Error>,
{ {
let db_object = self.dereference_from_db(data).await?; let db_object = self.dereference_from_db(data).await?;
// if its a local object, only fetch it from the database and not over http // if its a local object, only fetch it from the database and not over http
@ -145,15 +144,15 @@ where
db_object: Option<Kind>, db_object: Option<Kind>,
) -> Result<Kind, <Kind as Object>::Error> ) -> Result<Kind, <Kind as Object>::Error>
where where
<Kind as Object>::Error: From<Error> + From<anyhow::Error>, <Kind as Object>::Error: From<Error>,
{ {
let res = fetch_object_http(&self.0, data).await; let res = fetch_object_http(&self.0, data).await;
if let Err(Error::ObjectDeleted) = &res { if let Err(Error::ObjectDeleted(url)) = res {
if let Some(db_object) = db_object { if let Some(db_object) = db_object {
db_object.delete(data).await?; db_object.delete(data).await?;
} }
return Err(anyhow!("Fetched remote object {} which was deleted", self).into()); return Err(Error::ObjectDeleted(url).into());
} }
let res = res?; let res = res?;

View file

@ -5,7 +5,6 @@ use crate::{
traits::{Actor, Object}, traits::{Actor, Object},
FEDERATION_CONTENT_TYPE, FEDERATION_CONTENT_TYPE,
}; };
use anyhow::anyhow;
use itertools::Itertools; use itertools::Itertools;
use regex::Regex; use regex::Regex;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -24,8 +23,7 @@ pub async fn webfinger_resolve_actor<T: Clone, Kind>(
where where
Kind: Object + Actor + Send + 'static + Object<DataType = T>, Kind: Object + Actor + Send + 'static + Object<DataType = T>,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
<Kind as Object>::Error: <Kind as Object>::Error: From<crate::error::Error> + Send + Sync,
From<crate::error::Error> + From<anyhow::Error> + From<url::ParseError> + Send + Sync,
{ {
let (_, domain) = identifier let (_, domain) = identifier
.splitn(2, '@') .splitn(2, '@')
@ -36,10 +34,13 @@ where
format!("{protocol}://{domain}/.well-known/webfinger?resource=acct:{identifier}"); format!("{protocol}://{domain}/.well-known/webfinger?resource=acct:{identifier}");
debug!("Fetching webfinger url: {}", &fetch_url); debug!("Fetching webfinger url: {}", &fetch_url);
let res: Webfinger = let res: Webfinger = fetch_object_http_with_accept(
fetch_object_http_with_accept(&Url::parse(&fetch_url)?, data, "application/jrd+json") &Url::parse(&fetch_url).map_err(Error::UrlParse)?,
.await? data,
.object; "application/jrd+json",
)
.await?
.object;
debug_assert_eq!(res.subject, format!("acct:{identifier}")); debug_assert_eq!(res.subject, format!("acct:{identifier}"));
let links: Vec<Url> = res let links: Vec<Url> = res
@ -94,14 +95,16 @@ where
{ {
// Regex to extract usernames from webfinger query. Supports different alphabets using `\p{L}`. // Regex to extract usernames from webfinger query. Supports different alphabets using `\p{L}`.
// TODO: would be nice if we could implement this without regex and remove the dependency // TODO: would be nice if we could implement this without regex and remove the dependency
let regex = let result = Regex::new(&format!(r"^acct:([\p{{L}}0-9_]+)@{}$", data.domain()))
Regex::new(&format!(r"^acct:([\p{{L}}0-9_]+)@{}$", data.domain())).map_err(Error::other)?; .map_err(|_| Error::WebfingerRegexFailed)
Ok(regex .and_then(|regex| {
.captures(query) regex
.and_then(|c| c.get(1)) .captures(query)
.ok_or_else(|| Error::other(anyhow!("Webfinger regex failed to match")))? .and_then(|c| c.get(1))
.as_str() .ok_or_else(|| Error::WebfingerRegexFailed)
.to_string()) })?;
return Ok(result.as_str().to_string());
} }
/// Builds a basic webfinger response for the actor. /// Builds a basic webfinger response for the actor.

View file

@ -12,7 +12,6 @@ use crate::{
protocol::public_key::main_key_id, protocol::public_key::main_key_id,
traits::{Actor, Object}, traits::{Actor, Object},
}; };
use anyhow::Context;
use base64::{engine::general_purpose::STANDARD as Base64, Engine}; use base64::{engine::general_purpose::STANDARD as Base64, Engine};
use bytes::Bytes; use bytes::Bytes;
use http::{header::HeaderName, uri::PathAndQuery, HeaderValue, Method, Uri}; use http::{header::HeaderName, uri::PathAndQuery, HeaderValue, Method, Uri};
@ -83,7 +82,7 @@ pub(crate) async fn sign_request(
activity: Bytes, activity: Bytes,
private_key: PKey<Private>, private_key: PKey<Private>,
http_signature_compat: bool, http_signature_compat: bool,
) -> Result<Request, anyhow::Error> { ) -> Result<Request, Error> {
static CONFIG: Lazy<Config> = Lazy::new(|| Config::new().set_expiration(EXPIRES_AFTER)); static CONFIG: Lazy<Config> = Lazy::new(|| Config::new().set_expiration(EXPIRES_AFTER));
static CONFIG_COMPAT: Lazy<Config> = Lazy::new(|| { static CONFIG_COMPAT: Lazy<Config> = Lazy::new(|| {
Config::new() Config::new()
@ -103,14 +102,10 @@ pub(crate) async fn sign_request(
Sha256::new(), Sha256::new(),
activity, activity,
move |signing_string| { move |signing_string| {
let mut signer = Signer::new(MessageDigest::sha256(), &private_key) let mut signer = Signer::new(MessageDigest::sha256(), &private_key)?;
.context("instantiating signer")?; signer.update(signing_string.as_bytes())?;
signer
.update(signing_string.as_bytes())
.context("updating signer")?;
Ok(Base64.encode(signer.sign_to_vec().context("sign to vec")?)) Ok(Base64.encode(signer.sign_to_vec()?)) as Result<_, Error>
as Result<_, anyhow::Error>
}, },
) )
.await .await
@ -152,7 +147,7 @@ pub(crate) async fn signing_actor<'a, A, H>(
) -> Result<A, <A as Object>::Error> ) -> Result<A, <A as Object>::Error>
where where
A: Object + Actor, A: Object + Actor,
<A as Object>::Error: From<Error> + From<anyhow::Error>, <A as Object>::Error: From<Error>,
for<'de2> <A as Object>::Kind: Deserialize<'de2>, for<'de2> <A as Object>::Kind: Deserialize<'de2>,
H: IntoIterator<Item = (&'a HeaderName, &'a HeaderValue)>, H: IntoIterator<Item = (&'a HeaderName, &'a HeaderValue)>,
{ {
@ -197,8 +192,8 @@ fn verify_signature_inner(
let verified = CONFIG let verified = CONFIG
.begin_verify(method.as_str(), path_and_query, header_map) .begin_verify(method.as_str(), path_and_query, header_map)
.map_err(Error::other)? .map_err(|val| Error::Other(val.to_string()))?
.verify(|signature, signing_string| -> anyhow::Result<bool> { .verify(|signature, signing_string| -> Result<bool, Error> {
debug!( debug!(
"Verifying with key {}, message {}", "Verifying with key {}, message {}",
&public_key, &signing_string &public_key, &signing_string
@ -206,9 +201,13 @@ fn verify_signature_inner(
let public_key = PKey::public_key_from_pem(public_key.as_bytes())?; let public_key = PKey::public_key_from_pem(public_key.as_bytes())?;
let mut verifier = Verifier::new(MessageDigest::sha256(), &public_key)?; let mut verifier = Verifier::new(MessageDigest::sha256(), &public_key)?;
verifier.update(signing_string.as_bytes())?; verifier.update(signing_string.as_bytes())?;
Ok(verifier.verify(&Base64.decode(signature)?)?)
}) let base64_decoded = Base64
.map_err(Error::other)?; .decode(signature)
.map_err(|err| Error::Other(err.to_string()))?;
Ok(verifier.verify(&base64_decoded)?)
})?;
if verified { if verified {
debug!("verified signature for {}", uri); debug!("verified signature for {}", uri);

View file

@ -1,7 +1,6 @@
//! Verify that received data is valid //! Verify that received data is valid
use crate::error::Error; use crate::error::Error;
use anyhow::anyhow;
use url::Url; use url::Url;
/// Check that both urls have the same domain. If not, return UrlVerificationError. /// Check that both urls have the same domain. If not, return UrlVerificationError.
@ -16,7 +15,7 @@ use url::Url;
/// ``` /// ```
pub fn verify_domains_match(a: &Url, b: &Url) -> Result<(), Error> { pub fn verify_domains_match(a: &Url, b: &Url) -> Result<(), Error> {
if a.domain() != b.domain() { if a.domain() != b.domain() {
return Err(Error::UrlVerificationError(anyhow!("Domains do not match"))); return Err(Error::UrlVerificationError("Domains do not match"));
} }
Ok(()) Ok(())
} }
@ -33,7 +32,7 @@ pub fn verify_domains_match(a: &Url, b: &Url) -> Result<(), Error> {
/// ``` /// ```
pub fn verify_urls_match(a: &Url, b: &Url) -> Result<(), Error> { pub fn verify_urls_match(a: &Url, b: &Url) -> Result<(), Error> {
if a != b { if a != b {
return Err(Error::UrlVerificationError(anyhow!("Urls do not match"))); return Err(Error::UrlVerificationError("Urls do not match"));
} }
Ok(()) Ok(())
} }

View file

@ -30,10 +30,7 @@ impl Future for BytesFuture {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop { loop {
let this = self.as_mut().project(); let this = self.as_mut().project();
if let Some(chunk) = ready!(this.stream.poll_next(cx)) if let Some(chunk) = ready!(this.stream.poll_next(cx)).transpose()? {
.transpose()
.map_err(Error::other)?
{
this.aggregator.put(chunk); this.aggregator.put(chunk);
if this.aggregator.len() > *this.limit { if this.aggregator.len() > *this.limit {
return Poll::Ready(Err(Error::ResponseBodyLimit)); return Poll::Ready(Err(Error::ResponseBodyLimit));
@ -66,7 +63,7 @@ where
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project(); let this = self.project();
let bytes = ready!(this.future.poll(cx))?; let bytes = ready!(this.future.poll(cx))?;
Poll::Ready(serde_json::from_slice(&bytes).map_err(Error::other)) Poll::Ready(serde_json::from_slice(&bytes).map_err(Error::Json))
} }
} }
@ -83,7 +80,7 @@ impl Future for TextFuture {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project(); let this = self.project();
let bytes = ready!(this.future.poll(cx))?; let bytes = ready!(this.future.poll(cx))?;
Poll::Ready(String::from_utf8(bytes.to_vec()).map_err(Error::other)) Poll::Ready(String::from_utf8(bytes.to_vec()).map_err(Error::Utf8))
} }
} }

View file

@ -340,12 +340,12 @@ pub trait Collection: Sized {
pub mod tests { pub mod tests {
use super::*; use super::*;
use crate::{ use crate::{
error::Error,
fetch::object_id::ObjectId, fetch::object_id::ObjectId,
http_signatures::{generate_actor_keypair, Keypair}, http_signatures::{generate_actor_keypair, Keypair},
protocol::{public_key::PublicKey, verification::verify_domains_match}, protocol::{public_key::PublicKey, verification::verify_domains_match},
}; };
use activitystreams_kinds::{activity::FollowType, actor::PersonType}; use activitystreams_kinds::{activity::FollowType, actor::PersonType};
use anyhow::Error;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};