net/aws/transcriber: use two queues for sending transcript items

* A queue dedicated to transcript items not intended for translation.
* A queue dedicated to transcript items intended for translation. The items are
  enqueued after a separator is detected or translate-lookahead was reached.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1137>
This commit is contained in:
François Laignel 2023-03-16 18:20:08 +01:00
parent 5a5ca76d9d
commit 2b32d00589
3 changed files with 373 additions and 323 deletions

View file

@ -29,12 +29,12 @@ use futures::prelude::*;
use tokio::{runtime, sync::broadcast, task};
use std::collections::{BTreeSet, VecDeque};
use std::sync::Mutex;
use std::sync::{Arc, Mutex};
use once_cell::sync::Lazy;
use super::transcribe::{TranscriberLoop, TranscriptEvent, TranscriptItem, TranscriptionSettings};
use super::translate::{TranslateLoop, TranslateQueue, TranslatedItem};
use super::transcribe::{TranscriberSettings, TranscriberStream, TranscriptEvent, TranscriptItem};
use super::translate::{TranslateLoop, TranslatedItem};
use super::{
AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod,
TranslationTokenizationMethod, CAT,
@ -148,6 +148,7 @@ struct State {
srcpads: BTreeSet<super::TranslateSrcPad>,
pad_serial: u32,
seqnum: gst::Seqnum,
start_time: Option<gst::ClockTime>,
}
impl Default for State {
@ -158,6 +159,7 @@ impl Default for State {
srcpads: Default::default(),
pad_serial: 0,
seqnum: gst::Seqnum::next(),
start_time: None,
}
}
}
@ -168,7 +170,9 @@ pub struct Transcriber {
settings: Mutex<Settings>,
state: Mutex<State>,
pub(super) aws_config: Mutex<Option<aws_config::SdkConfig>>,
// sender to broadcast transcript items to the translate src pads.
// sender to broadcast transcript items to the src pads for translation.
transcript_event_for_translate_tx: broadcast::Sender<TranscriptEvent>,
// sender to broadcast transcript items to the src pads, not intended for translation.
transcript_event_tx: broadcast::Sender<TranscriptEvent>,
}
@ -276,7 +280,10 @@ impl Transcriber {
) -> Result<gst::FlowSuccess, gst::FlowError> {
gst::log!(CAT, obj: pad, "Handling {buffer:?}");
self.ensure_connection();
self.ensure_connection().map_err(|err| {
gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]);
gst::FlowError::Error
})?;
let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else {
gst::log!(CAT, obj: pad, "Flushing");
@ -292,12 +299,82 @@ impl Transcriber {
Ok(gst::FlowSuccess::Ok)
}
}
fn ensure_connection(&self) {
#[derive(Default)]
struct TranslateQueue {
items: VecDeque<TranscriptItem>,
}
impl TranslateQueue {
fn is_empty(&self) -> bool {
self.items.is_empty()
}
/// Pushes the provided item.
///
/// Returns `Some(..)` if items are ready for translation.
fn push(&mut self, transcript_item: &TranscriptItem) -> Option<Vec<TranscriptItem>> {
// Keep track of the item individually so we can schedule translation precisely.
self.items.push_back(transcript_item.clone());
if transcript_item.is_punctuation {
// This makes it a good chunk for translation.
// Concatenate as a single item for translation
return Some(self.items.drain(..).collect());
}
// Regular case: no separator detected, don't push transcript items
// to translation now. They will be pushed either if a punctuation
// is found or of a `dequeue()` is requested.
None
}
/// Dequeues items from the specified `deadline` up to `lookahead`.
///
/// Returns `Some(..)` if some items match the criteria.
fn dequeue(
&mut self,
latency: gst::ClockTime,
threshold: gst::ClockTime,
lookahead: gst::ClockTime,
) -> Option<Vec<TranscriptItem>> {
let first_pts = self.items.front()?.pts;
if first_pts + latency > threshold {
// First item is too early to be sent to translation now
// we can wait for more items to accumulate.
return None;
}
// Can't wait any longer to send the first item to translation
// Try to get up to lookahead worth of items to improve translation accuracy
let limit = first_pts + lookahead;
let mut items_acc = vec![self.items.pop_front().unwrap()];
while let Some(item) = self.items.front() {
if item.pts > limit {
break;
}
items_acc.push(self.items.pop_front().unwrap());
}
Some(items_acc)
}
fn drain(&mut self) -> impl Iterator<Item = TranscriptItem> + '_ {
self.items.drain(..)
}
}
impl Transcriber {
fn ensure_connection(&self) -> Result<(), gst::ErrorMessage> {
let mut state = self.state.lock().unwrap();
if state.buffer_tx.is_some() {
return;
return Ok(());
}
let settings = self.settings.lock().unwrap();
@ -306,21 +383,116 @@ impl Transcriber {
let s = in_caps.structure(0).unwrap();
let sample_rate = s.get::<i32>("rate").unwrap();
let transcription_settings = TranscriptionSettings::from(&settings, sample_rate);
let transcription_settings = TranscriberSettings::from(&settings, sample_rate);
let (buffer_tx, buffer_rx) = mpsc::channel(1);
let transcriber_loop = TranscriberLoop::new(
let _enter = RUNTIME.enter();
let mut transcriber_stream = futures::executor::block_on(TranscriberStream::try_new(
self,
transcription_settings,
settings.lateness,
buffer_rx,
self.transcript_event_tx.clone(),
);
let transcriber_loop_handle = RUNTIME.spawn(transcriber_loop.run());
))?;
// Latency budget for an item to be pushed to stream on time
// Margin:
// - 2 * GRANULARITY: to make sure we don't push items up to GRANULARITY late.
// - 1 * GRANULARITY: extra margin to account for additional overheads.
let latency = settings.transcribe_latency.saturating_sub(3 * GRANULARITY);
let translate_lookahead = settings.translate_lookahead;
let mut translate_queue = TranslateQueue::default();
let imp = self.ref_counted();
let transcriber_loop_handle = RUNTIME.spawn(async move {
loop {
// This is to make sure we send items on a timely basis or at least Gap events.
let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
futures::pin_mut!(timeout);
let transcriber_next = transcriber_stream.next().fuse();
futures::pin_mut!(transcriber_next);
// `transcriber_next` takes precedence over `timeout`
// because we don't want to loose any incoming items.
let res = futures::select_biased! {
event = transcriber_next => Some(event?),
_ = timeout => None,
};
use TranscriptEvent::*;
match res {
None => (),
Some(Items(items)) => {
if imp.transcript_event_tx.receiver_count() > 0 {
let _ = imp.transcript_event_tx.send(Items(items.clone()));
}
if imp.transcript_event_for_translate_tx.receiver_count() > 0 {
for item in items.iter() {
if let Some(items_to_translate) = translate_queue.push(item) {
let _ = imp
.transcript_event_for_translate_tx
.send(Items(items_to_translate.into()));
}
}
}
}
Some(Eos) => {
gst::debug!(CAT, imp: imp, "Transcriber loop sending EOS");
if imp.transcript_event_tx.receiver_count() > 0 {
let _ = imp.transcript_event_tx.send(Eos);
}
if imp.transcript_event_for_translate_tx.receiver_count() > 0 {
let items_to_translate: Vec<TranscriptItem> =
translate_queue.drain().collect();
let _ = imp
.transcript_event_for_translate_tx
.send(Items(items_to_translate.into()));
let _ = imp.transcript_event_for_translate_tx.send(Eos);
}
break;
}
}
if imp.transcript_event_for_translate_tx.receiver_count() > 0 {
// Check if we need to push items for translation
let Some((start_time, now)) = imp.get_start_time_and_now() else {
continue;
};
if !translate_queue.is_empty() {
let threshold = now - start_time;
if let Some(items_to_translate) =
translate_queue.dequeue(latency, threshold, translate_lookahead)
{
gst::debug!(
CAT,
imp: imp,
"Forcing to translation (threshold {threshold}): {items_to_translate:?}"
);
let _ = imp
.transcript_event_for_translate_tx
.send(Items(items_to_translate.into()));
}
}
}
}
gst::debug!(CAT, imp: imp, "Exiting transcriber loop");
Ok(())
});
state.transcriber_loop_handle = Some(transcriber_loop_handle);
state.buffer_tx = Some(buffer_tx);
Ok(())
}
fn prepare(&self) -> Result<(), gst::ErrorMessage> {
@ -382,6 +554,18 @@ impl Transcriber {
}
gst::info!(CAT, imp: self, "Unprepared");
}
fn get_start_time_and_now(&self) -> Option<(gst::ClockTime, gst::ClockTime)> {
let now = self.obj().current_running_time()?;
let mut state = self.state.lock().unwrap();
if state.start_time.is_none() {
state.start_time = Some(now);
}
Some((state.start_time.unwrap(), now))
}
}
#[glib::object_subclass]
@ -438,6 +622,7 @@ impl ObjectSubclass for Transcriber {
// Setting the channel capacity so that a TranslateSrcPad that would lag
// behind for some reasons get a chance to catch-up without loosing items.
// Receiver will be created by subscribing to sender later.
let (transcript_event_for_translate_tx, _) = broadcast::channel(128);
let (transcript_event_tx, _) = broadcast::channel(128);
Self {
@ -446,6 +631,7 @@ impl ObjectSubclass for Transcriber {
settings: Default::default(),
state: Default::default(),
aws_config: Default::default(),
transcript_event_for_translate_tx,
transcript_event_tx,
}
}
@ -876,51 +1062,93 @@ struct TranslationPadTask {
elem: super::Transcriber,
transcript_event_rx: broadcast::Receiver<TranscriptEvent>,
needs_translate: bool,
translate_queue: TranslateQueue,
translate_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
to_translate_tx: Option<mpsc::Sender<Vec<TranscriptItem>>>,
to_translate_tx: Option<mpsc::Sender<Arc<Vec<TranscriptItem>>>>,
from_translate_rx: Option<mpsc::Receiver<Vec<TranslatedItem>>>,
translate_latency: gst::ClockTime,
translate_lookahead: gst::ClockTime,
send_events: bool,
output_items: VecDeque<OutputItem>,
our_latency: gst::ClockTime,
seqnum: gst::Seqnum,
send_eos: bool,
pending_translations: usize,
start_time: Option<gst::ClockTime>,
}
impl TranslationPadTask {
fn try_new(
async fn try_new(
pad: &TranslateSrcPad,
elem: super::Transcriber,
transcript_event_rx: broadcast::Receiver<TranscriptEvent>,
) -> Result<TranslationPadTask, gst::ErrorMessage> {
let mut this = TranslationPadTask {
let mut translation_loop = None;
let mut translate_loop_handle = None;
let mut to_translate_tx = None;
let mut from_translate_rx = None;
let (our_latency, transcript_event_rx, needs_translate);
{
let elem_imp = elem.imp();
let elem_settings = elem_imp.settings.lock().unwrap();
let pad_settings = pad.settings.lock().unwrap();
our_latency = TranslateSrcPad::our_latency(&elem_settings, &pad_settings);
if our_latency + elem_settings.lateness <= 2 * GRANULARITY {
let err = format!(
"total latency + lateness must be greater than {}",
2 * GRANULARITY
);
gst::error!(CAT, imp: pad, "{err}");
return Err(gst::error_msg!(gst::LibraryError::Settings, ["{err}"]));
}
needs_translate = TranslateSrcPad::needs_translation(
&elem_settings.language_code,
pad_settings.language_code.as_deref(),
);
if needs_translate {
let (to_loop_tx, to_loop_rx) = mpsc::channel(64);
let (from_loop_tx, from_loop_rx) = mpsc::channel(64);
translation_loop = Some(TranslateLoop::new(
elem_imp,
pad,
&elem_settings.language_code,
pad_settings.language_code.as_deref().unwrap(),
pad_settings.tokenization_method,
to_loop_rx,
from_loop_tx,
));
to_translate_tx = Some(to_loop_tx);
from_translate_rx = Some(from_loop_rx);
transcript_event_rx = elem_imp.transcript_event_for_translate_tx.subscribe();
} else {
transcript_event_rx = elem_imp.transcript_event_tx.subscribe();
}
}
if let Some(translation_loop) = translation_loop {
translation_loop.check_language().await?;
translate_loop_handle = Some(RUNTIME.spawn(translation_loop.run()));
}
Ok(TranslationPadTask {
pad: pad.ref_counted(),
elem,
transcript_event_rx,
needs_translate: false,
translate_queue: TranslateQueue::default(),
translate_loop_handle: None,
to_translate_tx: None,
from_translate_rx: None,
translate_latency: DEFAULT_TRANSLATE_LATENCY,
translate_lookahead: DEFAULT_TRANSLATE_LOOKAHEAD,
needs_translate,
translate_loop_handle,
to_translate_tx,
from_translate_rx,
send_events: true,
output_items: VecDeque::new(),
our_latency: DEFAULT_TRANSCRIBE_LATENCY,
our_latency,
seqnum: gst::Seqnum::next(),
send_eos: false,
pending_translations: 0,
start_time: None,
};
let _enter_guard = RUNTIME.enter();
futures::executor::block_on(this.init_translate())?;
Ok(this)
})
}
}
@ -958,11 +1186,9 @@ impl TranslationPadTask {
let transcript_event_rx = self.transcript_event_rx.recv().fuse();
futures::pin_mut!(transcript_event_rx);
// `timeout` takes precedence over `transcript_events` reception
// because we may need to `dequeue` `items` or push a `Gap` event
// before current latency budget is exhausted.
// `transcript_event_rx` takes precedence over `timeout`
// because we don't want to loose any incoming items.
futures::select_biased! {
_ = timeout => (),
items_res = transcript_event_rx => {
use TranscriptEvent::*;
use broadcast::error::RecvError;
@ -983,6 +1209,7 @@ impl TranslationPadTask {
}
}
}
_ = timeout => (),
}
Ok(())
@ -999,121 +1226,100 @@ impl TranslationPadTask {
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
}
let transcript_items = {
let items_to_translate = {
// This is to make sure we send items on a timely basis or at least Gap events.
let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
futures::pin_mut!(timeout);
let from_translate_rx = self
.from_translate_rx
.as_mut()
.expect("from_translation chan must be available in translation mode");
let transcript_event_rx = self.transcript_event_rx.recv().fuse();
futures::pin_mut!(transcript_event_rx);
// `timeout` takes precedence over `transcript_events` reception
// because we may need to `dequeue` `items` or push a `Gap` event
// before current latency budget is exhausted.
// `transcript_event_rx` takes precedence over `timeout`
// because we don't want to loose any incoming items.
futures::select_biased! {
_ = timeout => return Ok(()),
translated_items = from_translate_rx.next() => {
let Some(translated_items) = translated_items else {
const ERR: &str = "translation chan terminated";
gst::debug!(CAT, imp: self.pad, "{ERR}");
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
};
self.output_items.extend(translated_items.into_iter().map(Into::into));
self.pending_translations = self.pending_translations.saturating_sub(1);
return Ok(());
}
items_res = transcript_event_rx => {
use TranscriptEvent::*;
use broadcast::error::RecvError;
match items_res {
Ok(Items(transcript_items)) => transcript_items,
Ok(Items(items_to_translate)) => Some(items_to_translate),
Ok(Eos) => {
gst::debug!(CAT, imp: self.pad, "Got eos");
self.send_eos = true;
return Ok(());
None
}
Err(RecvError::Lagged(nb_msg)) => {
gst::warning!(CAT, imp: self.pad, "Missed {nb_msg} transcript sets");
return Ok(());
None
}
Err(RecvError::Closed) => {
gst::debug!(CAT, imp: self.pad, "Transcript chan terminated: setting eos");
self.send_eos = true;
return Ok(());
None
}
}
}
_ = timeout => None,
}
};
for items in transcript_items.iter() {
if let Some(items_to_translate) = self.translate_queue.push(items) {
self.send_for_translation(items_to_translate).await?;
if let Some(items_to_translate) = items_to_translate {
if !items_to_translate.is_empty() {
let res = self
.to_translate_tx
.as_mut()
.expect("to_translation chan must be available in translation mode")
.send(items_to_translate)
.await;
if res.is_err() {
const ERR: &str = "to_translation chan terminated";
gst::debug!(CAT, imp: self.pad, "{ERR}");
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
}
self.pending_translations += 1;
}
}
Ok(())
}
// Check pending translated items
let from_translate_rx = self
.from_translate_rx
.as_mut()
.expect("from_translation chan must be available in translation mode");
async fn dequeue_for_translation(
&mut self,
start_time: gst::ClockTime,
now: gst::ClockTime,
) -> Result<(), gst::ErrorMessage> {
if !self.translate_queue.is_empty() {
// Latency budget for an item to be pushed to stream on time
// Margin:
// - 2 * GRANULARITY: to make sure we don't push items up to GRANULARITY late.
// - 1 * GRANULARITY: extra margin to account for additional overheads.
let latency = self.our_latency.saturating_sub(3 * GRANULARITY);
while let Ok(translated_items) = from_translate_rx.try_next() {
let Some(translated_items) = translated_items else {
const ERR: &str = "translation chan terminated";
gst::debug!(CAT, imp: self.pad, "{ERR}");
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
};
// Estimated time of arrival for an item sent to translation now.
// (in transcript item ts base)
let translation_eta = now + self.translate_latency - start_time;
if let Some(items_to_translate) =
self.translate_queue
.dequeue(latency, translation_eta, self.translate_lookahead)
{
gst::debug!(CAT, imp: self.pad, "Forcing to translation: {items_to_translate:?}");
self.send_for_translation(items_to_translate).await?;
}
self.output_items
.extend(translated_items.into_iter().map(Into::into));
self.pending_translations = self.pending_translations.saturating_sub(1);
}
Ok(())
}
async fn dequeue(&mut self) -> bool {
let (now, start_time, mut last_position, mut discont_pending);
{
let mut pad_state = self.pad.state.lock().unwrap();
let Some((start_time, now)) = self.elem.imp().get_start_time_and_now() else {
// Wait for the clock to be available
return true;
};
let Some(cur_rt) = self.elem.current_running_time() else {
// Wait for the clock to be available
return true;
let (mut last_position, mut discont_pending) = {
let mut state = self.pad.state.lock().unwrap();
let last_position = if let Some(pos) = state.out_segment.position() {
pos
} else {
state.out_segment.set_position(start_time);
start_time
};
now = cur_rt;
if self.start_time.is_none() {
self.start_time = Some(now);
pad_state.out_segment.set_position(now);
}
start_time = self.start_time.unwrap();
last_position = pad_state.out_segment.position().unwrap();
discont_pending = pad_state.discont_pending;
}
if self.needs_translate && self.dequeue_for_translation(start_time, now).await.is_err() {
return false;
}
(last_position, state.discont_pending)
};
/* First, check our pending buffers */
while let Some(item) = self.output_items.front() {
@ -1206,11 +1412,7 @@ impl TranslationPadTask {
}
}
if self.send_eos
&& self.pending_translations == 0
&& self.output_items.is_empty()
&& self.translate_queue.is_empty()
{
if self.send_eos && self.pending_translations == 0 && self.output_items.is_empty() {
/* We're EOS, we can pause and exit early */
let _ = self.pad.obj().pause_task();
@ -1261,28 +1463,6 @@ impl TranslationPadTask {
true
}
async fn send_for_translation(
&mut self,
transcript_items: Vec<TranscriptItem>,
) -> Result<(), gst::ErrorMessage> {
let res = self
.to_translate_tx
.as_mut()
.expect("to_translation chan must be available in translation mode")
.send(transcript_items)
.await;
if res.is_err() {
const ERR: &str = "to_translation chan terminated";
gst::debug!(CAT, imp: self.pad, "{ERR}");
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
}
self.pending_translations += 1;
Ok(())
}
fn ensure_init_events(&mut self) -> Result<(), gst::ErrorMessage> {
if !self.send_events {
return Ok(());
@ -1332,62 +1512,6 @@ impl TranslationPadTask {
}
}
impl TranslationPadTask {
async fn init_translate(&mut self) -> Result<(), gst::ErrorMessage> {
let mut translation_loop = None;
{
let elem_imp = self.elem.imp();
let elem_settings = elem_imp.settings.lock().unwrap();
let pad_settings = self.pad.settings.lock().unwrap();
self.our_latency = TranslateSrcPad::our_latency(&elem_settings, &pad_settings);
if self.our_latency + elem_settings.lateness <= 2 * GRANULARITY {
let err = format!(
"total latency + lateness must be greater than {}",
2 * GRANULARITY
);
gst::error!(CAT, imp: self.pad, "{err}");
return Err(gst::error_msg!(gst::LibraryError::Settings, ["{err}"]));
}
self.translate_latency = elem_settings.translate_latency;
self.translate_lookahead = elem_settings.translate_lookahead;
self.needs_translate = TranslateSrcPad::needs_translation(
&elem_settings.language_code,
pad_settings.language_code.as_deref(),
);
if self.needs_translate {
let (to_translate_tx, to_translate_rx) = mpsc::channel(64);
let (from_translate_tx, from_translate_rx) = mpsc::channel(64);
translation_loop = Some(TranslateLoop::new(
elem_imp,
&self.pad,
&elem_settings.language_code,
pad_settings.language_code.as_deref().unwrap(),
pad_settings.tokenization_method,
to_translate_rx,
from_translate_tx,
));
self.to_translate_tx = Some(to_translate_tx);
self.from_translate_rx = Some(from_translate_rx);
}
}
if let Some(translation_loop) = translation_loop {
translation_loop.check_language().await?;
self.translate_loop_handle = Some(RUNTIME.spawn(translation_loop.run()));
}
Ok(())
}
}
#[derive(Debug)]
struct TranslationPadState {
discont_pending: bool,
@ -1422,8 +1546,8 @@ impl TranslateSrcPad {
gst::debug!(CAT, imp: self, "Starting task");
let elem = self.parent();
let transcript_event_rx = elem.imp().transcript_event_tx.subscribe();
let mut pad_task = TranslationPadTask::try_new(self, elem, transcript_event_rx)
let _enter = RUNTIME.enter();
let mut pad_task = futures::executor::block_on(TranslationPadTask::try_new(self, elem))
.map_err(|err| gst::loggable_error!(CAT, format!("Failed to start pad task {err}")))?;
let imp = self.ref_counted();

View file

@ -15,7 +15,6 @@ use aws_sdk_transcribestreaming::model;
use futures::channel::mpsc;
use futures::prelude::*;
use tokio::sync::broadcast;
use std::sync::Arc;
@ -23,7 +22,7 @@ use super::imp::{Settings, Transcriber};
use super::CAT;
#[derive(Debug)]
pub struct TranscriptionSettings {
pub struct TranscriberSettings {
lang_code: model::LanguageCode,
sample_rate: i32,
vocabulary: Option<String>,
@ -33,9 +32,9 @@ pub struct TranscriptionSettings {
results_stability: model::PartialResultsStability,
}
impl TranscriptionSettings {
impl TranscriberSettings {
pub(super) fn from(settings: &Settings, sample_rate: i32) -> Self {
TranscriptionSettings {
TranscriberSettings {
lang_code: settings.language_code.as_str().into(),
sample_rate,
vocabulary: settings.vocabulary.clone(),
@ -83,43 +82,30 @@ impl From<Vec<TranscriptItem>> for TranscriptEvent {
}
}
pub struct TranscriberLoop {
pub struct TranscriberStream {
imp: glib::subclass::ObjectImplRef<Transcriber>,
client: aws_transcribe::Client,
settings: Option<TranscriptionSettings>,
output: aws_transcribe::output::StartStreamTranscriptionOutput,
lateness: gst::ClockTime,
buffer_rx: Option<mpsc::Receiver<gst::Buffer>>,
transcript_items_tx: broadcast::Sender<TranscriptEvent>,
partial_index: usize,
}
impl TranscriberLoop {
pub fn new(
impl TranscriberStream {
pub async fn try_new(
imp: &Transcriber,
settings: TranscriptionSettings,
settings: TranscriberSettings,
lateness: gst::ClockTime,
buffer_rx: mpsc::Receiver<gst::Buffer>,
transcript_items_tx: broadcast::Sender<TranscriptEvent>,
) -> Self {
let aws_config = imp.aws_config.lock().unwrap();
let aws_config = aws_config
.as_ref()
.expect("aws_config must be initialized at this stage");
) -> Result<Self, gst::ErrorMessage> {
let client = {
let aws_config = imp.aws_config.lock().unwrap();
let aws_config = aws_config
.as_ref()
.expect("aws_config must be initialized at this stage");
aws_transcribe::Client::new(aws_config)
};
TranscriberLoop {
imp: imp.ref_counted(),
client: aws_transcribe::Client::new(aws_config),
settings: Some(settings),
lateness,
buffer_rx: Some(buffer_rx),
transcript_items_tx,
partial_index: 0,
}
}
pub async fn run(mut self) -> Result<(), gst::ErrorMessage> {
// Stream the incoming buffers chunked
let chunk_stream = self.buffer_rx.take().unwrap().flat_map(move |buffer: gst::Buffer| {
let chunk_stream = buffer_rx.flat_map(move |buffer: gst::Buffer| {
async_stream::stream! {
let data = buffer.map_readable().unwrap();
use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob};
@ -129,9 +115,7 @@ impl TranscriberLoop {
}
});
let settings = self.settings.take().unwrap();
let mut transcribe_builder = self
.client
let mut transcribe_builder = client
.start_stream_transcription()
.language_code(settings.lang_code)
.media_sample_rate_hertz(settings.sample_rate)
@ -147,26 +131,42 @@ impl TranscriberLoop {
.vocabulary_filter_method(settings.vocabulary_filter_method);
}
let mut output = transcribe_builder
let output = transcribe_builder
.audio_stream(chunk_stream.into())
.send()
.await
.map_err(|err| {
let err = format!("Transcribe ws init error: {err}");
gst::error!(CAT, imp: self.imp, "{err}");
gst::error!(CAT, imp: imp, "{err}");
gst::error_msg!(gst::LibraryError::Init, ["{err}"])
})?;
while let Some(event) = output
.transcript_result_stream
.recv()
.await
.map_err(|err| {
let err = format!("Transcribe ws stream error: {err}");
gst::error!(CAT, imp: self.imp, "{err}");
gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
})?
{
Ok(TranscriberStream {
imp: imp.ref_counted(),
output,
lateness,
partial_index: 0,
})
}
pub async fn next(&mut self) -> Result<TranscriptEvent, gst::ErrorMessage> {
loop {
let event = self
.output
.transcript_result_stream
.recv()
.await
.map_err(|err| {
let err = format!("Transcribe ws stream error: {err}");
gst::error!(CAT, imp: self.imp, "{err}");
gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
})?;
let Some(event) = event else {
gst::debug!(CAT, imp: self.imp, "Transcriber loop sending EOS");
return Ok(TranscriptEvent::Eos);
};
if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
let mut ready_items = None;
@ -188,10 +188,7 @@ impl TranscriberLoop {
}
if let Some(ready_items) = ready_items {
if self.transcript_items_tx.send(ready_items.into()).is_err() {
gst::debug!(CAT, imp: self.imp, "No transcript items receivers");
break;
}
return Ok(ready_items.into());
}
} else {
gst::warning!(
@ -201,13 +198,6 @@ impl TranscriberLoop {
)
}
}
gst::debug!(CAT, imp: self.imp, "Transcriber loop sending EOS");
let _ = self.transcript_items_tx.send(TranscriptEvent::Eos);
gst::debug!(CAT, imp: self.imp, "Exiting transcriber loop");
Ok(())
}
/// Builds a list from the provided stable items.

View file

@ -14,7 +14,7 @@ use aws_sdk_translate as aws_translate;
use futures::channel::mpsc;
use futures::prelude::*;
use std::collections::VecDeque;
use std::sync::Arc;
use super::imp::TranslateSrcPad;
use super::transcribe::TranscriptItem;
@ -40,77 +40,13 @@ impl From<&TranscriptItem> for TranslatedItem {
}
}
#[derive(Default)]
pub struct TranslateQueue {
items: VecDeque<TranscriptItem>,
}
impl TranslateQueue {
pub fn is_empty(&self) -> bool {
self.items.is_empty()
}
/// Pushes the provided item.
///
/// Returns `Some(..)` if items are ready for translation.
pub fn push(&mut self, transcript_item: &TranscriptItem) -> Option<Vec<TranscriptItem>> {
// Keep track of the item individually so we can schedule translation precisely.
self.items.push_back(transcript_item.clone());
if transcript_item.is_punctuation {
// This makes it a good chunk for translation.
// Concatenate as a single item for translation
return Some(self.items.drain(..).collect());
}
// Regular case: no separator detected, don't push transcript items
// to translation now. They will be pushed either if a punctuation
// is found or of a `dequeue()` is requested.
None
}
/// Dequeues items from the specified `deadline` up to `lookahead`.
///
/// Returns `Some(..)` if some items match the criteria.
pub fn dequeue(
&mut self,
latency: gst::ClockTime,
threshold: gst::ClockTime,
lookahead: gst::ClockTime,
) -> Option<Vec<TranscriptItem>> {
let first_pts = self.items.front()?.pts;
if first_pts + latency > threshold {
// First item is too early to be sent to translation now
// we can wait for more items to accumulate.
return None;
}
// Can't wait any longer to send the first item to translation
// Try to get up to lookahead worth of items to improve translation accuracy
let limit = first_pts + lookahead;
let mut items_acc = vec![self.items.pop_front().unwrap()];
while let Some(item) = self.items.front() {
if item.pts > limit {
break;
}
items_acc.push(self.items.pop_front().unwrap());
}
Some(items_acc)
}
}
pub struct TranslateLoop {
pad: glib::subclass::ObjectImplRef<TranslateSrcPad>,
client: aws_translate::Client,
input_lang: String,
output_lang: String,
tokenization_method: TranslationTokenizationMethod,
transcript_rx: mpsc::Receiver<Vec<TranscriptItem>>,
transcript_rx: mpsc::Receiver<Arc<Vec<TranscriptItem>>>,
translate_tx: mpsc::Sender<Vec<TranslatedItem>>,
}
@ -121,7 +57,7 @@ impl TranslateLoop {
input_lang: &str,
output_lang: &str,
tokenization_method: TranslationTokenizationMethod,
transcript_rx: mpsc::Receiver<Vec<TranscriptItem>>,
transcript_rx: mpsc::Receiver<Arc<Vec<TranscriptItem>>>,
translate_tx: mpsc::Sender<Vec<TranslatedItem>>,
) -> Self {
let aws_config = imp.aws_config.lock().unwrap();
@ -175,12 +111,12 @@ impl TranslateLoop {
let (ts_duration_list, content): (Vec<(gst::ClockTime, gst::ClockTime)>, String) =
transcript_items
.into_iter()
.iter()
.map(|item| {
(
(item.pts, item.duration),
match self.tokenization_method {
Tokenization::None => item.content,
Tokenization::None => item.content.clone(),
Tokenization::SpanBased => {
format!("{SPAN_START}{}{SPAN_END}", item.content)
}