Skip to content

Commit

Permalink
Simplify interface (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
paulgb authored May 2, 2024
1 parent be73430 commit f7b747d
Show file tree
Hide file tree
Showing 15 changed files with 104 additions and 195 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 2 additions & 5 deletions examples/binary-echo/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
use stateroom_wasm::prelude::*;

#[stateroom_wasm]
#[derive(Default)]
struct BinaryEcho;

impl SimpleStateroomService for BinaryEcho {
fn new(_: &str, _: &impl StateroomContext) -> Self {
BinaryEcho
}

impl StateroomService for BinaryEcho {
fn message(&mut self, _: ClientId, message: &str, ctx: &impl StateroomContext) {
ctx.send_binary(
MessageRecipient::Broadcast,
Expand Down
12 changes: 6 additions & 6 deletions examples/clock/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use stateroom_wasm::prelude::*;

#[stateroom_wasm]
struct ClockServer(String, u32);
#[derive(Default)]
struct ClockServer(u32);

impl SimpleStateroomService for ClockServer {
fn new(room_id: &str, ctx: &impl StateroomContext) -> Self {
impl StateroomService for ClockServer {
fn init(&mut self, ctx: &impl StateroomContext) {
ctx.set_timer(4000);
ClockServer(room_id.to_string(), 0)
}

fn timer(&mut self, ctx: &impl StateroomContext) {
ctx.send_message(MessageRecipient::Broadcast, &format!("Here in room {} from timer @ {}", self.0, self.1));
self.1 += 1;
ctx.send_message(MessageRecipient::Broadcast, &format!("Timer @ {}", self.0));
self.0 += 1;
ctx.set_timer(4000);
}
}
7 changes: 2 additions & 5 deletions examples/counter-service/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
use stateroom_wasm::prelude::*;

#[stateroom_wasm]
#[derive(Default)]
struct SharedCounterServer(i32);

impl SimpleStateroomService for SharedCounterServer {
fn new(_: &str, _: &impl StateroomContext) -> Self {
SharedCounterServer(0)
}

impl StateroomService for SharedCounterServer {
fn message(&mut self, _: ClientId, message: &str, ctx: &impl StateroomContext) {
match message {
"increment" => self.0 += 1,
Expand Down
13 changes: 5 additions & 8 deletions examples/cpu-hog/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,20 @@ use stateroom_wasm::prelude::*;
const SECONDS: u64 = 1_000_000_000;

#[stateroom_wasm]
struct CpuHog(String);
#[derive(Default)]
struct CpuHog;

fn get_time() -> u64 {
unsafe {
wasi::clock_time_get(wasi::CLOCKID_REALTIME, 0).unwrap()
}
}

impl SimpleStateroomService for CpuHog {
fn new(room_id: &str, _: &impl StateroomContext) -> Self {
CpuHog(room_id.to_string())
}

impl StateroomService for CpuHog {
fn connect(&mut self, _: ClientId, ctx: &impl StateroomContext) {
ctx.send_message(
MessageRecipient::Broadcast,
&format!("Connected to room {}", self.0),
&format!("Connected."),
);

let init_time = get_time();
Expand All @@ -33,7 +30,7 @@ impl SimpleStateroomService for CpuHog {

ctx.send_message(
MessageRecipient::Broadcast,
&format!("Finished in room {}", self.0),
&format!("Finished."),
);
}
}
7 changes: 2 additions & 5 deletions examples/echo-server/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
use stateroom_wasm::prelude::*;

#[stateroom_wasm]
#[derive(Default)]
struct EchoServer;

impl SimpleStateroomService for EchoServer {
fn new(_: &str, _: &impl StateroomContext) -> Self {
EchoServer
}

impl StateroomService for EchoServer {
fn connect(&mut self, client_id: ClientId, ctx: &impl StateroomContext) {
ctx.send_message(client_id, &format!("User {:?} connected.", client_id));
}
Expand Down
7 changes: 2 additions & 5 deletions examples/randomness/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ use bytemuck::cast;
use stateroom_wasm::prelude::*;

#[stateroom_wasm]
#[derive(Default)]
struct RandomServer;

impl SimpleStateroomService for RandomServer {
fn new(_: &str, _: &impl StateroomContext) -> Self {
RandomServer
}

impl StateroomService for RandomServer {
fn connect(&mut self, client_id: ClientId, ctx: &impl StateroomContext) {
let mut buf: [u8; 4] = [0, 0, 0, 0];
unsafe {
Expand Down
1 change: 1 addition & 0 deletions stateroom-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ dashmap = "5.5.3"
futures-util = "0.3.30"
stateroom = {path="../stateroom", version="0.2.8"}
tokio = { version = "1.37.0", features = ["rt-multi-thread"] }
tracing = "0.1.40"
30 changes: 6 additions & 24 deletions stateroom-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use axum::{
routing::get,
Router,
};
use server::{ServerState, ServiceActorContext};
use stateroom::{StateroomService, StateroomServiceFactory};
use server::ServerState;
use stateroom::StateroomServiceFactory;
use std::{
net::{IpAddr, SocketAddr},
sync::Arc,
Expand Down Expand Up @@ -103,17 +103,8 @@ impl Server {
/// endpoints are available:
/// - `/` (GET): return HTTP 200 if the server is running (useful as a baseline status check)
/// - `/ws` (GET): initiate a WebSocket connection to the stateroom service.
pub async fn serve_async<J: StateroomService + Send + Sync + Unpin + 'static>(
self,
service_factory: impl StateroomServiceFactory<ServiceActorContext, Service = J>
+ Send
+ Sync
+ 'static,
) -> std::io::Result<()>
where
J: StateroomService + Send + Sync + Unpin + 'static,
{
let server_state = Arc::new(ServerState::new(service_factory));
pub async fn serve_async(self, factory: impl StateroomServiceFactory) -> std::io::Result<()> {
let server_state = Arc::new(ServerState::new(factory));

let app = Router::new()
.route("/ws", get(serve_websocket))
Expand All @@ -133,21 +124,12 @@ impl Server {
/// endpoints are available:
/// - `/` (GET): return HTTP 200 if the server is running (useful as a baseline status check)
/// - `/ws` (GET): initiate a WebSocket connection to the stateroom service.
pub fn serve<J>(
self,
service_factory: impl StateroomServiceFactory<ServiceActorContext, Service = J>
+ Send
+ Sync
+ 'static,
) -> std::io::Result<()>
where
J: StateroomService + Send + Sync + Unpin + 'static,
{
pub fn serve(self, factory: impl StateroomServiceFactory) -> std::io::Result<()> {
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(async { self.serve_async(service_factory).await })
.block_on(async { self.serve_async(factory).await })
}
}

Expand Down
54 changes: 24 additions & 30 deletions stateroom-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@ use tokio::{

/// A [StateroomContext] implementation for [StateroomService]s hosted in the
/// context of a [ServiceActor].
#[derive(Clone)]
pub struct ServiceActorContext {
pub struct ServerStateroomContext {
senders: Arc<DashMap<ClientId, Sender<Message>>>,
event_sender: Sender<Event>,
event_sender: Arc<Sender<Event>>,
}

impl ServiceActorContext {
impl ServerStateroomContext {
pub fn try_send(&self, recipient: MessageRecipient, message: Message) {
match recipient {
MessageRecipient::Broadcast => {
Expand All @@ -39,14 +38,14 @@ impl ServiceActorContext {
if let Some(sender) = self.senders.get(&client_id) {
sender.try_send(message).unwrap();
} else {
println!("No sender for client {:?}", client_id);
tracing::error!(?client_id, "No sender for client.");
}
}
}
}
}

impl StateroomContext for ServiceActorContext {
impl StateroomContext for ServerStateroomContext {
fn send_message(&self, recipient: impl Into<MessageRecipient>, message: &str) {
self.try_send(recipient.into(), Message::Text(message.to_string()));
}
Expand All @@ -60,7 +59,7 @@ impl StateroomContext for ServiceActorContext {
let sender = self.event_sender.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(ms_delay as u64)).await;
sender.send(Event::TimerEvent).await.unwrap();
sender.send(Event::Timer).await.unwrap();
});
}
}
Expand All @@ -78,44 +77,39 @@ pub enum Event {
Message { client: ClientId, message: Message },
Join { client: ClientId },
Leave { client: ClientId },
TimerEvent,
Timer,
}

impl ServerState {
pub fn new<T: StateroomService + Send + Sync + 'static>(
service_factory: impl StateroomServiceFactory<ServiceActorContext, Service = T> + Send + 'static,
) -> Self {
pub fn new(factory: impl StateroomServiceFactory) -> Self {
let (tx, mut rx) = tokio::sync::mpsc::channel::<Event>(100);

let senders = Arc::new(DashMap::new());

let senders_ = senders.clone();
let tx_ = tx.clone();
let handle = tokio::spawn(async move {
let mut service = service_factory
.build(
"",
ServiceActorContext {
senders: senders_.clone(),
event_sender: tx_,
},
)
.unwrap();

let context = Arc::new(ServerStateroomContext {
senders: senders_.clone(),
event_sender: Arc::new(tx_),
});

let mut service = factory.build("", context.clone()).unwrap();
service.init(context.as_ref());

loop {
let msg = rx.recv().await;
println!("{:?}", msg);
match msg {
Some(Event::Message { client, message }) => match message {
Message::Text(msg) => service.message(client, &msg),
Message::Binary(msg) => service.binary(client, &msg),
Message::Text(msg) => service.message(client, &msg, context.as_ref()),
Message::Binary(msg) => service.binary(client, &msg, context.as_ref()),
Message::Close(_) => {}
msg => println!("Ignoring unhandled message: {:?}", msg),
msg => tracing::warn!("Ignoring unhandled message: {:?}", msg),
},
Some(Event::Join { client }) => service.connect(client),
Some(Event::Leave { client }) => service.disconnect(client),
Some(Event::TimerEvent) => {
service.timer();
Some(Event::Join { client }) => service.connect(client, context.as_ref()),
Some(Event::Leave { client }) => service.disconnect(client, context.as_ref()),
Some(Event::Timer) => {
service.timer(context.as_ref());
}
None => break,
}
Expand All @@ -133,7 +127,7 @@ impl ServerState {
pub fn remove(&self, client: &ClientId) {
self.inbound_sender
.try_send(Event::Leave {
client: client.clone(),
client: *client,
})
.unwrap();
self.senders.remove(client);
Expand Down
Loading

0 comments on commit f7b747d

Please sign in to comment.