Skip to content

Commit

Permalink
feature(state-operator): Add try_load for State (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlejandroCabeza authored Jan 23, 2025
1 parent b18eebd commit 1b83775
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 18 deletions.
2 changes: 1 addition & 1 deletion examples/ping_pong/saved_states/ping_state.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"pong_count":12}
{"pong_count":5}
9 changes: 9 additions & 0 deletions examples/ping_pong/src/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ pub struct StateSaveOperator {
#[async_trait::async_trait]
impl StateOperator for StateSaveOperator {
type StateInput = PingState;
type LoadError = std::io::Error;

fn try_load(
settings: &<Self::StateInput as ServiceState>::Settings,
) -> Result<Option<Self::StateInput>, Self::LoadError> {
let state_string = std::fs::read_to_string(&settings.state_save_path)?;
serde_json::from_str(&state_string)
.map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))
}

fn from_settings(settings: <Self::StateInput as ServiceState>::Settings) -> Self {
Self {
Expand Down
17 changes: 2 additions & 15 deletions examples/ping_pong/src/states.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
// STD
use std::io;
// Crates
use overwatch_rs::services::state::ServiceState;
use serde::{Deserialize, Serialize};
Expand All @@ -14,22 +12,11 @@ pub struct PingState {
pub pong_count: u32,
}

impl PingState {
fn load_saved_state(save_path: &str) -> io::Result<Self> {
let json_state = std::fs::read(save_path)?;
let state = serde_json::from_slice(json_state.as_slice())
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
Ok(state)
}
}

impl ServiceState for PingState {
type Settings = PingSettings;
type Error = PingStateError;

fn from_settings(settings: &Self::Settings) -> Result<Self, Self::Error> {
let state = Self::load_saved_state(settings.state_save_path.as_str())
.unwrap_or_else(|_error| Self::default());
Ok(state)
fn from_settings(_settings: &Self::Settings) -> Result<Self, Self::Error> {
Ok(Self::default())
}
}
11 changes: 10 additions & 1 deletion overwatch-rs/src/services/handle.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// crates
use tokio::runtime::Handle;
use tracing::info;
// internal
use crate::overwatch::handle::OverwatchHandle;
use crate::services::life_cycle::LifecycleHandle;
Expand Down Expand Up @@ -50,7 +51,15 @@ impl<S: ServiceData> ServiceHandle<S> {
settings: S::Settings,
overwatch_handle: OverwatchHandle,
) -> Result<Self, <S::State as ServiceState>::Error> {
S::State::from_settings(&settings).map(|initial_state| Self {
let initial_state = if let Ok(Some(loaded_state)) = S::StateOperator::try_load(&settings) {
info!("Loaded state from Operator");
loaded_state
} else {
info!("Couldn't load state from Operator. Creating from settings.");
S::State::from_settings(&settings)?
};

Ok(Self {
outbound_relay: None,
overwatch_handle,
settings: SettingsUpdater::new(settings),
Expand Down
29 changes: 28 additions & 1 deletion overwatch-rs/src/services/state.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::convert::Infallible;
use std::error::Error;
// std
use std::marker::PhantomData;
use std::pin::Pin;
Expand All @@ -20,16 +22,26 @@ pub trait ServiceState: Sized {
type Settings;
/// Errors that can occur during state initialization
type Error;
/// Initialize a stage upon the provided settings
/// Initialize a state using the provided settings.
/// This is called when [StateOperator::try_load] doesn't return a state.
fn from_settings(settings: &Self::Settings) -> Result<Self, Self::Error>;
}

/// A state operator is an entity that can handle a state in a point of time
/// to perform any operation based on it.
/// A typical use case is to handle recovery: Saving and loading state.
#[async_trait]
pub trait StateOperator {
/// The type of state that the operator can handle
type StateInput: ServiceState;
/// Errors that can occur during state loading
type LoadError: Error;
/// State initialization method
/// In contrast to [ServiceState::from_settings], this is used to try to initialize
/// a (saved) [ServiceState] from an external source (e.g. file, database, etc.)
fn try_load(
settings: &<Self::StateInput as ServiceState>::Settings,
) -> Result<Option<Self::StateInput>, Self::LoadError>;
/// Operator initialization method. Can be implemented over some subset of settings
fn from_settings(settings: <Self::StateInput as ServiceState>::Settings) -> Self;
/// Asynchronously perform an operation for a given state
Expand All @@ -56,6 +68,13 @@ impl<T> Clone for NoOperator<T> {
#[async_trait]
impl<StateInput: ServiceState> StateOperator for NoOperator<StateInput> {
type StateInput = StateInput;
type LoadError = Infallible;

fn try_load(
_settings: &<Self::StateInput as ServiceState>::Settings,
) -> Result<Option<Self::StateInput>, Self::LoadError> {
Ok(None)
}

fn from_settings(_settings: <Self::StateInput as ServiceState>::Settings) -> Self {
NoOperator(PhantomData)
Expand Down Expand Up @@ -208,6 +227,7 @@ where
mod test {
use crate::services::state::{ServiceState, StateHandle, StateOperator, StateUpdater};
use async_trait::async_trait;
use std::convert::Infallible;
use std::time::Duration;
use tokio::io;
use tokio::io::AsyncWriteExt;
Expand All @@ -229,6 +249,13 @@ mod test {
#[async_trait]
impl StateOperator for PanicOnGreaterThanTen {
type StateInput = UsizeCounter;
type LoadError = Infallible;

fn try_load(
_settings: &<Self::StateInput as ServiceState>::Settings,
) -> Result<Option<Self::StateInput>, Self::LoadError> {
Ok(None)
}

fn from_settings(_settings: <Self::StateInput as ServiceState>::Settings) -> Self {
Self
Expand Down
8 changes: 8 additions & 0 deletions overwatch-rs/tests/state_handling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use overwatch_rs::services::handle::{ServiceHandle, ServiceStateHandle};
use overwatch_rs::services::relay::RelayMessage;
use overwatch_rs::services::state::{ServiceState, StateOperator};
use overwatch_rs::services::{ServiceCore, ServiceData, ServiceId};
use std::convert::Infallible;
use std::time::Duration;
use tokio::io::{self, AsyncWriteExt};
use tokio::time::sleep;
Expand Down Expand Up @@ -49,6 +50,13 @@ pub struct CounterStateOperator;
#[async_trait]
impl StateOperator for CounterStateOperator {
type StateInput = CounterState;
type LoadError = Infallible;

fn try_load(
_settings: &<Self::StateInput as ServiceState>::Settings,
) -> Result<Option<Self::StateInput>, Self::LoadError> {
Ok(None)
}

fn from_settings(_settings: <Self::StateInput as ServiceState>::Settings) -> Self {
CounterStateOperator
Expand Down
113 changes: 113 additions & 0 deletions overwatch-rs/tests/try_load.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use std::thread;
use std::time::Duration;
// Crates
use async_trait::async_trait;
use overwatch_derive::Services;
use overwatch_rs::overwatch::OverwatchRunner;
use overwatch_rs::services::handle::{ServiceHandle, ServiceStateHandle};
use overwatch_rs::services::relay::NoMessage;
use overwatch_rs::services::state::{ServiceState, StateOperator};
use overwatch_rs::services::{ServiceCore, ServiceData, ServiceId};
use overwatch_rs::DynError;
use tokio::sync::broadcast;
use tokio::sync::broadcast::error::SendError;

#[derive(Clone)]
struct TryLoadState;

impl ServiceState for TryLoadState {
type Settings = TryLoadSettings;
type Error = DynError;
fn from_settings(settings: &Self::Settings) -> Result<Self, DynError> {
settings
.origin_sender
.send(String::from("ServiceState::from_settings"))?;
Ok(Self {})
}
}

#[derive(Clone)]
struct TryLoadOperator;

#[async_trait]
impl StateOperator for TryLoadOperator {
type StateInput = TryLoadState;
type LoadError = SendError<String>;

fn try_load(
settings: &<Self::StateInput as ServiceState>::Settings,
) -> Result<Option<Self::StateInput>, Self::LoadError> {
settings
.origin_sender
.send(String::from("StateOperator::try_load"))?;
Ok(Some(Self::StateInput {}))
}

fn from_settings(_settings: <Self::StateInput as ServiceState>::Settings) -> Self {
Self {}
}

async fn run(&mut self, _state: Self::StateInput) {}
}

#[derive(Debug, Clone)]
struct TryLoadSettings {
origin_sender: broadcast::Sender<String>,
}

struct TryLoad {
service_state_handle: ServiceStateHandle<Self>,
}

impl ServiceData for TryLoad {
const SERVICE_ID: ServiceId = "try_load";
type Settings = TryLoadSettings;
type State = TryLoadState;
type StateOperator = TryLoadOperator;
type Message = NoMessage;
}

#[async_trait]
impl ServiceCore for TryLoad {
fn init(
service_state: ServiceStateHandle<Self>,
_initial_state: Self::State,
) -> Result<Self, DynError> {
Ok(Self {
service_state_handle: service_state,
})
}

async fn run(self) -> Result<(), DynError> {
let Self {
service_state_handle,
..
} = self;

service_state_handle.overwatch_handle.shutdown().await;
Ok(())
}
}

#[derive(Services)]
struct TryLoadApp {
try_load: ServiceHandle<TryLoad>,
}

#[test]
fn load_state_from_operator() {
// Create a sender that will be called wherever the state is loaded
let (origin_sender, mut origin_receiver) = broadcast::channel(1);
let settings = TryLoadAppServiceSettings {
try_load: TryLoadSettings { origin_sender },
};

// Run the app
let app = OverwatchRunner::<TryLoadApp>::run(settings, None).unwrap();
app.wait_finished();

// Check if the origin was called
thread::sleep(Duration::from_secs(1));
let origin = origin_receiver.try_recv().expect("Value was not sent");
assert_eq!(origin, "StateOperator::try_load");
}

0 comments on commit 1b83775

Please sign in to comment.