diff --git a/Cargo.lock b/Cargo.lock index debd8bcc5..073e0a95c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2659,6 +2659,7 @@ dependencies = [ "anyhow", "chrono", "encoding", + "enum-map", "futures", "getrandom", "jsonwebtoken", diff --git a/Cargo.toml b/Cargo.toml index 356abb035..9bf1000b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -129,6 +129,7 @@ rustls = { version = "0.21.3", features = ["dangerous_configuration"] } futures-util = "0.3.28" tokio-postgres = "0.7.8" tokio-postgres-rustls = "0.10.0" +enum-map = "2.6" [dependencies] lemmy_api = { workspace = true } diff --git a/crates/api_common/Cargo.toml b/crates/api_common/Cargo.toml index 5325350c8..a01e6008c 100644 --- a/crates/api_common/Cargo.toml +++ b/crates/api_common/Cargo.toml @@ -68,6 +68,7 @@ actix-web = { workspace = true, optional = true } jsonwebtoken = { version = "8.3.0", optional = true } # necessary for wasmt compilation getrandom = { version = "0.2.10", features = ["js"] } +enum-map = { workspace = true } [dev-dependencies] serial_test = { workspace = true } diff --git a/crates/api_common/src/claims.rs b/crates/api_common/src/claims.rs index 6676840dc..09191ad71 100644 --- a/crates/api_common/src/claims.rs +++ b/crates/api_common/src/claims.rs @@ -88,7 +88,7 @@ mod tests { traits::Crud, utils::build_db_pool_for_tests, }; - use lemmy_utils::rate_limit::{RateLimitCell, RateLimitConfig}; + use lemmy_utils::rate_limit::RateLimitCell; use reqwest::Client; use reqwest_middleware::ClientBuilder; use serial_test::serial; @@ -103,9 +103,7 @@ mod tests { pool_.clone(), ClientBuilder::new(Client::default()).build(), secret, - RateLimitCell::new(RateLimitConfig::builder().build()) - .await - .clone(), + RateLimitCell::with_test_config(), ); let inserted_instance = Instance::read_or_create(pool, "my_domain.tld".to_string()) diff --git a/crates/api_common/src/context.rs b/crates/api_common/src/context.rs index 0d448ef97..888a98741 100644 --- a/crates/api_common/src/context.rs +++ b/crates/api_common/src/context.rs @@ -46,7 +46,7 @@ impl LemmyContext { pub fn secret(&self) -> &Secret { &self.secret } - pub fn settings_updated_channel(&self) -> &RateLimitCell { + pub fn rate_limit_cell(&self) -> &RateLimitCell { &self.rate_limit_cell } } diff --git a/crates/api_common/src/utils.rs b/crates/api_common/src/utils.rs index b3dcd7558..5ba9a34c3 100644 --- a/crates/api_common/src/utils.rs +++ b/crates/api_common/src/utils.rs @@ -7,6 +7,7 @@ use crate::{ use actix_web::cookie::{Cookie, SameSite}; use anyhow::Context; use chrono::{DateTime, Days, Local, TimeZone, Utc}; +use enum_map::{enum_map, EnumMap}; use lemmy_db_schema::{ newtypes::{CommunityId, DbUrl, PersonId, PostId}, source::{ @@ -34,7 +35,7 @@ use lemmy_utils::{ email::{send_email, translations::Lang}, error::{LemmyError, LemmyErrorExt, LemmyErrorType, LemmyResult}, location_info, - rate_limit::RateLimitConfig, + rate_limit::{ActionType, BucketConfig}, settings::structs::Settings, utils::slurs::build_slur_regex, }; @@ -390,25 +391,21 @@ fn lang_str_to_lang(lang: &str) -> Lang { } pub fn local_site_rate_limit_to_rate_limit_config( - local_site_rate_limit: &LocalSiteRateLimit, -) -> RateLimitConfig { - let l = local_site_rate_limit; - RateLimitConfig { - message: l.message, - message_per_second: l.message_per_second, - post: l.post, - post_per_second: l.post_per_second, - register: l.register, - register_per_second: l.register_per_second, - image: l.image, - image_per_second: l.image_per_second, - comment: l.comment, - comment_per_second: l.comment_per_second, - search: l.search, - search_per_second: l.search_per_second, - import_user_settings: l.import_user_settings, - import_user_settings_per_second: l.import_user_settings_per_second, + l: &LocalSiteRateLimit, +) -> EnumMap { + enum_map! { + ActionType::Message => (l.message, l.message_per_second), + ActionType::Post => (l.post, l.post_per_second), + ActionType::Register => (l.register, l.register_per_second), + ActionType::Image => (l.image, l.image_per_second), + ActionType::Comment => (l.comment, l.comment_per_second), + ActionType::Search => (l.search, l.search_per_second), + ActionType::ImportUserSettings => (l.import_user_settings, l.import_user_settings_per_second), } + .map(|_key, (capacity, secs_to_refill)| BucketConfig { + capacity: u32::try_from(capacity).unwrap_or(0), + secs_to_refill: u32::try_from(secs_to_refill).unwrap_or(0), + }) } pub fn local_site_to_slur_regex(local_site: &LocalSite) -> Option { diff --git a/crates/api_crud/src/site/create.rs b/crates/api_crud/src/site/create.rs index 61dfd7c77..1449f4844 100644 --- a/crates/api_crud/src/site/create.rs +++ b/crates/api_crud/src/site/create.rs @@ -119,10 +119,7 @@ pub async fn create_site( let rate_limit_config = local_site_rate_limit_to_rate_limit_config(&site_view.local_site_rate_limit); - context - .settings_updated_channel() - .send(rate_limit_config) - .await?; + context.rate_limit_cell().set_config(rate_limit_config); Ok(Json(SiteResponse { site_view, diff --git a/crates/api_crud/src/site/update.rs b/crates/api_crud/src/site/update.rs index 3afc79559..b9d8f6a7f 100644 --- a/crates/api_crud/src/site/update.rs +++ b/crates/api_crud/src/site/update.rs @@ -157,10 +157,7 @@ pub async fn update_site( let rate_limit_config = local_site_rate_limit_to_rate_limit_config(&site_view.local_site_rate_limit); - context - .settings_updated_channel() - .send(rate_limit_config) - .await?; + context.rate_limit_cell().set_config(rate_limit_config); Ok(Json(SiteResponse { site_view, diff --git a/crates/apub/src/objects/mod.rs b/crates/apub/src/objects/mod.rs index b3653172a..6e27c0d09 100644 --- a/crates/apub/src/objects/mod.rs +++ b/crates/apub/src/objects/mod.rs @@ -61,10 +61,7 @@ pub(crate) mod tests { use anyhow::anyhow; use lemmy_api_common::{context::LemmyContext, request::build_user_agent}; use lemmy_db_schema::{source::secret::Secret, utils::build_db_pool_for_tests}; - use lemmy_utils::{ - rate_limit::{RateLimitCell, RateLimitConfig}, - settings::SETTINGS, - }; + use lemmy_utils::{rate_limit::RateLimitCell, settings::SETTINGS}; use reqwest::{Client, Request, Response}; use reqwest_middleware::{ClientBuilder, Middleware, Next}; use task_local_extensions::Extensions; @@ -101,8 +98,7 @@ pub(crate) mod tests { jwt_secret: String::new(), }; - let rate_limit_config = RateLimitConfig::builder().build(); - let rate_limit_cell = RateLimitCell::new(rate_limit_config).await; + let rate_limit_cell = RateLimitCell::with_test_config(); let context = LemmyContext::create(pool, client, secret, rate_limit_cell.clone()); let config = FederationConfig::builder() diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml index 20611702e..dc9714b0d 100644 --- a/crates/utils/Cargo.toml +++ b/crates/utils/Cargo.toml @@ -47,7 +47,7 @@ smart-default = "0.7.1" lettre = { version = "0.10.4", features = ["tokio1", "tokio1-native-tls"] } markdown-it = "0.5.1" ts-rs = { workspace = true, optional = true } -enum-map = "2.6" +enum-map = { workspace = true } [dev-dependencies] reqwest = { workspace = true } diff --git a/crates/utils/src/rate_limit/mod.rs b/crates/utils/src/rate_limit/mod.rs index 114daf452..63090749b 100644 --- a/crates/utils/src/rate_limit/mod.rs +++ b/crates/utils/src/rate_limit/mod.rs @@ -1,9 +1,9 @@ use crate::error::{LemmyError, LemmyErrorType}; use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform}; -use enum_map::enum_map; +use enum_map::{enum_map, EnumMap}; use futures::future::{ok, Ready}; -use rate_limiter::{InstantSecs, RateLimitStorage, RateLimitType}; -use serde::{Deserialize, Serialize}; +pub use rate_limiter::{ActionType, BucketConfig}; +use rate_limiter::{InstantSecs, RateLimitState}; use std::{ future::Future, net::{IpAddr, Ipv4Addr, SocketAddr}, @@ -14,208 +14,140 @@ use std::{ task::{Context, Poll}, time::Duration, }; -use tokio::sync::{mpsc, mpsc::Sender, OnceCell}; -use typed_builder::TypedBuilder; pub mod rate_limiter; -#[derive(Debug, Deserialize, Serialize, Clone, TypedBuilder)] -pub struct RateLimitConfig { - #[builder(default = 180)] - /// Maximum number of messages created in interval - pub message: i32, - #[builder(default = 60)] - /// Interval length for message limit, in seconds - pub message_per_second: i32, - #[builder(default = 6)] - /// Maximum number of posts created in interval - pub post: i32, - #[builder(default = 300)] - /// Interval length for post limit, in seconds - pub post_per_second: i32, - #[builder(default = 3)] - /// Maximum number of registrations in interval - pub register: i32, - #[builder(default = 3600)] - /// Interval length for registration limit, in seconds - pub register_per_second: i32, - #[builder(default = 6)] - /// Maximum number of image uploads in interval - pub image: i32, - #[builder(default = 3600)] - /// Interval length for image uploads, in seconds - pub image_per_second: i32, - #[builder(default = 6)] - /// Maximum number of comments created in interval - pub comment: i32, - #[builder(default = 600)] - /// Interval length for comment limit, in seconds - pub comment_per_second: i32, - #[builder(default = 60)] - /// Maximum number of searches created in interval - pub search: i32, - #[builder(default = 600)] - /// Interval length for search limit, in seconds - pub search_per_second: i32, - #[builder(default = 1)] - /// Maximum number of user settings imports in interval - pub import_user_settings: i32, - #[builder(default = 24 * 60 * 60)] - /// Interval length for importing user settings, in seconds (defaults to 24 hours) - pub import_user_settings_per_second: i32, -} - #[derive(Debug, Clone)] -struct RateLimit { - pub rate_limiter: RateLimitStorage, - pub rate_limit_config: RateLimitConfig, -} - -#[derive(Debug, Clone)] -pub struct RateLimitedGuard { - rate_limit: Arc>, - type_: RateLimitType, +pub struct RateLimitChecker { + state: Arc>, + action_type: ActionType, } /// Single instance of rate limit config and buckets, which is shared across all threads. #[derive(Clone)] pub struct RateLimitCell { - tx: Sender, - rate_limit: Arc>, + state: Arc>, } impl RateLimitCell { - /// Initialize cell if it wasnt initialized yet. Otherwise returns the existing cell. - pub async fn new(rate_limit_config: RateLimitConfig) -> &'static Self { - static LOCAL_INSTANCE: OnceCell = OnceCell::const_new(); - LOCAL_INSTANCE - .get_or_init(|| async { - let (tx, mut rx) = mpsc::channel::(4); - let rate_limit = Arc::new(Mutex::new(RateLimit { - rate_limiter: Default::default(), - rate_limit_config, - })); - let rate_limit2 = rate_limit.clone(); - tokio::spawn(async move { - while let Some(r) = rx.recv().await { - rate_limit2 - .lock() - .expect("Failed to lock rate limit mutex for updating") - .rate_limit_config = r; - } - }); - RateLimitCell { tx, rate_limit } - }) - .await + pub fn new(rate_limit_config: EnumMap) -> Self { + let state = Arc::new(Mutex::new(RateLimitState::new(rate_limit_config))); + + let state_weak_ref = Arc::downgrade(&state); + + tokio::spawn(async move { + let hour = Duration::from_secs(3600); + + // This loop stops when all other references to `state` are dropped + while let Some(state) = state_weak_ref.upgrade() { + tokio::time::sleep(hour).await; + state + .lock() + .expect("Failed to lock rate limit mutex for reading") + .remove_full_buckets(InstantSecs::now()); + } + }); + + RateLimitCell { state } } - /// Call this when the config was updated, to update all in-memory cells. - pub async fn send(&self, config: RateLimitConfig) -> Result<(), LemmyError> { - self.tx.send(config).await?; - Ok(()) - } - - /// Remove buckets older than the given duration - pub fn remove_older_than(&self, mut duration: Duration) { - let mut guard = self - .rate_limit + pub fn set_config(&self, config: EnumMap) { + self + .state .lock() - .expect("Failed to lock rate limit mutex for reading"); - let rate_limit = &guard.rate_limit_config; + .expect("Failed to lock rate limit mutex for updating") + .set_config(config); + } - // If any rate limit interval is greater than `duration`, then the largest interval is used instead. This preserves buckets that would not pass the rate limit check. - let max_interval_secs = enum_map! { - RateLimitType::Message => rate_limit.message_per_second, - RateLimitType::Post => rate_limit.post_per_second, - RateLimitType::Register => rate_limit.register_per_second, - RateLimitType::Image => rate_limit.image_per_second, - RateLimitType::Comment => rate_limit.comment_per_second, - RateLimitType::Search => rate_limit.search_per_second, - RateLimitType::ImportUserSettings => rate_limit.import_user_settings_per_second + pub fn message(&self) -> RateLimitChecker { + self.new_checker(ActionType::Message) + } + + pub fn post(&self) -> RateLimitChecker { + self.new_checker(ActionType::Post) + } + + pub fn register(&self) -> RateLimitChecker { + self.new_checker(ActionType::Register) + } + + pub fn image(&self) -> RateLimitChecker { + self.new_checker(ActionType::Image) + } + + pub fn comment(&self) -> RateLimitChecker { + self.new_checker(ActionType::Comment) + } + + pub fn search(&self) -> RateLimitChecker { + self.new_checker(ActionType::Search) + } + + pub fn import_user_settings(&self) -> RateLimitChecker { + self.new_checker(ActionType::ImportUserSettings) + } + + fn new_checker(&self, action_type: ActionType) -> RateLimitChecker { + RateLimitChecker { + state: self.state.clone(), + action_type, } - .into_values() - .max() - .and_then(|max| u64::try_from(max).ok()) - .unwrap_or(0); - - duration = std::cmp::max(duration, Duration::from_secs(max_interval_secs)); - - guard - .rate_limiter - .remove_older_than(duration, InstantSecs::now()) } - pub fn message(&self) -> RateLimitedGuard { - self.kind(RateLimitType::Message) - } - - pub fn post(&self) -> RateLimitedGuard { - self.kind(RateLimitType::Post) - } - - pub fn register(&self) -> RateLimitedGuard { - self.kind(RateLimitType::Register) - } - - pub fn image(&self) -> RateLimitedGuard { - self.kind(RateLimitType::Image) - } - - pub fn comment(&self) -> RateLimitedGuard { - self.kind(RateLimitType::Comment) - } - - pub fn search(&self) -> RateLimitedGuard { - self.kind(RateLimitType::Search) - } - - pub fn import_user_settings(&self) -> RateLimitedGuard { - self.kind(RateLimitType::ImportUserSettings) - } - - fn kind(&self, type_: RateLimitType) -> RateLimitedGuard { - RateLimitedGuard { - rate_limit: self.rate_limit.clone(), - type_, - } + pub fn with_test_config() -> Self { + Self::new(enum_map! { + ActionType::Message => BucketConfig { + capacity: 180, + secs_to_refill: 60, + }, + ActionType::Post => BucketConfig { + capacity: 6, + secs_to_refill: 300, + }, + ActionType::Register => BucketConfig { + capacity: 3, + secs_to_refill: 3600, + }, + ActionType::Image => BucketConfig { + capacity: 6, + secs_to_refill: 3600, + }, + ActionType::Comment => BucketConfig { + capacity: 6, + secs_to_refill: 600, + }, + ActionType::Search => BucketConfig { + capacity: 60, + secs_to_refill: 600, + }, + ActionType::ImportUserSettings => BucketConfig { + capacity: 1, + secs_to_refill: 24 * 60 * 60, + }, + }) } } pub struct RateLimitedMiddleware { - rate_limited: RateLimitedGuard, + checker: RateLimitChecker, service: Rc, } -impl RateLimitedGuard { +impl RateLimitChecker { /// Returns true if the request passed the rate limit, false if it failed and should be rejected. pub fn check(self, ip_addr: IpAddr) -> bool { // Does not need to be blocking because the RwLock in settings never held across await points, // and the operation here locks only long enough to clone - let mut guard = self - .rate_limit + let mut state = self + .state .lock() .expect("Failed to lock rate limit mutex for reading"); - let rate_limit = &guard.rate_limit_config; - let (kind, interval) = match self.type_ { - RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second), - RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second), - RateLimitType::Register => (rate_limit.register, rate_limit.register_per_second), - RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second), - RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second), - RateLimitType::Search => (rate_limit.search, rate_limit.search_per_second), - RateLimitType::ImportUserSettings => ( - rate_limit.import_user_settings, - rate_limit.import_user_settings_per_second, - ), - }; - let limiter = &mut guard.rate_limiter; - - limiter.check_rate_limit_full(self.type_, ip_addr, kind, interval, InstantSecs::now()) + state.check(self.action_type, ip_addr, InstantSecs::now()) } } -impl Transform for RateLimitedGuard +impl Transform for RateLimitChecker where S: Service + 'static, S::Future: 'static, @@ -228,7 +160,7 @@ where fn new_transform(&self, service: S) -> Self::Future { ok(RateLimitedMiddleware { - rate_limited: self.clone(), + checker: self.clone(), service: Rc::new(service), }) } @@ -252,11 +184,11 @@ where fn call(&self, req: ServiceRequest) -> Self::Future { let ip_addr = get_ip(&req.connection_info()); - let rate_limited = self.rate_limited.clone(); + let checker = self.checker.clone(); let service = self.service.clone(); Box::pin(async move { - if rate_limited.check(ip_addr) { + if checker.check(ip_addr) { service.call(req).await } else { let (http_req, _) = req.into_parts(); diff --git a/crates/utils/src/rate_limit/rate_limiter.rs b/crates/utils/src/rate_limit/rate_limiter.rs index 7ba1345c5..d0dad5df2 100644 --- a/crates/utils/src/rate_limit/rate_limiter.rs +++ b/crates/utils/src/rate_limit/rate_limiter.rs @@ -1,15 +1,13 @@ -use enum_map::{enum_map, EnumMap}; +use enum_map::EnumMap; use once_cell::sync::Lazy; use std::{ collections::HashMap, hash::Hash, net::{IpAddr, Ipv4Addr, Ipv6Addr}, - time::{Duration, Instant}, + time::Instant, }; use tracing::debug; -const UNINITIALIZED_TOKEN_AMOUNT: f32 = -2.0; - static START_TIME: Lazy = Lazy::new(Instant::now); /// Smaller than `std::time::Instant` because it uses a smaller integer for seconds and doesn't @@ -26,27 +24,48 @@ impl InstantSecs { .expect("server has been running for over 136 years"), } } - - fn secs_since(self, earlier: Self) -> u32 { - self.secs.saturating_sub(earlier.secs) - } - - fn to_instant(self) -> Instant { - *START_TIME + Duration::from_secs(self.secs.into()) - } } -#[derive(PartialEq, Debug, Clone)] -struct RateLimitBucket { +#[derive(PartialEq, Debug, Clone, Copy)] +struct Bucket { last_checked: InstantSecs, /// This field stores the amount of tokens that were present at `last_checked`. /// The amount of tokens steadily increases until it reaches the bucket's capacity. /// Performing the rate-limited action consumes 1 token. - tokens: f32, + tokens: u32, +} + +#[derive(PartialEq, Debug, Copy, Clone)] +pub struct BucketConfig { + pub capacity: u32, + pub secs_to_refill: u32, +} + +impl Bucket { + fn update(self, now: InstantSecs, config: BucketConfig) -> Self { + let secs_since_last_checked = now.secs.saturating_sub(self.last_checked.secs); + + // For `secs_since_last_checked` seconds, the amount of tokens increases by `capacity` every `secs_to_refill` seconds. + // The amount of tokens added per second is `capacity / secs_to_refill`. + // The expression below is like `secs_since_last_checked * (capacity / secs_to_refill)` but with precision and non-overflowing multiplication. + let added_tokens = u64::from(secs_since_last_checked) * u64::from(config.capacity) + / u64::from(config.secs_to_refill); + + // The amount of tokens there would be if the bucket had infinite capacity + let unbounded_tokens = self.tokens + (added_tokens as u32); + + // Bucket stops filling when capacity is reached + let tokens = std::cmp::min(unbounded_tokens, config.capacity); + + Bucket { + last_checked: now, + tokens, + } + } } #[derive(Debug, enum_map::Enum, Copy, Clone, AsRefStr)] -pub(crate) enum RateLimitType { +pub enum ActionType { Message, Register, Post, @@ -56,179 +75,228 @@ pub(crate) enum RateLimitType { ImportUserSettings, } -type Map = HashMap>; - #[derive(PartialEq, Debug, Clone)] struct RateLimitedGroup { - total: EnumMap, + total: EnumMap, children: C, } +type Map = HashMap>; + +/// Implemented for `()`, `Map`, `Map>`, etc. +trait MapLevel: Default { + type CapacityFactors; + type AddrParts; + + fn check( + &mut self, + action_type: ActionType, + now: InstantSecs, + configs: EnumMap, + capacity_factors: Self::CapacityFactors, + addr_parts: Self::AddrParts, + ) -> bool; + + /// Remove full buckets and return `true` if there's any buckets remaining + fn remove_full_buckets( + &mut self, + now: InstantSecs, + configs: EnumMap, + ) -> bool; +} + +impl MapLevel for Map { + type CapacityFactors = (u32, C::CapacityFactors); + type AddrParts = (K, C::AddrParts); + + fn check( + &mut self, + action_type: ActionType, + now: InstantSecs, + configs: EnumMap, + (capacity_factor, child_capacity_factors): Self::CapacityFactors, + (addr_part, child_addr_parts): Self::AddrParts, + ) -> bool { + // Multiplies capacities by `capacity_factor` for groups in `self` + let adjusted_configs = configs.map(|_, config| BucketConfig { + capacity: config.capacity.saturating_mul(capacity_factor), + ..config + }); + + // Remove groups that are no longer needed if the hash map's existing allocation has no space for new groups. + // This is done before calling `HashMap::entry` because that immediately allocates just like `HashMap::insert`. + if (self.capacity() == self.len()) && !self.contains_key(&addr_part) { + self.remove_full_buckets(now, configs); + } + + let group = self + .entry(addr_part) + .or_insert(RateLimitedGroup::new(now, adjusted_configs)); + + #[allow(clippy::indexing_slicing)] + let total_passes = group.check_total(action_type, now, adjusted_configs[action_type]); + + let children_pass = group.children.check( + action_type, + now, + configs, + child_capacity_factors, + child_addr_parts, + ); + + total_passes && children_pass + } + + fn remove_full_buckets( + &mut self, + now: InstantSecs, + configs: EnumMap, + ) -> bool { + self.retain(|_key, group| { + let some_children_remaining = group.children.remove_full_buckets(now, configs); + + // Evaluated if `some_children_remaining` is false + let total_has_refill_in_future = || { + group.total.into_iter().all(|(action_type, bucket)| { + #[allow(clippy::indexing_slicing)] + let config = configs[action_type]; + bucket.update(now, config).tokens != config.capacity + }) + }; + + some_children_remaining || total_has_refill_in_future() + }); + + self.shrink_to_fit(); + + !self.is_empty() + } +} + +impl MapLevel for () { + type CapacityFactors = (); + type AddrParts = (); + + fn check( + &mut self, + _: ActionType, + _: InstantSecs, + _: EnumMap, + _: Self::CapacityFactors, + _: Self::AddrParts, + ) -> bool { + true + } + + fn remove_full_buckets(&mut self, _: InstantSecs, _: EnumMap) -> bool { + false + } +} + impl RateLimitedGroup { - fn new(now: InstantSecs) -> Self { + fn new(now: InstantSecs, configs: EnumMap) -> Self { RateLimitedGroup { - total: enum_map! { - _ => RateLimitBucket { - last_checked: now, - tokens: UNINITIALIZED_TOKEN_AMOUNT, - }, - }, + total: configs.map(|_, config| Bucket { + last_checked: now, + tokens: config.capacity, + }), + // `HashMap::new()` or `()` children: Default::default(), } } fn check_total( &mut self, - type_: RateLimitType, + action_type: ActionType, now: InstantSecs, - capacity: i32, - secs_to_refill: i32, + config: BucketConfig, ) -> bool { - let capacity = capacity as f32; - let secs_to_refill = secs_to_refill as f32; - #[allow(clippy::indexing_slicing)] // `EnumMap` has no `get` funciton - let bucket = &mut self.total[type_]; + let bucket = &mut self.total[action_type]; - if bucket.tokens == UNINITIALIZED_TOKEN_AMOUNT { - bucket.tokens = capacity; - } + let new_bucket = bucket.update(now, config); - let secs_since_last_checked = now.secs_since(bucket.last_checked) as f32; - bucket.last_checked = now; - - // For `secs_since_last_checked` seconds, increase `bucket.tokens` - // by `capacity` every `secs_to_refill` seconds - bucket.tokens += { - let tokens_per_sec = capacity / secs_to_refill; - secs_since_last_checked * tokens_per_sec - }; - - // Prevent `bucket.tokens` from exceeding `capacity` - if bucket.tokens > capacity { - bucket.tokens = capacity; - } - - if bucket.tokens < 1.0 { + if new_bucket.tokens == 0 { // Not enough tokens yet - debug!( - "Rate limited type: {}, time_passed: {}, allowance: {}", - type_.as_ref(), - secs_since_last_checked, - bucket.tokens - ); + // Setting `bucket` to `new_bucket` here is useless and would cause the bucket to start over at 0 tokens because of rounding false } else { // Consume 1 token - bucket.tokens -= 1.0; + *bucket = new_bucket; + bucket.tokens -= 1; true } } } /// Rate limiting based on rate type and IP addr -#[derive(PartialEq, Debug, Clone, Default)] -pub struct RateLimitStorage { - /// One bucket per individual IPv4 address +#[derive(PartialEq, Debug, Clone)] +pub struct RateLimitState { + /// Each individual IPv4 address gets one `RateLimitedGroup`. ipv4_buckets: Map, - /// Seperate buckets for 48, 56, and 64 bit prefixes of IPv6 addresses + /// All IPv6 addresses that share the same first 64 bits share the same `RateLimitedGroup`. + /// + /// The same thing happens for the first 48 and 56 bits, but with increased capacity. + /// + /// This is done because all users can easily switch to any other IPv6 address that has the same first 64 bits. + /// It could be as low as 48 bits for some networks, which is the reason for 48 and 56 bit address groups. ipv6_buckets: Map<[u8; 6], Map>>, + /// This stores a `BucketConfig` for each `ActionType`. `EnumMap` makes it impossible to have a missing `BucketConfig`. + bucket_configs: EnumMap, } -impl RateLimitStorage { +impl RateLimitState { + pub fn new(bucket_configs: EnumMap) -> Self { + RateLimitState { + ipv4_buckets: HashMap::new(), + ipv6_buckets: HashMap::new(), + bucket_configs, + } + } + /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478 /// /// Returns true if the request passed the rate limit, false if it failed and should be rejected. - pub(super) fn check_rate_limit_full( - &mut self, - type_: RateLimitType, - ip: IpAddr, - capacity: i32, - secs_to_refill: i32, - now: InstantSecs, - ) -> bool { - let mut result = true; - - match ip { + pub fn check(&mut self, action_type: ActionType, ip: IpAddr, now: InstantSecs) -> bool { + let result = match ip { IpAddr::V4(ipv4) => { - // Only used by one address. - let group = self + self .ipv4_buckets - .entry(ipv4) - .or_insert(RateLimitedGroup::new(now)); - - result &= group.check_total(type_, now, capacity, secs_to_refill); + .check(action_type, now, self.bucket_configs, (1, ()), (ipv4, ())) } IpAddr::V6(ipv6) => { let (key_48, key_56, key_64) = split_ipv6(ipv6); - - // Contains all addresses with the same first 48 bits. These addresses might be part of the same network. - let group_48 = self - .ipv6_buckets - .entry(key_48) - .or_insert(RateLimitedGroup::new(now)); - result &= group_48.check_total(type_, now, capacity.saturating_mul(16), secs_to_refill); - - // Contains all addresses with the same first 56 bits. These addresses might be part of the same network. - let group_56 = group_48 - .children - .entry(key_56) - .or_insert(RateLimitedGroup::new(now)); - result &= group_56.check_total(type_, now, capacity.saturating_mul(4), secs_to_refill); - - // A group with no children. It is shared by all addresses with the same first 64 bits. These addresses are always part of the same network. - let group_64 = group_56 - .children - .entry(key_64) - .or_insert(RateLimitedGroup::new(now)); - - result &= group_64.check_total(type_, now, capacity, secs_to_refill); + self.ipv6_buckets.check( + action_type, + now, + self.bucket_configs, + (16, (4, (1, ()))), + (key_48, (key_56, (key_64, ()))), + ) } }; if !result { - debug!("Rate limited IP: {ip}"); + debug!("Rate limited IP: {ip}, type: {action_type:?}"); } result } - /// Remove buckets older than the given duration - pub(super) fn remove_older_than(&mut self, duration: Duration, now: InstantSecs) { - // Only retain buckets that were last used after `instant` - let Some(instant) = now.to_instant().checked_sub(duration) else { - return; - }; - - let is_recently_used = |group: &RateLimitedGroup<_>| { - group - .total - .values() - .all(|bucket| bucket.last_checked.to_instant() > instant) - }; - - retain_and_shrink(&mut self.ipv4_buckets, |_, group| is_recently_used(group)); - - retain_and_shrink(&mut self.ipv6_buckets, |_, group_48| { - retain_and_shrink(&mut group_48.children, |_, group_56| { - retain_and_shrink(&mut group_56.children, |_, group_64| { - is_recently_used(group_64) - }); - !group_56.children.is_empty() - }); - !group_48.children.is_empty() - }) + /// Remove buckets that are now full + pub fn remove_full_buckets(&mut self, now: InstantSecs) { + self + .ipv4_buckets + .remove_full_buckets(now, self.bucket_configs); + self + .ipv6_buckets + .remove_full_buckets(now, self.bucket_configs); } -} -fn retain_and_shrink(map: &mut HashMap, f: F) -where - K: Eq + Hash, - F: FnMut(&K, &mut V) -> bool, -{ - map.retain(f); - map.shrink_to_fit(); + pub fn set_config(&mut self, new_configs: EnumMap) { + self.bucket_configs = new_configs; + } } fn split_ipv6(ip: Ipv6Addr) -> ([u8; 6], u8, u8) { @@ -241,6 +309,8 @@ mod tests { #![allow(clippy::unwrap_used)] #![allow(clippy::indexing_slicing)] + use super::{ActionType, BucketConfig, InstantSecs, RateLimitState, RateLimitedGroup}; + #[test] fn test_split_ipv6() { let ip = std::net::Ipv6Addr::new( @@ -254,9 +324,20 @@ mod tests { #[test] fn test_rate_limiter() { - let mut rate_limiter = super::RateLimitStorage::default(); - let mut now = super::InstantSecs::now(); + let bucket_configs = enum_map::enum_map! { + ActionType::Message => BucketConfig { + capacity: 2, + secs_to_refill: 1, + }, + _ => BucketConfig { + capacity: 2, + secs_to_refill: 1, + }, + }; + let mut rate_limiter = RateLimitState::new(bucket_configs); + let mut now = InstantSecs::now(); + // Do 1 `Message` and 1 `Post` action for each IP address, and expect the limit to not be reached let ips = [ "123.123.123.123", "1:2:3::", @@ -266,66 +347,71 @@ mod tests { ]; for ip in ips { let ip = ip.parse().unwrap(); - let message_passed = - rate_limiter.check_rate_limit_full(super::RateLimitType::Message, ip, 2, 1, now); - let post_passed = - rate_limiter.check_rate_limit_full(super::RateLimitType::Post, ip, 3, 1, now); + let message_passed = rate_limiter.check(ActionType::Message, ip, now); + let post_passed = rate_limiter.check(ActionType::Post, ip, now); assert!(message_passed); assert!(post_passed); } #[allow(clippy::indexing_slicing)] - let expected_buckets = |factor: f32, tokens_consumed: f32| { - let mut buckets = super::RateLimitedGroup::<()>::new(now).total; - buckets[super::RateLimitType::Message] = super::RateLimitBucket { - last_checked: now, - tokens: (2.0 * factor) - tokens_consumed, - }; - buckets[super::RateLimitType::Post] = super::RateLimitBucket { - last_checked: now, - tokens: (3.0 * factor) - tokens_consumed, - }; + let expected_buckets = |factor: u32, tokens_consumed: u32| { + let adjusted_configs = bucket_configs.map(|_, config| BucketConfig { + capacity: config.capacity.saturating_mul(factor), + ..config + }); + let mut buckets = RateLimitedGroup::<()>::new(now, adjusted_configs).total; + buckets[ActionType::Message].tokens -= tokens_consumed; + buckets[ActionType::Post].tokens -= tokens_consumed; buckets }; - let bottom_group = |tokens_consumed| super::RateLimitedGroup { - total: expected_buckets(1.0, tokens_consumed), + let bottom_group = |tokens_consumed| RateLimitedGroup { + total: expected_buckets(1, tokens_consumed), children: (), }; assert_eq!( rate_limiter, - super::RateLimitStorage { - ipv4_buckets: [([123, 123, 123, 123].into(), bottom_group(1.0)),].into(), + RateLimitState { + bucket_configs, + ipv4_buckets: [([123, 123, 123, 123].into(), bottom_group(1))].into(), ipv6_buckets: [( [0, 1, 0, 2, 0, 3], - super::RateLimitedGroup { - total: expected_buckets(16.0, 4.0), + RateLimitedGroup { + total: expected_buckets(16, 4), children: [ ( 0, - super::RateLimitedGroup { - total: expected_buckets(4.0, 1.0), - children: [(0, bottom_group(1.0)),].into(), + RateLimitedGroup { + total: expected_buckets(4, 1), + children: [(0, bottom_group(1))].into(), } ), ( 4, - super::RateLimitedGroup { - total: expected_buckets(4.0, 3.0), - children: [(0, bottom_group(1.0)), (5, bottom_group(2.0)),].into(), + RateLimitedGroup { + total: expected_buckets(4, 3), + children: [(0, bottom_group(1)), (5, bottom_group(2))].into(), } ), ] .into(), } - ),] + )] .into(), } ); + // Do 2 `Message` actions for 1 IP address and expect only the 2nd one to fail + for expected_to_pass in [true, false] { + let ip = "1:2:3:0400::".parse().unwrap(); + let passed = rate_limiter.check(ActionType::Message, ip, now); + assert_eq!(passed, expected_to_pass); + } + + // Expect `remove_full_buckets` to remove everything when called 2 seconds later now.secs += 2; - rate_limiter.remove_older_than(std::time::Duration::from_secs(1), now); + rate_limiter.remove_full_buckets(now); assert!(rate_limiter.ipv4_buckets.is_empty()); assert!(rate_limiter.ipv6_buckets.is_empty()); } diff --git a/src/lib.rs b/src/lib.rs index c093faaca..2df231dd5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -156,7 +156,7 @@ pub async fn start_lemmy_server(args: CmdArgs) -> Result<(), LemmyError> { // Set up the rate limiter let rate_limit_config = local_site_rate_limit_to_rate_limit_config(&site_view.local_site_rate_limit); - let rate_limit_cell = RateLimitCell::new(rate_limit_config).await; + let rate_limit_cell = RateLimitCell::new(rate_limit_config); println!( "Starting http server at {}:{}", @@ -298,7 +298,7 @@ fn create_http_server( .expect("Should always be buildable"); let context: LemmyContext = federation_config.deref().clone(); - let rate_limit_cell = federation_config.settings_updated_channel().clone(); + let rate_limit_cell = federation_config.rate_limit_cell().clone(); let self_origin = settings.get_protocol_and_hostname(); // Create Http server with websocket support let server = HttpServer::new(move || { diff --git a/src/scheduled_tasks.rs b/src/scheduled_tasks.rs index 99dd16829..8db74ef9d 100644 --- a/src/scheduled_tasks.rs +++ b/src/scheduled_tasks.rs @@ -78,17 +78,6 @@ pub async fn setup(context: LemmyContext) -> Result<(), LemmyError> { } }); - let context_1 = context.clone(); - // Remove old rate limit buckets after 1 to 2 hours of inactivity - scheduler.every(CTimeUnits::hour(1)).run(move || { - let context = context_1.clone(); - - async move { - let hour = Duration::from_secs(3600); - context.settings_updated_channel().remove_older_than(hour); - } - }); - let context_1 = context.clone(); // Overwrite deleted & removed posts and comments every day scheduler.every(CTimeUnits::days(1)).run(move || { diff --git a/src/session_middleware.rs b/src/session_middleware.rs index ae82cd44d..f50e0eccd 100644 --- a/src/session_middleware.rs +++ b/src/session_middleware.rs @@ -112,7 +112,7 @@ mod tests { traits::Crud, utils::build_db_pool_for_tests, }; - use lemmy_utils::rate_limit::{RateLimitCell, RateLimitConfig}; + use lemmy_utils::rate_limit::RateLimitCell; use reqwest::Client; use reqwest_middleware::ClientBuilder; use serial_test::serial; @@ -131,9 +131,7 @@ mod tests { pool_.clone(), ClientBuilder::new(Client::default()).build(), secret, - RateLimitCell::new(RateLimitConfig::builder().build()) - .await - .clone(), + RateLimitCell::with_test_config(), ); let inserted_instance = Instance::read_or_create(pool, "my_domain.tld".to_string())