Skip to content

Commit

Permalink
feat: rtcp proxy add packet kind
Browse files Browse the repository at this point in the history
  • Loading branch information
hongcha98 committed Nov 2, 2023
1 parent 1ddaa63 commit 66398c0
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 45 deletions.
89 changes: 44 additions & 45 deletions src/forward/forward_internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::sync::{Arc, Weak};
use std::time::Duration;

use anyhow::Result;
use log::info;
use log::{debug, info};
use tokio::sync::mpsc::{channel, unbounded_channel, Receiver, Sender, UnboundedSender};
use tokio::sync::RwLock;
use webrtc::api::interceptor_registry::register_default_interceptors;
Expand All @@ -16,7 +16,6 @@ use webrtc::interceptor::registry::Registry;
use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
use webrtc::peer_connection::RTCPeerConnection;
use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication;
use webrtc::rtcp::reception_report::ReceptionReport;
use webrtc::rtcp::sender_report::SenderReport;
use webrtc::rtp::packet::Packet;
Expand All @@ -36,6 +35,7 @@ use webrtc::track::track_remote::TrackRemote;
use crate::forward::track_match::{track_match_codec, track_sort};
use crate::media;

use super::rtcp::RtcpMessage;
use super::track_match;

type ForwardData = Arc<Packet>;
Expand Down Expand Up @@ -106,7 +106,7 @@ type SubscriptionGroup = Arc<RwLock<HashMap<PeerWrap, SenderForwardData>>>;

#[derive(Clone)]
struct TrackForward {
pli_send: Sender<()>,
rtcp_send: Sender<RtcpMessage>,
subscription_group: SubscriptionGroup,
}

