diff --git a/net/aws/src/transcriber/imp.rs b/net/aws/src/transcriber/imp.rs index bc160b63..1d77b14a 100644 --- a/net/aws/src/transcriber/imp.rs +++ b/net/aws/src/transcriber/imp.rs @@ -143,12 +143,14 @@ impl From for OutputItem { } struct State { - buffer_tx: Option>, + // each item is (buffer, running time of the buffer's PTS in the input segment) + buffer_tx: Option>, transcriber_loop_handle: Option>>, srcpads: BTreeSet, pad_serial: u32, seqnum: gst::Seqnum, start_time: Option, + in_segment: gst::FormattedSegment, } impl Default for State { @@ -160,6 +162,7 @@ impl Default for State { pad_serial: 0, seqnum: gst::Seqnum::next(), start_time: None, + in_segment: gst::FormattedSegment::new(), } } } @@ -251,17 +254,21 @@ impl Transcriber { } } Segment(e) => { - let format = e.segment().format(); - if format != gst::Format::Time { - gst::element_imp_error!( - self, - gst::StreamError::Format, - ["Only Time segments supported, got {format:?}"] - ); - return false; + let segment = match e.segment().clone().downcast::() { + Err(segment) => { + gst::element_imp_error!( + self, + gst::StreamError::Format, + ["Only Time segments supported, got {:?}", segment.format(),] + ); + return false; + } + Ok(segment) => segment, }; - self.state.lock().unwrap().seqnum = e.seqnum(); + let mut state = self.state.lock().unwrap(); + state.seqnum = e.seqnum(); + state.in_segment = segment; true } @@ -297,12 +304,26 @@ impl Transcriber { gst::FlowError::Error })?; + let rtime = match self + .state + .lock() + .unwrap() + .in_segment + .to_running_time(buffer.pts()) + { + Some(rtime) => rtime, + None => { + gst::debug!(CAT, "Buffer outside segment, clipping {buffer:?}"); + return Ok(gst::FlowSuccess::Ok); + } + }; + let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else { gst::log!(CAT, obj: pad, "Flushing"); return Err(gst::FlowError::Flushing); }; - futures::executor::block_on(buffer_tx.send(buffer)).map_err(|err| { + futures::executor::block_on(buffer_tx.send((buffer, rtime))).map_err(|err| { 
gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]); gst::FlowError::Error })?; diff --git a/net/aws/src/transcriber/transcribe.rs b/net/aws/src/transcriber/transcribe.rs index 826e4918..fb32b77f 100644 --- a/net/aws/src/transcriber/transcribe.rs +++ b/net/aws/src/transcriber/transcribe.rs @@ -16,7 +16,7 @@ use aws_sdk_transcribestreaming::types; use futures::channel::mpsc; use futures::prelude::*; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use super::imp::{Settings, Transcriber}; use super::CAT; @@ -55,11 +55,19 @@ pub struct TranscriptItem { } impl TranscriptItem { - pub fn from(item: types::Item, lateness: gst::ClockTime) -> Option { + pub fn from( + item: types::Item, + lateness: gst::ClockTime, + discont_offset: gst::ClockTime, + ) -> Option { let content = item.content?; - let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness; - let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness; + let start_time = + ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness + discont_offset; + let end_time = + ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness + discont_offset; + + gst::log!(CAT, "Discont offset is {discont_offset}"); Some(TranscriptItem { pts: start_time, @@ -82,11 +90,17 @@ impl From> for TranscriptEvent { } } +struct DiscontOffsetTracker { + discont_offset: gst::ClockTime, + last_chained_buffer_rtime: Option, +} + pub struct TranscriberStream { imp: glib::subclass::ObjectImplRef, output: aws_transcribe::operation::start_stream_transcription::StartStreamTranscriptionOutput, lateness: gst::ClockTime, partial_index: usize, + discont_offset_tracker: Arc>, } impl TranscriberStream { @@ -94,7 +108,7 @@ impl TranscriberStream { imp: &Transcriber, settings: TranscriberSettings, lateness: gst::ClockTime, - buffer_rx: mpsc::Receiver, + buffer_rx: mpsc::Receiver<(gst::Buffer, gst::ClockTime)>, ) -> Result { let client = { let aws_config = 
imp.aws_config.lock().unwrap(); @@ -104,8 +118,23 @@ impl TranscriberStream { aws_transcribe::Client::new(aws_config) }; + let discont_offset_tracker = Arc::new(Mutex::new(DiscontOffsetTracker { + discont_offset: gst::ClockTime::ZERO, + last_chained_buffer_rtime: gst::ClockTime::NONE, + })); + + let discont_offset_tracker_clone = discont_offset_tracker.clone(); + // Stream the incoming buffers chunked - let chunk_stream = buffer_rx.flat_map(move |buffer: gst::Buffer| { + let chunk_stream = buffer_rx.flat_map(move |(buffer, running_time)| { + let mut discont_offset_tracker = discont_offset_tracker_clone.lock().unwrap(); + if buffer.flags().contains(gst::BufferFlags::DISCONT) { + if let Some(last_chained_buffer_rtime) = discont_offset_tracker.last_chained_buffer_rtime { + discont_offset_tracker.discont_offset += running_time.saturating_sub(last_chained_buffer_rtime); + } + } + + discont_offset_tracker.last_chained_buffer_rtime = Some(running_time); async_stream::stream! { let data = buffer.map_readable().unwrap(); use aws_transcribe::{types::{AudioEvent, AudioStream}, primitives::Blob}; @@ -146,6 +175,7 @@ impl TranscriberStream { output, lateness, partial_index: 0, + discont_offset_tracker, }) } @@ -229,7 +259,9 @@ impl TranscriberStream { break; } - let Some(item) = TranscriptItem::from(item, self.lateness) else { continue }; + let discont_offset = self.discont_offset_tracker.lock().unwrap().discont_offset; + + let Some(item) = TranscriptItem::from(item, self.lateness, discont_offset) else { continue }; gst::debug!( CAT, imp: self.imp,