Skip to content

Commit

Permalink
Remove some mutexes inside StreamSocket
Browse files Browse the repository at this point in the history
  • Loading branch information
zmerp committed Jul 14, 2023
1 parent e440137 commit 5ae730c
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 25 deletions.
3 changes: 1 addition & 2 deletions alvr/client_core/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,13 @@ fn connection_pipeline(
return Ok(());
}

let stream_socket = stream_socket_builder.accept_from_server(
let mut stream_socket = stream_socket_builder.accept_from_server(
&runtime,
Duration::from_secs(2),
server_ip,
settings.connection.stream_port,
settings.connection.packet_size as _,
)?;
let stream_socket = Arc::new(stream_socket);

info!("Connected to server");

Expand Down
3 changes: 1 addition & 2 deletions alvr/server/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ fn try_connect(mut client_ips: HashMap<IpAddr, String>) -> ConResult {

*BITRATE_MANAGER.lock() = BitrateManager::new(settings.video.bitrate.history_size, fps);

let stream_socket = StreamSocketBuilder::connect_to_client(
let mut stream_socket = StreamSocketBuilder::connect_to_client(
&runtime,
Duration::from_secs(1),
client_ip,
Expand All @@ -547,7 +547,6 @@ fn try_connect(mut client_ips: HashMap<IpAddr, String>) -> ConResult {
settings.connection.packet_size as _,
)
.map_err(to_con_e!())?;
let stream_socket = Arc::new(stream_socket);

let mut video_sender = stream_socket.request_stream(VIDEO);
let game_audio_sender = stream_socket.request_stream(AUDIO);
Expand Down
31 changes: 14 additions & 17 deletions alvr/sockets/src/stream_socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
mod tcp;
mod udp;

use alvr_common::{parking_lot::Mutex, prelude::*};
use alvr_common::prelude::*;
use alvr_session::{SocketBufferSize, SocketProtocol};
use bytes::{Buf, BufMut, BytesMut};
use futures::SinkExt;
Expand All @@ -17,10 +17,7 @@ use std::{
marker::PhantomData,
net::IpAddr,
ops::{Deref, DerefMut},
sync::{
mpsc::{self, RecvTimeoutError},
Arc,
},
sync::mpsc::{self, RecvTimeoutError},
time::Duration,
};
use tcp::{TcpStreamReceiveSocket, TcpStreamSendSocket};
Expand Down Expand Up @@ -368,8 +365,8 @@ impl StreamSocketBuilder {
Ok(StreamSocket {
max_packet_size,
send_socket,
receive_socket: Arc::new(Mutex::new(Some(receive_socket))),
packet_queues: Arc::new(Mutex::new(HashMap::new())),
receive_socket,
packet_queues: HashMap::new(),
})
}

Expand Down Expand Up @@ -415,17 +412,17 @@ impl StreamSocketBuilder {
Ok(StreamSocket {
max_packet_size,
send_socket,
receive_socket: Arc::new(Mutex::new(Some(receive_socket))),
packet_queues: Arc::new(Mutex::new(HashMap::new())),
receive_socket,
packet_queues: HashMap::new(),
})
}
}

pub struct StreamSocket {
max_packet_size: usize,
send_socket: StreamSendSocket,
receive_socket: Arc<Mutex<Option<StreamReceiveSocket>>>,
packet_queues: Arc<Mutex<HashMap<u16, mpsc::Sender<BytesMut>>>>,
receive_socket: StreamReceiveSocket,
packet_queues: HashMap<u16, mpsc::Sender<BytesMut>>,
}

impl StreamSocket {
Expand All @@ -440,10 +437,10 @@ impl StreamSocket {
}
}

pub fn subscribe_to_stream<T>(&self, stream_id: u16) -> StreamReceiver<T> {
pub fn subscribe_to_stream<T>(&mut self, stream_id: u16) -> StreamReceiver<T> {
let (sender, receiver) = mpsc::channel();

self.packet_queues.lock().insert(stream_id, sender);
self.packet_queues.insert(stream_id, sender);

StreamReceiver {
receiver,
Expand All @@ -454,13 +451,13 @@ impl StreamSocket {
}
}

pub fn recv(&self, runtime: &Runtime, timeout: Duration) -> ConResult {
match self.receive_socket.lock().as_mut().unwrap() {
pub fn recv(&mut self, runtime: &Runtime, timeout: Duration) -> ConResult {
match &mut self.receive_socket {
StreamReceiveSocket::Udp(socket) => {
udp::recv(runtime, timeout, socket, &self.packet_queues)
udp::recv(runtime, timeout, socket, &mut self.packet_queues)
}
StreamReceiveSocket::Tcp(socket) => {
tcp::recv(runtime, timeout, socket, &self.packet_queues)
tcp::recv(runtime, timeout, socket, &mut self.packet_queues)
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions alvr/sockets/src/stream_socket/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pub fn recv(
runtime: &Runtime,
timeout: Duration,
socket: &mut TcpStreamReceiveSocket,
packet_enqueuers: &Mutex<HashMap<u16, mpsc::Sender<BytesMut>>>,
packet_enqueuers: &mut HashMap<u16, mpsc::Sender<BytesMut>>,
) -> ConResult {
if let Some(maybe_packet) = runtime.block_on(async {
tokio::select! {
Expand All @@ -104,7 +104,7 @@ pub fn recv(
let mut packet = maybe_packet?;

let stream_id = packet.get_u16();
if let Some(enqueuer) = packet_enqueuers.lock().get_mut(&stream_id) {
if let Some(enqueuer) = packet_enqueuers.get_mut(&stream_id) {
enqueuer.send(packet).map_err(to_con_e!())?;
}

Expand Down
4 changes: 2 additions & 2 deletions alvr/sockets/src/stream_socket/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ pub fn recv(
runtime: &Runtime,
timeout: Duration,
socket: &mut UdpStreamReceiveSocket,
packet_enqueuers: &Mutex<HashMap<u16, mpsc::Sender<BytesMut>>>,
packet_enqueuers: &mut HashMap<u16, mpsc::Sender<BytesMut>>,
) -> ConResult {
if let Some(maybe_packet) = runtime.block_on(async {
tokio::select! {
Expand All @@ -88,7 +88,7 @@ pub fn recv(
}

let stream_id = packet_bytes.get_u16();
if let Some(enqueuer) = packet_enqueuers.lock().get_mut(&stream_id) {
if let Some(enqueuer) = packet_enqueuers.get_mut(&stream_id) {
enqueuer.send(packet_bytes).map_err(to_con_e!())?;
}

Expand Down

0 comments on commit 5ae730c

Please sign in to comment.