net/aws/transcriber: add translation request src pads

This commit adds an optional transcript translation feature implemented as
request src Pads.

When requesting a src Pad, the user can specify the translation language code
using Pad properties 'language-code'.

The following properties are defined on the Element:

- 'transcribe-latency': formerly 'latency', defines the expected latency for
  the Transcribe webservice.
- 'translate-latency': defines the expected latency for the Translate
  webservice.
- 'transcript-lookahead': maximum transcript duration to send to translation
  when a transcript is hitting its deadline and no punctuation was found.

When the input and output languages are the same, only the 'transcribe-latency'
is used for the Pad. Otherwise, the resulting latency is the addition of
'transcribe-latency' and 'translate-latency'.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1109>
This commit is contained in:
François Laignel 2023-03-10 14:47:38 +01:00 committed by GStreamer Marge Bot
parent 9a55fda69c
commit 743e97738f
6 changed files with 1707 additions and 730 deletions

View file

@ -628,6 +628,9 @@
"GInitiallyUnowned",
"GObject"
],
"interfaces": [
"GstChildProxy"
],
"klass": "Audio/Text/Filter",
"long-name": "Transcriber",
"pad-templates": {
@ -639,7 +642,14 @@
"src": {
"caps": "text/x-raw:\n format: utf8\n",
"direction": "src",
"presence": "always"
"presence": "always",
"type": "GstTranslationSrcPad"
},
"src_%%u": {
"caps": "text/x-raw:\n format: utf8\n",
"direction": "src",
"presence": "request",
"type": "GstTranslationSrcPad"
}
},
"properties": {
@ -668,7 +678,7 @@
"writable": true
},
"latency": {
"blurb": "Amount of milliseconds to allow AWS transcribe",
"blurb": "Amount of milliseconds to allow AWS transcribe (Deprecated. Use transcribe-latency)",
"conditionally-available": false,
"construct": false,
"construct-only": false,
@ -743,6 +753,48 @@
"type": "gchararray",
"writable": true
},
"transcribe-latency": {
"blurb": "Amount of milliseconds to allow AWS transcribe",
"conditionally-available": false,
"construct": false,
"construct-only": false,
"controllable": false,
"default": "8000",
"max": "-1",
"min": "0",
"mutable": "ready",
"readable": true,
"type": "guint",
"writable": true
},
"transcript-lookahead": {
"blurb": "Maximum duration in milliseconds of transcript to lookahead before sending to translation when no separator was encountered",
"conditionally-available": false,
"construct": false,
"construct-only": false,
"controllable": false,
"default": "3000",
"max": "-1",
"min": "0",
"mutable": "ready",
"readable": true,
"type": "guint",
"writable": true
},
"translate-latency": {
"blurb": "Amount of milliseconds to allow AWS translate (ignored if the input and output languages are the same)",
"conditionally-available": false,
"construct": false,
"construct-only": false,
"controllable": false,
"default": "500",
"max": "-1",
"min": "0",
"mutable": "ready",
"readable": true,
"type": "guint",
"writable": true
},
"vocabulary-filter-method": {
"blurb": "Defines how filtered words will be edited, has no effect when vocabulary-filter-name isn't set",
"conditionally-available": false,
@ -845,6 +897,30 @@
"value": "2"
}
]
},
"GstTranslationSrcPad": {
"hierarchy": [
"GstTranslationSrcPad",
"GstPad",
"GstObject",
"GInitiallyUnowned",
"GObject"
],
"kind": "object",
"properties": {
"language-code": {
"blurb": "The Language the Stream must be translated to",
"conditionally-available": false,
"construct": false,
"construct-only": false,
"controllable": false,
"default": "NULL",
"mutable": "ready",
"readable": true,
"type": "gchararray",
"writable": true
}
}
}
},
"package": "gst-plugin-aws",

View file

@ -16,6 +16,7 @@ base32 = "0.4"
aws-config = "0.54.0"
aws-sdk-s3 = "0.24.0"
aws-sdk-transcribestreaming = "0.24.0"
aws-sdk-translate = "0.24.0"
aws-types = "0.54.0"
aws-credential-types = "0.54.0"
aws-sig-auth = "0.54.0"

File diff suppressed because it is too large Load diff

View file

