From 730b3459f10bdabd5918b8e89c4814c42aa0921f Mon Sep 17 00:00:00 2001 From: Jordan Yelloz Date: Wed, 28 Feb 2024 09:13:12 -0700 Subject: [PATCH] livekit_signaller: Added dual-role support Part-of: --- net/webrtc/src/livekit_signaller/imp.rs | 641 ++++++++++++++++-------- net/webrtc/src/livekit_signaller/mod.rs | 16 +- 2 files changed, 452 insertions(+), 205 deletions(-) diff --git a/net/webrtc/src/livekit_signaller/imp.rs b/net/webrtc/src/livekit_signaller/imp.rs index e9ad0af7..d2f61a5a 100644 --- a/net/webrtc/src/livekit_signaller/imp.rs +++ b/net/webrtc/src/livekit_signaller/imp.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 -use crate::signaller::{Signallable, SignallableImpl}; +use crate::signaller::{Signallable, SignallableImpl, WebRTCSignallerRole}; use crate::utils::{wait_async, WaitError}; use crate::RUNTIME; @@ -13,6 +13,7 @@ use gst::subclass::prelude::*; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::str::FromStr; use std::sync::{Arc, Mutex}; use tokio::sync::oneshot; use tokio::task::JoinHandle; @@ -40,6 +41,9 @@ struct Settings { identity: Option, room_name: Option, auth_token: Option, + role: WebRTCSignallerRole, + producer_peer_id: Option, + excluded_produder_peer_ids: Vec, timeout: u32, } @@ -53,6 +57,9 @@ impl Default for Settings { identity: Some("gstreamer".to_string()), room_name: None, auth_token: None, + role: WebRTCSignallerRole::default(), + producer_peer_id: None, + excluded_produder_peer_ids: vec![], timeout: DEFAULT_TRACK_PUBLISH_TIMEOUT, } } @@ -93,6 +100,99 @@ impl Signaller { .emit_by_name::<()>("error", &[&format!("Error: {msg}")]); } + fn role(&self) -> Option { + self.settings.lock().map(|s| s.role).ok() + } + + fn is_subscriber(&self) -> bool { + matches!(self.role(), Some(WebRTCSignallerRole::Consumer)) + } + + fn producer_peer_id(&self) -> Option { + assert!(self.is_subscriber()); + let settings = self.settings.lock().ok()?; + settings.producer_peer_id.clone() + } + + fn auto_subscribe(&self) -> bool { + self.is_subscriber() + && self.producer_peer_id().is_none() + && self.excluded_producer_peer_ids_is_empty() + } + + fn signal_target(&self) -> Option { + match self.role()? { + WebRTCSignallerRole::Consumer => Some(proto::SignalTarget::Subscriber), + WebRTCSignallerRole::Producer => Some(proto::SignalTarget::Publisher), + _ => None, + } + } + + fn excluded_producer_peer_ids_is_empty(&self) -> bool { + assert!(self.is_subscriber()); + self.settings + .lock() + .unwrap() + .excluded_produder_peer_ids + .is_empty() + } + + fn is_peer_excluded(&self, peer_id: &str) -> bool { + self.settings + .lock() + .unwrap() + .excluded_produder_peer_ids + .iter() + .any(|id| id == peer_id) + } + + fn signal_client(&self) -> Option> { + let connection = self.connection.lock().unwrap(); + Some(connection.as_ref()?.signal_client.clone()) + } + + fn require_signal_client(&self) -> Arc { + self.signal_client().unwrap() + } + + async fn send_trickle_request(&self, candidate_init: &str) { + let Some(signal_client) = self.signal_client() else { + return; + }; + let Some(target) = self.signal_target() else { + return; + }; + signal_client + .send(proto::signal_request::Message::Trickle( + proto::TrickleRequest { + candidate_init: candidate_init.to_string(), + target: target as i32, + }, + )) + .await; + } + + async fn send_delayed_ice_candidates(&self) { + let Some(mut early_candidates) = self + .connection + .lock() + .unwrap() + .as_mut() + .and_then(|c| c.early_candidates.take()) + else { + return; + }; + + while let Some(candidate_str) = early_candidates.pop() { + gst::debug!( + CAT, + imp: self, + "Sending delayed ice candidate {candidate_str:?}" + ); + self.send_trickle_request(&candidate_str).await; + } + } + async fn signal_task(&self, mut signal_events: signal_client::SignalEvents) { loop { match wait_async(&self.signal_task_canceller, signal_events.recv(), 0).await { @@ -136,10 +236,36 @@ impl Signaller { self.obj() .emit_by_name::<()>("session-description", &[&"unique", &answer]); } + + proto::signal_response::Message::Offer(offer) => { + if !self.is_subscriber() { + gst::warning!(CAT, imp: self, "Ignoring subscriber offer in non-subscriber mode: {:?}", offer); + return; + } + gst::debug!(CAT, imp: self, "Received subscriber offer: {:?}", offer); + let sdp = match gst_sdp::SDPMessage::parse_buffer(offer.sdp.as_bytes()) { + Ok(sdp) => sdp, + Err(_) => { + self.raise_error("Couldn't parse Offer SDP".to_string()); + return; + } + }; + let offer = gst_webrtc::WebRTCSessionDescription::new( + gst_webrtc::WebRTCSDPType::Offer, + sdp, + ); + self.obj() + .emit_by_name::<()>("session-description", &[&"unique", &offer]); + } + proto::signal_response::Message::Trickle(trickle) => { gst::debug!(CAT, imp: self, "Received ice_candidate {:?}", trickle); - if trickle.target() == proto::SignalTarget::Publisher { + let Some(target) = self.signal_target() else { + return; + }; + + if target == trickle.target() { if let Ok(json) = serde_json::from_str::(&trickle.candidate_init) { @@ -165,6 +291,17 @@ impl Signaller { } } + proto::signal_response::Message::Update(update) => { + if !self.is_subscriber() { + gst::trace!(CAT, imp: self, "Ignoring update in non-subscriber mode: {:?}", update); + return; + } + gst::debug!(CAT, imp: self, "Update: {:?}", update); + for participant in update.participants { + self.on_participant(&participant, true) + } + } + proto::signal_response::Message::Leave(leave) => { gst::debug!(CAT, imp: self, "Leave: {:?}", leave); } @@ -172,182 +309,36 @@ impl Signaller { _ => {} } } -} - -impl SignallableImpl for Signaller { - fn start(&self) { - gst::debug!(CAT, imp: self, "Connecting"); - - let wsurl = if let Some(wsurl) = &self.settings.lock().unwrap().wsurl { - wsurl.clone() - } else { - self.raise_error("WebSocket URL must be set".to_string()); - return; - }; - - let auth_token = { - let settings = self.settings.lock().unwrap(); - - if let Some(auth_token) = &settings.auth_token { - auth_token.clone() - } else if let ( - Some(api_key), - Some(secret_key), - Some(identity), - Some(participant_name), - Some(room_name), - ) = ( - &settings.api_key, - &settings.secret_key, - &settings.identity, - &settings.participant_name, - &settings.room_name, - ) { - let grants = VideoGrants { - room_join: true, - can_subscribe: false, - room: room_name.clone(), - ..Default::default() - }; - let access_token = AccessToken::with_api_key(api_key, secret_key) - .with_name(participant_name) - .with_identity(identity) - .with_grants(grants); - match access_token.to_jwt() { - Ok(token) => token, - Err(err) => { - self.raise_error(format!( - "{:?}", - anyhow!("Could not create auth token {err}") - )); - return; - } - } - } else { - self.raise_error("Either auth-token or (api-key and secret-key and identity and room-name) must be set".to_string()); - return; - } - }; - - gst::debug!(CAT, imp: self, "We have an authentication token"); + fn send_sdp_answer(&self, _session_id: &str, sessdesc: &gst_webrtc::WebRTCSessionDescription) { let weak_imp = self.downgrade(); + let sessdesc = sessdesc.clone(); + RUNTIME.spawn(async move { - let imp = if let Some(imp) = weak_imp.upgrade() { - imp - } else { - return; - }; - - let options = signal_client::SignalOptions::default(); - gst::debug!(CAT, imp: imp, "Connecting to {}", wsurl); - - let res = signal_client::SignalClient::connect(&wsurl, &auth_token, options).await; - let (signal_client, join_response, signal_events) = match res { - Err(err) => { - imp.obj() - .emit_by_name::<()>("error", &[&format!("{:?}", anyhow!("Error: {err}"))]); - return; - } - Ok(ok) => ok, - }; - let signal_client = Arc::new(signal_client); - - gst::debug!( - CAT, - imp: imp, - "Connected with JoinResponse: {:?}", - join_response - ); - - let weak_imp = imp.downgrade(); - let signal_task = RUNTIME.spawn(async move { - if let Some(imp) = weak_imp.upgrade() { - imp.signal_task(signal_events).await; - } - }); - - let weak_imp = imp.downgrade(); - imp.obj().connect_closure( - "webrtcbin-ready", - false, - glib::closure!(|_signaler: &super::LiveKitSignaller, - _consumer_identifier: &str, - webrtcbin: &gst::Element| { - gst::info!(CAT, "Adding data channels"); - let reliable_channel = webrtcbin.emit_by_name::( - "create-data-channel", - &[ - &"_reliable", - &gst::Structure::builder("config") - .field("ordered", true) - .build(), - ], - ); - let lossy_channel = webrtcbin.emit_by_name::( - "create-data-channel", - &[ - &"_lossy", - &gst::Structure::builder("config") - .field("ordered", true) - .field("max-retransmits", 0) - .build(), - ], - ); - - if let Some(imp) = weak_imp.upgrade() { - let mut connection = imp.connection.lock().unwrap(); - if let Some(connection) = connection.as_mut() { - connection.channels = Some(Channels { - reliable_channel, - lossy_channel, - }); - } - } - }), - ); - - let connection = Connection { - signal_client, - signal_task, - pending_tracks: Default::default(), - early_candidates: Some(Vec::new()), - channels: None, - }; - - if let Ok(mut sc) = imp.connection.lock() { - *sc = Some(connection); + if let Some(imp) = weak_imp.upgrade() { + let sdp = sessdesc.sdp(); + gst::debug!(CAT, imp: imp, "Sending SDP {:?} now", &sdp); + let signal_client = imp.require_signal_client(); + signal_client + .send(proto::signal_request::Message::Answer( + proto::SessionDescription { + r#type: "answer".to_string(), + sdp: sdp.to_string(), + }, + )) + .await; + imp.send_delayed_ice_candidates().await; } - - imp.obj().emit_by_name::<()>( - "session-requested", - &[ - &"unique", - &"unique", - &None::, - ], - ); }); } - fn send_sdp(&self, _session_id: &str, sessdesc: &gst_webrtc::WebRTCSessionDescription) { - gst::debug!(CAT, imp: self, "Created offer SDP {:#?}", sessdesc.sdp()); - - assert!(sessdesc.type_() == gst_webrtc::WebRTCSDPType::Offer); - + fn send_sdp_offer(&self, _session_id: &str, sessdesc: &gst_webrtc::WebRTCSessionDescription) { let weak_imp = self.downgrade(); let sessdesc = sessdesc.clone(); RUNTIME.spawn(async move { if let Some(imp) = weak_imp.upgrade() { let sdp = sessdesc.sdp(); - let signal_client = imp - .connection - .lock() - .unwrap() - .as_ref() - .unwrap() - .signal_client - .clone(); + let signal_client = imp.require_signal_client(); let timeout = imp.settings.lock().unwrap().timeout; for media in sdp.medias() { @@ -457,35 +448,260 @@ impl SignallableImpl for Signaller { .await; if let Some(imp) = weak_imp.upgrade() { - let early_candidates = - if let Some(connection) = &mut *imp.connection.lock().unwrap() { - connection.early_candidates.take() - } else { - None - }; - - if let Some(mut early_candidates) = early_candidates { - while let Some(candidate_str) = early_candidates.pop() { - gst::debug!( - CAT, - imp: imp, - "Sending delayed ice candidate {candidate_str:?}" - ); - signal_client - .send(proto::signal_request::Message::Trickle( - proto::TrickleRequest { - candidate_init: candidate_str, - target: proto::SignalTarget::Publisher as i32, - }, - )) - .await; - } - } + imp.send_delayed_ice_candidates().await; } } }); } + fn on_participant(&self, participant: &proto::ParticipantInfo, new_connection: bool) { + gst::debug!(CAT, imp: self, "{:?}", participant); + if !participant.is_publisher { + return; + } + let peer_sid = &participant.sid; + let peer_identity = &participant.identity; + match self.producer_peer_id() { + Some(id) if id == *peer_sid => { + gst::debug!(CAT, imp: self, "matching peer sid {id:?}"); + } + Some(id) if id == *peer_identity => { + gst::debug!(CAT, imp: self, "matching peer identity {id:?}"); + } + None => { + if self.is_peer_excluded(peer_sid) || self.is_peer_excluded(peer_identity) { + gst::debug!(CAT, imp: self, "ignoring excluded peer {participant:?}"); + return; + } + gst::debug!(CAT, imp: self, "catch-all mode, matching {participant:?}"); + } + _ => return, + } + let meta = Some(&participant.metadata) + .filter(|meta| !meta.is_empty()) + .and_then(|meta| gst::Structure::from_str(meta).ok()); + match participant.state { + x if x == proto::participant_info::State::Active as i32 => { + let track_sids = participant + .tracks + .iter() + .filter(|t| !t.muted) + .map(|t| t.sid.clone()) + .collect::>(); + let update = proto::UpdateSubscription { + track_sids: track_sids.clone(), + subscribe: true, + participant_tracks: vec![proto::ParticipantTracks { + participant_sid: participant.sid.clone(), + track_sids: track_sids.clone(), + }], + }; + let update = proto::signal_request::Message::Subscription(update); + let weak_imp = self.downgrade(); + let peer_sid = peer_sid.clone(); + RUNTIME.spawn(async move { + let imp = match weak_imp.upgrade() { + Some(imp) => imp, + None => return, + }; + let signal_client = imp.require_signal_client(); + signal_client.send(update).await; + imp.obj() + .emit_by_name::<()>("producer-added", &[&peer_sid, &meta, &new_connection]); + }); + } + _ => { + self.obj() + .emit_by_name::<()>("producer-removed", &[&peer_sid, &meta]); + } + } + } +} + +impl SignallableImpl for Signaller { + fn start(&self) { + gst::debug!(CAT, imp: self, "Connecting"); + + let wsurl = if let Some(wsurl) = &self.settings.lock().unwrap().wsurl { + wsurl.clone() + } else { + self.raise_error("WebSocket URL must be set".to_string()); + return; + }; + + let auth_token = { + let settings = self.settings.lock().unwrap(); + let role = settings.role; + + if let Some(auth_token) = &settings.auth_token { + auth_token.clone() + } else if let ( + Some(api_key), + Some(secret_key), + Some(identity), + Some(participant_name), + Some(room_name), + ) = ( + &settings.api_key, + &settings.secret_key, + &settings.identity, + &settings.participant_name, + &settings.room_name, + ) { + let grants = VideoGrants { + room_join: true, + can_subscribe: role == WebRTCSignallerRole::Consumer, + room: room_name.clone(), + ..Default::default() + }; + let access_token = AccessToken::with_api_key(api_key, secret_key) + .with_name(participant_name) + .with_identity(identity) + .with_grants(grants); + match access_token.to_jwt() { + Ok(token) => token, + Err(err) => { + self.raise_error(format!( + "{:?}", + anyhow!("Could not create auth token {err}") + )); + return; + } + } + } else { + self.raise_error("Either auth-token or (api-key and secret-key and identity and room-name) must be set".to_string()); + return; + } + }; + + gst::debug!(CAT, imp: self, "We have an authentication token"); + + let weak_imp = self.downgrade(); + RUNTIME.spawn(async move { + let imp = if let Some(imp) = weak_imp.upgrade() { + imp + } else { + return; + }; + + let options = signal_client::SignalOptions { + auto_subscribe: imp.auto_subscribe(), + ..Default::default() + }; + gst::debug!(CAT, imp: imp, "Connecting to {}", wsurl); + + let res = signal_client::SignalClient::connect(&wsurl, &auth_token, options).await; + let (signal_client, join_response, signal_events) = match res { + Err(err) => { + imp.obj() + .emit_by_name::<()>("error", &[&format!("{:?}", anyhow!("Error: {err}"))]); + return; + } + Ok(ok) => ok, + }; + let signal_client = Arc::new(signal_client); + + gst::debug!( + CAT, + imp: imp, + "Connected with JoinResponse: {:?}", + join_response + ); + + let weak_imp = imp.downgrade(); + let signal_task = RUNTIME.spawn(async move { + if let Some(imp) = weak_imp.upgrade() { + imp.signal_task(signal_events).await; + } + }); + + if imp.is_subscriber() { + imp.obj() + .emit_by_name::<()>("session-started", &[&"unique", &"unique"]); + for participant in &join_response.other_participants { + imp.on_participant(participant, false) + } + } + + let weak_imp = imp.downgrade(); + imp.obj().connect_closure( + "webrtcbin-ready", + false, + glib::closure!(|_signaller: &super::LiveKitSignaller, + _consumer_identifier: &str, + webrtcbin: &gst::Element| { + gst::info!(CAT, "Adding data channels"); + let reliable_channel = webrtcbin.emit_by_name::( + "create-data-channel", + &[ + &"_reliable", + &gst::Structure::builder("config") + .field("ordered", true) + .build(), + ], + ); + let lossy_channel = webrtcbin.emit_by_name::( + "create-data-channel", + &[ + &"_lossy", + &gst::Structure::builder("config") + .field("ordered", true) + .field("max-retransmits", 0) + .build(), + ], + ); + + if let Some(imp) = weak_imp.upgrade() { + let mut connection = imp.connection.lock().unwrap(); + if let Some(connection) = connection.as_mut() { + connection.channels = Some(Channels { + reliable_channel, + lossy_channel, + }); + } + } + }), + ); + + let connection = Connection { + signal_client, + signal_task, + pending_tracks: Default::default(), + early_candidates: Some(Vec::new()), + channels: None, + }; + + if let Ok(mut sc) = imp.connection.lock() { + *sc = Some(connection); + } + + imp.obj().emit_by_name::<()>( + "session-requested", + &[ + &"unique", + &"unique", + &None::, + ], + ); + }); + } + + fn send_sdp(&self, session_id: &str, sessdesc: &gst_webrtc::WebRTCSessionDescription) { + gst::debug!(CAT, imp: self, "Created SDP {:?}", sessdesc.sdp()); + + match sessdesc.type_() { + gst_webrtc::WebRTCSDPType::Offer => { + self.send_sdp_offer(session_id, sessdesc); + } + gst_webrtc::WebRTCSDPType::Answer => { + self.send_sdp_answer(session_id, sessdesc); + } + _ => { + gst::debug!(CAT, imp: self, "Ignoring SDP {:?}", sessdesc.sdp()); + } + } + } + fn add_ice( &self, _session_id: &str, @@ -514,20 +730,7 @@ impl SignallableImpl for Signaller { let imp = self.downgrade(); RUNTIME.spawn(async move { if let Some(imp) = imp.upgrade() { - let signal_client = if let Some(connection) = &mut *imp.connection.lock().unwrap() { - connection.signal_client.clone() - } else { - return; - }; - - signal_client - .send(proto::signal_request::Message::Trickle( - proto::TrickleRequest { - candidate_init: candidate_str, - target: proto::SignalTarget::Publisher as i32, - }, - )) - .await; + imp.send_trickle_request(&candidate_str).await; }; }); } @@ -615,6 +818,22 @@ impl ObjectImpl for Signaller { .blurb("Lossy Data Channel object.") .flags(glib::ParamFlags::READABLE) .build(), + glib::ParamSpecEnum::builder_with_default("role", WebRTCSignallerRole::default()) + .nick("Sigaller Role") + .blurb("Whether this signaller acts as either a Consumer or Producer. Listener is not currently supported.") + .flags(glib::ParamFlags::READWRITE) + .build(), + glib::ParamSpecString::builder("producer-peer-id") + .nick("Producer Peer ID") + .blurb("When in Consumer Role, the signaller will subscribe to this peer's tracks.") + .flags(glib::ParamFlags::READWRITE) + .build(), + gst::ParamSpecArray::builder("excluded-producer-peer-ids") + .nick("Excluded Producer Peer IDs") + .blurb("When in Consumer Role, the signaller will not subscribe to these peers' tracks.") + .flags(glib::ParamFlags::READWRITE) + .element_spec(&glib::ParamSpecString::builder("producer-peer-id").build()) + .build(), ] }); @@ -648,6 +867,18 @@ impl ObjectImpl for Signaller { "timeout" => { settings.timeout = value.get().unwrap(); } + "role" => settings.role = value.get().unwrap(), + "producer-peer-id" => settings.producer_peer_id = value.get().unwrap(), + "excluded-producer-peer-ids" => { + settings.excluded_produder_peer_ids = value + .get::() + .expect("type checked upstream") + .as_slice() + .iter() + .filter_map(|id| id.get::<&str>().ok()) + .map(|id| id.to_string()) + .collect::>() + } _ => unimplemented!(), } } @@ -679,6 +910,8 @@ impl ObjectImpl for Signaller { }; channel.to_value() } + "role" => settings.role.to_value(), + "producer-peer-id" => settings.producer_peer_id.to_value(), _ => unimplemented!(), } } diff --git a/net/webrtc/src/livekit_signaller/mod.rs b/net/webrtc/src/livekit_signaller/mod.rs index a8346e6c..61c4549a 100644 --- a/net/webrtc/src/livekit_signaller/mod.rs +++ b/net/webrtc/src/livekit_signaller/mod.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 -use crate::signaller::Signallable; +use crate::signaller::{Signallable, WebRTCSignallerRole}; use gst::glib; mod imp; @@ -9,6 +9,20 @@ glib::wrapper! { pub struct LiveKitSignaller(ObjectSubclass) @implements Signallable; } +impl LiveKitSignaller { + fn new(role: WebRTCSignallerRole) -> Self { + glib::Object::builder().property("role", role).build() + } + + pub fn new_consumer() -> Self { + Self::new(WebRTCSignallerRole::Consumer) + } + + pub fn new_producer() -> Self { + Self::new(WebRTCSignallerRole::Producer) + } +} + impl Default for LiveKitSignaller { fn default() -> Self { glib::Object::new()