net/aws: enqueue transcribed buffers within the ws loop

Instead of sending transcription events to the src pad loop, this commit
enqueues the transcribed buffers immediately in the ws loop, then notifies
the src pad loop. The src pad loop is only in charge of dequeuing the buffers.

This should make upcoming changes easier to implement.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1104>
This commit is contained in:
François Laignel 2023-02-28 16:28:13 +01:00 committed by GStreamer Marge Bot
parent 00153754bb
commit 36ae29d746

View file

@ -66,7 +66,7 @@ struct Settings {
vocabulary_filter_method: AwsTranscriberVocabularyFilterMethod,
}
impl std::default::Default for Settings {
impl Default for Settings {
fn default() -> Self {
Self {
latency: DEFAULT_LATENCY,
@ -112,26 +112,25 @@ impl TranscriptionSettings {
struct State {
client: Option<aws_transcribe::Client>,
buffer_tx: Option<mpsc::Sender<gst::Buffer>>,
transcript_tx: Option<mpsc::Sender<model::TranscriptEvent>>,
transcript_notif_tx: Option<mpsc::Sender<()>>,
ws_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
in_segment: gst::FormattedSegment<gst::ClockTime>,
out_segment: gst::FormattedSegment<gst::ClockTime>,
seqnum: gst::Seqnum,
buffers: VecDeque<gst::Buffer>,
send_eos: bool,
// FIXME never set to true
discont: bool,
partial_index: usize,
send_events: bool,
start_time: Option<gst::ClockTime>,
}
impl std::default::Default for State {
impl Default for State {
fn default() -> Self {
Self {
client: None,
buffer_tx: None,
transcript_tx: None,
transcript_notif_tx: None,
ws_loop_handle: None,
in_segment: gst::FormattedSegment::new(),
out_segment: gst::FormattedSegment::new(),
@ -297,8 +296,11 @@ impl Transcriber {
true
}
fn enqueue(&self, state: &mut State, items: &[model::Item], partial: bool) {
let lateness = self.settings.lock().unwrap().lateness;
/// Enqueues a buffer for each of the provided stable items.
///
/// Returns `true` if at least one buffer was enqueued.
fn enqueue(&self, items: &[model::Item], partial: bool, lateness: gst::ClockTime) -> bool {
let mut state = self.state.lock().unwrap();
if items.len() <= state.partial_index {
gst::error!(
@ -313,53 +315,55 @@ impl Transcriber {
state.partial_index = 0;
}
return;
return false;
}
for item in &items[state.partial_index..] {
let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness;
let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness;
let mut enqueued = false;
for item in &items[state.partial_index..] {
if !item.stable().unwrap_or(false) {
break;
}
// FIXME could probably just unwrap
if let Some(content) = item.content() {
/* Should be sent now */
gst::debug!(
CAT,
imp: self,
"Item is ready for queuing: {content}, PTS {start_time}",
);
let Some(content) = item.content() else { continue };
let mut buf = gst::Buffer::from_mut_slice(content.to_string().into_bytes());
{
let buf = buf.get_mut().unwrap();
let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness;
let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness;
if state.discont {
buf.set_flags(gst::BufferFlags::DISCONT);
state.discont = false;
}
/* Should be sent now */
gst::debug!(
CAT,
imp: self,
"Item is ready for queuing: {content}, PTS {start_time}",
);
buf.set_pts(start_time);
buf.set_duration(end_time - start_time);
let mut buf = gst::Buffer::from_mut_slice(content.to_string().into_bytes());
{
let buf = buf.get_mut().unwrap();
if state.discont {
buf.set_flags(gst::BufferFlags::DISCONT);
state.discont = false;
}
state.partial_index += 1;
state.buffers.push_back(buf);
} else {
gst::debug!(CAT, imp: self, "None transcript item content");
buf.set_pts(start_time);
buf.set_duration(end_time - start_time);
}
state.partial_index += 1;
state.buffers.push_back(buf);
enqueued = true;
}
if !partial {
state.partial_index = 0;
}
enqueued
}
fn pad_loop_fn(&self, receiver: &mut mpsc::Receiver<model::TranscriptEvent>) -> Result<(), ()> {
fn pad_loop_fn(&self, transcript_notif_rx: &mut mpsc::Receiver<()>) {
let mut events = {
let mut events = vec![];
@ -400,56 +404,24 @@ impl Transcriber {
}
let future = async move {
enum Winner {
TranscriptEvent(Option<model::TranscriptEvent>),
Timeout,
}
let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
futures::pin_mut!(timeout);
let timer = tokio::time::sleep(GRANULARITY.into()).fuse();
futures::pin_mut!(timer);
let race_res = futures::select_biased! {
transcript_evt = receiver.next() => Winner::TranscriptEvent(transcript_evt),
_ = timer => Winner::Timeout,
futures::select! {
notif = transcript_notif_rx.next() => {
if notif.is_none() {
// Transcriber loop terminated
self.state.lock().unwrap().send_eos = true;
return;
};
}
_ = timeout => (),
};
use Winner::*;
match race_res {
TranscriptEvent(Some(transcript_evt)) => {
if let Some(result) = transcript_evt
.transcript
.as_ref()
.and_then(|transcript| transcript.results())
.and_then(|results| results.get(0))
{
gst::trace!(CAT, imp: self, "Received: {result:?}");
if let Some(alternative) = result
.alternatives
.as_ref()
.and_then(|alternatives| alternatives.get(0))
{
if let Some(items) = alternative.items() {
let mut state = self.state.lock().unwrap();
self.enqueue(&mut state, items, result.is_partial)
}
}
}
}
TranscriptEvent(None) => {
gst::info!(CAT, imp: self, "Transcript evt channel disconnected");
// Something bad happened elsewhere, let the other side report.
return Err(());
}
Timeout => (),
}
if !self.dequeue() {
gst::info!(CAT, imp: self, "Failed to dequeue buffer, pausing");
let _ = self.srcpad.pause_task();
}
Ok(())
};
let _enter = RUNTIME.enter();
@ -459,24 +431,19 @@ impl Transcriber {
fn start_task(&self) -> Result<(), gst::LoggableError> {
let mut state = self.state.lock().unwrap();
let (transcript_tx, mut transcript_rx) = mpsc::channel(1);
let (transcript_notif_tx, mut transcript_notif_rx) = mpsc::channel(1);
let imp = self.ref_counted();
let res = self.srcpad.start_task(move || {
if imp.pad_loop_fn(&mut transcript_rx).is_err() {
// Pad loop fn reported an unrecoverable error.
// FIXME we should probably stop the task as
// there's nothing we can do about it except restarting.
let _ = imp.srcpad.pause_task();
}
});
let res = self
.srcpad
.start_task(move || imp.pad_loop_fn(&mut transcript_notif_rx));
if res.is_err() {
state.transcript_tx = None;
state.transcript_notif_tx = None;
return Err(gst::loggable_error!(CAT, "Failed to start pad task"));
}
state.transcript_tx = Some(transcript_tx);
state.transcript_notif_tx = Some(transcript_notif_tx);
Ok(())
}
@ -490,7 +457,7 @@ impl Transcriber {
ws_loop_handle.abort();
}
state.transcript_tx = None;
state.transcript_notif_tx = None;
state.buffer_tx = None;
}
@ -652,7 +619,8 @@ impl Transcriber {
},
}
let (client_stage, transcription_settings, transcript_tx) = {
let (client_stage, transcription_settings, lateness, transcript_notif_tx);
{
let mut state = self.state.lock().unwrap();
if let Some(ref ws_loop_handle) = state.ws_loop_handle {
@ -667,14 +635,15 @@ impl Transcriber {
return Ok(());
}
let transcript_tx = state
.transcript_tx
transcript_notif_tx = state
.transcript_notif_tx
.take()
.expect("attempting to spawn the ws loop, but the srcpad task hasn't been started");
let settings = self.settings.lock().unwrap();
if settings.latency + settings.lateness <= 2 * GRANULARITY {
lateness = settings.lateness;
if settings.latency + lateness <= 2 * GRANULARITY {
const ERR: &str = "latency + lateness must be greater than 200 milliseconds";
gst::error!(CAT, imp: self, "{ERR}");
return Err(gst::error_msg!(gst::LibraryError::Settings, ["{ERR}"]));
@ -684,9 +653,9 @@ impl Transcriber {
let s = in_caps.structure(0).unwrap();
let sample_rate = s.get::<i32>("rate").unwrap();
let transcription_settings = TranscriptionSettings::from(&settings, sample_rate);
transcription_settings = TranscriptionSettings::from(&settings, sample_rate);
let client_stage = if let Some(client) = state.client.take() {
client_stage = if let Some(client) = state.client.take() {
ClientStage::Ready(client)
} else {
ClientStage::NotReady {
@ -695,8 +664,6 @@ impl Transcriber {
session_token: settings.session_token.to_owned(),
}
};
(client_stage, transcription_settings, transcript_tx)
};
let client = match client_stage {
@ -745,8 +712,9 @@ impl Transcriber {
let ws_loop_handle = RUNTIME.spawn(self.build_ws_loop_fut(
client,
transcription_settings,
lateness,
buffer_rx,
transcript_tx,
transcript_notif_tx,
));
state.ws_loop_handle = Some(ws_loop_handle);
@ -759,18 +727,19 @@ impl Transcriber {
&self,
client: aws_transcribe::Client,
settings: TranscriptionSettings,
lateness: gst::ClockTime,
buffer_rx: mpsc::Receiver<gst::Buffer>,
transcript_tx: mpsc::Sender<model::TranscriptEvent>,
transcript_notif_tx: mpsc::Sender<()>,
) -> impl Future<Output = Result<(), gst::ErrorMessage>> {
let imp_weak = self.downgrade();
async move {
use gst::glib::subclass::ObjectImplWeakRef;
// Guard that restores client & transcript_tx when the ws loop is done
// Guard that restores client & transcript_notif_tx when the ws loop is done
struct Guard {
imp_weak: ObjectImplWeakRef<Transcriber>,
client: Option<aws_transcribe::Client>,
transcript_tx: Option<mpsc::Sender<model::TranscriptEvent>>,
transcript_notif_tx: Option<mpsc::Sender<()>>,
}
impl Guard {
@ -778,8 +747,8 @@ impl Transcriber {
self.client.as_ref().unwrap()
}
fn transcript_tx(&mut self) -> &mut mpsc::Sender<model::TranscriptEvent> {
self.transcript_tx.as_mut().unwrap()
fn transcript_notif_tx(&mut self) -> &mut mpsc::Sender<()> {
self.transcript_notif_tx.as_mut().unwrap()
}
}
@ -788,7 +757,7 @@ impl Transcriber {
if let Some(imp) = self.imp_weak.upgrade() {
let mut state = imp.state.lock().unwrap();
state.client = self.client.take();
state.transcript_tx = self.transcript_tx.take();
state.transcript_notif_tx = self.transcript_notif_tx.take();
}
}
}
@ -796,7 +765,7 @@ impl Transcriber {
let mut guard = Guard {
imp_weak: imp_weak.clone(),
client: Some(client),
transcript_tx: Some(transcript_tx),
transcript_notif_tx: Some(transcript_notif_tx),
};
// Stream the incoming buffers chunked
@ -852,9 +821,32 @@ impl Transcriber {
})?
{
if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
if guard.transcript_tx().send(transcript_evt).await.is_err() {
let mut enqueued = false;
if let Some(result) = transcript_evt
.transcript
.as_ref()
.and_then(|transcript| transcript.results())
.and_then(|results| results.get(0))
{
let Some(imp) = imp_weak.upgrade() else { break };
gst::trace!(CAT, imp: imp, "Received: {result:?}");
if let Some(alternative) = result
.alternatives
.as_ref()
.and_then(|alternatives| alternatives.get(0))
{
if let Some(items) = alternative.items() {
enqueued = imp.enqueue(items, result.is_partial, lateness);
}
}
}
if enqueued && guard.transcript_notif_tx().send(()).await.is_err() {
if let Some(imp) = imp_weak.upgrade() {
gst::debug!(CAT, imp: imp, "Terminated transcript_evt channel");
gst::debug!(CAT, imp: imp, "Terminated transcript_notif_tx channel");
}
break;
}
@ -882,6 +874,7 @@ impl Transcriber {
let mut state = self.state.lock().unwrap();
gst::info!(CAT, imp: self, "Unpreparing");
self.stop_task();
// Also resets discont to true
*state = State::default();
gst::info!(CAT, imp: self, "Unprepared");
}