
Commit

refactor: track pli
hongcha98 committed Oct 10, 2023
1 parent 6d4fa46 commit dac58b9
Showing 2 changed files with 59 additions and 28 deletions.
83 changes: 56 additions & 27 deletions src/forward/forward_internal.rs
@@ -5,7 +5,8 @@ use std::time::Duration;

use anyhow::Result;
use log::info;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::select;
use tokio::sync::mpsc::{channel, unbounded_channel, Receiver, Sender, UnboundedSender};
use tokio::sync::RwLock;
use webrtc::api::interceptor_registry::register_default_interceptors;
use webrtc::api::media_engine::MediaEngine;
@@ -101,13 +102,19 @@ impl Hash for TrackRemoteWrap {
}
}

type ForwardHandle = Arc<RwLock<HashMap<PeerWrap, SenderForwardData>>>;
type SubscriptionGroup = Arc<RwLock<HashMap<PeerWrap, SenderForwardData>>>;

#[derive(Clone)]
struct TrackForward {
pli_send: Sender<()>,
subscription_group: SubscriptionGroup,
}

pub(crate) struct PeerForwardInternal {
pub(crate) id: String,
anchor: RwLock<Option<Arc<RTCPeerConnection>>>,
subscribe_group: RwLock<Vec<PeerWrap>>,
anchor_track_forward_map: Arc<RwLock<HashMap<TrackRemoteWrap, ForwardHandle>>>,
anchor_track_forward_map: Arc<RwLock<HashMap<TrackRemoteWrap, TrackForward>>>,
ice_server: Vec<RTCIceServer>,
}

@@ -180,6 +187,10 @@ impl PeerForwardInternal {
let mut subscribe_peers = self.subscribe_group.write().await;
subscribe_peers.push(PeerWrap(peer.clone()));
drop(subscribe_peers);
let anchor_track_forward_map = self.anchor_track_forward_map.read().await;
for (_, track_forward) in anchor_track_forward_map.iter() {
let _ = track_forward.pli_send.try_send(());
}
if self.publish_is_svc().await {
tokio::spawn(Self::subscribe_track_flush(
Arc::downgrade(&peer),
@@ -207,11 +218,11 @@ impl PeerForwardInternal {

async fn subscribe_track_flush(
peer: Weak<RTCPeerConnection>,
anchor_track_forward_map: Arc<RwLock<HashMap<TrackRemoteWrap, ForwardHandle>>>,
anchor_track_forward_map: Arc<RwLock<HashMap<TrackRemoteWrap, TrackForward>>>,
) {
let mut pre_report: Option<RemoteInboundRTPStats> = None;
loop {
let timeout = tokio::time::sleep(Duration::from_secs(10));
let timeout = tokio::time::sleep(Duration::from_secs(20));
tokio::pin!(timeout);
let _ = timeout.as_mut().await;
if let Some(pc) = peer.upgrade() {
@@ -250,7 +261,7 @@ impl PeerForwardInternal {

async fn subscribe_track_reallocate(
pc: Arc<RTCPeerConnection>,
anchor_track_forward_map: Arc<RwLock<HashMap<TrackRemoteWrap, ForwardHandle>>>,
anchor_track_forward_map: Arc<RwLock<HashMap<TrackRemoteWrap, TrackForward>>>,
upgrade: bool,
) {
let peer_wrap = PeerWrap(pc);
@@ -262,8 +273,8 @@
.collect();
let mut original_track = None;
for track in tracks.iter() {
if let Some(subscribes) = anchor_track_forward_map.get(track) {
let subscribes = subscribes.read().await;
if let Some(track_forward) = anchor_track_forward_map.get(track) {
let subscribes = track_forward.subscription_group.read().await;
if subscribes.contains_key(&peer_wrap) {
original_track = Some(track.clone());
break;
@@ -295,21 +306,23 @@ impl PeerForwardInternal {
return;
}
let target_track = tracks.get(target_index).unwrap();
let original_subscribes = anchor_track_forward_map.get(&original_track).unwrap();
let mut subscribes = original_subscribes.write().await;
if let Some(sender) = subscribes.remove(&peer_wrap) {
let target_subscribes = anchor_track_forward_map.get(target_track).unwrap();
let mut target_subscribes = target_subscribes.write().await;
target_subscribes.insert(peer_wrap, sender);
let original_track_forward = anchor_track_forward_map.get(&original_track).unwrap();
let mut subscription_group = original_track_forward.subscription_group.write().await;
if let Some(sender) = subscription_group.remove(&peer_wrap) {
let target_track_forward = anchor_track_forward_map.get(target_track).unwrap();
let mut target_subscription_group =
target_track_forward.subscription_group.write().await;
target_subscription_group.insert(peer_wrap, sender);
let _ = target_track_forward.pli_send.try_send(());
}
}
}

pub async fn remove_subscribe(&self, peer: Arc<RTCPeerConnection>) -> Result<()> {
let peer_wrap = PeerWrap(peer.clone());
for (_, track_forward_map) in self.anchor_track_forward_map.write().await.iter() {
let mut track_forward_map = track_forward_map.write().await;
track_forward_map.remove(&peer_wrap);
for (_, track_forward) in self.anchor_track_forward_map.write().await.iter() {
let mut subscription_group = track_forward.subscription_group.write().await;
subscription_group.remove(&peer_wrap);
}
let mut subscribe_peers = self.subscribe_group.write().await;
subscribe_peers.retain(|x| x != &peer_wrap);
@@ -395,12 +408,13 @@ impl PeerForwardInternal {
)
.await
{
let mut subscription_map = anchor_track_forward_map
let mut subscription_group = anchor_track_forward_map
.get(&TrackRemoteWrap(track))
.unwrap()
.subscription_group
.write()
.await;
subscription_map.insert(PeerWrap(peer.clone()), sender);
subscription_group.insert(PeerWrap(peer.clone()), sender);
}
}
}
@@ -514,22 +528,30 @@ impl PeerForwardInternal {
if anchor.as_ref().unwrap().get_stats_id() != peer.get_stats_id() {
return Err(anyhow::anyhow!("anchor is not self"));
}
tokio::spawn(Self::anchor_track_pli(Arc::downgrade(&peer), track.ssrc()));
let (send, recv) = channel(1);
tokio::spawn(Self::anchor_track_pli(
Arc::downgrade(&peer),
track.ssrc(),
recv,
));
let mut anchor_track_forward_map = self.anchor_track_forward_map.write().await;
let subscription: ForwardHandle = Default::default();
anchor_track_forward_map.insert(TrackRemoteWrap(track.clone()), subscription.clone());
let handle = TrackForward {
pli_send: send,
subscription_group: Default::default(),
};
anchor_track_forward_map.insert(TrackRemoteWrap(track.clone()), handle.clone());
tokio::spawn(Self::anchor_track_forward(
self.id.clone(),
track,
subscription,
handle.subscription_group,
));
Ok(())
}

async fn anchor_track_forward(
id: String,
track: Arc<TrackRemote>,
subscription: ForwardHandle,
subscription: SubscriptionGroup,
) {
let mut b = vec![0u8; 1500];
info!(
Expand All @@ -556,11 +578,18 @@ impl PeerForwardInternal {
);
}

async fn anchor_track_pli(peer: Weak<RTCPeerConnection>, media_ssrc: u32) {
async fn anchor_track_pli(
peer: Weak<RTCPeerConnection>,
media_ssrc: u32,
mut recv: Receiver<()>,
) {
loop {
let timeout = tokio::time::sleep(Duration::from_secs(1));
let timeout = tokio::time::sleep(Duration::from_secs(10));
tokio::pin!(timeout);
let _ = timeout.as_mut().await;
select! {
_= recv.recv() => {},
_ = timeout.as_mut() => {},
}
if let Some(pc) = peer.upgrade() {
if pc
.write_rtcp(&[Box::new(PictureLossIndication {
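
The heart of this change is a per-track keyframe-request channel: each incoming track gets a bounded channel(1) sender stored in TrackForward, subscribers call try_send(()) when they join or are moved to another track, and the anchor_track_pli task now select!s between that receiver and a periodic timeout before writing a PictureLossIndication. A minimal standalone sketch of that pattern, with placeholder names (pli_loop, send_pli; send_pli stands in for the write_rtcp call on the peer connection):

```rust
use std::time::Duration;

use tokio::select;
use tokio::sync::mpsc::{channel, Receiver, Sender};

/// Sketch of the per-track PLI loop: wake up on demand (a subscriber joined or
/// switched tracks) or on a periodic timer, then request a keyframe.
/// `send_pli` is a stand-in for writing a PictureLossIndication RTCP packet.
async fn pli_loop(mut pli_recv: Receiver<()>, mut send_pli: impl FnMut()) {
    loop {
        let timeout = tokio::time::sleep(Duration::from_secs(10));
        tokio::pin!(timeout);
        select! {
            req = pli_recv.recv() => {
                if req.is_none() {
                    // All senders dropped: the track is gone, stop the loop.
                    return;
                }
            }
            _ = timeout.as_mut() => {}
        }
        send_pli();
    }
}

#[tokio::main]
async fn main() {
    // Capacity 1 + try_send coalesces bursts: if a keyframe request is already
    // pending, additional requests are dropped rather than queued.
    let (pli_send, pli_recv): (Sender<()>, Receiver<()>) = channel(1);
    tokio::spawn(pli_loop(pli_recv, || println!("would send PLI here")));

    let _ = pli_send.try_send(()); // e.g. a new subscriber
    let _ = pli_send.try_send(()); // another subscriber in the same instant: coalesced

    tokio::time::sleep(Duration::from_millis(100)).await;
}
```

Because the channel has capacity 1 and requests use try_send, a burst of new subscribers collapses into a single pending keyframe request instead of one PLI per subscriber, while the periodic timeout still covers long-lived subscribers recovering from packet loss.
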
4 changes: 3 additions & 1 deletion src/forward/mod.rs
@@ -112,6 +112,9 @@ impl PeerForward {
RTCPeerConnectionState::Failed | RTCPeerConnectionState::Disconnected => {
let _ = pc.close().await;
}
RTCPeerConnectionState::Connected => {
let _ = internal.add_subscribe(pc).await;
}
RTCPeerConnectionState::Closed => {
let _ = internal.remove_subscribe(pc).await;
}
@@ -121,7 +124,6 @@
}
Box::pin(async {})
}));
let _ = self.internal.add_subscribe(peer.clone()).await;
Ok((
peer_complete(offer, peer.clone()).await?,
get_peer_key(peer),
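
In mod.rs, the add_subscribe call moves from right after handler registration to the Connected arm of the connection-state callback, so the keyframe request described above fires once the subscriber can actually receive media. A rough sketch of that registration shape, assuming the boxed-handler style webrtc-rs uses in this file (subscribe_when_connected and on_connected are hypothetical names; on_connected stands in for internal.add_subscribe):

```rust
use std::sync::Arc;

use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
use webrtc::peer_connection::RTCPeerConnection;

/// Defer subscriber setup until the connection is actually up.
fn subscribe_when_connected<F>(pc: &Arc<RTCPeerConnection>, on_connected: F)
where
    F: Fn(Arc<RTCPeerConnection>) + Send + Sync + 'static,
{
    let weak = Arc::downgrade(pc);
    let on_connected = Arc::new(on_connected);
    pc.on_peer_connection_state_change(Box::new(move |state: RTCPeerConnectionState| {
        let weak = weak.clone();
        let on_connected = on_connected.clone();
        Box::pin(async move {
            if state == RTCPeerConnectionState::Connected {
                if let Some(pc) = weak.upgrade() {
                    // This is where the real code awaits internal.add_subscribe(pc).
                    on_connected(pc);
                }
            }
        })
    }));
}
```
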
