From 5ae730c9b1cb9e6efa47ba4dd23133665668b496 Mon Sep 17 00:00:00 2001 From: Riccardo Zaglia Date: Fri, 14 Jul 2023 11:57:27 +0800 Subject: [PATCH] Remove some mutexes inside StreamSocket --- alvr/client_core/src/connection.rs | 3 +-- alvr/server/src/connection.rs | 3 +-- alvr/sockets/src/stream_socket/mod.rs | 31 ++++++++++++--------------- alvr/sockets/src/stream_socket/tcp.rs | 4 ++-- alvr/sockets/src/stream_socket/udp.rs | 4 ++-- 5 files changed, 20 insertions(+), 25 deletions(-) diff --git a/alvr/client_core/src/connection.rs b/alvr/client_core/src/connection.rs index c7255102ef..ff516a09d7 100644 --- a/alvr/client_core/src/connection.rs +++ b/alvr/client_core/src/connection.rs @@ -264,14 +264,13 @@ fn connection_pipeline( return Ok(()); } - let stream_socket = stream_socket_builder.accept_from_server( + let mut stream_socket = stream_socket_builder.accept_from_server( &runtime, Duration::from_secs(2), server_ip, settings.connection.stream_port, settings.connection.packet_size as _, )?; - let stream_socket = Arc::new(stream_socket); info!("Connected to server"); diff --git a/alvr/server/src/connection.rs b/alvr/server/src/connection.rs index 44bb097590..b24b1cf552 100644 --- a/alvr/server/src/connection.rs +++ b/alvr/server/src/connection.rs @@ -536,7 +536,7 @@ fn try_connect(mut client_ips: HashMap) -> ConResult { *BITRATE_MANAGER.lock() = BitrateManager::new(settings.video.bitrate.history_size, fps); - let stream_socket = StreamSocketBuilder::connect_to_client( + let mut stream_socket = StreamSocketBuilder::connect_to_client( &runtime, Duration::from_secs(1), client_ip, @@ -547,7 +547,6 @@ fn try_connect(mut client_ips: HashMap) -> ConResult { settings.connection.packet_size as _, ) .map_err(to_con_e!())?; - let stream_socket = Arc::new(stream_socket); let mut video_sender = stream_socket.request_stream(VIDEO); let game_audio_sender = stream_socket.request_stream(AUDIO); diff --git a/alvr/sockets/src/stream_socket/mod.rs b/alvr/sockets/src/stream_socket/mod.rs index f862b47220..b24d2de840 100644 --- a/alvr/sockets/src/stream_socket/mod.rs +++ b/alvr/sockets/src/stream_socket/mod.rs @@ -7,7 +7,7 @@ mod tcp; mod udp; -use alvr_common::{parking_lot::Mutex, prelude::*}; +use alvr_common::prelude::*; use alvr_session::{SocketBufferSize, SocketProtocol}; use bytes::{Buf, BufMut, BytesMut}; use futures::SinkExt; @@ -17,10 +17,7 @@ use std::{ marker::PhantomData, net::IpAddr, ops::{Deref, DerefMut}, - sync::{ - mpsc::{self, RecvTimeoutError}, - Arc, - }, + sync::mpsc::{self, RecvTimeoutError}, time::Duration, }; use tcp::{TcpStreamReceiveSocket, TcpStreamSendSocket}; @@ -368,8 +365,8 @@ impl StreamSocketBuilder { Ok(StreamSocket { max_packet_size, send_socket, - receive_socket: Arc::new(Mutex::new(Some(receive_socket))), - packet_queues: Arc::new(Mutex::new(HashMap::new())), + receive_socket, + packet_queues: HashMap::new(), }) } @@ -415,8 +412,8 @@ impl StreamSocketBuilder { Ok(StreamSocket { max_packet_size, send_socket, - receive_socket: Arc::new(Mutex::new(Some(receive_socket))), - packet_queues: Arc::new(Mutex::new(HashMap::new())), + receive_socket, + packet_queues: HashMap::new(), }) } } @@ -424,8 +421,8 @@ impl StreamSocketBuilder { pub struct StreamSocket { max_packet_size: usize, send_socket: StreamSendSocket, - receive_socket: Arc>>, - packet_queues: Arc>>>, + receive_socket: StreamReceiveSocket, + packet_queues: HashMap>, } impl StreamSocket { @@ -440,10 +437,10 @@ impl StreamSocket { } } - pub fn subscribe_to_stream(&self, stream_id: u16) -> StreamReceiver { + pub fn subscribe_to_stream(&mut self, stream_id: u16) -> StreamReceiver { let (sender, receiver) = mpsc::channel(); - self.packet_queues.lock().insert(stream_id, sender); + self.packet_queues.insert(stream_id, sender); StreamReceiver { receiver, @@ -454,13 +451,13 @@ impl StreamSocket { } } - pub fn recv(&self, runtime: &Runtime, timeout: Duration) -> ConResult { - match self.receive_socket.lock().as_mut().unwrap() { + pub fn recv(&mut self, runtime: &Runtime, timeout: Duration) -> ConResult { + match &mut self.receive_socket { StreamReceiveSocket::Udp(socket) => { - udp::recv(runtime, timeout, socket, &self.packet_queues) + udp::recv(runtime, timeout, socket, &mut self.packet_queues) } StreamReceiveSocket::Tcp(socket) => { - tcp::recv(runtime, timeout, socket, &self.packet_queues) + tcp::recv(runtime, timeout, socket, &mut self.packet_queues) } } } diff --git a/alvr/sockets/src/stream_socket/tcp.rs b/alvr/sockets/src/stream_socket/tcp.rs index a3f840b0d0..b5aa36d3da 100644 --- a/alvr/sockets/src/stream_socket/tcp.rs +++ b/alvr/sockets/src/stream_socket/tcp.rs @@ -93,7 +93,7 @@ pub fn recv( runtime: &Runtime, timeout: Duration, socket: &mut TcpStreamReceiveSocket, - packet_enqueuers: &Mutex>>, + packet_enqueuers: &mut HashMap>, ) -> ConResult { if let Some(maybe_packet) = runtime.block_on(async { tokio::select! { @@ -104,7 +104,7 @@ pub fn recv( let mut packet = maybe_packet?; let stream_id = packet.get_u16(); - if let Some(enqueuer) = packet_enqueuers.lock().get_mut(&stream_id) { + if let Some(enqueuer) = packet_enqueuers.get_mut(&stream_id) { enqueuer.send(packet).map_err(to_con_e!())?; } diff --git a/alvr/sockets/src/stream_socket/udp.rs b/alvr/sockets/src/stream_socket/udp.rs index bb39c9e3a0..43f3d42bfa 100644 --- a/alvr/sockets/src/stream_socket/udp.rs +++ b/alvr/sockets/src/stream_socket/udp.rs @@ -72,7 +72,7 @@ pub fn recv( runtime: &Runtime, timeout: Duration, socket: &mut UdpStreamReceiveSocket, - packet_enqueuers: &Mutex>>, + packet_enqueuers: &mut HashMap>, ) -> ConResult { if let Some(maybe_packet) = runtime.block_on(async { tokio::select! { @@ -88,7 +88,7 @@ pub fn recv( } let stream_id = packet_bytes.get_u16(); - if let Some(enqueuer) = packet_enqueuers.lock().get_mut(&stream_id) { + if let Some(enqueuer) = packet_enqueuers.get_mut(&stream_id) { enqueuer.send(packet_bytes).map_err(to_con_e!())?; }