From a4cb1a34005337fce1684bc6a6aaf88197c02dd2 Mon Sep 17 00:00:00 2001 From: Riccardo Zaglia Date: Thu, 13 Jul 2023 14:18:56 +0800 Subject: [PATCH] Progress on sync sockets (24) Rewrite packet reconstruction code to suit the timeout patten --- alvr/audio/src/lib.rs | 18 ++-- alvr/client_core/src/connection.rs | 36 +++---- alvr/server/src/connection.rs | 30 ++---- alvr/sockets/src/stream_socket/mod.rs | 139 +++++++++++++------------- 4 files changed, 98 insertions(+), 125 deletions(-) diff --git a/alvr/audio/src/lib.rs b/alvr/audio/src/lib.rs index b9f51a5f21..4584e9cfe7 100644 --- a/alvr/audio/src/lib.rs +++ b/alvr/audio/src/lib.rs @@ -24,7 +24,7 @@ use std::{ thread, time::Duration, }; -use tokio::{runtime::Runtime, time}; +use tokio::runtime::Runtime; static VIRTUAL_MICROPHONE_PAIRS: Lazy> = Lazy::new(|| { [ @@ -302,7 +302,7 @@ pub fn record_audio_blocking( #[cfg(windows)] if mute && device.is_output { - crate::windows::set_mute_windows_device(&device, true).ok(); + crate::windows::set_mute_windows_device(device, true).ok(); } let mut res = stream.play().map_err(err!()); @@ -371,16 +371,10 @@ pub fn receive_samples_loop( let mut recovery_sample_buffer = vec![]; loop { if let Some(runtime) = &*runtime.read() { - let res = runtime.block_on(async { - tokio::select! { - res = receiver.recv_buffer(&mut receiver_buffer) => Some(res), - _ = time::sleep(Duration::from_millis(500)) => None, - } - }); - match res { - Some(Ok(())) => (), - Some(err_res) => return err_res.map_err(err!()), - None => continue, + match receiver.recv_buffer(runtime, Duration::from_millis(500), &mut receiver_buffer) { + Ok(true) => (), + Ok(false) | Err(ConnectionError::Timeout) => continue, + Err(ConnectionError::Other(e)) => return fmt_e!("{e}"), } } else { return Ok(()); diff --git a/alvr/client_core/src/connection.rs b/alvr/client_core/src/connection.rs index 299410b382..243c3b0d31 100644 --- a/alvr/client_core/src/connection.rs +++ b/alvr/client_core/src/connection.rs @@ -32,7 +32,7 @@ use std::{ thread, time::{Duration, Instant}, }; -use tokio::{runtime::Runtime, time}; +use tokio::runtime::Runtime; #[cfg(target_os = "android")] use crate::audio; @@ -307,17 +307,14 @@ fn connection_pipeline( let mut stream_corrupted = false; loop { if let Some(runtime) = &*CONNECTION_RUNTIME.read() { - let res = runtime.block_on(async { - tokio::select! { - res = video_receiver.recv_buffer(&mut receiver_buffer) => Some(res), - _ = time::sleep(Duration::from_millis(500)) => None, - } - }); - - match res { - Some(Ok(())) => (), - Some(Err(_)) => return, - None => continue, + match video_receiver.recv_buffer( + runtime, + Duration::from_millis(500), + &mut receiver_buffer, + ) { + Ok(true) => (), + Ok(false) | Err(ConnectionError::Timeout) => continue, + Err(ConnectionError::Other(_)) => return, } } else { return; @@ -399,17 +396,10 @@ fn connection_pipeline( let haptics_receive_thread = thread::spawn(move || loop { let haptics = if let Some(runtime) = &*CONNECTION_RUNTIME.read() { - let res = runtime.block_on(async { - tokio::select! { - res = haptics_receiver.recv_header_only() => Some(res), - _ = time::sleep(Duration::from_millis(500)) => None, - } - }); - - match res { - Some(Ok(packet)) => packet, - Some(Err(_)) => return, - None => continue, + match haptics_receiver.recv_header_only(runtime, Duration::from_millis(500)) { + Ok(packet) => packet, + Err(ConnectionError::Timeout) => continue, + Err(ConnectionError::Other(_)) => return, } } else { return; diff --git a/alvr/server/src/connection.rs b/alvr/server/src/connection.rs index 9dd888c7a8..dd22cfa36f 100644 --- a/alvr/server/src/connection.rs +++ b/alvr/server/src/connection.rs @@ -42,7 +42,7 @@ use std::{ thread, time::Duration, }; -use tokio::{runtime::Runtime, time}; +use tokio::runtime::Runtime; const RETRY_CONNECT_MIN_INTERVAL: Duration = Duration::from_secs(1); @@ -697,16 +697,10 @@ fn try_connect(mut client_ips: HashMap) -> ConResult { loop { let tracking = if let Some(runtime) = &*CONNECTION_RUNTIME.read() { - let maybe_tracking = runtime.block_on(async { - tokio::select! { - res = tracking_receiver.recv_header_only() => Some(res), - _ = time::sleep(Duration::from_millis(500)) => None, - } - }); - match maybe_tracking { - Some(Ok(tracking)) => tracking, - Some(Err(_)) => return, - None => continue, + match tracking_receiver.recv_header_only(runtime, Duration::from_millis(500)) { + Ok(tracking) => tracking, + Err(ConnectionError::Timeout) => continue, + Err(ConnectionError::Other(_)) => return, } } else { return; @@ -815,16 +809,10 @@ fn try_connect(mut client_ips: HashMap) -> ConResult { let statistics_thread = thread::spawn(move || loop { let client_stats = if let Some(runtime) = &*CONNECTION_RUNTIME.read() { - let maybe_client_stats = runtime.block_on(async { - tokio::select! { - res = statics_receiver.recv_header_only() => Some(res), - _ = time::sleep(Duration::from_millis(500)) => None, - } - }); - match maybe_client_stats { - Some(Ok(stats)) => stats, - Some(Err(_)) => return, - None => continue, + match statics_receiver.recv_header_only(runtime, Duration::from_millis(500)) { + Ok(stats) => stats, + Err(ConnectionError::Timeout) => continue, + Err(ConnectionError::Other(_)) => return, } } else { return; diff --git a/alvr/sockets/src/stream_socket/mod.rs b/alvr/sockets/src/stream_socket/mod.rs index 03ee314c0d..cb0c9415fb 100644 --- a/alvr/sockets/src/stream_socket/mod.rs +++ b/alvr/sockets/src/stream_socket/mod.rs @@ -13,9 +13,8 @@ use bytes::{Buf, BufMut, BytesMut}; use futures::SinkExt; use serde::{de::DeserializeOwned, Serialize}; use std::{ - collections::HashMap, + collections::{BTreeMap, HashMap}, marker::PhantomData, - mem, net::IpAddr, ops::{Deref, DerefMut}, sync::Arc, @@ -229,86 +228,88 @@ impl ReceiverBuffer { pub struct StreamReceiver { receiver: mpsc::UnboundedReceiver, - next_packet_shards: HashMap, - next_packet_shards_count: Option, - next_packet_index: u32, + last_reconstructed_packet_index: u32, + packet_shards: BTreeMap>, + empty_shard_maps: Vec>, _phantom: PhantomData, } -/// Get next packet reconstructing from shards. It can store at max shards from two packets; if the -/// reordering entropy is too high, packets will never be successfully reconstructed. +/// Get next packet reconstructing from shards. +/// Returns true if a packet has been recontructed and copied into the buffer. impl StreamReceiver { - pub async fn recv_buffer(&mut self, buffer: &mut ReceiverBuffer) -> StrResult { - buffer.had_packet_loss = false; - - loop { - let current_packet_index = self.next_packet_index; - self.next_packet_index += 1; - - let mut current_packet_shards = - HashMap::with_capacity(self.next_packet_shards.capacity()); - mem::swap(&mut current_packet_shards, &mut self.next_packet_shards); - - let mut current_packet_shards_count = self.next_packet_shards_count.take(); - - loop { - if let Some(shards_count) = current_packet_shards_count { - if current_packet_shards.len() >= shards_count { - buffer.inner.clear(); - - for i in 0..shards_count { - if let Some(shard) = current_packet_shards.get(&i) { - buffer.inner.put_slice(shard); - } else { - error!("Cannot find shard with given index!"); - buffer.had_packet_loss = true; - - self.next_packet_shards.clear(); - - break; - } - } + pub fn recv_buffer( + &mut self, + runtime: &Runtime, + timeout: Duration, + buffer: &mut ReceiverBuffer, + ) -> ConResult { + // Get shard + let mut shard = runtime.block_on(async { + tokio::select! { + res = self.receiver.recv() => res.ok_or_else(enone!()).map_err(to_con_e!()), + _ = time::sleep(timeout) => alvr_common::timeout(), + } + })?; + let shard_packet_index = shard.get_u32(); + let shards_count = shard.get_u32() as usize; + let shard_index = shard.get_u32() as usize; + + // Discard shard if too old + if shard_packet_index <= self.last_reconstructed_packet_index { + debug!("Received old shard!"); + return Ok(false); + } - return Ok(()); - } + // Insert shards into map + let shard_map = self + .packet_shards + .entry(shard_packet_index) + .or_insert_with(|| self.empty_shard_maps.pop().unwrap_or_default()); + shard_map.insert(shard_index, shard); + + // If the shard map is (probably) complete: + if shard_map.len() == shards_count { + buffer.inner.clear(); + + // Copy shards into final buffer. Fail if there are missing shards. This is impossibly + // rare (if the shards_count value got corrupted) but should be handled. + for idx in 0..shards_count { + if let Some(shard) = shard_map.get(&idx) { + buffer.inner.put_slice(shard); + } else { + error!("Cannot find shard with given index!"); + return Ok(false); } + } - let mut shard = self.receiver.recv().await.ok_or_else(enone!())?; - - let shard_packet_index = shard.get_u32(); - let shards_count = shard.get_u32() as usize; - let shard_index = shard.get_u32() as usize; - - if shard_packet_index == current_packet_index { - current_packet_shards.insert(shard_index, shard); - current_packet_shards_count = Some(shards_count); - } else if shard_packet_index >= self.next_packet_index { - if shard_packet_index > self.next_packet_index { - self.next_packet_shards.clear(); - } + // Check if current packet index is one up the last successful reconstucted packet. + buffer.had_packet_loss = shard_packet_index != self.last_reconstructed_packet_index + 1; + self.last_reconstructed_packet_index = shard_packet_index; - self.next_packet_shards.insert(shard_index, shard); - self.next_packet_shards_count = Some(shards_count); - self.next_packet_index = shard_packet_index; + // Pop old shards and recycle containers + while let Some((packet_index, mut shards)) = self.packet_shards.pop_first() { + shards.clear(); + self.empty_shard_maps.push(shards); - if shard_packet_index > self.next_packet_index - || self.next_packet_shards.len() == shards_count - { - debug!("Skipping to next packet. Signaling packet loss."); - buffer.had_packet_loss = true; - break; - } + if packet_index == shard_packet_index { + break; } - // else: ignore old shard } + + Ok(true) + } else { + Ok(false) } } - pub async fn recv_header_only(&mut self) -> StrResult { + pub fn recv_header_only(&mut self, runtime: &Runtime, timeout: Duration) -> ConResult { let mut buffer = ReceiverBuffer::new(); - self.recv_buffer(&mut buffer).await?; - Ok(buffer.get()?.0) + loop { + if self.recv_buffer(runtime, timeout, &mut buffer)? { + return Ok(buffer.get().map_err(to_con_e!())?.0); + } + } } } @@ -458,9 +459,9 @@ impl StreamSocket { StreamReceiver { receiver, - next_packet_shards: HashMap::new(), - next_packet_shards_count: None, - next_packet_index: 0, + last_reconstructed_packet_index: 0, + packet_shards: BTreeMap::new(), + empty_shard_maps: vec![], _phantom: PhantomData, } }