diff --git a/src/forward/forward_internal.rs b/src/forward/forward_internal.rs index 4d84711c..72089e09 100644 --- a/src/forward/forward_internal.rs +++ b/src/forward/forward_internal.rs @@ -4,7 +4,7 @@ use std::sync::{Arc, Weak}; use std::time::Duration; use anyhow::Result; -use log::info; +use log::{debug, info}; use tokio::sync::mpsc::{channel, unbounded_channel, Receiver, Sender, UnboundedSender}; use tokio::sync::RwLock; use webrtc::api::interceptor_registry::register_default_interceptors; @@ -16,7 +16,6 @@ use webrtc::interceptor::registry::Registry; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::RTCPeerConnection; -use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; use webrtc::rtcp::reception_report::ReceptionReport; use webrtc::rtcp::sender_report::SenderReport; use webrtc::rtp::packet::Packet; @@ -36,6 +35,7 @@ use webrtc::track::track_remote::TrackRemote; use crate::forward::track_match::{track_match_codec, track_sort}; use crate::media; +use super::rtcp::RtcpMessage; use super::track_match; type ForwardData = Arc; @@ -106,7 +106,7 @@ type SubscriptionGroup = Arc>>; #[derive(Clone)] struct TrackForward { - pli_send: Sender<()>, + rtcp_send: Sender, subscription_group: SubscriptionGroup, } @@ -302,14 +302,23 @@ impl PeerForwardInternal { return; } let target_track = tracks.get(target_index).unwrap(); - let original_track_forward = anchor_track_forward_map.get(&original_track).unwrap(); - let mut subscription_group = original_track_forward.subscription_group.write().await; - if let Some(sender) = subscription_group.remove(&peer_wrap) { + if let Some(sender) = anchor_track_forward_map + .get(&original_track) + .unwrap() + .subscription_group + .write() + .await + .remove(&peer_wrap) + { let target_track_forward = anchor_track_forward_map.get(target_track).unwrap(); - let mut target_subscription_group = - target_track_forward.subscription_group.write().await; - target_subscription_group.insert(peer_wrap, sender); - let _ = target_track_forward.pli_send.try_send(()); + target_track_forward + .subscription_group + .write() + .await + .insert(peer_wrap, sender); + let _ = target_track_forward + .rtcp_send + .try_send(RtcpMessage::PictureLossIndication); } } } @@ -481,20 +490,18 @@ impl PeerForwardInternal { sender: Arc, track_forward_map: Arc>>, ) { - while let Ok((packets, _)) = sender.read_rtcp().await { - if let Some(pc) = pc.upgrade() { - for packet in packets { - if let Some(_pli) = packet.as_any().downcast_ref::() { - if let Some(track) = sender.track().await { - let kind = track.kind(); - let track_forward_map = track_forward_map.read().await; - for (track_remote, track_forward) in track_forward_map.iter() { - if track_remote.0.kind() == kind { - let subscription_group = - track_forward.subscription_group.read().await; - if subscription_group.contains_key(&PeerWrap(pc.clone())) { - let _ = track_forward.pli_send.try_send(()); - } + while let (Ok((packets, _)), Some(pc)) = (sender.read_rtcp().await, pc.upgrade()) { + for packet in packets { + if let Some(msg) = RtcpMessage::from_rtcp_packet(packet) { + if let Some(track) = sender.track().await { + let kind = track.kind(); + let track_forward_map = track_forward_map.read().await; + for (track_remote, track_forward) in track_forward_map.iter() { + if track_remote.0.kind() == kind { + let subscription_group = + track_forward.subscription_group.read().await; + if subscription_group.contains_key(&PeerWrap(pc.clone())) { + let _ = track_forward.rtcp_send.try_send(msg); } } } @@ -557,14 +564,14 @@ impl PeerForwardInternal { return Err(anyhow::anyhow!("anchor is not self")); } let (send, recv) = channel(1); - tokio::spawn(Self::anchor_track_pli( + tokio::spawn(Self::peer_send_rtcp( Arc::downgrade(&peer), track.ssrc(), recv, )); let mut anchor_track_forward_map = self.anchor_track_forward_map.write().await; let handle = TrackForward { - pli_send: send, + rtcp_send: send, subscription_group: Default::default(), }; anchor_track_forward_map.insert(TrackRemoteWrap(track.clone()), handle.clone()); @@ -592,10 +599,9 @@ impl PeerForwardInternal { let anchor_track_forward = subscription.read().await; let packet = Arc::new(rtp_packet); for (peer_wrap, sender) in anchor_track_forward.iter() { - if peer_wrap.0.connection_state() != RTCPeerConnectionState::Connected { - continue; + if peer_wrap.0.connection_state() == RTCPeerConnectionState::Connected { + let _ = sender.send(packet.clone()); } - let _ = sender.send(packet.clone()); } } info!( @@ -606,25 +612,18 @@ impl PeerForwardInternal { ); } - async fn anchor_track_pli( + async fn peer_send_rtcp( peer: Weak, media_ssrc: u32, - mut recv: Receiver<()>, + mut recv: Receiver, ) { - loop { - let _ = recv.recv().await; - if let Some(pc) = peer.upgrade() { - if pc - .write_rtcp(&[Box::new(PictureLossIndication { - sender_ssrc: 0, - media_ssrc, - })]) - .await - .is_err() - { - break; - } - } else { + while let (Some(rtcp_message), Some(pc)) = (recv.recv().await, peer.upgrade()) { + debug!("ssrc : {} ,send rtcp : {:?}", media_ssrc, rtcp_message); + if pc + .write_rtcp(&[rtcp_message.to_rtcp_packet(media_ssrc)]) + .await + .is_err() + { break; } } diff --git a/src/forward/mod.rs b/src/forward/mod.rs index 590601a6..ffd7f1dd 100644 --- a/src/forward/mod.rs +++ b/src/forward/mod.rs @@ -16,6 +16,7 @@ use crate::forward::forward_internal::{get_peer_key, PeerForwardInternal}; use crate::{media, metrics}; mod forward_internal; +mod rtcp; mod track_match; #[derive(Clone)] diff --git a/src/forward/rtcp.rs b/src/forward/rtcp.rs new file mode 100644 index 00000000..8239f8c9 --- /dev/null +++ b/src/forward/rtcp.rs @@ -0,0 +1,44 @@ +use webrtc::rtcp::packet::Packet; +use webrtc::rtcp::payload_feedbacks::full_intra_request::FullIntraRequest; +use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; +use webrtc::rtcp::payload_feedbacks::slice_loss_indication::SliceLossIndication; + +#[derive(Debug, Clone, Copy)] +pub(crate) enum RtcpMessage { + FullIntraRequest, + PictureLossIndication, + SliceLossIndication, +} + +impl RtcpMessage { + pub(crate) fn from_rtcp_packet(packet: Box) -> Option { + let x = packet.as_any(); + if let Some(_) = x.downcast_ref::() { + return Some(RtcpMessage::FullIntraRequest); + } else if let Some(_) = x.downcast_ref::() { + return Some(RtcpMessage::PictureLossIndication); + } else if let Some(_) = x.downcast_ref::() { + return Some(RtcpMessage::SliceLossIndication); + } + None + } + + pub(crate) fn to_rtcp_packet(&self, ssrc: u32) -> Box { + match self { + RtcpMessage::FullIntraRequest => Box::new(FullIntraRequest { + sender_ssrc: 0, + media_ssrc: ssrc, + fir: vec![], + }), + RtcpMessage::PictureLossIndication => Box::new(PictureLossIndication { + sender_ssrc: 0, + media_ssrc: ssrc, + }), + RtcpMessage::SliceLossIndication => Box::new(SliceLossIndication { + sender_ssrc: 0, + media_ssrc: ssrc, + sli_entries: vec![], + }), + } + } +}