net/aws/transcriber: use a TranscriberLoop struct

This helps gather together the details related to the `TranscriberLoop`.
One difference with previous implementation is that the ws `Client` is
build each time the loop is started instead of being reused. With the new
approach, we don't keep the connection open after EOS and we should be
more resistant in case of a connection failure.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1104>
This commit is contained in:
François Laignel 2023-03-01 09:44:51 +01:00 committed by GStreamer Marge Bot
parent f1a080c94e
commit 4a988aaeb8

View file

@ -109,8 +109,124 @@ impl TranscriptionSettings {
}
}
struct TranscriberLoop {
imp: glib::subclass::ObjectImplRef<Transcriber>,
client: aws_transcribe::Client,
settings: TranscriptionSettings,
lateness: gst::ClockTime,
buffer_rx: mpsc::Receiver<gst::Buffer>,
transcript_notif_tx: mpsc::Sender<()>,
}
impl TranscriberLoop {
fn new(
imp: &Transcriber,
aws_config: &aws_config::SdkConfig,
settings: TranscriptionSettings,
lateness: gst::ClockTime,
buffer_rx: mpsc::Receiver<gst::Buffer>,
transcript_notif_tx: mpsc::Sender<()>,
) -> Self {
TranscriberLoop {
imp: imp.ref_counted(),
client: aws_transcribe::Client::new(aws_config),
settings,
lateness,
buffer_rx,
transcript_notif_tx,
}
}
async fn run(mut self) -> Result<(), gst::ErrorMessage> {
// Stream the incoming buffers chunked
let chunk_stream = self.buffer_rx.flat_map(move |buffer: gst::Buffer| {
async_stream::stream! {
let data = buffer.map_readable().unwrap();
use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob};
for chunk in data.chunks(8192) {
yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build()));
}
}
});
let mut transcribe_builder = self
.client
.start_stream_transcription()
.language_code(self.settings.lang_code)
.media_sample_rate_hertz(self.settings.sample_rate)
.media_encoding(model::MediaEncoding::Pcm)
.enable_partial_results_stabilization(true)
.partial_results_stability(self.settings.results_stability)
.set_vocabulary_name(self.settings.vocabulary)
.set_session_id(self.settings.session_id);
if let Some(vocabulary_filter) = self.settings.vocabulary_filter {
transcribe_builder = transcribe_builder
.vocabulary_filter_name(vocabulary_filter)
.vocabulary_filter_method(self.settings.vocabulary_filter_method);
}
let mut output = transcribe_builder
.audio_stream(chunk_stream.into())
.send()
.await
.map_err(|err| {
let err = format!("Transcribe ws init error: {err}");
gst::error!(CAT, imp: self.imp, "{err}");
gst::error_msg!(gst::LibraryError::Init, ["{err}"])
})?;
while let Some(event) = output
.transcript_result_stream
.recv()
.await
.map_err(|err| {
let err = format!("Transcribe ws stream error: {err}");
gst::error!(CAT, imp: self.imp, "{err}");
gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
})?
{
if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
let mut enqueued = false;
if let Some(result) = transcript_evt
.transcript
.and_then(|transcript| transcript.results)
.and_then(|mut results| results.drain(..).next())
{
gst::trace!(CAT, imp: self.imp, "Received: {result:?}");
if let Some(alternative) = result
.alternatives
.and_then(|mut alternatives| alternatives.drain(..).next())
{
if let Some(items) = alternative.items {
enqueued = self.imp.enqueue(items, result.is_partial, self.lateness);
}
}
}
if enqueued && self.transcript_notif_tx.send(()).await.is_err() {
gst::debug!(CAT, imp: self.imp, "Terminated transcript_notif_tx channel");
break;
}
} else {
gst::warning!(
CAT,
imp: self.imp,
"Transcribe ws returned unknown event: consider upgrading the SDK"
)
}
}
gst::debug!(CAT, imp: self.imp, "Exiting ws loop");
Ok(())
}
}
struct State {
client: Option<aws_transcribe::Client>,
aws_config: Option<aws_config::SdkConfig>,
buffer_tx: Option<mpsc::Sender<gst::Buffer>>,
transcript_notif_tx: Option<mpsc::Sender<()>>,
ws_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
@ -128,7 +244,7 @@ struct State {
impl Default for State {
fn default() -> Self {
Self {
client: None,
aws_config: None,
buffer_tx: None,
transcript_notif_tx: None,
ws_loop_handle: None,
@ -615,8 +731,8 @@ impl Transcriber {
}
fn ensure_connection(&self) -> Result<(), gst::ErrorMessage> {
enum ClientStage {
Ready(aws_transcribe::Client),
enum ConfigStatus {
Ready(aws_config::SdkConfig),
NotReady {
access_key: Option<String>,
secret_access_key: Option<String>,
@ -624,7 +740,7 @@ impl Transcriber {
},
}
let (client_stage, transcription_settings, lateness, transcript_notif_tx);
let (config_status, transcription_settings, lateness, transcript_notif_tx);
{
let mut state = self.state.lock().unwrap();
@ -660,10 +776,10 @@ impl Transcriber {
transcription_settings = TranscriptionSettings::from(&settings, sample_rate);
client_stage = if let Some(client) = state.client.take() {
ClientStage::Ready(client)
config_status = if let Some(aws_config) = state.aws_config.take() {
ConfigStatus::Ready(aws_config)
} else {
ClientStage::NotReady {
ConfigStatus::NotReady {
access_key: settings.access_key.to_owned(),
secret_access_key: settings.secret_access_key.to_owned(),
session_token: settings.session_token.to_owned(),
@ -671,14 +787,14 @@ impl Transcriber {
};
};
let client = match client_stage {
ClientStage::Ready(client) => client,
ClientStage::NotReady {
let aws_config = match config_status {
ConfigStatus::Ready(aws_config) => aws_config,
ConfigStatus::NotReady {
access_key,
secret_access_key,
session_token,
} => {
gst::info!(CAT, imp: self, "Connecting...");
gst::info!(CAT, imp: self, "Loading aws config...");
let _enter_guard = RUNTIME.enter();
let config_loader = match (access_key, secret_access_key) {
@ -707,172 +823,31 @@ impl Transcriber {
let config = futures::executor::block_on(config_loader.load());
gst::debug!(CAT, imp: self, "Using region {}", config.region().unwrap());
aws_transcribe::Client::new(&config)
config
}
};
let mut state = self.state.lock().unwrap();
let (buffer_tx, buffer_rx) = mpsc::channel(1);
let ws_loop_handle = RUNTIME.spawn(self.build_ws_loop_fut(
client,
let ws_loop_ctx = TranscriberLoop::new(
self,
&aws_config,
transcription_settings,
lateness,
buffer_rx,
transcript_notif_tx,
));
);
let ws_loop_handle = RUNTIME.spawn(ws_loop_ctx.run());
state.aws_config = Some(aws_config);
state.ws_loop_handle = Some(ws_loop_handle);
state.buffer_tx = Some(buffer_tx);
Ok(())
}
fn build_ws_loop_fut(
&self,
client: aws_transcribe::Client,
settings: TranscriptionSettings,
lateness: gst::ClockTime,
buffer_rx: mpsc::Receiver<gst::Buffer>,
transcript_notif_tx: mpsc::Sender<()>,
) -> impl Future<Output = Result<(), gst::ErrorMessage>> {
let imp_weak = self.downgrade();
async move {
use gst::glib::subclass::ObjectImplWeakRef;
// Guard that restores client & transcript_notif_tx when the ws loop is done
struct Guard {
imp_weak: ObjectImplWeakRef<Transcriber>,
client: Option<aws_transcribe::Client>,
transcript_notif_tx: Option<mpsc::Sender<()>>,
}
impl Guard {
fn client(&self) -> &aws_transcribe::Client {
self.client.as_ref().unwrap()
}
fn transcript_notif_tx(&mut self) -> &mut mpsc::Sender<()> {
self.transcript_notif_tx.as_mut().unwrap()
}
}
impl Drop for Guard {
fn drop(&mut self) {
if let Some(imp) = self.imp_weak.upgrade() {
let mut state = imp.state.lock().unwrap();
state.client = self.client.take();
state.transcript_notif_tx = self.transcript_notif_tx.take();
}
}
}
let mut guard = Guard {
imp_weak: imp_weak.clone(),
client: Some(client),
transcript_notif_tx: Some(transcript_notif_tx),
};
// Stream the incoming buffers chunked
let chunk_stream = buffer_rx.flat_map(move |buffer: gst::Buffer| {
async_stream::stream! {
let data = buffer.map_readable().unwrap();
use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob};
for chunk in data.chunks(8192) {
yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build()));
}
}
});
let mut transcribe_builder = guard
.client()
.start_stream_transcription()
.language_code(settings.lang_code)
.media_sample_rate_hertz(settings.sample_rate)
.media_encoding(model::MediaEncoding::Pcm)
.enable_partial_results_stabilization(true)
.partial_results_stability(settings.results_stability)
.set_vocabulary_name(settings.vocabulary)
.set_session_id(settings.session_id);
if let Some(vocabulary_filter) = settings.vocabulary_filter {
transcribe_builder = transcribe_builder
.vocabulary_filter_name(vocabulary_filter)
.vocabulary_filter_method(settings.vocabulary_filter_method);
}
let mut output = transcribe_builder
.audio_stream(chunk_stream.into())
.send()
.await
.map_err(|err| {
let err = format!("Transcribe ws init error: {err}");
if let Some(imp) = imp_weak.upgrade() {
gst::error!(CAT, imp: imp, "{err}");
}
gst::error_msg!(gst::LibraryError::Init, ["{err}"])
})?;
while let Some(event) = output
.transcript_result_stream
.recv()
.await
.map_err(|err| {
let err = format!("Transcribe ws stream error: {err}");
if let Some(imp) = imp_weak.upgrade() {
gst::error!(CAT, imp: imp, "{err}");
}
gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
})?
{
if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
let mut enqueued = false;
if let Some(result) = transcript_evt
.transcript
.and_then(|transcript| transcript.results)
.and_then(|mut results| results.drain(..).next())
{
let Some(imp) = imp_weak.upgrade() else { break };
gst::trace!(CAT, imp: imp, "Received: {result:?}");
if let Some(alternative) = result
.alternatives
.and_then(|mut alternatives| alternatives.drain(..).next())
{
if let Some(items) = alternative.items {
enqueued = imp.enqueue(items, result.is_partial, lateness);
}
}
}
if enqueued && guard.transcript_notif_tx().send(()).await.is_err() {
if let Some(imp) = imp_weak.upgrade() {
gst::debug!(CAT, imp: imp, "Terminated transcript_notif_tx channel");
}
break;
}
} else if let Some(imp) = imp_weak.upgrade() {
gst::warning!(
CAT,
imp: imp,
"Transcribe ws returned unknown event: consider upgrading the SDK"
)
} else {
// imp has left the building
break;
}
}
if let Some(imp) = imp_weak.upgrade() {
gst::debug!(CAT, imp: imp, "Exiting ws loop");
}
Ok(())
}
}
fn disconnect(&self) {
let mut state = self.state.lock().unwrap();
gst::info!(CAT, imp: self, "Unpreparing");