diff --git a/Cargo.lock b/Cargo.lock index bf77294ce..870584995 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2255,6 +2255,7 @@ dependencies = [ "reqwest-retry", "reqwest-tracing", "serde", + "tokio", "tracing", "tracing-actix-web", "tracing-error", @@ -2295,6 +2296,7 @@ dependencies = [ "smart-default", "strum", "strum_macros", + "tokio", "tracing", "tracing-error", "typed-builder", diff --git a/Cargo.toml b/Cargo.toml index 30d36e948..a888514f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,3 +73,4 @@ console-subscriber = { version = "0.1.8", optional = true } opentelemetry = { version = "0.17.0", features = ["rt-tokio"], optional = true } opentelemetry-otlp = { version = "0.10.0", optional = true } tracing-opentelemetry = { version = "0.17.2", optional = true } +tokio = "1.21.2" diff --git a/config/config.hjson b/config/config.hjson index c26e52d4f..252fca250 100644 --- a/config/config.hjson +++ b/config/config.hjson @@ -2,8 +2,4 @@ # https://join-lemmy.org/docs/en/administration/configuration.html { hostname: lemmy-alpha - database: { - # Username to connect to postgres - user: "&££%^!£*!:::!"£:!:" - } } diff --git a/crates/api_crud/src/site/create.rs b/crates/api_crud/src/site/create.rs index 7a9b04840..99d67e157 100644 --- a/crates/api_crud/src/site/create.rs +++ b/crates/api_crud/src/site/create.rs @@ -6,6 +6,7 @@ use lemmy_api_common::{ utils::{ get_local_user_view_from_jwt, is_admin, + local_site_rate_limit_to_rate_limit_config, local_site_to_slur_regex, site_description_length_check, }, @@ -139,6 +140,13 @@ impl PerformCrud for CreateSite { let site_view = SiteView::read_local(context.pool()).await?; + let rate_limit_config = + local_site_rate_limit_to_rate_limit_config(&site_view.local_site_rate_limit); + context + .settings_updated_channel() + .send(rate_limit_config) + .await?; + Ok(SiteResponse { site_view }) } } diff --git a/crates/api_crud/src/site/update.rs b/crates/api_crud/src/site/update.rs index 5f22d9f2a..a4eba29bd 100644 --- a/crates/api_crud/src/site/update.rs +++ b/crates/api_crud/src/site/update.rs @@ -5,6 +5,7 @@ use lemmy_api_common::{ utils::{ get_local_user_view_from_jwt, is_admin, + local_site_rate_limit_to_rate_limit_config, local_site_to_slur_regex, site_description_length_check, }, @@ -176,6 +177,13 @@ impl PerformCrud for EditSite { let site_view = SiteView::read_local(context.pool()).await?; + let rate_limit_config = + local_site_rate_limit_to_rate_limit_config(&site_view.local_site_rate_limit); + context + .settings_updated_channel() + .send(rate_limit_config) + .await?; + let res = SiteResponse { site_view }; context.chat_server().do_send(SendAllMessage { diff --git a/crates/apub/src/objects/mod.rs b/crates/apub/src/objects/mod.rs index 1e2ac79b4..29f030aed 100644 --- a/crates/apub/src/objects/mod.rs +++ b/crates/apub/src/objects/mod.rs @@ -60,13 +60,12 @@ pub(crate) mod tests { use lemmy_db_schema::{source::secret::Secret, utils::build_db_pool_for_tests}; use lemmy_utils::{ error::LemmyError, - rate_limit::{rate_limiter::RateLimiter, RateLimit, RateLimitConfig}, + rate_limit::{RateLimitCell, RateLimitConfig}, settings::SETTINGS, }; use lemmy_websocket::{chat_server::ChatServer, LemmyContext}; use reqwest::{Client, Request, Response}; use reqwest_middleware::{ClientBuilder, Middleware, Next}; - use std::sync::{Arc, Mutex}; use task_local_extensions::Extensions; struct BlockedMiddleware; @@ -105,22 +104,25 @@ pub(crate) mod tests { } let rate_limit_config = RateLimitConfig::builder().build(); - - let rate_limiter = RateLimit { - rate_limiter: Arc::new(Mutex::new(RateLimiter::default())), - rate_limit_config, - }; + let rate_limit_cell = RateLimitCell::new(rate_limit_config).await; let chat_server = ChatServer::startup( pool.clone(), - rate_limiter, |_, _, _, _| Box::pin(x()), |_, _, _, _| Box::pin(x()), client.clone(), settings.clone(), secret.clone(), + rate_limit_cell.clone(), ) .start(); - LemmyContext::create(pool, chat_server, client, settings, secret) + LemmyContext::create( + pool, + chat_server, + client, + settings, + secret, + rate_limit_cell.clone(), + ) } } diff --git a/crates/routes/src/images.rs b/crates/routes/src/images.rs index 97134078b..716b18be8 100644 --- a/crates/routes/src/images.rs +++ b/crates/routes/src/images.rs @@ -13,13 +13,17 @@ use actix_web::{ use futures::stream::{Stream, StreamExt}; use lemmy_api_common::utils::get_local_user_view_from_jwt; use lemmy_db_schema::source::local_site::LocalSite; -use lemmy_utils::{claims::Claims, rate_limit::RateLimit, REQWEST_TIMEOUT}; +use lemmy_utils::{claims::Claims, rate_limit::RateLimitCell, REQWEST_TIMEOUT}; use lemmy_websocket::LemmyContext; use reqwest::Body; use reqwest_middleware::{ClientWithMiddleware, RequestBuilder}; use serde::{Deserialize, Serialize}; -pub fn config(cfg: &mut web::ServiceConfig, client: ClientWithMiddleware, rate_limit: &RateLimit) { +pub fn config( + cfg: &mut web::ServiceConfig, + client: ClientWithMiddleware, + rate_limit: &RateLimitCell, +) { cfg .app_data(web::Data::new(client)) .service( diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml index 43cbd8bff..dd79fa84e 100644 --- a/crates/utils/Cargo.toml +++ b/crates/utils/Cargo.toml @@ -45,6 +45,7 @@ rosetta-i18n = "0.1.2" parking_lot = "0.12.1" typed-builder = "0.10.0" percent-encoding = "2.2.0" +tokio = "1.21.2" [build-dependencies] rosetta-build = "0.1.2" diff --git a/crates/utils/src/rate_limit/mod.rs b/crates/utils/src/rate_limit/mod.rs index ed019255f..6dc9dcbef 100644 --- a/crates/utils/src/rate_limit/mod.rs +++ b/crates/utils/src/rate_limit/mod.rs @@ -1,7 +1,7 @@ use crate::{error::LemmyError, utils::get_ip, IpAddr}; use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform}; use futures::future::{ok, Ready}; -use rate_limiter::{RateLimitType, RateLimiter}; +use rate_limiter::{RateLimitStorage, RateLimitType}; use serde::{Deserialize, Serialize}; use std::{ future::Future, @@ -10,6 +10,7 @@ use std::{ sync::{Arc, Mutex}, task::{Context, Poll}, }; +use tokio::sync::{mpsc, mpsc::Sender, OnceCell}; use typed_builder::TypedBuilder; pub mod rate_limiter; @@ -55,65 +56,102 @@ pub struct RateLimitConfig { } #[derive(Debug, Clone)] -pub struct RateLimit { - // it might be reasonable to use a std::sync::Mutex here, since we don't need to lock this - // across await points - pub rate_limiter: Arc>, +struct RateLimit { + pub rate_limiter: RateLimitStorage, pub rate_limit_config: RateLimitConfig, } #[derive(Debug, Clone)] -pub struct RateLimited { - rate_limiter: Arc>, - rate_limit_config: RateLimitConfig, +pub struct RateLimitedGuard { + rate_limit: Arc>, type_: RateLimitType, } -pub struct RateLimitedMiddleware { - rate_limited: RateLimited, - service: Rc, +/// Single instance of rate limit config and buckets, which is shared across all threads. +#[derive(Clone)] +pub struct RateLimitCell { + tx: Sender, + rate_limit: Arc>, } -impl RateLimit { - pub fn message(&self) -> RateLimited { +impl RateLimitCell { + /// Initialize cell if it wasnt initialized yet. Otherwise returns the existing cell. + pub async fn new(rate_limit_config: RateLimitConfig) -> &'static Self { + static LOCAL_INSTANCE: OnceCell = OnceCell::const_new(); + LOCAL_INSTANCE + .get_or_init(|| async { + let (tx, mut rx) = mpsc::channel::(4); + let rate_limit = Arc::new(Mutex::new(RateLimit { + rate_limiter: Default::default(), + rate_limit_config, + })); + let rate_limit2 = rate_limit.clone(); + tokio::spawn(async move { + while let Some(r) = rx.recv().await { + rate_limit2 + .lock() + .expect("Failed to lock rate limit mutex for updating") + .rate_limit_config = r; + } + }); + RateLimitCell { tx, rate_limit } + }) + .await + } + + /// Call this when the config was updated, to update all in-memory cells. + pub async fn send(&self, config: RateLimitConfig) -> Result<(), LemmyError> { + self.tx.send(config).await?; + Ok(()) + } + + pub fn message(&self) -> RateLimitedGuard { self.kind(RateLimitType::Message) } - pub fn post(&self) -> RateLimited { + pub fn post(&self) -> RateLimitedGuard { self.kind(RateLimitType::Post) } - pub fn register(&self) -> RateLimited { + pub fn register(&self) -> RateLimitedGuard { self.kind(RateLimitType::Register) } - pub fn image(&self) -> RateLimited { + pub fn image(&self) -> RateLimitedGuard { self.kind(RateLimitType::Image) } - pub fn comment(&self) -> RateLimited { + pub fn comment(&self) -> RateLimitedGuard { self.kind(RateLimitType::Comment) } - pub fn search(&self) -> RateLimited { + pub fn search(&self) -> RateLimitedGuard { self.kind(RateLimitType::Search) } - fn kind(&self, type_: RateLimitType) -> RateLimited { - RateLimited { - rate_limiter: self.rate_limiter.clone(), - rate_limit_config: self.rate_limit_config.clone(), + fn kind(&self, type_: RateLimitType) -> RateLimitedGuard { + RateLimitedGuard { + rate_limit: self.rate_limit.clone(), type_, } } } -impl RateLimited { +pub struct RateLimitedMiddleware { + rate_limited: RateLimitedGuard, + service: Rc, +} + +impl RateLimitedGuard { /// Returns true if the request passed the rate limit, false if it failed and should be rejected. pub fn check(self, ip_addr: IpAddr) -> bool { // Does not need to be blocking because the RwLock in settings never held across await points, // and the operation here locks only long enough to clone - let rate_limit = self.rate_limit_config; + let mut guard = self + .rate_limit + .lock() + .expect("Failed to lock rate limit mutex for reading"); + let rate_limit = &guard.rate_limit_config; let (kind, interval) = match self.type_ { RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second), @@ -123,13 +161,13 @@ impl RateLimited { RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second), RateLimitType::Search => (rate_limit.search, rate_limit.search_per_second), }; - let mut limiter = self.rate_limiter.lock().expect("mutex poison error"); + let limiter = &mut guard.rate_limiter; limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval) } } -impl Transform for RateLimited +impl Transform for RateLimitedGuard where S: Service + 'static, S::Future: 'static, diff --git a/crates/utils/src/rate_limit/rate_limiter.rs b/crates/utils/src/rate_limit/rate_limiter.rs index 258d7704a..80148340a 100644 --- a/crates/utils/src/rate_limit/rate_limiter.rs +++ b/crates/utils/src/rate_limit/rate_limiter.rs @@ -21,11 +21,11 @@ pub(crate) enum RateLimitType { /// Rate limiting based on rate type and IP addr #[derive(Debug, Clone, Default)] -pub struct RateLimiter { +pub struct RateLimitStorage { buckets: HashMap>, } -impl RateLimiter { +impl RateLimitStorage { fn insert_ip(&mut self, ip: &IpAddr) { for rate_limit_type in RateLimitType::iter() { if self.buckets.get(&rate_limit_type).is_none() { diff --git a/crates/websocket/src/chat_server.rs b/crates/websocket/src/chat_server.rs index 65ab4cbd7..d07f11bd1 100644 --- a/crates/websocket/src/chat_server.rs +++ b/crates/websocket/src/chat_server.rs @@ -17,7 +17,7 @@ use lemmy_db_schema::{ use lemmy_utils::{ error::LemmyError, location_info, - rate_limit::RateLimit, + rate_limit::RateLimitCell, settings::structs::Settings, ConnectionId, IpAddr, @@ -76,9 +76,6 @@ pub struct ChatServer { /// The Secrets pub(super) secret: Secret, - /// Rate limiting based on rate type and IP addr - pub(super) rate_limiter: RateLimit, - /// A list of the current captchas pub(super) captchas: Vec, @@ -87,6 +84,8 @@ pub struct ChatServer { /// An HTTP Client client: ClientWithMiddleware, + + rate_limit_cell: RateLimitCell, } pub struct SessionInfo { @@ -101,12 +100,12 @@ impl ChatServer { #![allow(clippy::too_many_arguments)] pub fn startup( pool: DbPool, - rate_limiter: RateLimit, message_handler: MessageHandlerType, message_handler_crud: MessageHandlerCrudType, client: ClientWithMiddleware, settings: Settings, secret: Secret, + rate_limit_cell: RateLimitCell, ) -> ChatServer { ChatServer { sessions: HashMap::new(), @@ -116,13 +115,13 @@ impl ChatServer { user_rooms: HashMap::new(), rng: rand::thread_rng(), pool, - rate_limiter, captchas: Vec::new(), message_handler, message_handler_crud, client, settings, secret, + rate_limit_cell, } } @@ -446,8 +445,6 @@ impl ChatServer { msg: StandardMessage, ctx: &mut Context, ) -> impl Future> { - let rate_limiter = self.rate_limiter.clone(); - let ip: IpAddr = match self.sessions.get(&msg.id) { Some(info) => info.ip.to_owned(), None => IpAddr("blank_ip".to_string()), @@ -459,9 +456,11 @@ impl ChatServer { client: self.client.to_owned(), settings: self.settings.to_owned(), secret: self.secret.to_owned(), + rate_limit_cell: self.rate_limit_cell.to_owned(), }; let message_handler_crud = self.message_handler_crud; let message_handler = self.message_handler; + let rate_limiter = self.rate_limit_cell.clone(); async move { let json: Value = serde_json::from_str(&msg.msg)?; let data = &json["data"].to_string(); diff --git a/crates/websocket/src/lib.rs b/crates/websocket/src/lib.rs index 7f363b0b5..e73e784e2 100644 --- a/crates/websocket/src/lib.rs +++ b/crates/websocket/src/lib.rs @@ -6,6 +6,7 @@ use actix::Addr; use lemmy_db_schema::{source::secret::Secret, utils::DbPool}; use lemmy_utils::{ error::LemmyError, + rate_limit::RateLimitCell, settings::{structs::Settings, SETTINGS}, }; use reqwest_middleware::ClientWithMiddleware; @@ -23,6 +24,7 @@ pub struct LemmyContext { client: ClientWithMiddleware, settings: Settings, secret: Secret, + rate_limit_cell: RateLimitCell, } impl LemmyContext { @@ -32,6 +34,7 @@ impl LemmyContext { client: ClientWithMiddleware, settings: Settings, secret: Secret, + settings_updated_channel: RateLimitCell, ) -> LemmyContext { LemmyContext { pool, @@ -39,6 +42,7 @@ impl LemmyContext { client, settings, secret, + rate_limit_cell: settings_updated_channel, } } pub fn pool(&self) -> &DbPool { @@ -56,6 +60,9 @@ impl LemmyContext { pub fn secret(&self) -> &Secret { &self.secret } + pub fn settings_updated_channel(&self) -> &RateLimitCell { + &self.rate_limit_cell + } } impl Clone for LemmyContext { @@ -66,6 +73,7 @@ impl Clone for LemmyContext { client: self.client.clone(), settings: self.settings.clone(), secret: self.secret.clone(), + rate_limit_cell: self.rate_limit_cell.clone(), } } } diff --git a/crates/websocket/src/routes.rs b/crates/websocket/src/routes.rs index e99e683eb..453a87b9e 100644 --- a/crates/websocket/src/routes.rs +++ b/crates/websocket/src/routes.rs @@ -6,7 +6,7 @@ use crate::{ use actix::prelude::*; use actix_web::{web, Error, HttpRequest, HttpResponse}; use actix_web_actors::ws; -use lemmy_utils::{rate_limit::RateLimit, utils::get_ip, ConnectionId, IpAddr}; +use lemmy_utils::{rate_limit::RateLimitCell, utils::get_ip, ConnectionId, IpAddr}; use std::time::{Duration, Instant}; use tracing::{debug, error, info}; @@ -20,7 +20,7 @@ pub async fn chat_route( req: HttpRequest, stream: web::Payload, context: web::Data, - rate_limiter: web::Data, + rate_limiter: web::Data, ) -> Result { ws::start( WsSession { @@ -44,7 +44,7 @@ struct WsSession { /// otherwise we drop connection. hb: Instant, /// A rate limiter for websocket joins - rate_limiter: RateLimit, + rate_limiter: RateLimitCell, } impl Actor for WsSession { diff --git a/src/api_routes.rs b/src/api_routes.rs index 006140262..02161121e 100644 --- a/src/api_routes.rs +++ b/src/api_routes.rs @@ -10,11 +10,11 @@ use lemmy_api_common::{ websocket::*, }; use lemmy_api_crud::PerformCrud; -use lemmy_utils::rate_limit::RateLimit; +use lemmy_utils::rate_limit::RateLimitCell; use lemmy_websocket::{routes::chat_route, LemmyContext}; use serde::Deserialize; -pub fn config(cfg: &mut web::ServiceConfig, rate_limit: &RateLimit) { +pub fn config(cfg: &mut web::ServiceConfig, rate_limit: &RateLimitCell) { cfg.service( web::scope("/api/v3") // Websocket diff --git a/src/main.rs b/src/main.rs index 5d1a3a1a7..f57aa5025 100644 --- a/src/main.rs +++ b/src/main.rs @@ -29,7 +29,7 @@ use lemmy_server::{ }; use lemmy_utils::{ error::LemmyError, - rate_limit::{rate_limiter::RateLimiter, RateLimit}, + rate_limit::RateLimitCell, settings::{structs::Settings, SETTINGS}, }; use lemmy_websocket::{chat_server::ChatServer, LemmyContext}; @@ -37,12 +37,7 @@ use reqwest::Client; use reqwest_middleware::ClientBuilder; use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use reqwest_tracing::TracingMiddleware; -use std::{ - env, - sync::{Arc, Mutex}, - thread, - time::Duration, -}; +use std::{env, thread, time::Duration}; use tracing_actix_web::TracingLogger; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!(); @@ -107,13 +102,7 @@ async fn main() -> Result<(), LemmyError> { // Set up the rate limiter let rate_limit_config = local_site_rate_limit_to_rate_limit_config(&site_view.local_site_rate_limit); - - // TODO this isn't live-updating - // https://github.com/LemmyNet/lemmy/issues/2508 - let rate_limiter = RateLimit { - rate_limiter: Arc::new(Mutex::new(RateLimiter::default())), - rate_limit_config, - }; + let rate_limit_cell = RateLimitCell::new(rate_limit_config).await; println!( "Starting http server at {}:{}", @@ -144,12 +133,12 @@ async fn main() -> Result<(), LemmyError> { let chat_server = ChatServer::startup( pool.clone(), - rate_limiter.clone(), |c, i, o, d| Box::pin(match_websocket_operation(c, i, o, d)), |c, i, o, d| Box::pin(match_websocket_operation_crud(c, i, o, d)), client.clone(), settings.clone(), secret.clone(), + rate_limit_cell.clone(), ) .start(); @@ -162,15 +151,15 @@ async fn main() -> Result<(), LemmyError> { client.clone(), settings.to_owned(), secret.to_owned(), + rate_limit_cell.clone(), ); - let rate_limiter = rate_limiter.clone(); App::new() - .wrap(actix_web::middleware::Logger::default()) + .wrap(middleware::Logger::default()) .wrap(TracingLogger::::new()) .app_data(Data::new(context)) - .app_data(Data::new(rate_limiter.clone())) + .app_data(Data::new(rate_limit_cell.clone())) // The routes - .configure(|cfg| api_routes::config(cfg, &rate_limiter)) + .configure(|cfg| api_routes::config(cfg, rate_limit_cell)) .configure(|cfg| { if federation_enabled { lemmy_apub::http::routes::config(cfg); @@ -178,7 +167,7 @@ async fn main() -> Result<(), LemmyError> { } }) .configure(feeds::config) - .configure(|cfg| images::config(cfg, pictrs_client.clone(), &rate_limiter)) + .configure(|cfg| images::config(cfg, pictrs_client.clone(), rate_limit_cell)) .configure(nodeinfo::config) }) .bind((settings_bind.bind, settings_bind.port))?