Move usage of env::var to lemmy_utils, simplify db init (ref #4095) (#4108)

This commit is contained in:
Nutomic 2023-10-25 17:34:38 +02:00 committed by GitHub
parent b63836b31b
commit 08739e2925
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 60 additions and 80 deletions

View file

@ -30,7 +30,7 @@ use diesel_migrations::EmbeddedMigrations;
use futures_util::{future::BoxFuture, Future, FutureExt}; use futures_util::{future::BoxFuture, Future, FutureExt};
use lemmy_utils::{ use lemmy_utils::{
error::{LemmyError, LemmyErrorExt, LemmyErrorType}, error::{LemmyError, LemmyErrorExt, LemmyErrorType},
settings::structs::Settings, settings::SETTINGS,
}; };
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use regex::Regex; use regex::Regex;
@ -39,8 +39,6 @@ use rustls::{
ServerName, ServerName,
}; };
use std::{ use std::{
env,
env::VarError,
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
sync::Arc, sync::Arc,
time::{Duration, SystemTime}, time::{Duration, SystemTime},
@ -149,10 +147,6 @@ macro_rules! try_join_with_pool {
}}; }};
} }
pub fn get_database_url_from_env() -> Result<String, VarError> {
env::var("LEMMY_DATABASE_URL")
}
pub fn fuzzy_search(q: &str) -> String { pub fn fuzzy_search(q: &str) -> String {
let replaced = q.replace('%', "\\%").replace('_', "\\_").replace(' ', "%"); let replaced = q.replace('%', "\\%").replace('_', "\\_").replace(' ', "%");
format!("%{replaced}%") format!("%{replaced}%")
@ -238,36 +232,6 @@ pub fn diesel_option_overwrite_to_url_create(
} }
} }
async fn build_db_pool_settings_opt(
settings: Option<&Settings>,
) -> Result<ActualDbPool, LemmyError> {
let db_url = get_database_url(settings);
let pool_size = settings.map(|s| s.database.pool_size).unwrap_or(5);
// We only support TLS with sslmode=require currently
let tls_enabled = db_url.contains("sslmode=require");
let manager = if tls_enabled {
// diesel-async does not support any TLS connections out of the box, so we need to manually
// provide a setup function which handles creating the connection
AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_setup(&db_url, establish_connection)
} else {
AsyncDieselConnectionManager::<AsyncPgConnection>::new(&db_url)
};
let pool = Pool::builder(manager)
.max_size(pool_size)
.wait_timeout(POOL_TIMEOUT)
.create_timeout(POOL_TIMEOUT)
.recycle_timeout(POOL_TIMEOUT)
.runtime(Runtime::Tokio1)
.build()?;
// If there's no settings, that means its a unit test, and migrations need to be run
if settings.is_none() {
run_migrations(&db_url);
}
Ok(pool)
}
fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<AsyncPgConnection>> { fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<AsyncPgConnection>> {
let fut = async { let fut = async {
let rustls_config = rustls::ClientConfig::builder() let rustls_config = rustls::ClientConfig::builder()
@ -308,7 +272,7 @@ impl ServerCertVerifier for NoCertVerifier {
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!(); pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!();
pub fn run_migrations(db_url: &str) { fn run_migrations(db_url: &str) {
// Needs to be a sync connection // Needs to be a sync connection
let mut conn = let mut conn =
PgConnection::establish(db_url).unwrap_or_else(|e| panic!("Error connecting to {db_url}: {e}")); PgConnection::establish(db_url).unwrap_or_else(|e| panic!("Error connecting to {db_url}: {e}"));
@ -319,29 +283,36 @@ pub fn run_migrations(db_url: &str) {
info!("Database migrations complete."); info!("Database migrations complete.");
} }
pub async fn build_db_pool(settings: &Settings) -> Result<ActualDbPool, LemmyError> { pub async fn build_db_pool() -> Result<ActualDbPool, LemmyError> {
build_db_pool_settings_opt(Some(settings)).await let db_url = SETTINGS.get_database_url();
// We only support TLS with sslmode=require currently
let tls_enabled = db_url.contains("sslmode=require");
let manager = if tls_enabled {
// diesel-async does not support any TLS connections out of the box, so we need to manually
// provide a setup function which handles creating the connection
AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_setup(&db_url, establish_connection)
} else {
AsyncDieselConnectionManager::<AsyncPgConnection>::new(&db_url)
};
let pool = Pool::builder(manager)
.max_size(SETTINGS.database.pool_size)
.wait_timeout(POOL_TIMEOUT)
.create_timeout(POOL_TIMEOUT)
.recycle_timeout(POOL_TIMEOUT)
.runtime(Runtime::Tokio1)
.build()?;
run_migrations(&db_url);
Ok(pool)
} }
pub async fn build_db_pool_for_tests() -> ActualDbPool { pub async fn build_db_pool_for_tests() -> ActualDbPool {
build_db_pool_settings_opt(None) build_db_pool().await.expect("db pool missing")
.await
.expect("db pool missing")
}
pub fn get_database_url(settings: Option<&Settings>) -> String {
// The env var should override anything in the settings config
match get_database_url_from_env() {
Ok(url) => url,
Err(e) => match settings {
Some(settings) => settings.get_database_url(),
None => panic!("Failed to read database URL from env var LEMMY_DATABASE_URL: {e}"),
},
}
} }
pub fn naive_now() -> DateTime<Utc> { pub fn naive_now() -> DateTime<Utc> {
chrono::prelude::Utc::now() Utc::now()
} }
pub fn post_to_comment_sort_type(sort: SortType) -> CommentSortType { pub fn post_to_comment_sort_type(sort: SortType) -> CommentSortType {

View file

@ -75,10 +75,7 @@ pub async fn send_email(
}; };
// Set the creds if they exist // Set the creds if they exist
let smtp_password = std::env::var("LEMMY_SMTP_PASSWORD") let smtp_password = email_config.smtp_password();
.ok()
.or(email_config.smtp_password);
if let (Some(username), Some(password)) = (email_config.smtp_login, smtp_password) { if let (Some(username), Some(password)) = (email_config.smtp_login, smtp_password) {
builder = builder.credentials(Credentials::new(username, password)); builder = builder.credentials(Credentials::new(username, password));
} }

View file

@ -45,6 +45,9 @@ impl Settings {
} }
pub fn get_database_url(&self) -> String { pub fn get_database_url(&self) -> String {
if let Ok(url) = env::var("LEMMY_DATABASE_URL") {
return url;
}
match &self.database.connection { match &self.database.connection {
DatabaseConnection::Uri { uri } => uri.clone(), DatabaseConnection::Uri { uri } => uri.clone(),
DatabaseConnection::Parts(parts) => { DatabaseConnection::Parts(parts) => {

View file

@ -1,6 +1,9 @@
use doku::Document; use doku::Document;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::net::{IpAddr, Ipv4Addr}; use std::{
env,
net::{IpAddr, Ipv4Addr},
};
use url::Url; use url::Url;
#[derive(Debug, Deserialize, Serialize, Clone, SmartDefault, Document)] #[derive(Debug, Deserialize, Serialize, Clone, SmartDefault, Document)]
@ -53,7 +56,15 @@ pub struct Settings {
/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin /// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
#[default(None)] #[default(None)]
#[doku(example = "*")] #[doku(example = "*")]
pub cors_origin: Option<String>, cors_origin: Option<String>,
}
impl Settings {
pub fn cors_origin(&self) -> Option<String> {
env::var("LEMMY_CORS_ORIGIN")
.ok()
.or(self.cors_origin.clone())
}
} }
#[derive(Debug, Deserialize, Serialize, Clone, SmartDefault, Document)] #[derive(Debug, Deserialize, Serialize, Clone, SmartDefault, Document)]
@ -77,7 +88,7 @@ pub struct PictrsConfig {
#[serde(default)] #[serde(default)]
pub struct DatabaseConfig { pub struct DatabaseConfig {
#[serde(flatten, default)] #[serde(flatten, default)]
pub connection: DatabaseConnection, pub(crate) connection: DatabaseConnection,
/// Maximum number of active sql connections /// Maximum number of active sql connections
#[default(95)] #[default(95)]
@ -122,10 +133,10 @@ pub struct DatabaseConnectionParts {
pub(super) user: String, pub(super) user: String,
/// Password to connect to postgres /// Password to connect to postgres
#[default("password")] #[default("password")]
pub password: String, pub(super) password: String,
#[default("localhost")] #[default("localhost")]
/// Host where postgres is running /// Host where postgres is running
pub host: String, pub(super) host: String,
/// Port where postgres can be accessed /// Port where postgres can be accessed
#[default(5432)] #[default(5432)]
pub(super) port: i32, pub(super) port: i32,
@ -143,7 +154,7 @@ pub struct EmailConfig {
/// Login name for smtp server /// Login name for smtp server
pub smtp_login: Option<String>, pub smtp_login: Option<String>,
/// Password to login to the smtp server /// Password to login to the smtp server
pub smtp_password: Option<String>, smtp_password: Option<String>,
#[doku(example = "noreply@example.com")] #[doku(example = "noreply@example.com")]
/// Address to send emails from, eg "noreply@your-instance.com" /// Address to send emails from, eg "noreply@your-instance.com"
pub smtp_from_address: String, pub smtp_from_address: String,
@ -153,6 +164,14 @@ pub struct EmailConfig {
pub tls_type: String, pub tls_type: String,
} }
impl EmailConfig {
pub fn smtp_password(&self) -> Option<String> {
std::env::var("LEMMY_SMTP_PASSWORD")
.ok()
.or(self.smtp_password.clone())
}
}
#[derive(Debug, Deserialize, Serialize, Clone, SmartDefault, Document)] #[derive(Debug, Deserialize, Serialize, Clone, SmartDefault, Document)]
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
pub struct SetupConfig { pub struct SetupConfig {

View file

@ -39,10 +39,7 @@ use lemmy_apub::{
VerifyUrlData, VerifyUrlData,
FEDERATION_HTTP_FETCH_LIMIT, FEDERATION_HTTP_FETCH_LIMIT,
}; };
use lemmy_db_schema::{ use lemmy_db_schema::{source::secret::Secret, utils::build_db_pool};
source::secret::Secret,
utils::{build_db_pool, get_database_url, run_migrations},
};
use lemmy_federate::{start_stop_federation_workers_cancellable, Opts}; use lemmy_federate::{start_stop_federation_workers_cancellable, Opts};
use lemmy_routes::{feeds, images, nodeinfo, webfinger}; use lemmy_routes::{feeds, images, nodeinfo, webfinger};
use lemmy_utils::{ use lemmy_utils::{
@ -114,12 +111,8 @@ pub async fn start_lemmy_server(args: CmdArgs) -> Result<(), LemmyError> {
startup_server_handle = Some(create_startup_server()?); startup_server_handle = Some(create_startup_server()?);
} }
// Run the DB migrations
let db_url = get_database_url(Some(&SETTINGS));
run_migrations(&db_url);
// Set up the connection pool // Set up the connection pool
let pool = build_db_pool(&SETTINGS).await?; let pool = build_db_pool().await?;
// Run the Code-required migrations // Run the Code-required migrations
run_advanced_migrations(&mut (&pool).into(), &SETTINGS).await?; run_advanced_migrations(&mut (&pool).into(), &SETTINGS).await?;
@ -282,13 +275,10 @@ fn create_http_server(
let context: LemmyContext = federation_config.deref().clone(); let context: LemmyContext = federation_config.deref().clone();
let rate_limit_cell = federation_config.rate_limit_cell().clone(); let rate_limit_cell = federation_config.rate_limit_cell().clone();
let self_origin = settings.get_protocol_and_hostname(); let self_origin = settings.get_protocol_and_hostname();
let cors_origin_setting = settings.cors_origin; let cors_origin_setting = settings.cors_origin();
// Create Http server with websocket support // Create Http server with websocket support
let server = HttpServer::new(move || { let server = HttpServer::new(move || {
let cors_origin = env::var("LEMMY_CORS_ORIGIN") let cors_config = match (cors_origin_setting.clone(), cfg!(debug_assertions)) {
.ok()
.or(cors_origin_setting.clone());
let cors_config = match (cors_origin, cfg!(debug_assertions)) {
(Some(origin), false) => Cors::default() (Some(origin), false) => Cors::default()
.allowed_origin(&origin) .allowed_origin(&origin)
.allowed_origin(&self_origin), .allowed_origin(&self_origin),
@ -341,7 +331,7 @@ fn create_http_server(
pub fn init_logging(opentelemetry_url: &Option<Url>) -> Result<(), LemmyError> { pub fn init_logging(opentelemetry_url: &Option<Url>) -> Result<(), LemmyError> {
LogTracer::init()?; LogTracer::init()?;
let log_description = std::env::var("RUST_LOG").unwrap_or_else(|_| "info".into()); let log_description = env::var("RUST_LOG").unwrap_or_else(|_| "info".into());
let targets = log_description let targets = log_description
.trim() .trim()