diff --git a/crates/db_schema/src/utils.rs b/crates/db_schema/src/utils.rs index 7e83569a7..1ebdd36e2 100644 --- a/crates/db_schema/src/utils.rs +++ b/crates/db_schema/src/utils.rs @@ -30,7 +30,7 @@ use diesel_migrations::EmbeddedMigrations; use futures_util::{future::BoxFuture, Future, FutureExt}; use lemmy_utils::{ error::{LemmyError, LemmyErrorExt, LemmyErrorType}, - settings::structs::Settings, + settings::SETTINGS, }; use once_cell::sync::Lazy; use regex::Regex; @@ -39,8 +39,6 @@ use rustls::{ ServerName, }; use std::{ - env, - env::VarError, ops::{Deref, DerefMut}, sync::Arc, time::{Duration, SystemTime}, @@ -149,10 +147,6 @@ macro_rules! try_join_with_pool { }}; } -pub fn get_database_url_from_env() -> Result { - env::var("LEMMY_DATABASE_URL") -} - pub fn fuzzy_search(q: &str) -> String { let replaced = q.replace('%', "\\%").replace('_', "\\_").replace(' ', "%"); format!("%{replaced}%") @@ -238,36 +232,6 @@ pub fn diesel_option_overwrite_to_url_create( } } -async fn build_db_pool_settings_opt( - settings: Option<&Settings>, -) -> Result { - let db_url = get_database_url(settings); - let pool_size = settings.map(|s| s.database.pool_size).unwrap_or(5); - // We only support TLS with sslmode=require currently - let tls_enabled = db_url.contains("sslmode=require"); - let manager = if tls_enabled { - // diesel-async does not support any TLS connections out of the box, so we need to manually - // provide a setup function which handles creating the connection - AsyncDieselConnectionManager::::new_with_setup(&db_url, establish_connection) - } else { - AsyncDieselConnectionManager::::new(&db_url) - }; - let pool = Pool::builder(manager) - .max_size(pool_size) - .wait_timeout(POOL_TIMEOUT) - .create_timeout(POOL_TIMEOUT) - .recycle_timeout(POOL_TIMEOUT) - .runtime(Runtime::Tokio1) - .build()?; - - // If there's no settings, that means its a unit test, and migrations need to be run - if settings.is_none() { - run_migrations(&db_url); - } - - Ok(pool) -} - fn establish_connection(config: &str) -> BoxFuture> { let fut = async { let rustls_config = rustls::ClientConfig::builder() @@ -308,7 +272,7 @@ impl ServerCertVerifier for NoCertVerifier { pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!(); -pub fn run_migrations(db_url: &str) { +fn run_migrations(db_url: &str) { // Needs to be a sync connection let mut conn = PgConnection::establish(db_url).unwrap_or_else(|e| panic!("Error connecting to {db_url}: {e}")); @@ -319,29 +283,36 @@ pub fn run_migrations(db_url: &str) { info!("Database migrations complete."); } -pub async fn build_db_pool(settings: &Settings) -> Result { - build_db_pool_settings_opt(Some(settings)).await +pub async fn build_db_pool() -> Result { + let db_url = SETTINGS.get_database_url(); + // We only support TLS with sslmode=require currently + let tls_enabled = db_url.contains("sslmode=require"); + let manager = if tls_enabled { + // diesel-async does not support any TLS connections out of the box, so we need to manually + // provide a setup function which handles creating the connection + AsyncDieselConnectionManager::::new_with_setup(&db_url, establish_connection) + } else { + AsyncDieselConnectionManager::::new(&db_url) + }; + let pool = Pool::builder(manager) + .max_size(SETTINGS.database.pool_size) + .wait_timeout(POOL_TIMEOUT) + .create_timeout(POOL_TIMEOUT) + .recycle_timeout(POOL_TIMEOUT) + .runtime(Runtime::Tokio1) + .build()?; + + run_migrations(&db_url); + + Ok(pool) } pub async fn build_db_pool_for_tests() -> ActualDbPool { - build_db_pool_settings_opt(None) - .await - .expect("db pool missing") -} - -pub fn get_database_url(settings: Option<&Settings>) -> String { - // The env var should override anything in the settings config - match get_database_url_from_env() { - Ok(url) => url, - Err(e) => match settings { - Some(settings) => settings.get_database_url(), - None => panic!("Failed to read database URL from env var LEMMY_DATABASE_URL: {e}"), - }, - } + build_db_pool().await.expect("db pool missing") } pub fn naive_now() -> DateTime { - chrono::prelude::Utc::now() + Utc::now() } pub fn post_to_comment_sort_type(sort: SortType) -> CommentSortType { diff --git a/crates/utils/src/email.rs b/crates/utils/src/email.rs index fdff19033..1a786b0ef 100644 --- a/crates/utils/src/email.rs +++ b/crates/utils/src/email.rs @@ -75,10 +75,7 @@ pub async fn send_email( }; // Set the creds if they exist - let smtp_password = std::env::var("LEMMY_SMTP_PASSWORD") - .ok() - .or(email_config.smtp_password); - + let smtp_password = email_config.smtp_password(); if let (Some(username), Some(password)) = (email_config.smtp_login, smtp_password) { builder = builder.credentials(Credentials::new(username, password)); } diff --git a/crates/utils/src/settings/mod.rs b/crates/utils/src/settings/mod.rs index 6b8982a11..25aa7206d 100644 --- a/crates/utils/src/settings/mod.rs +++ b/crates/utils/src/settings/mod.rs @@ -45,6 +45,9 @@ impl Settings { } pub fn get_database_url(&self) -> String { + if let Ok(url) = env::var("LEMMY_DATABASE_URL") { + return url; + } match &self.database.connection { DatabaseConnection::Uri { uri } => uri.clone(), DatabaseConnection::Parts(parts) => { diff --git a/crates/utils/src/settings/structs.rs b/crates/utils/src/settings/structs.rs index 4c39e08aa..a31b3605e 100644 --- a/crates/utils/src/settings/structs.rs +++ b/crates/utils/src/settings/structs.rs @@ -1,6 +1,9 @@ use doku::Document; use serde::{Deserialize, Serialize}; -use std::net::{IpAddr, Ipv4Addr}; +use std::{ + env, + net::{IpAddr, Ipv4Addr}, +}; use url::Url; #[derive(Debug, Deserialize, Serialize, Clone, SmartDefault, Document)] @@ -53,7 +56,15 @@ pub struct Settings { /// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin #[default(None)] #[doku(example = "*")] - pub cors_origin: Option, + cors_origin: Option, +} + +impl Settings { + pub fn cors_origin(&self) -> Option { + env::var("LEMMY_CORS_ORIGIN") + .ok() + .or(self.cors_origin.clone()) + } } #[derive(Debug, Deserialize, Serialize, Clone, SmartDefault, Document)] @@ -77,7 +88,7 @@ pub struct PictrsConfig { #[serde(default)] pub struct DatabaseConfig { #[serde(flatten, default)] - pub connection: DatabaseConnection, + pub(crate) connection: DatabaseConnection, /// Maximum number of active sql connections #[default(95)] @@ -122,10 +133,10 @@ pub struct DatabaseConnectionParts { pub(super) user: String, /// Password to connect to postgres #[default("password")] - pub password: String, + pub(super) password: String, #[default("localhost")] /// Host where postgres is running - pub host: String, + pub(super) host: String, /// Port where postgres can be accessed #[default(5432)] pub(super) port: i32, @@ -143,7 +154,7 @@ pub struct EmailConfig { /// Login name for smtp server pub smtp_login: Option, /// Password to login to the smtp server - pub smtp_password: Option, + smtp_password: Option, #[doku(example = "noreply@example.com")] /// Address to send emails from, eg "noreply@your-instance.com" pub smtp_from_address: String, @@ -153,6 +164,14 @@ pub struct EmailConfig { pub tls_type: String, } +impl EmailConfig { + pub fn smtp_password(&self) -> Option { + std::env::var("LEMMY_SMTP_PASSWORD") + .ok() + .or(self.smtp_password.clone()) + } +} + #[derive(Debug, Deserialize, Serialize, Clone, SmartDefault, Document)] #[serde(deny_unknown_fields)] pub struct SetupConfig { diff --git a/src/lib.rs b/src/lib.rs index 6e62a6803..ec2a8fdae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -39,10 +39,7 @@ use lemmy_apub::{ VerifyUrlData, FEDERATION_HTTP_FETCH_LIMIT, }; -use lemmy_db_schema::{ - source::secret::Secret, - utils::{build_db_pool, get_database_url, run_migrations}, -}; +use lemmy_db_schema::{source::secret::Secret, utils::build_db_pool}; use lemmy_federate::{start_stop_federation_workers_cancellable, Opts}; use lemmy_routes::{feeds, images, nodeinfo, webfinger}; use lemmy_utils::{ @@ -114,12 +111,8 @@ pub async fn start_lemmy_server(args: CmdArgs) -> Result<(), LemmyError> { startup_server_handle = Some(create_startup_server()?); } - // Run the DB migrations - let db_url = get_database_url(Some(&SETTINGS)); - run_migrations(&db_url); - // Set up the connection pool - let pool = build_db_pool(&SETTINGS).await?; + let pool = build_db_pool().await?; // Run the Code-required migrations run_advanced_migrations(&mut (&pool).into(), &SETTINGS).await?; @@ -282,13 +275,10 @@ fn create_http_server( let context: LemmyContext = federation_config.deref().clone(); let rate_limit_cell = federation_config.rate_limit_cell().clone(); let self_origin = settings.get_protocol_and_hostname(); - let cors_origin_setting = settings.cors_origin; + let cors_origin_setting = settings.cors_origin(); // Create Http server with websocket support let server = HttpServer::new(move || { - let cors_origin = env::var("LEMMY_CORS_ORIGIN") - .ok() - .or(cors_origin_setting.clone()); - let cors_config = match (cors_origin, cfg!(debug_assertions)) { + let cors_config = match (cors_origin_setting.clone(), cfg!(debug_assertions)) { (Some(origin), false) => Cors::default() .allowed_origin(&origin) .allowed_origin(&self_origin), @@ -341,7 +331,7 @@ fn create_http_server( pub fn init_logging(opentelemetry_url: &Option) -> Result<(), LemmyError> { LogTracer::init()?; - let log_description = std::env::var("RUST_LOG").unwrap_or_else(|_| "info".into()); + let log_description = env::var("RUST_LOG").unwrap_or_else(|_| "info".into()); let targets = log_description .trim()