diff --git a/net/aws/src/transcriber/imp.rs b/net/aws/src/transcriber/imp.rs index 63f1c0e3..7af8f50f 100644 --- a/net/aws/src/transcriber/imp.rs +++ b/net/aws/src/transcriber/imp.rs @@ -109,8 +109,124 @@ impl TranscriptionSettings { } } +struct TranscriberLoop { + imp: glib::subclass::ObjectImplRef, + client: aws_transcribe::Client, + settings: TranscriptionSettings, + lateness: gst::ClockTime, + buffer_rx: mpsc::Receiver, + transcript_notif_tx: mpsc::Sender<()>, +} + +impl TranscriberLoop { + fn new( + imp: &Transcriber, + aws_config: &aws_config::SdkConfig, + settings: TranscriptionSettings, + lateness: gst::ClockTime, + buffer_rx: mpsc::Receiver, + transcript_notif_tx: mpsc::Sender<()>, + ) -> Self { + TranscriberLoop { + imp: imp.ref_counted(), + client: aws_transcribe::Client::new(aws_config), + settings, + lateness, + buffer_rx, + transcript_notif_tx, + } + } + + async fn run(mut self) -> Result<(), gst::ErrorMessage> { + // Stream the incoming buffers chunked + let chunk_stream = self.buffer_rx.flat_map(move |buffer: gst::Buffer| { + async_stream::stream! { + let data = buffer.map_readable().unwrap(); + use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob}; + for chunk in data.chunks(8192) { + yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build())); + } + } + }); + + let mut transcribe_builder = self + .client + .start_stream_transcription() + .language_code(self.settings.lang_code) + .media_sample_rate_hertz(self.settings.sample_rate) + .media_encoding(model::MediaEncoding::Pcm) + .enable_partial_results_stabilization(true) + .partial_results_stability(self.settings.results_stability) + .set_vocabulary_name(self.settings.vocabulary) + .set_session_id(self.settings.session_id); + + if let Some(vocabulary_filter) = self.settings.vocabulary_filter { + transcribe_builder = transcribe_builder + .vocabulary_filter_name(vocabulary_filter) + .vocabulary_filter_method(self.settings.vocabulary_filter_method); + } + + let mut output = transcribe_builder + .audio_stream(chunk_stream.into()) + .send() + .await + .map_err(|err| { + let err = format!("Transcribe ws init error: {err}"); + gst::error!(CAT, imp: self.imp, "{err}"); + gst::error_msg!(gst::LibraryError::Init, ["{err}"]) + })?; + + while let Some(event) = output + .transcript_result_stream + .recv() + .await + .map_err(|err| { + let err = format!("Transcribe ws stream error: {err}"); + gst::error!(CAT, imp: self.imp, "{err}"); + gst::error_msg!(gst::LibraryError::Failed, ["{err}"]) + })? + { + if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event { + let mut enqueued = false; + + if let Some(result) = transcript_evt + .transcript + .and_then(|transcript| transcript.results) + .and_then(|mut results| results.drain(..).next()) + { + gst::trace!(CAT, imp: self.imp, "Received: {result:?}"); + + if let Some(alternative) = result + .alternatives + .and_then(|mut alternatives| alternatives.drain(..).next()) + { + if let Some(items) = alternative.items { + enqueued = self.imp.enqueue(items, result.is_partial, self.lateness); + } + } + } + + if enqueued && self.transcript_notif_tx.send(()).await.is_err() { + gst::debug!(CAT, imp: self.imp, "Terminated transcript_notif_tx channel"); + break; + } + } else { + gst::warning!( + CAT, + imp: self.imp, + "Transcribe ws returned unknown event: consider upgrading the SDK" + ) + } + } + + gst::debug!(CAT, imp: self.imp, "Exiting ws loop"); + + Ok(()) + } +} + struct State { - client: Option, + aws_config: Option, buffer_tx: Option>, transcript_notif_tx: Option>, ws_loop_handle: Option>>, @@ -128,7 +244,7 @@ struct State { impl Default for State { fn default() -> Self { Self { - client: None, + aws_config: None, buffer_tx: None, transcript_notif_tx: None, ws_loop_handle: None, @@ -615,8 +731,8 @@ impl Transcriber { } fn ensure_connection(&self) -> Result<(), gst::ErrorMessage> { - enum ClientStage { - Ready(aws_transcribe::Client), + enum ConfigStatus { + Ready(aws_config::SdkConfig), NotReady { access_key: Option, secret_access_key: Option, @@ -624,7 +740,7 @@ impl Transcriber { }, } - let (client_stage, transcription_settings, lateness, transcript_notif_tx); + let (config_status, transcription_settings, lateness, transcript_notif_tx); { let mut state = self.state.lock().unwrap(); @@ -660,10 +776,10 @@ impl Transcriber { transcription_settings = TranscriptionSettings::from(&settings, sample_rate); - client_stage = if let Some(client) = state.client.take() { - ClientStage::Ready(client) + config_status = if let Some(aws_config) = state.aws_config.take() { + ConfigStatus::Ready(aws_config) } else { - ClientStage::NotReady { + ConfigStatus::NotReady { access_key: settings.access_key.to_owned(), secret_access_key: settings.secret_access_key.to_owned(), session_token: settings.session_token.to_owned(), @@ -671,14 +787,14 @@ impl Transcriber { }; }; - let client = match client_stage { - ClientStage::Ready(client) => client, - ClientStage::NotReady { + let aws_config = match config_status { + ConfigStatus::Ready(aws_config) => aws_config, + ConfigStatus::NotReady { access_key, secret_access_key, session_token, } => { - gst::info!(CAT, imp: self, "Connecting..."); + gst::info!(CAT, imp: self, "Loading aws config..."); let _enter_guard = RUNTIME.enter(); let config_loader = match (access_key, secret_access_key) { @@ -707,172 +823,31 @@ impl Transcriber { let config = futures::executor::block_on(config_loader.load()); gst::debug!(CAT, imp: self, "Using region {}", config.region().unwrap()); - aws_transcribe::Client::new(&config) + config } }; let mut state = self.state.lock().unwrap(); let (buffer_tx, buffer_rx) = mpsc::channel(1); - let ws_loop_handle = RUNTIME.spawn(self.build_ws_loop_fut( - client, + + let ws_loop_ctx = TranscriberLoop::new( + self, + &aws_config, transcription_settings, lateness, buffer_rx, transcript_notif_tx, - )); + ); + let ws_loop_handle = RUNTIME.spawn(ws_loop_ctx.run()); + state.aws_config = Some(aws_config); state.ws_loop_handle = Some(ws_loop_handle); state.buffer_tx = Some(buffer_tx); Ok(()) } - fn build_ws_loop_fut( - &self, - client: aws_transcribe::Client, - settings: TranscriptionSettings, - lateness: gst::ClockTime, - buffer_rx: mpsc::Receiver, - transcript_notif_tx: mpsc::Sender<()>, - ) -> impl Future> { - let imp_weak = self.downgrade(); - async move { - use gst::glib::subclass::ObjectImplWeakRef; - - // Guard that restores client & transcript_notif_tx when the ws loop is done - struct Guard { - imp_weak: ObjectImplWeakRef, - client: Option, - transcript_notif_tx: Option>, - } - - impl Guard { - fn client(&self) -> &aws_transcribe::Client { - self.client.as_ref().unwrap() - } - - fn transcript_notif_tx(&mut self) -> &mut mpsc::Sender<()> { - self.transcript_notif_tx.as_mut().unwrap() - } - } - - impl Drop for Guard { - fn drop(&mut self) { - if let Some(imp) = self.imp_weak.upgrade() { - let mut state = imp.state.lock().unwrap(); - state.client = self.client.take(); - state.transcript_notif_tx = self.transcript_notif_tx.take(); - } - } - } - - let mut guard = Guard { - imp_weak: imp_weak.clone(), - client: Some(client), - transcript_notif_tx: Some(transcript_notif_tx), - }; - - // Stream the incoming buffers chunked - let chunk_stream = buffer_rx.flat_map(move |buffer: gst::Buffer| { - async_stream::stream! { - let data = buffer.map_readable().unwrap(); - use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob}; - for chunk in data.chunks(8192) { - yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build())); - } - } - }); - - let mut transcribe_builder = guard - .client() - .start_stream_transcription() - .language_code(settings.lang_code) - .media_sample_rate_hertz(settings.sample_rate) - .media_encoding(model::MediaEncoding::Pcm) - .enable_partial_results_stabilization(true) - .partial_results_stability(settings.results_stability) - .set_vocabulary_name(settings.vocabulary) - .set_session_id(settings.session_id); - - if let Some(vocabulary_filter) = settings.vocabulary_filter { - transcribe_builder = transcribe_builder - .vocabulary_filter_name(vocabulary_filter) - .vocabulary_filter_method(settings.vocabulary_filter_method); - } - - let mut output = transcribe_builder - .audio_stream(chunk_stream.into()) - .send() - .await - .map_err(|err| { - let err = format!("Transcribe ws init error: {err}"); - if let Some(imp) = imp_weak.upgrade() { - gst::error!(CAT, imp: imp, "{err}"); - } - gst::error_msg!(gst::LibraryError::Init, ["{err}"]) - })?; - - while let Some(event) = output - .transcript_result_stream - .recv() - .await - .map_err(|err| { - let err = format!("Transcribe ws stream error: {err}"); - if let Some(imp) = imp_weak.upgrade() { - gst::error!(CAT, imp: imp, "{err}"); - } - gst::error_msg!(gst::LibraryError::Failed, ["{err}"]) - })? - { - if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event { - let mut enqueued = false; - - if let Some(result) = transcript_evt - .transcript - .and_then(|transcript| transcript.results) - .and_then(|mut results| results.drain(..).next()) - { - let Some(imp) = imp_weak.upgrade() else { break }; - - gst::trace!(CAT, imp: imp, "Received: {result:?}"); - - if let Some(alternative) = result - .alternatives - .and_then(|mut alternatives| alternatives.drain(..).next()) - { - if let Some(items) = alternative.items { - enqueued = imp.enqueue(items, result.is_partial, lateness); - } - } - } - - if enqueued && guard.transcript_notif_tx().send(()).await.is_err() { - if let Some(imp) = imp_weak.upgrade() { - gst::debug!(CAT, imp: imp, "Terminated transcript_notif_tx channel"); - } - break; - } - } else if let Some(imp) = imp_weak.upgrade() { - gst::warning!( - CAT, - imp: imp, - "Transcribe ws returned unknown event: consider upgrading the SDK" - ) - } else { - // imp has left the building - break; - } - } - - if let Some(imp) = imp_weak.upgrade() { - gst::debug!(CAT, imp: imp, "Exiting ws loop"); - } - - Ok(()) - } - } - fn disconnect(&self) { let mut state = self.state.lock().unwrap(); gst::info!(CAT, imp: self, "Unpreparing");