2022-10-18 18:16:49 +00:00
|
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
|
2022-03-03 03:30:44 +00:00
|
|
|
use anyhow::Error;
|
|
|
|
use async_tungstenite::tungstenite::Message as WsMessage;
|
|
|
|
use futures::channel::mpsc;
|
|
|
|
use futures::prelude::*;
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use std::collections::HashMap;
|
|
|
|
use std::pin::Pin;
|
|
|
|
use std::sync::{Arc, Mutex};
|
2022-12-27 02:12:39 +00:00
|
|
|
use tokio::io::{AsyncRead, AsyncWrite};
|
|
|
|
use tokio::task;
|
2022-03-03 03:30:44 +00:00
|
|
|
use tracing::{info, instrument, trace, warn};
|
|
|
|
|
|
|
|
struct Peer {
|
|
|
|
receive_task_handle: task::JoinHandle<()>,
|
|
|
|
send_task_handle: task::JoinHandle<Result<(), Error>>,
|
|
|
|
sender: mpsc::Sender<String>,
|
|
|
|
}
|
|
|
|
|
|
|
|
struct State {
|
|
|
|
tx: Option<mpsc::Sender<(String, Option<String>)>>,
|
|
|
|
peers: HashMap<String, Peer>,
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Clone)]
|
|
|
|
pub struct Server {
|
|
|
|
state: Arc<Mutex<State>>,
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(thiserror::Error, Debug)]
|
|
|
|
pub enum ServerError {
|
|
|
|
#[error("error during handshake {0}")]
|
|
|
|
Handshake(#[from] async_tungstenite::tungstenite::Error),
|
2023-12-20 12:31:11 +00:00
|
|
|
#[error("error during TLS handshake {0}")]
|
|
|
|
TLSHandshake(#[from] tokio_native_tls::native_tls::Error),
|
|
|
|
#[error("timeout during TLS handshake {0}")]
|
|
|
|
TLSHandshakeTimeout(#[from] tokio::time::error::Elapsed),
|
2022-03-03 03:30:44 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
impl Server {
|
|
|
|
#[instrument(level = "debug", skip(factory))]
|
|
|
|
pub fn spawn<
|
|
|
|
I: for<'a> Deserialize<'a>,
|
2023-06-07 23:51:45 +00:00
|
|
|
O: Serialize + std::fmt::Debug + Send + Sync,
|
2022-03-03 03:30:44 +00:00
|
|
|
Factory: FnOnce(Pin<Box<dyn Stream<Item = (String, Option<I>)> + Send>>) -> St,
|
|
|
|
St: Stream<Item = (String, O)>,
|
|
|
|
>(
|
|
|
|
factory: Factory,
|
|
|
|
) -> Self
|
|
|
|
where
|
|
|
|
O: Serialize + std::fmt::Debug,
|
|
|
|
St: Send + Unpin + 'static,
|
|
|
|
{
|
|
|
|
let (tx, rx) = mpsc::channel::<(String, Option<String>)>(1000);
|
|
|
|
let mut handler = factory(Box::pin(rx.filter_map(|(peer_id, msg)| async move {
|
|
|
|
if let Some(msg) = msg {
|
|
|
|
match serde_json::from_str::<I>(&msg) {
|
|
|
|
Ok(msg) => Some((peer_id, Some(msg))),
|
|
|
|
Err(err) => {
|
2022-03-23 00:33:00 +00:00
|
|
|
warn!("Failed to parse incoming message: {} ({})", err, msg);
|
2022-03-03 03:30:44 +00:00
|
|
|
None
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
Some((peer_id, None))
|
|
|
|
}
|
|
|
|
})));
|
|
|
|
|
|
|
|
let state = Arc::new(Mutex::new(State {
|
|
|
|
tx: Some(tx),
|
|
|
|
peers: HashMap::new(),
|
|
|
|
}));
|
|
|
|
|
|
|
|
let state_clone = state.clone();
|
2023-01-25 08:23:46 +00:00
|
|
|
task::spawn(async move {
|
2022-03-03 03:30:44 +00:00
|
|
|
while let Some((peer_id, msg)) = handler.next().await {
|
|
|
|
match serde_json::to_string(&msg) {
|
2023-06-07 23:51:45 +00:00
|
|
|
Ok(msg_str) => {
|
|
|
|
let sender = {
|
|
|
|
let mut state = state_clone.lock().unwrap();
|
|
|
|
if let Some(peer) = state.peers.get_mut(&peer_id) {
|
|
|
|
Some(peer.sender.clone())
|
|
|
|
} else {
|
|
|
|
None
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
if let Some(mut sender) = sender {
|
|
|
|
trace!("Sending {}", msg_str);
|
|
|
|
let _ = sender.send(msg_str).await;
|
2022-03-03 03:30:44 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
Err(err) => {
|
|
|
|
warn!("Failed to serialize outgoing message: {}", err);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
});
|
|
|
|
|
|
|
|
Self { state }
|
|
|
|
}
|
|
|
|
|
|
|
|
#[instrument(level = "debug", skip(state))]
|
|
|
|
fn remove_peer(state: Arc<Mutex<State>>, peer_id: &str) {
|
|
|
|
if let Some(mut peer) = state.lock().unwrap().peers.remove(peer_id) {
|
|
|
|
let peer_id = peer_id.to_string();
|
|
|
|
task::spawn(async move {
|
|
|
|
peer.sender.close_channel();
|
|
|
|
if let Err(err) = peer.send_task_handle.await {
|
|
|
|
trace!(peer_id = %peer_id, "Error while joining send task: {}", err);
|
|
|
|
}
|
2022-12-12 13:36:11 +00:00
|
|
|
|
|
|
|
if let Err(err) = peer.receive_task_handle.await {
|
|
|
|
trace!(peer_id = %peer_id, "Error while joining receive task: {}", err);
|
|
|
|
}
|
2022-03-03 03:30:44 +00:00
|
|
|
});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[instrument(level = "debug", skip(self, stream))]
|
|
|
|
pub async fn accept_async<S: 'static>(&mut self, stream: S) -> Result<String, ServerError>
|
|
|
|
where
|
|
|
|
S: AsyncRead + AsyncWrite + Unpin + Send,
|
|
|
|
{
|
2022-12-12 13:36:11 +00:00
|
|
|
let ws = match async_tungstenite::tokio::accept_async(stream).await {
|
2022-03-03 03:30:44 +00:00
|
|
|
Ok(ws) => ws,
|
|
|
|
Err(err) => {
|
|
|
|
warn!("Error during the websocket handshake: {}", err);
|
|
|
|
return Err(ServerError::Handshake(err));
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
let this_id = uuid::Uuid::new_v4().to_string();
|
|
|
|
info!(this_id = %this_id, "New WebSocket connection");
|
|
|
|
|
|
|
|
// 1000 is completely arbitrary, we simply don't want infinite piling
|
|
|
|
// up of messages as with unbounded
|
|
|
|
let (websocket_sender, mut websocket_receiver) = mpsc::channel::<String>(1000);
|
|
|
|
|
|
|
|
let this_id_clone = this_id.clone();
|
|
|
|
let (mut ws_sink, mut ws_stream) = ws.split();
|
|
|
|
let send_task_handle = task::spawn(async move {
|
2022-05-25 16:09:33 +00:00
|
|
|
loop {
|
2022-12-12 13:36:11 +00:00
|
|
|
match tokio::time::timeout(
|
2022-05-25 16:09:33 +00:00
|
|
|
std::time::Duration::from_secs(30),
|
|
|
|
websocket_receiver.next(),
|
|
|
|
)
|
|
|
|
.await
|
|
|
|
{
|
|
|
|
Ok(Some(msg)) => {
|
|
|
|
trace!(this_id = %this_id_clone, "sending {}", msg);
|
|
|
|
ws_sink.send(WsMessage::Text(msg)).await?;
|
|
|
|
}
|
|
|
|
Ok(None) => {
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
Err(_) => {
|
|
|
|
trace!(this_id = %this_id_clone, "timeout, sending ping");
|
|
|
|
ws_sink.send(WsMessage::Ping(vec![])).await?;
|
|
|
|
}
|
|
|
|
}
|
2022-03-03 03:30:44 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
ws_sink.send(WsMessage::Close(None)).await?;
|
|
|
|
ws_sink.close().await?;
|
|
|
|
|
|
|
|
Ok::<(), Error>(())
|
|
|
|
});
|
|
|
|
|
|
|
|
let mut tx = self.state.lock().unwrap().tx.clone();
|
|
|
|
let this_id_clone = this_id.clone();
|
|
|
|
let state_clone = self.state.clone();
|
|
|
|
let receive_task_handle = task::spawn(async move {
|
2022-08-18 16:05:16 +00:00
|
|
|
if let Some(tx) = tx.as_mut() {
|
|
|
|
if let Err(err) = tx
|
|
|
|
.send((
|
|
|
|
this_id_clone.clone(),
|
|
|
|
Some(
|
|
|
|
serde_json::json!({
|
|
|
|
"type": "newPeer",
|
|
|
|
})
|
|
|
|
.to_string(),
|
|
|
|
),
|
|
|
|
))
|
|
|
|
.await
|
|
|
|
{
|
|
|
|
warn!(this = %this_id_clone, "Error handling message: {:?}", err);
|
|
|
|
}
|
|
|
|
}
|
2022-03-03 03:30:44 +00:00
|
|
|
while let Some(msg) = ws_stream.next().await {
|
2022-08-18 16:05:16 +00:00
|
|
|
info!("Received message {msg:?}");
|
2022-03-03 03:30:44 +00:00
|
|
|
match msg {
|
|
|
|
Ok(WsMessage::Text(msg)) => {
|
|
|
|
if let Some(tx) = tx.as_mut() {
|
|
|
|
if let Err(err) = tx.send((this_id_clone.clone(), Some(msg))).await {
|
|
|
|
warn!(this = %this_id_clone, "Error handling message: {:?}", err);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Ok(WsMessage::Close(reason)) => {
|
|
|
|
info!(this_id = %this_id_clone, "connection closed: {:?}", reason);
|
|
|
|
break;
|
|
|
|
}
|
2022-05-25 16:09:33 +00:00
|
|
|
Ok(WsMessage::Pong(_)) => {
|
|
|
|
continue;
|
|
|
|
}
|
2022-03-03 03:30:44 +00:00
|
|
|
Ok(_) => warn!(this_id = %this_id_clone, "Unsupported message type"),
|
|
|
|
Err(err) => {
|
|
|
|
warn!(this_id = %this_id_clone, "recv error: {}", err);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if let Some(tx) = tx.as_mut() {
|
|
|
|
let _ = tx.send((this_id_clone.clone(), None)).await;
|
|
|
|
}
|
|
|
|
|
|
|
|
Self::remove_peer(state_clone, &this_id_clone);
|
|
|
|
});
|
|
|
|
|
|
|
|
self.state.lock().unwrap().peers.insert(
|
|
|
|
this_id.clone(),
|
|
|
|
Peer {
|
|
|
|
receive_task_handle,
|
|
|
|
send_task_handle,
|
|
|
|
sender: websocket_sender,
|
|
|
|
},
|
|
|
|
);
|
|
|
|
|
|
|
|
Ok(this_id)
|
|
|
|
}
|
|
|
|
}
|