Expand Down Expand Up @@ -302,14 +302,23 @@ impl PeerForwardInternal {
return;
}
let target_track = tracks.get(target_index).unwrap();
let original_track_forward = anchor_track_forward_map.get(&original_track).unwrap();
let mut subscription_group = original_track_forward.subscription_group.write().await;
if let Some(sender) = subscription_group.remove(&peer_wrap) {
if let Some(sender) = anchor_track_forward_map
.get(&original_track)
.unwrap()
.subscription_group
.write()
.await
.remove(&peer_wrap)
{
let target_track_forward = anchor_track_forward_map.get(target_track).unwrap();
let mut target_subscription_group =
target_track_forward.subscription_group.write().await;
target_subscription_group.insert(peer_wrap, sender);
let _ = target_track_forward.pli_send.try_send(());
target_track_forward
.subscription_group
.write()
.await
.insert(peer_wrap, sender);
let _ = target_track_forward
.rtcp_send
.try_send(RtcpMessage::PictureLossIndication);
}
}
}
Expand Down Expand Up @@ -481,20 +490,18 @@ impl PeerForwardInternal {
sender: Arc<RTCRtpSender>,
track_forward_map: Arc<RwLock<HashMap<TrackRemoteWrap, TrackForward>>>,
) {
while let Ok((packets, _)) = sender.read_rtcp().await {
if let Some(pc) = pc.upgrade() {
for packet in packets {
if let Some(_pli) = packet.as_any().downcast_ref::<PictureLossIndication>() {
if let Some(track) = sender.track().await {
let kind = track.kind();
let track_forward_map = track_forward_map.read().await;
for (track_remote, track_forward) in track_forward_map.iter() {
if track_remote.0.kind() == kind {
let subscription_group =
track_forward.subscription_group.read().await;
if subscription_group.contains_key(&PeerWrap(pc.clone())) {
let _ = track_forward.pli_send.try_send(());
}
while let (Ok((packets, _)), Some(pc)) = (sender.read_rtcp().await, pc.upgrade()) {
for packet in packets {
if let Some(msg) = RtcpMessage::from_rtcp_packet(packet) {
if let Some(track) = sender.track().await {
let kind = track.kind();
let track_forward_map = track_forward_map.read().await;
for (track_remote, track_forward) in track_forward_map.iter() {
if track_remote.0.kind() == kind {
let subscription_group =
track_forward.subscription_group.read().await;
if subscription_group.contains_key(&PeerWrap(pc.clone())) {
let _ = track_forward.rtcp_send.try_send(msg);
}
}
}
Expand Down Expand Up @@ -557,14 +564,14 @@ impl PeerForwardInternal {
return Err(anyhow::anyhow!("anchor is not self"));
}
let (send, recv) = channel(1);
tokio::spawn(Self::anchor_track_pli(
tokio::spawn(Self::peer_send_rtcp(
Arc::downgrade(&peer),
track.ssrc(),
recv,
));
let mut anchor_track_forward_map = self.anchor_track_forward_map.write().await;
let handle = TrackForward {
pli_send: send,
rtcp_send: send,
subscription_group: Default::default(),
};
anchor_track_forward_map.insert(TrackRemoteWrap(track.clone()), handle.clone());
Expand Down Expand Up @@ -592,10 +599,9 @@ impl PeerForwardInternal {
let anchor_track_forward = subscription.read().await;
let packet = Arc::new(rtp_packet);
for (peer_wrap, sender) in anchor_track_forward.iter() {
if peer_wrap.0.connection_state() != RTCPeerConnectionState::Connected {
continue;
if peer_wrap.0.connection_state() == RTCPeerConnectionState::Connected {
let _ = sender.send(packet.clone());
}
let _ = sender.send(packet.clone());
}
}
info!(
Expand All @@ -606,25 +612,18 @@ impl PeerForwardInternal {
);
}

async fn anchor_track_pli(
async fn peer_send_rtcp(
peer: Weak<RTCPeerConnection>,
media_ssrc: u32,
mut recv: Receiver<()>,
mut recv: Receiver<RtcpMessage>,
) {
loop {
let _ = recv.recv().await;
if let Some(pc) = peer.upgrade() {
if pc
.write_rtcp(&[Box::new(PictureLossIndication {
sender_ssrc: 0,
media_ssrc,
})])
.await
.is_err()
{
break;
}
} else {
while let (Some(rtcp_message), Some(pc)) = (recv.recv().await, peer.upgrade()) {
debug!("ssrc : {} ,send rtcp : {:?}", media_ssrc, rtcp_message);
if pc
.write_rtcp(&[rtcp_message.to_rtcp_packet(media_ssrc)])
.await
.is_err()
{
break;
}
}
Expand Down
1 change: 1 addition & 0 deletions src/forward/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::forward::forward_internal::{get_peer_key, PeerForwardInternal};
use crate::{media, metrics};

mod forward_internal;
mod rtcp;
mod track_match;

#[derive(Clone)]
Expand Down
44 changes: 44 additions & 0 deletions src/forward/rtcp.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use webrtc::rtcp::packet::Packet;
use webrtc::rtcp::payload_feedbacks::full_intra_request::FullIntraRequest;
use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication;
use webrtc::rtcp::payload_feedbacks::slice_loss_indication::SliceLossIndication;

#[derive(Debug, Clone, Copy)]
pub(crate) enum RtcpMessage {
FullIntraRequest,
PictureLossIndication,
SliceLossIndication,
}

impl RtcpMessage {
pub(crate) fn from_rtcp_packet(packet: Box<dyn Packet + Send + Sync>) -> Option<Self> {
let x = packet.as_any();
if let Some(_) = x.downcast_ref::<FullIntraRequest>() {
return Some(RtcpMessage::FullIntraRequest);
} else if let Some(_) = x.downcast_ref::<PictureLossIndication>() {
return Some(RtcpMessage::PictureLossIndication);
} else if let Some(_) = x.downcast_ref::<SliceLossIndication>() {
return Some(RtcpMessage::SliceLossIndication);
}
None
}

pub(crate) fn to_rtcp_packet(&self, ssrc: u32) -> Box<dyn Packet + Send + Sync> {
match self {
RtcpMessage::FullIntraRequest => Box::new(FullIntraRequest {
sender_ssrc: 0,
media_ssrc: ssrc,
fir: vec![],
}),
RtcpMessage::PictureLossIndication => Box::new(PictureLossIndication {
sender_ssrc: 0,
media_ssrc: ssrc,
}),
RtcpMessage::SliceLossIndication => Box::new(SliceLossIndication {
sender_ssrc: 0,
media_ssrc: ssrc,
sli_entries: vec![],
}),
}
}
}

0 comments on commit 66398c0

Please sign in to comment.