From 2839e0078b6b827af7929c1e85aeb5e098790b2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Dr=C3=B6ge?= Date: Mon, 20 Nov 2023 17:02:55 +0200 Subject: [PATCH] rtp: Port RTP AV1 payloader/depayloader to new base classes Part-of: --- docs/plugins/gst_plugins_cache.json | 8 +- net/rtp/src/av1/depay/imp.rs | 202 ++++---- net/rtp/src/av1/depay/mod.rs | 12 +- net/rtp/src/av1/depay/tests.rs | 117 +++++ net/rtp/src/av1/pay/imp.rs | 443 +++++++----------- net/rtp/src/av1/pay/mod.rs | 11 +- .../{tests/rtpav1.rs => src/av1/pay/tests.rs} | 108 +---- net/rtp/src/lib.rs | 7 +- 8 files changed, 422 insertions(+), 486 deletions(-) create mode 100644 net/rtp/src/av1/depay/tests.rs rename net/rtp/{tests/rtpav1.rs => src/av1/pay/tests.rs} (54%) diff --git a/docs/plugins/gst_plugins_cache.json b/docs/plugins/gst_plugins_cache.json index ef0f4a11..40d5bbea 100644 --- a/docs/plugins/gst_plugins_cache.json +++ b/docs/plugins/gst_plugins_cache.json @@ -6193,7 +6193,7 @@ "description": "Depayload AV1 from RTP packets", "hierarchy": [ "GstRtpAv1Depay", - "GstRTPBaseDepayload", + "GstRtpBaseDepay2", "GstElement", "GstObject", "GInitiallyUnowned", @@ -6203,7 +6203,7 @@ "long-name": "RTP AV1 Depayloader", "pad-templates": { "sink": { - "caps": "application/x-rtp:\n media: video\n payload: [ 96, 127 ]\n clock-rate: 90000\n encoding-name: AV1\n", + "caps": "application/x-rtp:\n media: video\n clock-rate: 90000\n encoding-name: AV1\n", "direction": "sink", "presence": "always" }, @@ -6220,7 +6220,7 @@ "description": "Payload AV1 as RTP packets", "hierarchy": [ "GstRtpAv1Pay", - "GstRTPBasePayload", + "GstRtpBasePay2", "GstElement", "GstObject", "GInitiallyUnowned", @@ -6680,7 +6680,7 @@ "construct": false, "construct-only": false, "controllable": false, - "default": "8", + "default": "96", "max": "127", "min": "0", "mutable": "ready", diff --git a/net/rtp/src/av1/depay/imp.rs b/net/rtp/src/av1/depay/imp.rs index 0cc92f70..68f7587b 100644 --- a/net/rtp/src/av1/depay/imp.rs +++ b/net/rtp/src/av1/depay/imp.rs @@ -7,28 +7,36 @@ // // SPDX-License-Identifier: MPL-2.0 -use gst::{glib, subclass::prelude::*}; -use gst_rtp::prelude::*; -use gst_rtp::subclass::prelude::*; +use atomic_refcell::AtomicRefCell; +use gst::{glib, prelude::*, subclass::prelude::*}; use std::{ cmp::Ordering, io::{Cursor, Read, Seek, SeekFrom}, - sync::Mutex, }; use bitstream_io::{BitReader, BitWriter}; use once_cell::sync::Lazy; -use crate::av1::common::{ - err_flow, leb128_size, parse_leb128, write_leb128, AggregationHeader, ObuType, SizedObu, - UnsizedObu, CLOCK_RATE, ENDIANNESS, +use crate::{ + av1::common::{ + err_flow, leb128_size, parse_leb128, write_leb128, AggregationHeader, ObuType, SizedObu, + UnsizedObu, CLOCK_RATE, ENDIANNESS, + }, + basedepay::PacketToBufferRelation, }; +use crate::basedepay::RtpBaseDepay2Ext; + // TODO: handle internal size fields in RTP OBUs -#[derive(Debug)] +struct PendingFragment { + ext_seqnum: u64, + obu: UnsizedObu, + bytes: Vec, +} + struct State { - last_timestamp: Option, + last_timestamp: Option, /// if true, the last packet of a temporal unit has been received marked_packet: bool, /// if the next output buffer needs the DISCONT flag set @@ -36,7 +44,7 @@ struct State { /// if we saw a valid OBU since the last reset found_valid_obu: bool, /// holds data for a fragment - obu_fragment: Option<(UnsizedObu, Vec)>, + obu_fragment: Option, } impl Default for State { @@ -51,9 +59,9 @@ impl Default for State { } } -#[derive(Debug, Default)] +#[derive(Default)] pub struct RTPAv1Depay { - state: Mutex, + state: AtomicRefCell, } static CAT: Lazy = Lazy::new(|| { @@ -78,7 +86,7 @@ impl RTPAv1Depay { impl ObjectSubclass for RTPAv1Depay { const NAME: &'static str = "GstRtpAv1Depay"; type Type = super::RTPAv1Depay; - type ParentType = gst_rtp::RTPBaseDepayload; + type ParentType = crate::basedepay::RtpBaseDepay2; } impl ObjectImpl for RTPAv1Depay {} @@ -107,7 +115,6 @@ impl ElementImpl for RTPAv1Depay { gst::PadPresence::Always, &gst::Caps::builder("application/x-rtp") .field("media", "video") - .field("payload", gst::IntRange::new(96, 127)) .field("clock-rate", CLOCK_RATE as i32) .field("encoding-name", "AV1") .build(), @@ -131,87 +138,66 @@ impl ElementImpl for RTPAv1Depay { PAD_TEMPLATES.as_ref() } - - fn change_state( - &self, - transition: gst::StateChange, - ) -> Result { - gst::debug!(CAT, imp: self, "changing state: {}", transition); - - if matches!(transition, gst::StateChange::ReadyToPaused) { - let mut state = self.state.lock().unwrap(); - self.reset(&mut state); - } - - let ret = self.parent_change_state(transition); - - if matches!(transition, gst::StateChange::PausedToReady) { - let mut state = self.state.lock().unwrap(); - self.reset(&mut state); - } - - ret - } } -impl RTPBaseDepayloadImpl for RTPAv1Depay { - fn set_caps(&self, _caps: &gst::Caps) -> Result<(), gst::LoggableError> { - let element = self.obj(); - let src_pad = element.src_pad(); - let src_caps = src_pad.pad_template_caps(); - src_pad.push_event(gst::event::Caps::builder(&src_caps).build()); +impl crate::basedepay::RtpBaseDepay2Impl for RTPAv1Depay { + const ALLOWED_META_TAGS: &'static [&'static str] = &["video"]; + + fn start(&self) -> Result<(), gst::ErrorMessage> { + let mut state = self.state.borrow_mut(); + self.reset(&mut state); Ok(()) } - fn handle_event(&self, event: gst::Event) -> bool { - match event.view() { - gst::EventView::Eos(_) | gst::EventView::FlushStop(_) => { - let mut state = self.state.lock().unwrap(); - self.reset(&mut state); - } - _ => (), - } + fn stop(&self) -> Result<(), gst::ErrorMessage> { + let mut state = self.state.borrow_mut(); + self.reset(&mut state); - self.parent_handle_event(event) + Ok(()) } - fn process_rtp_packet( + fn set_sink_caps(&self, _caps: &gst::Caps) -> bool { + self.obj() + .set_src_caps(&self.obj().src_pad().pad_template_caps()); + + true + } + + fn flush(&self) { + let mut state = self.state.borrow_mut(); + self.reset(&mut state); + } + + fn handle_packet( &self, - rtp: &gst_rtp::RTPBuffer, - ) -> Option { - if let Err(err) = self.handle_rtp_packet(rtp) { + packet: &crate::basedepay::Packet, + ) -> Result { + let res = self.handle_rtp_packet(packet); + + if let Err(err) = res { gst::warning!(CAT, imp: self, "Failed to handle RTP packet: {err:?}"); - self.reset(&mut self.state.lock().unwrap()); + self.reset(&mut self.state.borrow_mut()); } - None + res } } impl RTPAv1Depay { fn handle_rtp_packet( &self, - rtp: &gst_rtp::RTPBuffer, - ) -> Result<(), gst::FlowError> { - gst::log!( + packet: &crate::basedepay::Packet, + ) -> Result { + gst::trace!( CAT, imp: self, - "processing RTP packet with payload type {} and size {}", - rtp.payload_type(), - rtp.buffer().size(), + "Processing RTP packet {packet:?}", ); - let payload = rtp.payload().map_err(err_flow!(self, payload_buf))?; + let mut state = self.state.borrow_mut(); - let mut state = self.state.lock().unwrap(); - - if rtp.buffer().flags().contains(gst::BufferFlags::DISCONT) { - gst::debug!(CAT, imp: self, "buffer discontinuity"); - self.reset(&mut state); - } - - let mut reader = Cursor::new(payload); + let mut reader = Cursor::new(packet.payload()); let mut ready_obus = Vec::new(); let aggr_header = { @@ -223,7 +209,7 @@ impl RTPAv1Depay { }; // handle new temporal units - if state.marked_packet || state.last_timestamp != Some(rtp.timestamp()) { + if state.marked_packet || state.last_timestamp != Some(packet.ext_timestamp()) { if state.last_timestamp.is_some() && state.obu_fragment.is_some() { gst::error!( CAT, @@ -242,8 +228,8 @@ impl RTPAv1Depay { // the next temporal unit starts with a temporal delimiter OBU ready_obus.extend_from_slice(&TEMPORAL_DELIMITER); } - state.marked_packet = rtp.is_marker(); - state.last_timestamp = Some(rtp.timestamp()); + state.marked_packet = packet.marker_bit(); + state.last_timestamp = Some(packet.ext_timestamp()); // parse and prepare the received OBUs let mut idx = 0; @@ -258,10 +244,20 @@ impl RTPAv1Depay { self.reset(&mut state); } - if let Some((obu, ref mut bytes)) = &mut state.obu_fragment { + // If we finish an OBU here, it will start with the ext seqnum of this packet + // but if it also extends a fragment then the start will be set to the start + // of the fragment instead. + let mut start_ext_seqnum = packet.ext_seqnum(); + + if let Some(PendingFragment { + ext_seqnum, + obu, + ref mut bytes, + }) = state.obu_fragment + { assert!(aggr_header.leading_fragment); let (element_size, is_last_obu) = self - .find_element_info(rtp, &mut reader, &aggr_header, idx) + .find_element_info(&mut reader, &aggr_header, idx) .map_err(err_flow!(self, find_element))?; let bytes_end = bytes.len(); @@ -283,6 +279,7 @@ impl RTPAv1Depay { &full_obu, &mut ready_obus, )?; + start_ext_seqnum = ext_seqnum; state.obu_fragment = None; } @@ -290,9 +287,9 @@ impl RTPAv1Depay { } // handle other OBUs, including trailing fragments - while reader.position() < rtp.payload_size() as u64 { + while (reader.position() as usize) < reader.get_ref().len() { let (element_size, is_last_obu) = - self.find_element_info(rtp, &mut reader, &aggr_header, idx)?; + self.find_element_info(&mut reader, &aggr_header, idx)?; if idx == 0 && aggr_header.leading_fragment { if state.found_valid_obu { @@ -330,13 +327,17 @@ impl RTPAv1Depay { // trailing OBU fragments are stored in the state if is_last_obu && aggr_header.trailing_fragment { - let bytes_left = rtp.payload_size() - (reader.position() as u32); - let mut bytes = vec![0; bytes_left as usize]; + let bytes_left = reader.get_ref().len() - (reader.position() as usize); + let mut bytes = vec![0; bytes_left]; reader .read_exact(bytes.as_mut_slice()) .map_err(err_flow!(self, buf_read))?; - state.obu_fragment = Some((obu, bytes)); + state.obu_fragment = Some(PendingFragment { + ext_seqnum: packet.ext_seqnum(), + obu, + bytes, + }); } // full OBUs elements are translated and appended to the ready OBUs else { @@ -396,10 +397,13 @@ impl RTPAv1Depay { drop(state); if let Some(buffer) = buffer { - self.obj().push(buffer)?; + self.obj().queue_buffer( + PacketToBufferRelation::Seqnums(start_ext_seqnum..=packet.ext_seqnum()), + buffer, + ) + } else { + Ok(gst::FlowSuccess::Ok) } - - Ok(()) } /// Find out the next OBU element's size, and if it is the last OBU in the packet. @@ -408,7 +412,6 @@ impl RTPAv1Depay { /// and will be at the first byte past the element's size field afterwards. fn find_element_info( &self, - rtp: &gst_rtp::RTPBuffer, reader: &mut Cursor<&[u8]>, aggr_header: &AggregationHeader, index: u32, @@ -418,7 +421,7 @@ impl RTPAv1Depay { let element_size = if let Some(count) = aggr_header.obu_count { is_last_obu = index + 1 == count as u32; if is_last_obu { - rtp.payload_size() - (reader.position() as u32) + (reader.get_ref().len() - reader.position() as usize) as u32 } else { let mut bitreader = BitReader::endian(reader, ENDIANNESS); let (size, _) = parse_leb128(&mut bitreader).map_err(err_flow!(self, leb_read))?; @@ -427,7 +430,11 @@ impl RTPAv1Depay { } else { let (size, _) = parse_leb128(&mut BitReader::endian(&mut *reader, ENDIANNESS)) .map_err(err_flow!(self, leb_read))?; - is_last_obu = match rtp.payload_size().cmp(&(reader.position() as u32 + size)) { + is_last_obu = match reader + .get_ref() + .len() + .cmp(&(reader.position() as usize + size as usize)) + { Ordering::Greater => false, Ordering::Equal => true, Ordering::Less => { @@ -545,7 +552,9 @@ mod tests { ) ]; - let element = ::Type::new(); + // Element exists just for logging purposes + let element = glib::Object::new::(); + for (idx, (obu, rtp_bytes, out_bytes)) in test_data.into_iter().enumerate() { println!("running test {idx}..."); let mut reader = Cursor::new(rtp_bytes.as_slice()); @@ -563,40 +572,35 @@ mod tests { fn test_find_element_info() { gst::init().unwrap(); - let test_data: [(Vec<(u32, bool)>, u32, Vec, AggregationHeader); 4] = [ + let test_data: [(Vec<(u32, bool)>, Vec, AggregationHeader); 4] = [ ( vec![(1, false)], // expected results - 100, // RTP payload size - vec![0b0000_0001, 0b0001_0000], + vec![0b0000_0001, 0b0001_0000, 0], AggregationHeader { obu_count: None, ..AggregationHeader::default() }, ), ( vec![(5, true)], - 5, vec![0b0111_1000, 0, 0, 0, 0], AggregationHeader { obu_count: Some(1), ..AggregationHeader::default() }, ), ( vec![(7, true)], - 8, vec![0b0000_0111, 0b0011_0110, 0b0010_1000, 0b0000_1010, 1, 2, 3, 4], AggregationHeader { obu_count: None, ..AggregationHeader::default() }, ), ( vec![(6, false), (4, true)], - 11, vec![0b0000_0110, 0b0111_1000, 1, 2, 3, 4, 5, 0b0011_0000, 1, 2, 3], AggregationHeader { obu_count: Some(2), ..AggregationHeader::default() }, ) ]; - let element = ::Type::new(); + // Element exists just for logging purposes + let element = glib::Object::new::(); + for (idx, ( info, - payload_size, rtp_bytes, aggr_header, )) in test_data.into_iter().enumerate() { println!("running test {idx}..."); - let buffer = gst::Buffer::new_rtp_with_sizes(payload_size, 0, 0).unwrap(); - let rtp = gst_rtp::RTPBuffer::from_buffer_readable(&buffer).unwrap(); let mut reader = Cursor::new(rtp_bytes.as_slice()); let mut element_size = 0; @@ -607,7 +611,7 @@ mod tests { println!("testing element {} with reader position {}...", obu_idx, reader.position()); - let actual = element.imp().find_element_info(&rtp, &mut reader, &aggr_header, obu_idx as u32); + let actual = element.imp().find_element_info(&mut reader, &aggr_header, obu_idx as u32); assert_eq!(actual, Ok(expected)); element_size = actual.unwrap().0; } diff --git a/net/rtp/src/av1/depay/mod.rs b/net/rtp/src/av1/depay/mod.rs index be18ecd0..1131f644 100644 --- a/net/rtp/src/av1/depay/mod.rs +++ b/net/rtp/src/av1/depay/mod.rs @@ -6,22 +6,18 @@ // . // // SPDX-License-Identifier: MPL-2.0 -#![allow(clippy::new_without_default)] use gst::glib; use gst::prelude::*; pub mod imp; +#[cfg(test)] +mod tests; + glib::wrapper! { pub struct RTPAv1Depay(ObjectSubclass) - @extends gst_rtp::RTPBaseDepayload, gst::Element, gst::Object; -} - -impl RTPAv1Depay { - pub fn new() -> Self { - glib::Object::new() - } + @extends crate::basedepay::RtpBaseDepay2, gst::Element, gst::Object; } pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> { diff --git a/net/rtp/src/av1/depay/tests.rs b/net/rtp/src/av1/depay/tests.rs new file mode 100644 index 00000000..b366effd --- /dev/null +++ b/net/rtp/src/av1/depay/tests.rs @@ -0,0 +1,117 @@ +// +// Copyright (C) 2022 Vivienne Watermeier +// +// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0. +// If a copy of the MPL was not distributed with this file, You can obtain one at +// . +// +// SPDX-License-Identifier: MPL-2.0 + +use gst::{event::Eos, Caps}; +use gst_check::Harness; + +fn init() { + use std::sync::Once; + static INIT: Once = Once::new(); + + INIT.call_once(|| { + gst::init().unwrap(); + crate::plugin_register_static().expect("rtpav1 test"); + }); +} + +#[test] +fn test_depayloader() { + #[rustfmt::skip] + let test_packets: [(Vec, gst::ClockTime, bool, u32); 4] = [ + ( // simple packet, complete TU + vec![ // RTP payload + 0b0001_1000, + 0b0011_0000, 1, 2, 3, 4, 5, 6, + ], + gst::ClockTime::from_seconds(0), + true, // marker bit + 100_000, // timestamp + ), ( // 2 OBUs, last is fragmented + vec![ + 0b0110_0000, + 0b0000_0110, 0b0111_1000, 1, 2, 3, 4, 5, + 0b0011_0000, 1, 2, 3, + ], + gst::ClockTime::from_seconds(1), + false, + 190_000, + ), ( // continuation of the last OBU + vec![ + 0b1100_0000, + 0b0000_0100, 4, 5, 6, 7, + ], + gst::ClockTime::from_seconds(1), + false, + 190_000, + ), ( // finishing the OBU fragment + vec![ + 0b1001_0000, + 8, 9, 10, + ], + gst::ClockTime::from_seconds(1), + true, + 190_000, + ) + ]; + + #[rustfmt::skip] + let expected: [(gst::ClockTime, Vec); 3] = [ + ( + gst::ClockTime::from_seconds(0), + vec![0b0001_0010, 0, 0b0011_0010, 0b0000_0110, 1, 2, 3, 4, 5, 6], + ), + ( + gst::ClockTime::from_seconds(1), + vec![0b0001_0010, 0, 0b0111_1010, 0b0000_0101, 1, 2, 3, 4, 5], + ), + ( + gst::ClockTime::from_seconds(1), + vec![0b0011_0010, 0b0000_1010, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + ), + ]; + + init(); + + let mut h = Harness::new("rtpav1depay"); + h.play(); + + let caps = Caps::builder("application/x-rtp") + .field("media", "video") + .field("payload", 96) + .field("clock-rate", 90000) + .field("encoding-name", "AV1") + .build(); + h.set_src_caps(caps); + + for (idx, (bytes, pts, marker, timestamp)) in test_packets.iter().enumerate() { + let builder = rtp_types::RtpPacketBuilder::new() + .marker_bit(*marker) + .timestamp(*timestamp) + .payload_type(96) + .sequence_number(idx as u16) + .payload(bytes.as_slice()); + let buf = builder.write_vec().unwrap(); + let mut buf = gst::Buffer::from_mut_slice(buf); + { + buf.get_mut().unwrap().set_pts(*pts); + } + + h.push(buf).unwrap(); + } + h.push_event(Eos::new()); + + for (idx, (pts, ex)) in expected.iter().enumerate() { + println!("checking buffer {idx}..."); + + let buffer = h.pull().unwrap(); + assert_eq!(buffer.pts(), Some(*pts)); + let actual = buffer.into_mapped_buffer_readable().unwrap(); + assert_eq!(actual.as_slice(), ex.as_slice()); + } +} diff --git a/net/rtp/src/av1/pay/imp.rs b/net/rtp/src/av1/pay/imp.rs index 4182c14b..ba73f8c1 100644 --- a/net/rtp/src/av1/pay/imp.rs +++ b/net/rtp/src/av1/pay/imp.rs @@ -7,20 +7,19 @@ // // SPDX-License-Identifier: MPL-2.0 +use atomic_refcell::AtomicRefCell; use gst::{glib, subclass::prelude::*}; -use gst_rtp::{prelude::*, subclass::prelude::*}; use std::{ - cmp, collections::VecDeque, io::{Cursor, Read, Seek, SeekFrom, Write}, - sync::Mutex, }; use bitstream_io::{BitReader, BitWriter}; use once_cell::sync::Lazy; -use crate::av1::common::{ - err_flow, leb128_size, write_leb128, ObuType, SizedObu, CLOCK_RATE, ENDIANNESS, +use crate::{ + av1::common::{err_flow, leb128_size, write_leb128, ObuType, SizedObu, CLOCK_RATE, ENDIANNESS}, + basepay::{PacketToBufferRelation, RtpBasePay2Ext}, }; static CAT: Lazy = Lazy::new(|| { @@ -31,8 +30,6 @@ static CAT: Lazy = Lazy::new(|| { ) }); -// TODO: properly handle `max_ptime` and `min_ptime` - /// Information about the OBUs intended to be grouped into one packet #[derive(Copy, Clone, Debug, PartialEq, Eq)] struct PacketOBUData { @@ -60,8 +57,7 @@ struct ObuData { info: SizedObu, bytes: Vec, offset: usize, - dts: Option, - pts: Option, + id: u64, } #[derive(Clone, Debug, PartialEq, Eq)] @@ -78,18 +74,13 @@ struct State { /// (Corresponds to `N` field in the aggregation header) first_packet_in_seq: bool, - /// The last observed DTS if upstream does not provide DTS for each OBU - last_dts: Option, - /// The last observed PTS if upstream does not provide PTS for each OBU - last_pts: Option, - /// If the input is TU or frame aligned. framed: bool, } #[derive(Debug, Default)] pub struct RTPAv1Pay { - state: Mutex, + state: AtomicRefCell, } impl Default for State { @@ -98,8 +89,6 @@ impl Default for State { obus: VecDeque::new(), open_obu_fragment: false, first_packet_in_seq: true, - last_dts: None, - last_pts: None, framed: false, } } @@ -124,11 +113,10 @@ impl RTPAv1Pay { fn handle_new_obus( &self, state: &mut State, + id: u64, data: &[u8], marker: bool, - dts: Option, - pts: Option, - ) -> Result { + ) -> Result { let mut reader = Cursor::new(data); while reader.position() < data.len() as u64 { @@ -163,8 +151,7 @@ impl RTPAv1Pay { info: obu, bytes: Vec::new(), offset: 0, - dts, - pts, + id, }); } @@ -195,23 +182,17 @@ impl RTPAv1Pay { info: obu, bytes, offset: 0, - dts, - pts, + id, }); } } } - let mut list = gst::BufferList::new(); - { - let list = list.get_mut().unwrap(); - while let Some(packet_data) = self.consider_new_packet(state, false, marker) { - let buffer = self.generate_new_packet(state, packet_data)?; - list.add(buffer); - } + while let Some(packet_data) = self.consider_new_packet(state, false, marker) { + self.generate_new_packet(state, packet_data)?; } - Ok(list) + Ok(gst::FlowSuccess::Ok) } /// Look at the size the currently stored OBUs would require, @@ -237,7 +218,7 @@ impl RTPAv1Pay { marker, ); - let payload_limit = gst_rtp::calc_payload_len(self.obj().mtu(), 0, 0); + let payload_limit = self.obj().max_payload_size(); // Create information about the packet that can be created now while iterating over the // OBUs and return this if a full packet can indeed be created now. @@ -361,7 +342,7 @@ impl RTPAv1Pay { &self, state: &mut State, packet: PacketOBUData, - ) -> Result { + ) -> Result { gst::log!( CAT, imp: self, @@ -370,186 +351,134 @@ impl RTPAv1Pay { ); // prepare the outgoing buffer - let mut outbuf = - gst::Buffer::new_rtp_with_sizes(packet.payload_size, 0, 0).map_err(|err| { - gst::element_imp_error!( - self, - gst::ResourceError::Write, - ["Failed to allocate output buffer: {}", err] - ); - - gst::FlowError::Error - })?; + let mut payload = Vec::with_capacity(packet.payload_size as usize); + let mut writer = Cursor::new(&mut payload); { - // this block enforces that outbuf_mut is dropped before pushing outbuf - let first_obu = state.obus.front().unwrap(); - if let Some(dts) = first_obu.dts { - state.last_dts = Some( - state - .last_dts - .map_or(dts, |last_dts| cmp::max(last_dts, dts)), - ); - } - if let Some(pts) = first_obu.pts { - state.last_pts = Some( - state - .last_pts - .map_or(pts, |last_pts| cmp::max(last_pts, pts)), - ); - } + // construct aggregation header + let w = if packet.omit_last_size_field && packet.obu_count < 4 { + packet.obu_count + } else { + 0 + }; - let outbuf_mut = outbuf - .get_mut() - .expect("Failed to get mutable reference to outbuf"); - outbuf_mut.set_dts(state.last_dts); - outbuf_mut.set_pts(state.last_pts); - - let mut rtp = gst_rtp::RTPBuffer::from_buffer_writable(outbuf_mut) - .expect("Failed to create RTPBuffer"); - rtp.set_marker(packet.ends_temporal_unit); - - let payload = rtp - .payload_mut() - .expect("Failed to get mutable reference to RTP payload"); - let mut writer = Cursor::new(payload); - - { - // construct aggregation header - let w = if packet.omit_last_size_field && packet.obu_count < 4 { - packet.obu_count - } else { - 0 - }; - - let aggr_header: [u8; 1] = [ + let aggr_header: [u8; 1] = [ (state.open_obu_fragment as u8) << 7 | // Z ((packet.last_obu_fragment_size.is_some()) as u8) << 6 | // Y (w as u8) << 4 | // W (state.first_packet_in_seq as u8) << 3 // N ; 1]; - writer - .write(&aggr_header) - .map_err(err_flow!(self, aggr_header_write))?; + writer + .write(&aggr_header) + .map_err(err_flow!(self, aggr_header_write))?; - state.first_packet_in_seq = false; + state.first_packet_in_seq = false; + } + + let mut start_id = None; + let end_id; + + // append OBUs to the buffer + for _ in 1..packet.obu_count { + let obu = loop { + let obu = state.obus.pop_front().unwrap(); + + // Drop temporal delimiter from here + if obu.info.obu_type != ObuType::TemporalDelimiter { + break obu; + } + }; + + if start_id.is_none() { + start_id = Some(obu.id); } - // append OBUs to the buffer - for _ in 1..packet.obu_count { - let obu = loop { - let obu = state.obus.pop_front().unwrap(); + write_leb128( + &mut BitWriter::endian(&mut writer, ENDIANNESS), + obu.info.size + obu.info.header_len, + ) + .map_err(err_flow!(self, leb_write))?; + writer + .write(&obu.bytes[obu.offset..]) + .map_err(err_flow!(self, obu_write))?; + } + state.open_obu_fragment = false; - if let Some(dts) = obu.dts { - state.last_dts = Some( - state - .last_dts - .map_or(dts, |last_dts| cmp::max(last_dts, dts)), - ); - } - if let Some(pts) = obu.pts { - state.last_pts = Some( - state - .last_pts - .map_or(pts, |last_pts| cmp::max(last_pts, pts)), - ); - } + { + let last_obu = loop { + let obu = state.obus.front_mut().unwrap(); - // Drop temporal delimiter from here - if obu.info.obu_type != ObuType::TemporalDelimiter { - break obu; - } - }; + // Drop temporal delimiter from here + if obu.info.obu_type != ObuType::TemporalDelimiter { + break obu; + } + let _ = state.obus.pop_front().unwrap(); + }; - write_leb128( - &mut BitWriter::endian(&mut writer, ENDIANNESS), - obu.info.size + obu.info.header_len, - ) - .map_err(err_flow!(self, leb_write))?; + if start_id.is_none() { + start_id = Some(last_obu.id); + } + end_id = last_obu.id; + + // do the last OBU separately + // in this instance `obu_size` includes the header length + let obu_size = if let Some(size) = packet.last_obu_fragment_size { + state.open_obu_fragment = true; + size + } else { + last_obu.bytes.len() as u32 - last_obu.offset as u32 + }; + + if !packet.omit_last_size_field { + write_leb128(&mut BitWriter::endian(&mut writer, ENDIANNESS), obu_size) + .map_err(err_flow!(self, leb_write))?; + } + + // if this OBU is not a fragment, handle it as usual + if packet.last_obu_fragment_size.is_none() { writer - .write(&obu.bytes[obu.offset..]) + .write(&last_obu.bytes[last_obu.offset..]) .map_err(err_flow!(self, obu_write))?; + let _ = state.obus.pop_front().unwrap(); } - state.open_obu_fragment = false; + // otherwise write only a slice, and update the element + // to only contain the unwritten bytes + else { + writer + .write(&last_obu.bytes[last_obu.offset..last_obu.offset + obu_size as usize]) + .map_err(err_flow!(self, obu_write))?; - { - let last_obu = loop { - let obu = state.obus.front_mut().unwrap(); - - if let Some(dts) = obu.dts { - state.last_dts = Some( - state - .last_dts - .map_or(dts, |last_dts| cmp::max(last_dts, dts)), - ); - } - if let Some(pts) = obu.pts { - state.last_pts = Some( - state - .last_pts - .map_or(pts, |last_pts| cmp::max(last_pts, pts)), - ); - } - - // Drop temporal delimiter from here - if obu.info.obu_type != ObuType::TemporalDelimiter { - break obu; - } - let _ = state.obus.pop_front().unwrap(); + let new_size = last_obu.bytes.len() as u32 - last_obu.offset as u32 - obu_size; + last_obu.info = SizedObu { + size: new_size, + header_len: 0, + leb_size: leb128_size(new_size) as u32, + is_fragment: true, + ..last_obu.info }; - - // do the last OBU separately - // in this instance `obu_size` includes the header length - let obu_size = if let Some(size) = packet.last_obu_fragment_size { - state.open_obu_fragment = true; - size - } else { - last_obu.bytes.len() as u32 - last_obu.offset as u32 - }; - - if !packet.omit_last_size_field { - write_leb128(&mut BitWriter::endian(&mut writer, ENDIANNESS), obu_size) - .map_err(err_flow!(self, leb_write))?; - } - - // if this OBU is not a fragment, handle it as usual - if packet.last_obu_fragment_size.is_none() { - writer - .write(&last_obu.bytes[last_obu.offset..]) - .map_err(err_flow!(self, obu_write))?; - let _ = state.obus.pop_front().unwrap(); - } - // otherwise write only a slice, and update the element - // to only contain the unwritten bytes - else { - writer - .write( - &last_obu.bytes[last_obu.offset..last_obu.offset + obu_size as usize], - ) - .map_err(err_flow!(self, obu_write))?; - - let new_size = last_obu.bytes.len() as u32 - last_obu.offset as u32 - obu_size; - last_obu.info = SizedObu { - size: new_size, - header_len: 0, - leb_size: leb128_size(new_size) as u32, - is_fragment: true, - ..last_obu.info - }; - last_obu.offset += obu_size as usize; - } + last_obu.offset += obu_size as usize; } } + // OBUs were consumed above so start_id will be set now + let start_id = start_id.unwrap(); + gst::log!( CAT, imp: self, "generated RTP packet of size {}", - outbuf.size() + payload.len() ); - Ok(outbuf) + self.obj().queue_packet( + PacketToBufferRelation::Ids(start_id..=end_id), + rtp_types::RtpPacketBuilder::new() + .marker_bit(packet.ends_temporal_unit) + .payload(&payload), + )?; + + Ok(gst::FlowSuccess::Ok) } } @@ -557,7 +486,7 @@ impl RTPAv1Pay { impl ObjectSubclass for RTPAv1Pay { const NAME: &'static str = "GstRtpAv1Pay"; type Type = super::RTPAv1Pay; - type ParentType = gst_rtp::RTPBasePayload; + type ParentType = crate::basepay::RtpBasePay2; } impl ObjectImpl for RTPAv1Pay {} @@ -610,64 +539,58 @@ impl ElementImpl for RTPAv1Pay { PAD_TEMPLATES.as_ref() } - - fn change_state( - &self, - transition: gst::StateChange, - ) -> Result { - gst::debug!(CAT, imp: self, "changing state: {}", transition); - - if matches!(transition, gst::StateChange::ReadyToPaused) { - let mut state = self.state.lock().unwrap(); - self.reset(&mut state, true); - } - - let ret = self.parent_change_state(transition); - - if matches!(transition, gst::StateChange::PausedToReady) { - let mut state = self.state.lock().unwrap(); - self.reset(&mut state, true); - } - - ret - } } -impl RTPBasePayloadImpl for RTPAv1Pay { - fn set_caps(&self, caps: &gst::Caps) -> Result<(), gst::LoggableError> { - gst::debug!(CAT, imp: self, "received caps {caps:?}"); +impl crate::basepay::RtpBasePay2Impl for RTPAv1Pay { + const ALLOWED_META_TAGS: &'static [&'static str] = &["video"]; - { - let mut state = self.state.lock().unwrap(); - let s = caps.structure(0).unwrap(); - match s.get::<&str>("alignment").unwrap() { - "tu" | "frame" => { - state.framed = true; - } - _ => { - state.framed = false; - } - } - } - - self.obj().set_options("video", true, "AV1", CLOCK_RATE); + fn start(&self) -> Result<(), gst::ErrorMessage> { + let mut state = self.state.borrow_mut(); + self.reset(&mut state, true); Ok(()) } - fn handle_buffer(&self, buffer: gst::Buffer) -> Result { - gst::trace!(CAT, imp: self, "received buffer of size {}", buffer.size()); + fn stop(&self) -> Result<(), gst::ErrorMessage> { + let mut state = self.state.borrow_mut(); + self.reset(&mut state, true); - let mut state = self.state.lock().unwrap(); + Ok(()) + } - if buffer.flags().contains(gst::BufferFlags::DISCONT) { - gst::debug!(CAT, imp: self, "buffer discontinuity"); - self.reset(&mut state, false); + fn set_sink_caps(&self, caps: &gst::Caps) -> bool { + gst::debug!(CAT, imp: self, "received caps {caps:?}"); + + self.obj().set_src_caps( + &gst::Caps::builder("application/x-rtp") + .field("media", "video") + .field("clock-rate", CLOCK_RATE as i32) + .field("encoding-name", "AV1") + .build(), + ); + + let mut state = self.state.borrow_mut(); + let s = caps.structure(0).unwrap(); + match s.get::<&str>("alignment").unwrap() { + "tu" | "frame" => { + state.framed = true; + } + _ => { + state.framed = false; + } } - let dts = buffer.dts(); - let pts = buffer.pts(); + true + } + fn handle_buffer( + &self, + buffer: &gst::Buffer, + id: u64, + ) -> Result { + gst::trace!(CAT, imp: self, "received buffer of size {}", buffer.size()); + + let mut state = self.state.borrow_mut(); let map = buffer.map_readable().map_err(|_| { gst::element_imp_error!( self, @@ -680,49 +603,31 @@ impl RTPBasePayloadImpl for RTPAv1Pay { // Does the buffer finished a full TU? let marker = buffer.flags().contains(gst::BufferFlags::MARKER) || state.framed; - let list = self.handle_new_obus(&mut state, map.as_slice(), marker, dts, pts)?; + let res = self.handle_new_obus(&mut state, id, map.as_slice(), marker)?; drop(map); drop(state); - if !list.is_empty() { - self.obj().push_list(list) - } else { - Ok(gst::FlowSuccess::Ok) - } + Ok(res) } - fn sink_event(&self, event: gst::Event) -> bool { - gst::log!(CAT, imp: self, "sink event: {}", event.type_()); + fn drain(&self) -> Result { + // flush all remaining OBUs + let mut res = Ok(gst::FlowSuccess::Ok); - match event.view() { - gst::EventView::Eos(_) => { - // flush all remaining OBUs - let mut list = gst::BufferList::new(); - { - let mut state = self.state.lock().unwrap(); - let list = list.get_mut().unwrap(); - - while let Some(packet_data) = self.consider_new_packet(&mut state, true, true) { - match self.generate_new_packet(&mut state, packet_data) { - Ok(buffer) => list.add(buffer), - Err(_) => break, - } - } - - self.reset(&mut state, false); - } - if !list.is_empty() { - let _ = self.obj().push_list(list); - } + let mut state = self.state.borrow_mut(); + while let Some(packet_data) = self.consider_new_packet(&mut state, true, true) { + res = self.generate_new_packet(&mut state, packet_data); + if res.is_err() { + break; } - gst::EventView::FlushStop(_) => { - let mut state = self.state.lock().unwrap(); - self.reset(&mut state, false); - } - _ => (), } - self.parent_sink_event(event) + res + } + + fn flush(&self) { + let mut state = self.state.borrow_mut(); + self.reset(&mut state, false); } } @@ -927,12 +832,14 @@ mod tests { ), ]; - let element = ::Type::new(); + // Element exists just for logging purposes + let element = glib::Object::new::(); + let pay = element.imp(); for idx in 0..input_data.len() { println!("running test {idx}..."); - let mut state = pay.state.lock().unwrap(); + let mut state = pay.state.borrow_mut(); *state = input_data[idx].1.clone(); assert_eq!( diff --git a/net/rtp/src/av1/pay/mod.rs b/net/rtp/src/av1/pay/mod.rs index 0717bafd..beb729b8 100644 --- a/net/rtp/src/av1/pay/mod.rs +++ b/net/rtp/src/av1/pay/mod.rs @@ -6,22 +6,17 @@ // . // // SPDX-License-Identifier: MPL-2.0 -#![allow(clippy::new_without_default)] use gst::glib; use gst::prelude::*; pub mod imp; +#[cfg(test)] +mod tests; glib::wrapper! { pub struct RTPAv1Pay(ObjectSubclass) - @extends gst_rtp::RTPBasePayload, gst::Element, gst::Object; -} - -impl RTPAv1Pay { - pub fn new() -> Self { - glib::Object::new() - } + @extends crate::basepay::RtpBasePay2, gst::Element, gst::Object; } pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> { diff --git a/net/rtp/tests/rtpav1.rs b/net/rtp/src/av1/pay/tests.rs similarity index 54% rename from net/rtp/tests/rtpav1.rs rename to net/rtp/src/av1/pay/tests.rs index 3ae68b42..813f9d70 100644 --- a/net/rtp/tests/rtpav1.rs +++ b/net/rtp/src/av1/pay/tests.rs @@ -9,7 +9,6 @@ use gst::{event::Eos, prelude::*, Buffer, Caps, ClockTime}; use gst_check::Harness; -use gst_rtp::{rtp_buffer::RTPBufferExt, RTPBuffer}; fn init() { use std::sync::Once; @@ -17,101 +16,13 @@ fn init() { INIT.call_once(|| { gst::init().unwrap(); - gstrsrtp::plugin_register_static().expect("rtpav1 test"); + crate::plugin_register_static().expect("rtpav1 test"); }); } #[test] -#[rustfmt::skip] -fn test_depayloader() { - let test_packets: [(Vec, bool, u32); 4] = [ - ( // simple packet, complete TU - vec![ // RTP payload - 0b0001_1000, - 0b0011_0000, 1, 2, 3, 4, 5, 6, - ], - true, // marker bit - 100_000, // timestamp - ), ( // 2 OBUs, last is fragmented - vec![ - 0b0110_0000, - 0b0000_0110, 0b0111_1000, 1, 2, 3, 4, 5, - 0b0011_0000, 1, 2, 3, - ], - false, - 190_000, - ), ( // continuation of the last OBU - vec![ - 0b1100_0000, - 0b0000_0100, 4, 5, 6, 7, - ], - false, - 190_000, - ), ( // finishing the OBU fragment - vec![ - 0b1001_0000, - 8, 9, 10, - ], - true, - 190_000, - ) - ]; - - let expected: [Vec; 3] = [ - vec![ - 0b0001_0010, 0, - 0b0011_0010, 0b0000_0110, 1, 2, 3, 4, 5, 6, - ], - vec![ - 0b0001_0010, 0, - 0b0111_1010, 0b0000_0101, 1, 2, 3, 4, 5, - ], - vec![ - 0b0011_0010, 0b0000_1010, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, - ], - ]; - - init(); - - let mut h = Harness::new("rtpav1depay"); - h.play(); - - let caps = Caps::builder("application/x-rtp") - .field("media", "video") - .field("payload", 96) - .field("clock-rate", 90000) - .field("encoding-name", "AV1") - .build(); - h.set_src_caps(caps); - - for (idx, (bytes, marker, timestamp)) in test_packets.iter().enumerate() { - let mut buf = Buffer::new_rtp_with_sizes(bytes.len() as u32, 0, 0).unwrap(); - { - let buf_mut = buf.get_mut().unwrap(); - let mut rtp_mut = RTPBuffer::from_buffer_writable(buf_mut).unwrap(); - rtp_mut.set_marker(*marker); - rtp_mut.set_timestamp(*timestamp); - rtp_mut.set_payload_type(96); - rtp_mut.set_seq(idx as u16); - rtp_mut.payload_mut().unwrap().copy_from_slice(bytes); - } - - h.push(buf).unwrap(); - } - h.push_event(Eos::new()); - - for (idx, ex) in expected.iter().enumerate() { - println!("checking buffer {idx}..."); - - let buffer = h.pull().unwrap(); - let actual = buffer.into_mapped_buffer_readable().unwrap(); - assert_eq!(actual.as_slice(), ex.as_slice()); - } -} - -#[test] -#[rustfmt::skip] fn test_payloader() { + #[rustfmt::skip] let test_buffers: [(u64, Vec); 3] = [ ( 0, @@ -136,6 +47,7 @@ fn test_payloader() { ) ]; + #[rustfmt::skip] let expected = [ ( false, // marker bit @@ -183,7 +95,7 @@ fn test_payloader() { let pay = h.element().unwrap(); pay.set_property( "mtu", - gst_rtp::calc_packet_len(25, 0, 0) + 25u32 + rtp_types::RtpPacket::MIN_RTP_PACKET_LEN as u32, ); } h.play(); @@ -203,7 +115,10 @@ fn test_payloader() { buffer.copy_from_slice(bytes); let mut buffer = buffer.into_buffer(); - buffer.get_mut().unwrap().set_pts(ClockTime::try_from(*pts).unwrap()); + buffer + .get_mut() + .unwrap() + .set_pts(ClockTime::from_nseconds(*pts)); h.push(buffer).unwrap(); } @@ -214,13 +129,14 @@ fn test_payloader() { println!("checking packet {idx}..."); let buffer = h.pull().unwrap(); - let packet = RTPBuffer::from_buffer_readable(&buffer).unwrap(); + let map = buffer.map_readable().unwrap(); + let packet = rtp_types::RtpPacket::parse(&map).unwrap(); if base_ts.is_none() { base_ts = Some(packet.timestamp()); } - assert_eq!(packet.payload().unwrap(), payload.as_slice()); - assert_eq!(packet.is_marker(), *marker); + assert_eq!(packet.payload(), payload.as_slice()); + assert_eq!(packet.marker_bit(), *marker); assert_eq!(packet.timestamp(), base_ts.unwrap() + ts_offset); } } diff --git a/net/rtp/src/lib.rs b/net/rtp/src/lib.rs index e04415c3..3fa2ee5a 100644 --- a/net/rtp/src/lib.rs +++ b/net/rtp/src/lib.rs @@ -16,7 +16,6 @@ */ use gst::glib; -mod av1; mod gcc; mod audio_discont; @@ -24,14 +23,13 @@ mod baseaudiopay; mod basedepay; mod basepay; +mod av1; mod pcmau; #[cfg(test)] mod tests; fn plugin_init(plugin: &gst::Plugin) -> Result<(), glib::BoolError> { - av1::depay::register(plugin)?; - av1::pay::register(plugin)?; gcc::register(plugin)?; #[cfg(feature = "doc")] @@ -46,6 +44,9 @@ fn plugin_init(plugin: &gst::Plugin) -> Result<(), glib::BoolError> { .mark_as_plugin_api(gst::PluginAPIFlags::empty()); } + av1::depay::register(plugin)?; + av1::pay::register(plugin)?; + pcmau::depay::register(plugin)?; pcmau::pay::register(plugin)?;