diff --git a/net/webrtc/signalling/src/bin/server.rs b/net/webrtc/signalling/src/bin/server.rs index fe3c6925..d2f8e50b 100644 --- a/net/webrtc/signalling/src/bin/server.rs +++ b/net/webrtc/signalling/src/bin/server.rs @@ -2,17 +2,20 @@ use clap::Parser; use gst_plugin_webrtc_signalling::handlers::Handler; -use gst_plugin_webrtc_signalling::server::Server; +use gst_plugin_webrtc_signalling::server::{Server, ServerError}; use tokio::io::AsyncReadExt; use tokio::task; use tracing_subscriber::prelude::*; use anyhow::Error; +use std::time::Duration; use tokio::fs; use tokio::net::TcpListener; use tokio_native_tls::native_tls::TlsAcceptor; use tracing::{info, warn}; +const TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(5); + #[derive(Parser, Debug)] #[clap(about, version, author)] /// Program arguments @@ -94,15 +97,20 @@ async fn main() -> Result<(), Error> { info!("Accepting connection from {}", address); - if let Some(ref acceptor) = acceptor { - let stream = match acceptor.accept(stream).await { - Ok(stream) => stream, - Err(err) => { - warn!("Failed to accept TLS connection from {}: {}", address, err); - continue; + if let Some(acceptor) = acceptor.clone() { + tokio::spawn(async move { + match tokio::time::timeout(TLS_HANDSHAKE_TIMEOUT, acceptor.accept(stream)).await { + Ok(Ok(stream)) => server_clone.accept_async(stream).await, + Ok(Err(err)) => { + warn!("Failed to accept TLS connection from {}: {}", address, err); + Err(ServerError::TLSHandshake(err)) + } + Err(elapsed) => { + warn!("TLS connection timed out {} after {}", address, elapsed); + Err(ServerError::TLSHandshakeTimeout(elapsed)) + } } - }; - task::spawn(async move { server_clone.accept_async(stream).await }); + }); } else { task::spawn(async move { server_clone.accept_async(stream).await }); } diff --git a/net/webrtc/signalling/src/server/mod.rs b/net/webrtc/signalling/src/server/mod.rs index 0510437c..e0b034a3 100644 --- a/net/webrtc/signalling/src/server/mod.rs +++ b/net/webrtc/signalling/src/server/mod.rs @@ -32,6 +32,10 @@ pub struct Server { pub enum ServerError { #[error("error during handshake {0}")] Handshake(#[from] async_tungstenite::tungstenite::Error), + #[error("error during TLS handshake {0}")] + TLSHandshake(#[from] tokio_native_tls::native_tls::Error), + #[error("timeout during TLS handshake {0}")] + TLSHandshakeTimeout(#[from] tokio::time::error::Elapsed), } impl Server {