awstranscriber: implement use-partial-results property

The current implementation only makes use of non-partial results,
which requires a very high latency.

With this mode, we use items from partial results when they're
older than latency - 2 * GRANULARITY_MS. Depending on the latency
that the user has set, this may result in reduced accuracy. The
default latency has been changed to a fairly conservative sweet
spot of 8 seconds.

This complicates the code a bit, as items aren't identified by
AWS, and their timings can change.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/348>
This commit is contained in:
Mathieu Duponchelle 2020-05-28 23:55:00 +02:00 committed by Mathieu Duponchelle
parent 08da51744b
commit 815aa80789

View file

@ -108,10 +108,11 @@ static RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| {
.unwrap()
});
const DEFAULT_LATENCY_MS: u32 = 30000;
const DEFAULT_LATENCY_MS: u32 = 8000;
const DEFAULT_USE_PARTIAL_RESULTS: bool = true;
const GRANULARITY_MS: u32 = 100;
static PROPERTIES: [subclass::Property; 2] = [
static PROPERTIES: [subclass::Property; 3] = [
subclass::Property("language-code", |name| {
glib::ParamSpec::string(
name,
@ -123,12 +124,21 @@ static PROPERTIES: [subclass::Property; 2] = [
glib::ParamFlags::READWRITE,
)
}),
// Boolean property controlling whether partial results from AWS are used.
// The nick (short human-readable name) describes this property rather than
// repeating the "Latency" nick that belongs to the "latency" property.
subclass::Property("use-partial-results", |name| {
    glib::ParamSpec::boolean(
        name,
        "Use partial results",
        "Whether partial results from AWS should be used",
        DEFAULT_USE_PARTIAL_RESULTS,
        glib::ParamFlags::READWRITE,
    )
}),
subclass::Property("latency", |name| {
glib::ParamSpec::uint(
name,
"Latency",
"Amount of milliseconds to allow AWS transcribe",
GRANULARITY_MS,
2 * GRANULARITY_MS,
std::u32::MAX,
DEFAULT_LATENCY_MS,
glib::ParamFlags::READWRITE,
@ -140,6 +150,7 @@ static PROPERTIES: [subclass::Property; 2] = [
/// User-configurable settings of the transcriber.
struct Settings {
    /// Value of the "latency" property: amount of milliseconds to allow AWS
    /// transcribe.
    latency_ms: u32,
    /// Language code, "en-US" by default.
    /// NOTE(review): presumably backs the "language-code" property - confirm.
    language_code: Option<String>,
    /// Value of the "use-partial-results" property: whether partial results
    /// from AWS should be used.
    use_partial_results: bool,
}
impl Default for Settings {
@ -147,6 +158,7 @@ impl Default for Settings {
Self {
latency_ms: DEFAULT_LATENCY_MS,
language_code: Some("en-US".to_string()),
use_partial_results: DEFAULT_USE_PARTIAL_RESULTS,
}
}
}
@ -162,6 +174,8 @@ struct State {
buffers: VecDeque<gst::Buffer>,
send_eos: bool,
discont: bool,
last_partial_end_time: gst::ClockTime,
partial_alternative: Option<TranscriptAlternative>,
}
impl Default for State {
@ -177,6 +191,8 @@ impl Default for State {
buffers: VecDeque::new(),
send_eos: false,
discont: true,
last_partial_end_time: gst::CLOCK_TIME_NONE,
partial_alternative: None,
}
}
}
@ -257,13 +273,19 @@ impl Transcriber {
let (latency, now, mut last_position, send_eos, seqnum) = {
let mut state = self.state.lock().unwrap();
let send_eos = state.send_eos && state.buffers.is_empty();
// Multiply GRANULARITY by 2 in order to not send buffers that
// are less than GRANULARITY milliseconds away too late
let latency: gst::ClockTime = (self.settings.lock().unwrap().latency_ms as u64
- GRANULARITY_MS as u64)
- (2 * GRANULARITY_MS) as u64)
* gst::MSECOND;
let now = element.get_current_running_time();
if let Some(alternative) = state.partial_alternative.take() {
self.enqueue(element, &mut state, &alternative, true, latency, now);
state.partial_alternative = Some(alternative);
}
let send_eos = state.send_eos && state.buffers.is_empty();
while let Some(buf) = state.buffers.front() {
if now - buf.get_pts() > latency {
/* Safe unwrap, we know we have an item */
@ -352,6 +374,64 @@ impl Transcriber {
true
}
/// Converts the items of `alternative` into text buffers and appends them to
/// `state.buffers`.
///
/// Items whose start time is at or before `state.last_partial_end_time` are
/// skipped. When `partial` is true, an item is only queued once
/// `start_time + latency` lies before `now`, and the first item that is not
/// yet due ends the iteration. When `partial` is false, every remaining item
/// is queued.
///
/// Each queued item updates `state.last_partial_end_time` with the item's
/// (unadjusted) end time, and consumes the pending `state.discont` flag if set.
fn enqueue(
    &self,
    element: &gst::Element,
    state: &mut State,
    alternative: &TranscriptAlternative,
    partial: bool,
    latency: gst::ClockTime,
    now: gst::ClockTime,
) {
    // NOTE(review): item timings are presumably expressed in seconds and
    // scaled to nanoseconds here - confirm against the AWS item schema.
    let to_clock_time =
        |seconds: f64| -> gst::ClockTime { ((seconds * 1_000_000_000.0) as u64).into() };

    for item in &alternative.items {
        let mut pts = to_clock_time(item.start_time as f64);
        let mut end = to_clock_time(item.end_time as f64);

        if pts <= state.last_partial_end_time {
            /* Already sent (hopefully) */
            continue;
        }

        let due = !partial || pts + latency < now;
        if !due {
            /* Doesn't need to be sent yet */
            break;
        }

        /* Should be sent now */
        gst_debug!(CAT, obj: element, "Item is ready: {}", item.content);

        let mut buffer = gst::Buffer::from_mut_slice(item.content.clone().into_bytes());

        // Recorded before any clipping below, so the next call compares
        // against the end time as reported in the transcript.
        state.last_partial_end_time = end;

        {
            let buffer = buffer.get_mut().unwrap();

            if state.discont {
                buffer.set_flags(gst::BufferFlags::DISCONT);
                state.discont = false;
            }

            // Never place an item before the current output position.
            let position = state.out_segment.get_position();
            if pts < position {
                gst_debug!(
                    CAT,
                    obj: element,
                    "Adjusting item timing({:?} < {:?})",
                    pts,
                    position
                );
                pts = position;
                if end < pts {
                    end = pts;
                }
            }

            buffer.set_pts(pts);
            buffer.set_duration(end - pts);
        }

        state.buffers.push_back(buffer);
    }
}
fn loop_fn(
&self,
element: &gst::Element,
@ -417,50 +497,78 @@ impl Transcriber {
if !transcript.transcript.results.is_empty() {
let mut result = transcript.transcript.results.remove(0);
let use_partial_results = self.settings.lock().unwrap().use_partial_results;
if !result.is_partial && !result.alternatives.is_empty() {
let alternative = result.alternatives.remove(0);
gst_info!(CAT, obj: element, "Transcript: {}", alternative.transcript);
let mut start_time: gst::ClockTime =
((result.start_time as f64 * 1_000_000_000.0) as u64).into();
let end_time: gst::ClockTime =
((result.end_time as f64 * 1_000_000_000.0) as u64).into();
let mut state = self.state.lock().unwrap();
let position = state.out_segment.get_position();
if end_time < position {
gst_warning!(CAT, obj: element,
"Received transcript is too late by {:?}, dropping, consider increasing the latency",
position - start_time);
} else {
if start_time < position {
gst_warning!(CAT, obj: element,
"Received transcript is too late by {:?}, clipping, consider increasing the latency",
position - start_time);
start_time = position;
}
let mut buf = gst::Buffer::from_mut_slice(
alternative.transcript.into_bytes(),
if !use_partial_results {
let alternative = result.alternatives.remove(0);
gst_info!(
CAT,
obj: element,
"Transcript: {}",
alternative.transcript
);
{
let buf = buf.get_mut().unwrap();
let mut start_time: gst::ClockTime =
((result.start_time as f64 * 1_000_000_000.0) as u64).into();
let end_time: gst::ClockTime =
((result.end_time as f64 * 1_000_000_000.0) as u64).into();
if state.discont {
buf.set_flags(gst::BufferFlags::DISCONT);
state.discont = false;
let mut state = self.state.lock().unwrap();
let position = state.out_segment.get_position();
if end_time < position {
gst_warning!(CAT, obj: element,
"Received transcript is too late by {:?}, dropping, consider increasing the latency",
position - start_time);
} else {
if start_time < position {
gst_warning!(CAT, obj: element,
"Received transcript is too late by {:?}, clipping, consider increasing the latency",
position - start_time);
start_time = position;
}
buf.set_pts(start_time);
buf.set_duration(end_time - start_time);
let mut buf = gst::Buffer::from_mut_slice(
alternative.transcript.into_bytes(),
);
{
let buf = buf.get_mut().unwrap();
if state.discont {
buf.set_flags(gst::BufferFlags::DISCONT);
state.discont = false;
}
buf.set_pts(start_time);
buf.set_duration(end_time - start_time);
}
gst_debug!(
CAT,
obj: element,
"Adding pending buffer: {:?}",
buf
);
state.buffers.push_back(buf);
}
gst_debug!(CAT, obj: element, "Adding pending buffer: {:?}", buf);
state.buffers.push_back(buf);
} else {
let alternative = result.alternatives.remove(0);
let mut state = self.state.lock().unwrap();
self.enqueue(
element,
&mut state,
&alternative,
false,
0.into(),
0.into(),
);
state.partial_alternative = None;
}
} else if !result.alternatives.is_empty() && use_partial_results {
let mut state = self.state.lock().unwrap();
state.partial_alternative = Some(result.alternatives.remove(0));
}
}
Ok(())
@ -1001,6 +1109,10 @@ impl ObjectImpl for Transcriber {
let mut settings = self.settings.lock().unwrap();
settings.latency_ms = value.get_some().expect("type checked upstream");
}
subclass::Property("use-partial-results", ..) => {
let mut settings = self.settings.lock().unwrap();
settings.use_partial_results = value.get_some().expect("type checked upstream");
}
_ => unimplemented!(),
}
}
@ -1017,6 +1129,10 @@ impl ObjectImpl for Transcriber {
let settings = self.settings.lock().unwrap();
Ok(settings.latency_ms.to_value())
}
subclass::Property("use-partial-results", ..) => {
let settings = self.settings.lock().unwrap();
Ok(settings.use_partial_results.to_value())
}
_ => unimplemented!(),
}
}