@ -10,6 +10,18 @@ use gst::glib;
use gst::prelude::*;
mod imp;
mod transcribe;
mod translate;
use once_cell::sync::Lazy;
static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
gst::DebugCategory::new(
"awstranscribe",
gst::DebugColorFlags::empty(),
Some("AWS Transcribe element"),
)
});
use aws_sdk_transcribestreaming::model::{PartialResultsStability, VocabularyFilterMethod};
@ -68,7 +80,11 @@ impl From<AwsTranscriberVocabularyFilterMethod> for VocabularyFilterMethod {
}
glib::wrapper! {
pub struct Transcriber(ObjectSubclass<imp::Transcriber>) @extends gst::Element, gst::Object;
pub struct Transcriber(ObjectSubclass<imp::Transcriber>) @extends gst::Element, gst::Object, @implements gst::ChildProxy;
}
glib::wrapper! {
pub struct TranslationSrcPad(ObjectSubclass<imp::TranslationSrcPad>) @extends gst::Pad, gst::Object;
}
pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> {
@ -78,6 +94,7 @@ pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> {
.mark_as_plugin_api(gst::PluginAPIFlags::empty());
AwsTranscriberVocabularyFilterMethod::static_type()
.mark_as_plugin_api(gst::PluginAPIFlags::empty());
TranslationSrcPad::static_type().mark_as_plugin_api(gst::PluginAPIFlags::empty());
}
gst::Element::register(
Some(plugin),

View file

@ -0,0 +1,277 @@
// Copyright (C) 2020 Mathieu Duponchelle <mathieu@centricular.com>
// Copyright (C) 2023 François Laignel <francois@centricular.com>
//
// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0.
// If a copy of the MPL was not distributed with this file, You can obtain one at
// <https://mozilla.org/MPL/2.0/>.
//
// SPDX-License-Identifier: MPL-2.0
use gst::subclass::prelude::*;
use gst::{glib, prelude::*};
use aws_sdk_transcribestreaming as aws_transcribe;
use aws_sdk_transcribestreaming::model;
use futures::channel::mpsc;
use futures::prelude::*;
use tokio::sync::broadcast;
use std::sync::Arc;
use super::imp::{Settings, Transcriber};
use super::CAT;
#[derive(Debug)]
pub struct TranscriptionSettings {
lang_code: model::LanguageCode,
sample_rate: i32,
vocabulary: Option<String>,
vocabulary_filter: Option<String>,
vocabulary_filter_method: model::VocabularyFilterMethod,
session_id: Option<String>,
results_stability: model::PartialResultsStability,
}
impl TranscriptionSettings {
pub(super) fn from(settings: &Settings, sample_rate: i32) -> Self {
TranscriptionSettings {
lang_code: settings.language_code.as_str().into(),
sample_rate,
vocabulary: settings.vocabulary.clone(),
vocabulary_filter: settings.vocabulary_filter.clone(),
vocabulary_filter_method: settings.vocabulary_filter_method.into(),
session_id: settings.session_id.clone(),
results_stability: settings.results_stability.into(),
}
}
}
#[derive(Clone, Debug, Default)]
pub struct TranscriptItem {
pub pts: gst::ClockTime,
pub duration: gst::ClockTime,
pub content: String,
pub is_punctuation: bool,
}
impl TranscriptItem {
pub fn from(item: model::Item, lateness: gst::ClockTime) -> Option<TranscriptItem> {
let content = item.content?;
let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness;
let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness;
Some(TranscriptItem {
pts: start_time,
duration: end_time - start_time,
content,
is_punctuation: matches!(item.r#type, Some(model::ItemType::Punctuation)),
})
}
#[inline]
pub fn push(&mut self, item: &TranscriptItem) {
self.duration += item.duration;
self.is_punctuation &= item.is_punctuation;
if !item.is_punctuation {
self.content.push(' ');
}
self.content.push_str(&item.content);
}
}
#[derive(Clone)]
pub enum TranscriptEvent {
Items(Arc<Vec<TranscriptItem>>),
Eos,
}
impl From<Vec<TranscriptItem>> for TranscriptEvent {
fn from(transcript_items: Vec<TranscriptItem>) -> Self {
TranscriptEvent::Items(transcript_items.into())
}
}
pub struct TranscriberLoop {
imp: glib::subclass::ObjectImplRef<Transcriber>,
client: aws_transcribe::Client,
settings: Option<TranscriptionSettings>,
lateness: gst::ClockTime,
buffer_rx: Option<mpsc::Receiver<gst::Buffer>>,
transcript_items_tx: broadcast::Sender<TranscriptEvent>,
partial_index: usize,
}
impl TranscriberLoop {
pub fn new(
imp: &Transcriber,
settings: TranscriptionSettings,
lateness: gst::ClockTime,
buffer_rx: mpsc::Receiver<gst::Buffer>,
transcript_items_tx: broadcast::Sender<TranscriptEvent>,
) -> Self {
let aws_config = imp.aws_config.lock().unwrap();
let aws_config = aws_config
.as_ref()
.expect("aws_config must be initialized at this stage");
TranscriberLoop {
imp: imp.ref_counted(),
client: aws_transcribe::Client::new(aws_config),
settings: Some(settings),
lateness,
buffer_rx: Some(buffer_rx),
transcript_items_tx,
partial_index: 0,
}
}
pub async fn run(mut self) -> Result<(), gst::ErrorMessage> {
// Stream the incoming buffers chunked
let chunk_stream = self.buffer_rx.take().unwrap().flat_map(move |buffer: gst::Buffer| {
async_stream::stream! {
let data = buffer.map_readable().unwrap();
use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob};
for chunk in data.chunks(8192) {
yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build()));
}
}
});
let settings = self.settings.take().unwrap();
let mut transcribe_builder = self
.client
.start_stream_transcription()
.language_code(settings.lang_code)
.media_sample_rate_hertz(settings.sample_rate)
.media_encoding(model::MediaEncoding::Pcm)
.enable_partial_results_stabilization(true)
.partial_results_stability(settings.results_stability)
.set_vocabulary_name(settings.vocabulary)
.set_session_id(settings.session_id);
if let Some(vocabulary_filter) = settings.vocabulary_filter {
transcribe_builder = transcribe_builder
.vocabulary_filter_name(vocabulary_filter)
.vocabulary_filter_method(settings.vocabulary_filter_method);
}
let mut output = transcribe_builder
.audio_stream(chunk_stream.into())
.send()
.await
.map_err(|err| {
let err = format!("Transcribe ws init error: {err}");
gst::error!(CAT, imp: self.imp, "{err}");
gst::error_msg!(gst::LibraryError::Init, ["{err}"])
})?;
while let Some(event) = output
.transcript_result_stream
.recv()
.await
.map_err(|err| {
let err = format!("Transcribe ws stream error: {err}");
gst::error!(CAT, imp: self.imp, "{err}");
gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
})?
{
if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
let mut ready_items = None;
if let Some(result) = transcript_evt
.transcript
.and_then(|transcript| transcript.results)
.and_then(|mut results| results.drain(..).next())
{
gst::trace!(CAT, imp: self.imp, "Received: {result:?}");
if let Some(alternative) = result
.alternatives
.and_then(|mut alternatives| alternatives.drain(..).next())
{
ready_items = alternative.items.and_then(|items| {
self.get_ready_transcript_items(items, result.is_partial)
});
}
}
if let Some(ready_items) = ready_items {
if self.transcript_items_tx.send(ready_items.into()).is_err() {
gst::debug!(CAT, imp: self.imp, "No transcript items receivers");
break;
}
}
} else {
gst::warning!(
CAT,
imp: self.imp,
"Transcribe ws returned unknown event: consider upgrading the SDK"
)
}
}
gst::debug!(CAT, imp: self.imp, "Transcriber loop sending EOS");
let _ = self.transcript_items_tx.send(TranscriptEvent::Eos);
gst::debug!(CAT, imp: self.imp, "Exiting transcriber loop");
Ok(())
}
/// Builds a list from the provided stable items.
fn get_ready_transcript_items(
&mut self,
mut items: Vec<model::Item>,
partial: bool,
) -> Option<Vec<TranscriptItem>> {
if items.len() <= self.partial_index {
gst::error!(
CAT,
imp: self.imp,
"sanity check failed, alternative length {} < partial_index {}",
items.len(),
self.partial_index
);
if !partial {
self.partial_index = 0;
}
return None;
}
let mut output = vec![];
for item in items.drain(self.partial_index..) {
if !item.stable().unwrap_or(false) {
break;
}
let Some(item) = TranscriptItem::from(item, self.lateness) else { continue };
gst::debug!(
CAT,
imp: self.imp,
"Item is ready for queuing: {}, PTS {}",
item.content,
item.pts,
);
self.partial_index += 1;
output.push(item);
}
if !partial {
self.partial_index = 0;
}
if output.is_empty() {
return None;
}
Some(output)
}
}

