Skip to content

Commit

Permalink
Progress on sync sockets (24)
Browse files Browse the repository at this point in the history
Rewrite packet reconstruction code to suit the timeout patten
  • Loading branch information
zmerp committed Jul 13, 2023
1 parent a2f09de commit a4cb1a3
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 125 deletions.
18 changes: 6 additions & 12 deletions alvr/audio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use std::{
thread,
time::Duration,
};
use tokio::{runtime::Runtime, time};
use tokio::runtime::Runtime;

static VIRTUAL_MICROPHONE_PAIRS: Lazy<HashMap<&str, &str>> = Lazy::new(|| {
[
Expand Down Expand Up @@ -302,7 +302,7 @@ pub fn record_audio_blocking(

#[cfg(windows)]
if mute && device.is_output {
crate::windows::set_mute_windows_device(&device, true).ok();
crate::windows::set_mute_windows_device(device, true).ok();
}

let mut res = stream.play().map_err(err!());
Expand Down Expand Up @@ -371,16 +371,10 @@ pub fn receive_samples_loop(
let mut recovery_sample_buffer = vec![];
loop {
if let Some(runtime) = &*runtime.read() {
let res = runtime.block_on(async {
tokio::select! {
res = receiver.recv_buffer(&mut receiver_buffer) => Some(res),
_ = time::sleep(Duration::from_millis(500)) => None,
}
});
match res {
Some(Ok(())) => (),
Some(err_res) => return err_res.map_err(err!()),
None => continue,
match receiver.recv_buffer(runtime, Duration::from_millis(500), &mut receiver_buffer) {
Ok(true) => (),
Ok(false) | Err(ConnectionError::Timeout) => continue,
Err(ConnectionError::Other(e)) => return fmt_e!("{e}"),
}
} else {
return Ok(());
Expand Down
36 changes: 13 additions & 23 deletions alvr/client_core/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use std::{
thread,
time::{Duration, Instant},
};
use tokio::{runtime::Runtime, time};
use tokio::runtime::Runtime;

#[cfg(target_os = "android")]
use crate::audio;
Expand Down Expand Up @@ -307,17 +307,14 @@ fn connection_pipeline(
let mut stream_corrupted = false;
loop {
if let Some(runtime) = &*CONNECTION_RUNTIME.read() {
let res = runtime.block_on(async {
tokio::select! {
res = video_receiver.recv_buffer(&mut receiver_buffer) => Some(res),
_ = time::sleep(Duration::from_millis(500)) => None,
}
});

match res {
Some(Ok(())) => (),
Some(Err(_)) => return,
None => continue,
match video_receiver.recv_buffer(
runtime,
Duration::from_millis(500),
&mut receiver_buffer,
) {
Ok(true) => (),
Ok(false) | Err(ConnectionError::Timeout) => continue,
Err(ConnectionError::Other(_)) => return,
}
} else {
return;
Expand Down Expand Up @@ -399,17 +396,10 @@ fn connection_pipeline(

let haptics_receive_thread = thread::spawn(move || loop {
let haptics = if let Some(runtime) = &*CONNECTION_RUNTIME.read() {
let res = runtime.block_on(async {
tokio::select! {
res = haptics_receiver.recv_header_only() => Some(res),
_ = time::sleep(Duration::from_millis(500)) => None,
}
});

match res {
Some(Ok(packet)) => packet,
Some(Err(_)) => return,
None => continue,
match haptics_receiver.recv_header_only(runtime, Duration::from_millis(500)) {
Ok(packet) => packet,
Err(ConnectionError::Timeout) => continue,
Err(ConnectionError::Other(_)) => return,
}
} else {
return;
Expand Down
30 changes: 9 additions & 21 deletions alvr/server/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ use std::{
thread,
time::Duration,
};
use tokio::{runtime::Runtime, time};
use tokio::runtime::Runtime;

const RETRY_CONNECT_MIN_INTERVAL: Duration = Duration::from_secs(1);

Expand Down Expand Up @@ -697,16 +697,10 @@ fn try_connect(mut client_ips: HashMap<IpAddr, String>) -> ConResult {

loop {
let tracking = if let Some(runtime) = &*CONNECTION_RUNTIME.read() {
let maybe_tracking = runtime.block_on(async {
tokio::select! {
res = tracking_receiver.recv_header_only() => Some(res),
_ = time::sleep(Duration::from_millis(500)) => None,
}
});
match maybe_tracking {
Some(Ok(tracking)) => tracking,
Some(Err(_)) => return,
None => continue,
match tracking_receiver.recv_header_only(runtime, Duration::from_millis(500)) {
Ok(tracking) => tracking,
Err(ConnectionError::Timeout) => continue,
Err(ConnectionError::Other(_)) => return,
}
} else {
return;
Expand Down Expand Up @@ -815,16 +809,10 @@ fn try_connect(mut client_ips: HashMap<IpAddr, String>) -> ConResult {

let statistics_thread = thread::spawn(move || loop {
let client_stats = if let Some(runtime) = &*CONNECTION_RUNTIME.read() {
let maybe_client_stats = runtime.block_on(async {
tokio::select! {
res = statics_receiver.recv_header_only() => Some(res),
_ = time::sleep(Duration::from_millis(500)) => None,
}
});
match maybe_client_stats {
Some(Ok(stats)) => stats,
Some(Err(_)) => return,
None => continue,
match statics_receiver.recv_header_only(runtime, Duration::from_millis(500)) {
Ok(stats) => stats,
Err(ConnectionError::Timeout) => continue,
Err(ConnectionError::Other(_)) => return,
}
} else {
return;
Expand Down
139 changes: 70 additions & 69 deletions alvr/sockets/src/stream_socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ use bytes::{Buf, BufMut, BytesMut};
use futures::SinkExt;
use serde::{de::DeserializeOwned, Serialize};
use std::{
collections::HashMap,
collections::{BTreeMap, HashMap},
marker::PhantomData,
mem,
net::IpAddr,
ops::{Deref, DerefMut},
sync::Arc,
Expand Down Expand Up @@ -229,86 +228,88 @@ impl<T: DeserializeOwned> ReceiverBuffer<T> {

pub struct StreamReceiver<T> {
receiver: mpsc::UnboundedReceiver<BytesMut>,
next_packet_shards: HashMap<usize, BytesMut>,
next_packet_shards_count: Option<usize>,
next_packet_index: u32,
last_reconstructed_packet_index: u32,
packet_shards: BTreeMap<u32, HashMap<usize, BytesMut>>,
empty_shard_maps: Vec<HashMap<usize, BytesMut>>,
_phantom: PhantomData<T>,
}

/// Get next packet reconstructing from shards. It can store at max shards from two packets; if the
/// reordering entropy is too high, packets will never be successfully reconstructed.
/// Get next packet reconstructing from shards.
/// Returns true if a packet has been recontructed and copied into the buffer.
impl<T: DeserializeOwned> StreamReceiver<T> {
pub async fn recv_buffer(&mut self, buffer: &mut ReceiverBuffer<T>) -> StrResult {
buffer.had_packet_loss = false;

loop {
let current_packet_index = self.next_packet_index;
self.next_packet_index += 1;

let mut current_packet_shards =
HashMap::with_capacity(self.next_packet_shards.capacity());
mem::swap(&mut current_packet_shards, &mut self.next_packet_shards);

let mut current_packet_shards_count = self.next_packet_shards_count.take();

loop {
if let Some(shards_count) = current_packet_shards_count {
if current_packet_shards.len() >= shards_count {
buffer.inner.clear();

for i in 0..shards_count {
if let Some(shard) = current_packet_shards.get(&i) {
buffer.inner.put_slice(shard);
} else {
error!("Cannot find shard with given index!");
buffer.had_packet_loss = true;

self.next_packet_shards.clear();

break;
}
}
pub fn recv_buffer(
&mut self,
runtime: &Runtime,
timeout: Duration,
buffer: &mut ReceiverBuffer<T>,
) -> ConResult<bool> {
// Get shard
let mut shard = runtime.block_on(async {
tokio::select! {
res = self.receiver.recv() => res.ok_or_else(enone!()).map_err(to_con_e!()),
_ = time::sleep(timeout) => alvr_common::timeout(),
}
})?;
let shard_packet_index = shard.get_u32();
let shards_count = shard.get_u32() as usize;
let shard_index = shard.get_u32() as usize;

// Discard shard if too old
if shard_packet_index <= self.last_reconstructed_packet_index {
debug!("Received old shard!");
return Ok(false);
}

return Ok(());
}
// Insert shards into map
let shard_map = self
.packet_shards
.entry(shard_packet_index)
.or_insert_with(|| self.empty_shard_maps.pop().unwrap_or_default());
shard_map.insert(shard_index, shard);

// If the shard map is (probably) complete:
if shard_map.len() == shards_count {
buffer.inner.clear();

// Copy shards into final buffer. Fail if there are missing shards. This is impossibly
// rare (if the shards_count value got corrupted) but should be handled.
for idx in 0..shards_count {
if let Some(shard) = shard_map.get(&idx) {
buffer.inner.put_slice(shard);
} else {
error!("Cannot find shard with given index!");
return Ok(false);
}
}

let mut shard = self.receiver.recv().await.ok_or_else(enone!())?;

let shard_packet_index = shard.get_u32();
let shards_count = shard.get_u32() as usize;
let shard_index = shard.get_u32() as usize;

if shard_packet_index == current_packet_index {
current_packet_shards.insert(shard_index, shard);
current_packet_shards_count = Some(shards_count);
} else if shard_packet_index >= self.next_packet_index {
if shard_packet_index > self.next_packet_index {
self.next_packet_shards.clear();
}
// Check if current packet index is one up the last successful reconstucted packet.
buffer.had_packet_loss = shard_packet_index != self.last_reconstructed_packet_index + 1;
self.last_reconstructed_packet_index = shard_packet_index;

self.next_packet_shards.insert(shard_index, shard);
self.next_packet_shards_count = Some(shards_count);
self.next_packet_index = shard_packet_index;
// Pop old shards and recycle containers
while let Some((packet_index, mut shards)) = self.packet_shards.pop_first() {
shards.clear();
self.empty_shard_maps.push(shards);

if shard_packet_index > self.next_packet_index
|| self.next_packet_shards.len() == shards_count
{
debug!("Skipping to next packet. Signaling packet loss.");
buffer.had_packet_loss = true;
break;
}
if packet_index == shard_packet_index {
break;
}
// else: ignore old shard
}

Ok(true)
} else {
Ok(false)
}
}

pub async fn recv_header_only(&mut self) -> StrResult<T> {
pub fn recv_header_only(&mut self, runtime: &Runtime, timeout: Duration) -> ConResult<T> {
let mut buffer = ReceiverBuffer::new();
self.recv_buffer(&mut buffer).await?;

Ok(buffer.get()?.0)
loop {
if self.recv_buffer(runtime, timeout, &mut buffer)? {
return Ok(buffer.get().map_err(to_con_e!())?.0);
}
}
}
}

Expand Down Expand Up @@ -458,9 +459,9 @@ impl StreamSocket {

StreamReceiver {
receiver,
next_packet_shards: HashMap::new(),
next_packet_shards_count: None,
next_packet_index: 0,
last_reconstructed_packet_index: 0,
packet_shards: BTreeMap::new(),
empty_shard_maps: vec![],
_phantom: PhantomData,
}
}
Expand Down

0 comments on commit a4cb1a3

Please sign in to comment.