Skip to content

Commit

Permalink
Merge pull request #8 from sine-fdn/connection-closed
Browse files Browse the repository at this point in the history
Connection closed
  • Loading branch information
raimundo-henriques authored Nov 20, 2023
2 parents 9c47c2a + e743673 commit fbcc064
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ jobs:
~/.rustup
key: ${{ env.cache-name }}-${{ hashFiles('**/Cargo.toml') }}
- run: cargo build --all-features
- run: cargo test --all-features -- --skip session
- run: cargo test --all-features -- --skip session --skip quit_and_rejoin_session
- run: cargo clippy -- -Dwarnings
104 changes: 84 additions & 20 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use futures::StreamExt;
use libp2p::{
gossipsub, noise,
swarm::{NetworkBehaviour, SwarmEvent},
upnp, yamux, Multiaddr,
upnp, yamux, Multiaddr, PeerId,
};
use log::{error, info};
use rsa::signature::SignatureEncoding;
Expand Down Expand Up @@ -90,13 +90,15 @@ struct MyBehaviour {
enum Event {
Upnp(upnp::Event),
StdIn(String),
Msg(Msg),
Msg(Msg, PeerId),
ConnectionClosed(PeerId),
}

#[derive(Debug, Clone, Serialize, Deserialize)]
enum Msg {
Join(PublicKey, String),
Participants(HashMap<PublicKey, String>),
Quit(PeerId, String),
Participants(HashMap<PublicKey, (String, PeerId)>),
LobbyNowClosed,
Share {
from: PublicKey,
Expand All @@ -120,7 +122,17 @@ enum Phase {
SendingShares,
}

fn print_results(results: &BTreeMap<String, i64>, participants: &HashMap<PublicKey, String>) {
fn print_participants(participants: &HashMap<PublicKey, (String, PeerId)>) {
println!("\n-- Participants --");
for (pub_key, (name, _)) in participants {
println!("{pub_key} - {name}");
}
}

fn print_results(
results: &BTreeMap<String, i64>,
participants: &HashMap<PublicKey, (String, PeerId)>,
) {
println!("\nAverage results:");
for (key, result) in results.iter() {
let avg = (*result as f64 / participants.len() as f64) / 100.00;
Expand Down Expand Up @@ -200,7 +212,7 @@ async fn main() -> Result<(), Box<dyn Error>> {

let mut phase = Phase::WaitingForParticipants;
let mut stdin = io::BufReader::new(io::stdin()).lines();
let mut participants = HashMap::<PublicKey, String>::new();
let mut participants = HashMap::<PublicKey, (String, PeerId)>::new();
let mut sent_shares = HashMap::<PublicKey, HashMap<&String, i64>>::new();
let mut received_shares = HashMap::<PublicKey, Vec<u8>>::new();
let mut sums = HashMap::<PublicKey, HashMap<String, i64>>::new();
Expand Down Expand Up @@ -363,20 +375,13 @@ async fn main() -> Result<(), Box<dyn Error>> {
received_shares.insert(from, share);
}
}
Event::Msg(msg)
Event::Msg(msg, propagation_source)
},
SwarmEvent::IncomingConnectionError { .. } => {
eprintln!("Error while establishing incoming connection");
continue;
},
SwarmEvent::ConnectionClosed { .. } => {
if result.is_none() {
eprintln!("Connection has been closed by one of the participants");
std::process::exit(1);
} else {
std::process::exit(0);
}
},
SwarmEvent::ConnectionClosed { peer_id, .. } => Event::ConnectionClosed(peer_id),
ev => {
info!("{ev:?}");
continue;
Expand Down Expand Up @@ -433,7 +438,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
println!("{pub_key} - {name}");
}
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
participants.insert(pub_key.clone(), name.clone());
participants.insert(pub_key.clone(), (name.clone(), *swarm.local_peer_id()));
}
(_, Event::Upnp(upnp::Event::GatewayNotFound)) => {
error!("Gateway does not support UPnP");
Expand All @@ -444,20 +449,62 @@ async fn main() -> Result<(), Box<dyn Error>> {
break;
}
(_, Event::Upnp(ev)) => info!("{ev:?}"),
(Phase::WaitingForParticipants, Event::Msg(msg)) => match msg {
(Phase::WaitingForParticipants, Event::ConnectionClosed(peer_id)) => {
if result.is_none() {
let Some(disconnected) =
participants.iter().find(|(_, (_, id))| *id == peer_id)
else {
println!("Connection error, please try again.");
std::process::exit(1);
};

let disconnected = disconnected.1 .0.clone();

println!("\nParticipant {disconnected} disconnected");

if swarm.connected_peers().count() == 0 && is_leader {
participants.retain(|_, (_, id)| *id != peer_id);
} else if is_leader {
let msg = Msg::Quit(peer_id, disconnected).serialize()?;
swarm
.behaviour_mut()
.gossipsub
.publish(topic.clone(), msg)?;

participants.retain(|_, (_, id)| *id != peer_id);

print_participants(&participants);

let msg = Msg::Participants(participants.clone()).serialize()?;
if let Err(e) = swarm.behaviour_mut().gossipsub.publish(topic.clone(), msg)
{
error!("Could not publish to gossipsub: {e:?}");
}
}
continue;
} else {
std::process::exit(0);
}
}
(Phase::WaitingForParticipants, Event::Msg(msg, peer_id)) => match msg {
Msg::Join(public_key, name) => {
if is_leader {
println!("{public_key} - {name}");
participants.insert(public_key, name);
participants.insert(public_key, (name, peer_id));
let msg = Msg::Participants(participants.clone()).serialize()?;
if let Err(e) = swarm.behaviour_mut().gossipsub.publish(topic.clone(), msg)
{
error!("Could not publish to gossipsub: {e:?}");
}
}
}
Msg::Quit(_, name) => {
println!("\nParticipant {name} disconnected");

print_participants(&participants);
}
Msg::Participants(all_participants) => {
for (public_key, name) in all_participants.iter() {
for (public_key, (name, _)) in all_participants.iter() {
if !participants.contains_key(public_key) {
println!("{public_key} - {name}");
}
Expand Down Expand Up @@ -485,14 +532,14 @@ async fn main() -> Result<(), Box<dyn Error>> {
std::process::exit(1);
}
},
(Phase::SendingShares, Event::Msg(msg)) => match msg {
(Phase::SendingShares, Event::Msg(msg, _peer_id)) => match msg {
Msg::Join(_, _) | Msg::Participants(_) | Msg::LobbyNowClosed => {
println!(
"Already waiting for shares, but some participant still tried to join!"
);
continue;
}
Msg::Share { .. } => {}
Msg::Quit(..) | Msg::Share { .. } => {},
Msg::Sum(public_key, sum) => {
if is_leader {
sums.insert(public_key, sum);
Expand All @@ -503,6 +550,23 @@ async fn main() -> Result<(), Box<dyn Error>> {
std::process::exit(0);
}
},
(Phase::SendingShares, Event::ConnectionClosed(peer_id)) => {
if is_leader {
let Some((_, (disconnected, _))) =
participants.iter().find(|(_, (_, id))| *id == peer_id)
else {
println!("Connection error, please try again.");
std::process::exit(1);
};

println!(
"Aborting benchmark: participant {disconnected} left the while waiting for shares"
);
} else {
println!("Aborting benchmark: a participant left while waiting for shares");
}
std::process::exit(1);
}
(Phase::ConfirmingParticipants, _) => {}
}
}
Expand Down
131 changes: 131 additions & 0 deletions tests/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,137 @@ fn invalid_address() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

#[test]
fn quit_and_rejoin_session() -> Result<(), Box<dyn std::error::Error>> {
let mut new_session = new_command("foo", None, "tests/test_files/valid_json.json")?;

let mut leader = new_session
.stdout(Stdio::piped())
.stdin(Stdio::piped())
.spawn()?;
let stdout = leader.stdout.take().unwrap();
let reader = BufReader::new(stdout);
let stdin = leader.stdin.take().unwrap();
let mut writer = BufWriter::new(stdin);
let mut lines = reader.lines();

let address = loop {
if let Some(Ok(l)) = lines.next() {
if l.contains("--address=/ip4/") {
break l
.split(" ")
.find(|s| s.contains("--address=/ip4/"))
.unwrap()
.replace("--address=", "");
}
}
};

let bar_address = address.clone();
let bar_handle = thread::spawn(move || {
let mut participant = new_command(
"bar",
Some(&bar_address),
"tests/test_files/valid_json.json",
)
.unwrap()
.stdout(Stdio::piped())
.spawn()
.unwrap();

let stdout = participant.stdout.take().unwrap();
let reader = BufReader::new(stdout);
let mut lines = reader.lines();

while let Some(Ok(l)) = lines.next() {
println!("bar > {l}");
if l.contains("- foo") {
participant.kill().unwrap();
break;
}
}
});

while let Some(Ok(l)) = lines.next() {
println!("foo > {l}");
if l.contains("bar disconnected") {
break;
}
}

bar_handle.join().unwrap();

let mut threads = vec![];
for name in ["baz", "qux"] {
sleep(Duration::from_millis(200));
let address = address.clone();
threads.push(thread::spawn(move || {
let mut participant =
new_command(name, Some(&address), "tests/test_files/valid_json.json")
.unwrap()
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()
.unwrap();

let stdout = participant.stdout.take().unwrap();
let reader = BufReader::new(stdout);
let stdin = participant.stdin.take().unwrap();
let mut writer = BufWriter::new(stdin);
let mut lines = reader.lines();

while let Some(Ok(l)) = lines.next() {
println!("{name} > {l}");

if l.contains("Do you want to join the benchmark?") {
sleep(Duration::from_millis(200));
writeln!(writer, "y").unwrap();
writer.flush().unwrap();
}

if l.contains("results") {
participant.kill().unwrap();
return;
}
}
}));
}

let mut participant_count = 1;
let mut benchmark_complete = false;
while let Some(Ok(l)) = lines.next() {
println!("foo > {}", l);
if l.contains("- baz") || l.contains("- qux") {
participant_count += 1;
}
if participant_count == 3 {
sleep(Duration::from_millis(200));
writeln!(writer, "").unwrap();
writer.flush().unwrap();
}
if l.contains("results") {
benchmark_complete = true;
break;
}
}

sleep(Duration::from_millis(200));
leader.kill()?;

for t in threads {
t.join().unwrap();
}

if benchmark_complete {
Ok(())
} else {
Err(Box::new(Error::new(
ErrorKind::Other,
"Could not complete benchmark",
)))
}
}

#[test]
fn session() -> Result<(), Box<dyn std::error::Error>> {
let mut new_session = new_command("foo", None, "tests/test_files/valid_json.json")?;
Expand Down

0 comments on commit fbcc064

Please sign in to comment.