View file

@ -0,0 +1,215 @@
// Copyright (C) 2023 François Laignel <francois@centricular.com>
//
// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0.
// If a copy of the MPL was not distributed with this file, You can obtain one at
// <https://mozilla.org/MPL/2.0/>.
//
// SPDX-License-Identifier: MPL-2.0
use gst::glib;
use gst::subclass::prelude::*;
use aws_sdk_translate as aws_translate;
use futures::channel::mpsc;
use futures::prelude::*;
use std::collections::VecDeque;
use super::imp::TranslationSrcPad;
use super::transcribe::TranscriptItem;
use super::CAT;
pub struct TranslatedItem {
pub pts: gst::ClockTime,
pub duration: gst::ClockTime,
pub content: String,
}
impl From<&TranscriptItem> for TranslatedItem {
fn from(transcript_item: &TranscriptItem) -> Self {
TranslatedItem {
pts: transcript_item.pts,
duration: transcript_item.duration,
content: transcript_item.content.clone(),
}
}
}
#[derive(Default)]
pub struct TranslationQueue {
items: VecDeque<TranscriptItem>,
}
impl TranslationQueue {
pub fn is_empty(&self) -> bool {
self.items.is_empty()
}
/// Pushes the provided item.
///
/// Returns `Some(..)` if items are ready for translation.
pub fn push(&mut self, transcript_item: &TranscriptItem) -> Option<TranscriptItem> {
// Keep track of the item individually so we can schedule translation precisely.
self.items.push_back(transcript_item.clone());
if transcript_item.is_punctuation {
// This makes it a good chunk for translation.
// Concatenate as a single item for translation
let mut items = self.items.drain(..);
let mut item_acc = items.next()?;
for item in items {
item_acc.push(&item);
}
item_acc.push(transcript_item);
return Some(item_acc);
}
// Regular case: no separator detected, don't push transcript items
// to translation now. They will be pushed either if a punctuation
// is found or of a `dequeue()` is requested.
None
}
/// Dequeues items from the specified `deadline` up to `lookahead`.
///
/// Returns `Some(..)` with the accumulated items matching the criteria.
pub fn dequeue(
&mut self,
deadline: gst::ClockTime,
lookahead: gst::ClockTime,
) -> Option<TranscriptItem> {
if self.items.front()?.pts < deadline {
// First item is too early to be sent to translation now
// we can wait for more items to accumulate.
return None;
}
// Can't wait any longer to send the first item to translation
// Try to get up to lookahead more items to improve translation accuracy
let limit = deadline + lookahead;
let mut item_acc = self.items.pop_front().unwrap();
while let Some(item) = self.items.front() {
if item.pts > limit {
break;
}
let item = self.items.pop_front().unwrap();
item_acc.push(&item);
}
Some(item_acc)
}
}
pub struct TranslationLoop {
pad: glib::subclass::ObjectImplRef<TranslationSrcPad>,
client: aws_translate::Client,
input_lang: String,
output_lang: String,
transcript_rx: mpsc::Receiver<TranscriptItem>,
translation_tx: mpsc::Sender<TranslatedItem>,
}
impl TranslationLoop {
pub fn new(
imp: &super::imp::Transcriber,
pad: &TranslationSrcPad,
input_lang: &str,
output_lang: &str,
transcript_rx: mpsc::Receiver<TranscriptItem>,
translation_tx: mpsc::Sender<TranslatedItem>,
) -> Self {
let aws_config = imp.aws_config.lock().unwrap();
let aws_config = aws_config
.as_ref()
.expect("aws_config must be initialized at this stage");
TranslationLoop {
pad: pad.ref_counted(),
client: aws_sdk_translate::Client::new(aws_config),
input_lang: input_lang.to_string(),
output_lang: output_lang.to_string(),
transcript_rx,
translation_tx,
}
}
pub async fn check_language(&self) -> Result<(), gst::ErrorMessage> {
let language_list = self.client.list_languages().send().await.map_err(|err| {
let err = format!("Failed to call list_languages service: {err}");
gst::info!(CAT, imp: self.pad, "{err}");
gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
})?;
let found_output_lang = language_list
.languages()
.and_then(|langs| {
langs
.iter()
.find(|lang| lang.language_code() == Some(&self.output_lang))
})
.is_some();
if !found_output_lang {
let err = format!("Unknown output languages: {}", self.output_lang);
gst::info!(CAT, imp: self.pad, "{err}");
return Err(gst::error_msg!(gst::LibraryError::Failed, ["{err}"]));
}
Ok(())
}
pub async fn run(mut self) -> Result<(), gst::ErrorMessage> {
while let Some(transcript_item) = self.transcript_rx.next().await {
let TranscriptItem {
pts,
duration,
content,
..
} = transcript_item;
let translated_text = if content.is_empty() {
content
} else {
self.client
.translate_text()
.set_source_language_code(Some(self.input_lang.clone()))
.set_target_language_code(Some(self.output_lang.clone()))
.set_text(Some(content))
.send()
.await
.map_err(|err| {
let err = format!("Failed to call translation service: {err}");
gst::info!(CAT, imp: self.pad, "{err}");
gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
})?
.translated_text
.unwrap_or_default()
};
let translated_item = TranslatedItem {
pts,
duration,
content: translated_text,
};
if self.translation_tx.send(translated_item).await.is_err() {
gst::info!(
CAT,
imp: self.pad,
"translation chan terminated, exiting translation loop"
);
break;
}
}
Ok(())
}
}