Skip to content

Commit

Permalink
Add support for Logger, refactor threads using new StatefulThread type
Browse files Browse the repository at this point in the history
  • Loading branch information
willcrichton committed Aug 26, 2024
1 parent 7e0c591 commit 3bd45be
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 51 deletions.
97 changes: 46 additions & 51 deletions crates/server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
use std::{
path::PathBuf,
pin::pin,
sync::{Arc, LazyLock},
time::{Duration, Instant},
};

use miniserve::{http::StatusCode, Content, Request, Response};
use serde::{Deserialize, Serialize};
use tokio::{
fs, join,
sync::{mpsc, oneshot},
task::JoinSet,
};
use stateful::StatefulThread;
use tokio::{fs, join, task::JoinSet};

mod stateful;

async fn index(_req: Request) -> Response {
let content = include_str!("../index.html").to_string();
Expand Down Expand Up @@ -42,55 +39,49 @@ async fn load_docs(paths: Vec<PathBuf>) -> Vec<String> {
docs
}

type Payload = (Arc<Vec<String>>, oneshot::Sender<Option<Vec<String>>>);

fn chatbot_thread() -> (mpsc::Sender<Payload>, mpsc::Sender<()>) {
let (req_tx, mut req_rx) = mpsc::channel::<Payload>(1024);
let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
tokio::spawn(async move {
let mut chatbot = chatbot::Chatbot::new(vec![":-)".into(), "^^".into()]);
while let Some((messages, responder)) = req_rx.recv().await {
let doc_paths = chatbot.retrieval_documents(&messages);
let docs = load_docs(doc_paths).await;
let mut chat_fut = pin!(chatbot.query_chat(&messages, &docs));
let mut cancel_fut = pin!(cancel_rx.recv());
let start = Instant::now();
loop {
let log_fut = tokio::time::sleep(Duration::from_secs(1));
tokio::select! {
response = &mut chat_fut => {
responder.send(Some(response)).unwrap();
break;
}
_ = &mut cancel_fut => {
responder.send(None).unwrap();
break;
}
_ = log_fut => {
println!("Waiting for {} seconds", start.elapsed().as_secs());
}
}
}
}
});
(req_tx, cancel_tx)
struct LogFunction {
logger: chatbot::Logger,
}

static CHATBOT_THREAD: LazyLock<(mpsc::Sender<Payload>, mpsc::Sender<()>)> =
LazyLock::new(chatbot_thread);
impl stateful::StatefulFunction for LogFunction {
type Input = Arc<Vec<String>>;
type Output = ();

async fn query_chat(messages: &Arc<Vec<String>>) -> Option<Vec<String>> {
let (tx, rx) = oneshot::channel();
CHATBOT_THREAD
.0
.send((Arc::clone(messages), tx))
.await
.unwrap();
rx.await.unwrap()
async fn call(&mut self, messages: Self::Input) -> Self::Output {
self.logger.append(messages.last().unwrap());
self.logger.save().await.unwrap();
}
}

static LOG_THREAD: LazyLock<StatefulThread<LogFunction>> = LazyLock::new(|| {
StatefulThread::new(LogFunction {
logger: chatbot::Logger::default(),
})
});

struct ChatbotFunction {
chatbot: chatbot::Chatbot,
}

impl stateful::StatefulFunction for ChatbotFunction {
type Input = Arc<Vec<String>>;
type Output = Vec<String>;

async fn call(&mut self, messages: Self::Input) -> Self::Output {
let doc_paths = self.chatbot.retrieval_documents(&messages);
let docs = load_docs(doc_paths).await;
self.chatbot.query_chat(&messages, &docs).await
}
}

static CHATBOT_THREAD: LazyLock<StatefulThread<ChatbotFunction>> = LazyLock::new(|| {
StatefulThread::new(ChatbotFunction {
chatbot: chatbot::Chatbot::new(vec![":-)".into(), "^^".into()]),
})
});

async fn cancel(_req: Request) -> Response {
CHATBOT_THREAD.1.send(()).await.unwrap();
CHATBOT_THREAD.cancel().await;
Ok(Content::Html("success".into()))
}

Expand All @@ -103,7 +94,11 @@ async fn chat(req: Request) -> Response {
};

let messages = Arc::new(data.messages);
let (i, responses_opt) = join!(chatbot::gen_random_number(), query_chat(&messages));
let (i, responses_opt, _) = join!(
chatbot::gen_random_number(),
CHATBOT_THREAD.call(Arc::clone(&messages)),
LOG_THREAD.call(Arc::clone(&messages))
);

let response = match responses_opt {
Some(mut responses) => {
Expand Down
73 changes: 73 additions & 0 deletions crates/server/src/stateful.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use std::{
fmt::Debug,
future::Future,
pin::pin,
time::{Duration, Instant},
};

use tokio::{
sync::{mpsc, oneshot},
task::JoinHandle,
};

pub trait StatefulFunction: Send + 'static {
type Input: Send;
type Output: Send + Debug;
fn call(&mut self, input: Self::Input) -> impl Future<Output = Self::Output> + Send;
}

type Payload<F> = (
<F as StatefulFunction>::Input,
oneshot::Sender<Option<<F as StatefulFunction>::Output>>,
);

pub struct StatefulThread<F: StatefulFunction> {
_handle: JoinHandle<()>,
input_tx: mpsc::Sender<Payload<F>>,
cancel_tx: mpsc::Sender<()>,
}

impl<F: StatefulFunction> StatefulThread<F> {
pub fn new(mut func: F) -> Self {
let (input_tx, mut input_rx) = mpsc::channel::<Payload<F>>(1024);
let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
let _handle = tokio::spawn(async move {
while let Some((input, responder)) = input_rx.recv().await {
let mut output_fut = pin!(func.call(input));
let mut cancel_fut = pin!(cancel_rx.recv());
let start = Instant::now();
loop {
let log_fut = tokio::time::sleep(Duration::from_secs(1));
tokio::select! {
response = &mut output_fut => {
responder.send(Some(response)).unwrap();
break;
}
_ = &mut cancel_fut => {
responder.send(None).unwrap();
break;
}
_ = log_fut => {
println!("Waiting for {} seconds", start.elapsed().as_secs());
}
}
}
}
});
StatefulThread {
_handle,
input_tx,
cancel_tx,
}
}

pub async fn call(&self, input: F::Input) -> Option<F::Output> {
let (tx, rx) = oneshot::channel();
self.input_tx.send((input, tx)).await.unwrap();
rx.await.unwrap()
}

pub async fn cancel(&self) {
self.cancel_tx.send(()).await.unwrap();
}
}

0 comments on commit 3bd45be

Please sign in to comment.