net/aws/transcriber: fix deadlock when the pipeline is interrupted

... also makes sure to abort the taks_iter Future.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1122>
This commit is contained in:
François Laignel 2023-03-09 11:50:50 +01:00
parent f7c02cb3b0
commit 2ea9f147ab

View file

@ -15,6 +15,7 @@ use aws_sdk_transcribestreaming as aws_transcribe;
use aws_sdk_transcribestreaming::model; use aws_sdk_transcribestreaming::model;
use futures::channel::mpsc; use futures::channel::mpsc;
use futures::future::AbortHandle;
use futures::prelude::*; use futures::prelude::*;
use tokio::{runtime, task}; use tokio::{runtime, task};
@ -230,6 +231,7 @@ struct State {
buffer_tx: Option<mpsc::Sender<gst::Buffer>>, buffer_tx: Option<mpsc::Sender<gst::Buffer>>,
transcript_notif_tx: Option<mpsc::Sender<()>>, transcript_notif_tx: Option<mpsc::Sender<()>>,
ws_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>, ws_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
task_abort_handle: Option<AbortHandle>,
in_segment: gst::FormattedSegment<gst::ClockTime>, in_segment: gst::FormattedSegment<gst::ClockTime>,
out_segment: gst::FormattedSegment<gst::ClockTime>, out_segment: gst::FormattedSegment<gst::ClockTime>,
seqnum: gst::Seqnum, seqnum: gst::Seqnum,
@ -248,6 +250,7 @@ impl Default for State {
buffer_tx: None, buffer_tx: None,
transcript_notif_tx: None, transcript_notif_tx: None,
ws_loop_handle: None, ws_loop_handle: None,
task_abort_handle: None,
in_segment: gst::FormattedSegment::new(), in_segment: gst::FormattedSegment::new(),
out_segment: gst::FormattedSegment::new(), out_segment: gst::FormattedSegment::new(),
seqnum: gst::Seqnum::next(), seqnum: gst::Seqnum::next(),
@ -545,11 +548,17 @@ impl Transcriber {
} }
}; };
let (abortable_future, abort_handle) = future::abortable(future);
self.state.lock().unwrap().task_abort_handle = Some(abort_handle);
let _enter = RUNTIME.enter(); let _enter = RUNTIME.enter();
futures::executor::block_on(future) if futures::executor::block_on(abortable_future).is_err() {
gst::debug!(CAT, imp: self, "task iter aborted");
}
} }
fn start_task(&self) -> Result<(), gst::LoggableError> { fn start_task(&self) -> Result<(), gst::LoggableError> {
gst::debug!(CAT, imp: self, "Starting task");
let mut state = self.state.lock().unwrap(); let mut state = self.state.lock().unwrap();
let (transcript_notif_tx, mut transcript_notif_rx) = mpsc::channel(1); let (transcript_notif_tx, mut transcript_notif_rx) = mpsc::channel(1);
@ -566,20 +575,30 @@ impl Transcriber {
state.transcript_notif_tx = Some(transcript_notif_tx); state.transcript_notif_tx = Some(transcript_notif_tx);
gst::debug!(CAT, imp: self, "Task started");
Ok(()) Ok(())
} }
fn stop_task(&self) { fn stop_task(&self) {
let mut state = self.state.lock().unwrap(); gst::debug!(CAT, imp: self, "Stopping task");
let _ = self.srcpad.stop_task(); let _ = self.srcpad.stop_task();
let mut state = self.state.lock().unwrap();
if let Some(task_abort_handle) = state.task_abort_handle.take() {
task_abort_handle.abort();
}
if let Some(ws_loop_handle) = state.ws_loop_handle.take() { if let Some(ws_loop_handle) = state.ws_loop_handle.take() {
ws_loop_handle.abort(); ws_loop_handle.abort();
} }
state.transcript_notif_tx = None; state.transcript_notif_tx = None;
state.buffer_tx = None; state.buffer_tx = None;
gst::debug!(CAT, imp: self, "Task Stopped");
} }
fn stop_ws_loop(&self) { fn stop_ws_loop(&self) {
@ -849,11 +868,13 @@ impl Transcriber {
} }
fn disconnect(&self) { fn disconnect(&self) {
let mut state = self.state.lock().unwrap();
gst::info!(CAT, imp: self, "Unpreparing"); gst::info!(CAT, imp: self, "Unpreparing");
self.stop_task(); self.stop_task();
// Also resets discont to true // Also resets discont to true
*state = State::default(); *self.state.lock().unwrap() = State::default();
gst::info!(CAT, imp: self, "Unprepared"); gst::info!(CAT, imp: self, "Unprepared");
} }
} }