From 4a988aaeb8e650b3680514025fe7d52c6b10c659 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Laignel?= Date: Wed, 1 Mar 2023 09:44:51 +0100 Subject: [PATCH] net/aws/transcriber: use a TranscriberLoop struct This helps gather together the details related to the `TranscriberLoop`. One difference with the previous implementation is that the ws `Client` is built each time the loop is started instead of being reused. With the new approach, we don't keep the connection open after EOS and we should be more resistant in case of a connection failure. Part-of: --- net/aws/src/transcriber/imp.rs | 297 +++++++++++++++------------------ 1 file changed, 136 insertions(+), 161 deletions(-) diff --git a/net/aws/src/transcriber/imp.rs b/net/aws/src/transcriber/imp.rs index 63f1c0e3..7af8f50f 100644 --- a/net/aws/src/transcriber/imp.rs +++ b/net/aws/src/transcriber/imp.rs @@ -109,8 +109,124 @@ impl TranscriptionSettings { } } +struct TranscriberLoop { + imp: glib::subclass::ObjectImplRef, + client: aws_transcribe::Client, + settings: TranscriptionSettings, + lateness: gst::ClockTime, + buffer_rx: mpsc::Receiver, + transcript_notif_tx: mpsc::Sender<()>, +} + +impl TranscriberLoop { + fn new( + imp: &Transcriber, + aws_config: &aws_config::SdkConfig, + settings: TranscriptionSettings, + lateness: gst::ClockTime, + buffer_rx: mpsc::Receiver, + transcript_notif_tx: mpsc::Sender<()>, + ) -> Self { + TranscriberLoop { + imp: imp.ref_counted(), + client: aws_transcribe::Client::new(aws_config), + settings, + lateness, + buffer_rx, + transcript_notif_tx, + } + } + + async fn run(mut self) -> Result<(), gst::ErrorMessage> { + // Stream the incoming buffers chunked + let chunk_stream = self.buffer_rx.flat_map(move |buffer: gst::Buffer| { + async_stream::stream! 
{ + let data = buffer.map_readable().unwrap(); + use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob}; + for chunk in data.chunks(8192) { + yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build())); + } + } + }); + + let mut transcribe_builder = self + .client + .start_stream_transcription() + .language_code(self.settings.lang_code) + .media_sample_rate_hertz(self.settings.sample_rate) + .media_encoding(model::MediaEncoding::Pcm) + .enable_partial_results_stabilization(true) + .partial_results_stability(self.settings.results_stability) + .set_vocabulary_name(self.settings.vocabulary) + .set_session_id(self.settings.session_id); + + if let Some(vocabulary_filter) = self.settings.vocabulary_filter { + transcribe_builder = transcribe_builder + .vocabulary_filter_name(vocabulary_filter) + .vocabulary_filter_method(self.settings.vocabulary_filter_method); + } + + let mut output = transcribe_builder + .audio_stream(chunk_stream.into()) + .send() + .await + .map_err(|err| { + let err = format!("Transcribe ws init error: {err}"); + gst::error!(CAT, imp: self.imp, "{err}"); + gst::error_msg!(gst::LibraryError::Init, ["{err}"]) + })?; + + while let Some(event) = output + .transcript_result_stream + .recv() + .await + .map_err(|err| { + let err = format!("Transcribe ws stream error: {err}"); + gst::error!(CAT, imp: self.imp, "{err}"); + gst::error_msg!(gst::LibraryError::Failed, ["{err}"]) + })? 
+ { + if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event { + let mut enqueued = false; + + if let Some(result) = transcript_evt + .transcript + .and_then(|transcript| transcript.results) + .and_then(|mut results| results.drain(..).next()) + { + gst::trace!(CAT, imp: self.imp, "Received: {result:?}"); + + if let Some(alternative) = result + .alternatives + .and_then(|mut alternatives| alternatives.drain(..).next()) + { + if let Some(items) = alternative.items { + enqueued = self.imp.enqueue(items, result.is_partial, self.lateness); + } + } + } + + if enqueued && self.transcript_notif_tx.send(()).await.is_err() { + gst::debug!(CAT, imp: self.imp, "Terminated transcript_notif_tx channel"); + break; + } + } else { + gst::warning!( + CAT, + imp: self.imp, + "Transcribe ws returned unknown event: consider upgrading the SDK" + ) + } + } + + gst::debug!(CAT, imp: self.imp, "Exiting ws loop"); + + Ok(()) + } +} + struct State { - client: Option, + aws_config: Option, buffer_tx: Option>, transcript_notif_tx: Option>, ws_loop_handle: Option>>, @@ -128,7 +244,7 @@ struct State { impl Default for State { fn default() -> Self { Self { - client: None, + aws_config: None, buffer_tx: None, transcript_notif_tx: None, ws_loop_handle: None, @@ -615,8 +731,8 @@ impl Transcriber { } fn ensure_connection(&self) -> Result<(), gst::ErrorMessage> { - enum ClientStage { - Ready(aws_transcribe::Client), + enum ConfigStatus { + Ready(aws_config::SdkConfig), NotReady { access_key: Option, secret_access_key: Option, @@ -624,7 +740,7 @@ impl Transcriber { }, } - let (client_stage, transcription_settings, lateness, transcript_notif_tx); + let (config_status, transcription_settings, lateness, transcript_notif_tx); { let mut state = self.state.lock().unwrap(); @@ -660,10 +776,10 @@ impl Transcriber { transcription_settings = TranscriptionSettings::from(&settings, sample_rate); - client_stage = if let Some(client) = state.client.take() { - ClientStage::Ready(client) + 
config_status = if let Some(aws_config) = state.aws_config.take() { + ConfigStatus::Ready(aws_config) } else { - ClientStage::NotReady { + ConfigStatus::NotReady { access_key: settings.access_key.to_owned(), secret_access_key: settings.secret_access_key.to_owned(), session_token: settings.session_token.to_owned(), @@ -671,14 +787,14 @@ impl Transcriber { }; }; - let client = match client_stage { - ClientStage::Ready(client) => client, - ClientStage::NotReady { + let aws_config = match config_status { + ConfigStatus::Ready(aws_config) => aws_config, + ConfigStatus::NotReady { access_key, secret_access_key, session_token, } => { - gst::info!(CAT, imp: self, "Connecting..."); + gst::info!(CAT, imp: self, "Loading aws config..."); let _enter_guard = RUNTIME.enter(); let config_loader = match (access_key, secret_access_key) { @@ -707,172 +823,31 @@ impl Transcriber { let config = futures::executor::block_on(config_loader.load()); gst::debug!(CAT, imp: self, "Using region {}", config.region().unwrap()); - aws_transcribe::Client::new(&config) + config } }; let mut state = self.state.lock().unwrap(); let (buffer_tx, buffer_rx) = mpsc::channel(1); - let ws_loop_handle = RUNTIME.spawn(self.build_ws_loop_fut( - client, + + let ws_loop_ctx = TranscriberLoop::new( + self, + &aws_config, transcription_settings, lateness, buffer_rx, transcript_notif_tx, - )); + ); + let ws_loop_handle = RUNTIME.spawn(ws_loop_ctx.run()); + state.aws_config = Some(aws_config); state.ws_loop_handle = Some(ws_loop_handle); state.buffer_tx = Some(buffer_tx); Ok(()) } - fn build_ws_loop_fut( - &self, - client: aws_transcribe::Client, - settings: TranscriptionSettings, - lateness: gst::ClockTime, - buffer_rx: mpsc::Receiver, - transcript_notif_tx: mpsc::Sender<()>, - ) -> impl Future> { - let imp_weak = self.downgrade(); - async move { - use gst::glib::subclass::ObjectImplWeakRef; - - // Guard that restores client & transcript_notif_tx when the ws loop is done - struct Guard { - imp_weak: 
ObjectImplWeakRef, - client: Option, - transcript_notif_tx: Option>, - } - - impl Guard { - fn client(&self) -> &aws_transcribe::Client { - self.client.as_ref().unwrap() - } - - fn transcript_notif_tx(&mut self) -> &mut mpsc::Sender<()> { - self.transcript_notif_tx.as_mut().unwrap() - } - } - - impl Drop for Guard { - fn drop(&mut self) { - if let Some(imp) = self.imp_weak.upgrade() { - let mut state = imp.state.lock().unwrap(); - state.client = self.client.take(); - state.transcript_notif_tx = self.transcript_notif_tx.take(); - } - } - } - - let mut guard = Guard { - imp_weak: imp_weak.clone(), - client: Some(client), - transcript_notif_tx: Some(transcript_notif_tx), - }; - - // Stream the incoming buffers chunked - let chunk_stream = buffer_rx.flat_map(move |buffer: gst::Buffer| { - async_stream::stream! { - let data = buffer.map_readable().unwrap(); - use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob}; - for chunk in data.chunks(8192) { - yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build())); - } - } - }); - - let mut transcribe_builder = guard - .client() - .start_stream_transcription() - .language_code(settings.lang_code) - .media_sample_rate_hertz(settings.sample_rate) - .media_encoding(model::MediaEncoding::Pcm) - .enable_partial_results_stabilization(true) - .partial_results_stability(settings.results_stability) - .set_vocabulary_name(settings.vocabulary) - .set_session_id(settings.session_id); - - if let Some(vocabulary_filter) = settings.vocabulary_filter { - transcribe_builder = transcribe_builder - .vocabulary_filter_name(vocabulary_filter) - .vocabulary_filter_method(settings.vocabulary_filter_method); - } - - let mut output = transcribe_builder - .audio_stream(chunk_stream.into()) - .send() - .await - .map_err(|err| { - let err = format!("Transcribe ws init error: {err}"); - if let Some(imp) = imp_weak.upgrade() { - gst::error!(CAT, imp: imp, "{err}"); - } - 
gst::error_msg!(gst::LibraryError::Init, ["{err}"]) - })?; - - while let Some(event) = output - .transcript_result_stream - .recv() - .await - .map_err(|err| { - let err = format!("Transcribe ws stream error: {err}"); - if let Some(imp) = imp_weak.upgrade() { - gst::error!(CAT, imp: imp, "{err}"); - } - gst::error_msg!(gst::LibraryError::Failed, ["{err}"]) - })? - { - if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event { - let mut enqueued = false; - - if let Some(result) = transcript_evt - .transcript - .and_then(|transcript| transcript.results) - .and_then(|mut results| results.drain(..).next()) - { - let Some(imp) = imp_weak.upgrade() else { break }; - - gst::trace!(CAT, imp: imp, "Received: {result:?}"); - - if let Some(alternative) = result - .alternatives - .and_then(|mut alternatives| alternatives.drain(..).next()) - { - if let Some(items) = alternative.items { - enqueued = imp.enqueue(items, result.is_partial, lateness); - } - } - } - - if enqueued && guard.transcript_notif_tx().send(()).await.is_err() { - if let Some(imp) = imp_weak.upgrade() { - gst::debug!(CAT, imp: imp, "Terminated transcript_notif_tx channel"); - } - break; - } - } else if let Some(imp) = imp_weak.upgrade() { - gst::warning!( - CAT, - imp: imp, - "Transcribe ws returned unknown event: consider upgrading the SDK" - ) - } else { - // imp has left the building - break; - } - } - - if let Some(imp) = imp_weak.upgrade() { - gst::debug!(CAT, imp: imp, "Exiting ws loop"); - } - - Ok(()) - } - } - fn disconnect(&self) { let mut state = self.state.lock().unwrap(); gst::info!(CAT, imp: self, "Unpreparing");