diff --git a/examples/ping_pong/saved_states/ping_state.json b/examples/ping_pong/saved_states/ping_state.json index 320e7e2..9989827 100644 --- a/examples/ping_pong/saved_states/ping_state.json +++ b/examples/ping_pong/saved_states/ping_state.json @@ -1 +1 @@ -{"pong_count":12} \ No newline at end of file +{"pong_count":5} \ No newline at end of file diff --git a/examples/ping_pong/src/operators.rs b/examples/ping_pong/src/operators.rs index 12d9125..7240969 100644 --- a/examples/ping_pong/src/operators.rs +++ b/examples/ping_pong/src/operators.rs @@ -13,6 +13,15 @@ pub struct StateSaveOperator { #[async_trait::async_trait] impl StateOperator for StateSaveOperator { type StateInput = PingState; + type LoadError = std::io::Error; + + fn try_load( + settings: &::Settings, + ) -> Result, Self::LoadError> { + let state_string = std::fs::read_to_string(&settings.state_save_path)?; + serde_json::from_str(&state_string) + .map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error)) + } fn from_settings(settings: ::Settings) -> Self { Self { diff --git a/examples/ping_pong/src/states.rs b/examples/ping_pong/src/states.rs index 0f8ad37..1e00ed6 100644 --- a/examples/ping_pong/src/states.rs +++ b/examples/ping_pong/src/states.rs @@ -1,5 +1,3 @@ -// STD -use std::io; // Crates use overwatch_rs::services::state::ServiceState; use serde::{Deserialize, Serialize}; @@ -14,22 +12,11 @@ pub struct PingState { pub pong_count: u32, } -impl PingState { - fn load_saved_state(save_path: &str) -> io::Result { - let json_state = std::fs::read(save_path)?; - let state = serde_json::from_slice(json_state.as_slice()) - .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; - Ok(state) - } -} - impl ServiceState for PingState { type Settings = PingSettings; type Error = PingStateError; - fn from_settings(settings: &Self::Settings) -> Result { - let state = Self::load_saved_state(settings.state_save_path.as_str()) - .unwrap_or_else(|_error| Self::default()); - Ok(state) + fn from_settings(_settings: &Self::Settings) -> Result { + Ok(Self::default()) } } diff --git a/overwatch-rs/src/services/handle.rs b/overwatch-rs/src/services/handle.rs index ce19f4b..5cb94ab 100644 --- a/overwatch-rs/src/services/handle.rs +++ b/overwatch-rs/src/services/handle.rs @@ -1,5 +1,6 @@ // crates use tokio::runtime::Handle; +use tracing::info; // internal use crate::overwatch::handle::OverwatchHandle; use crate::services::life_cycle::LifecycleHandle; @@ -50,7 +51,15 @@ impl ServiceHandle { settings: S::Settings, overwatch_handle: OverwatchHandle, ) -> Result::Error> { - S::State::from_settings(&settings).map(|initial_state| Self { + let initial_state = if let Ok(Some(loaded_state)) = S::StateOperator::try_load(&settings) { + info!("Loaded state from Operator"); + loaded_state + } else { + info!("Couldn't load state from Operator. Creating from settings."); + S::State::from_settings(&settings)? + }; + + Ok(Self { outbound_relay: None, overwatch_handle, settings: SettingsUpdater::new(settings), diff --git a/overwatch-rs/src/services/state.rs b/overwatch-rs/src/services/state.rs index ae11291..1148720 100644 --- a/overwatch-rs/src/services/state.rs +++ b/overwatch-rs/src/services/state.rs @@ -1,3 +1,5 @@ +use std::convert::Infallible; +use std::error::Error; // std use std::marker::PhantomData; use std::pin::Pin; @@ -20,16 +22,26 @@ pub trait ServiceState: Sized { type Settings; /// Errors that can occur during state initialization type Error; - /// Initialize a stage upon the provided settings + /// Initialize a state using the provided settings. + /// This is called when [StateOperator::try_load] doesn't return a state. fn from_settings(settings: &Self::Settings) -> Result; } /// A state operator is an entity that can handle a state in a point of time /// to perform any operation based on it. +/// A typical use case is to handle recovery: Saving and loading state. #[async_trait] pub trait StateOperator { /// The type of state that the operator can handle type StateInput: ServiceState; + /// Errors that can occur during state loading + type LoadError: Error; + /// State initialization method + /// In contrast to [ServiceState::from_settings], this is used to try to initialize + /// a (saved) [ServiceState] from an external source (e.g. file, database, etc.) + fn try_load( + settings: &::Settings, + ) -> Result, Self::LoadError>; /// Operator initialization method. Can be implemented over some subset of settings fn from_settings(settings: ::Settings) -> Self; /// Asynchronously perform an operation for a given state @@ -56,6 +68,13 @@ impl Clone for NoOperator { #[async_trait] impl StateOperator for NoOperator { type StateInput = StateInput; + type LoadError = Infallible; + + fn try_load( + _settings: &::Settings, + ) -> Result, Self::LoadError> { + Ok(None) + } fn from_settings(_settings: ::Settings) -> Self { NoOperator(PhantomData) @@ -208,6 +227,7 @@ where mod test { use crate::services::state::{ServiceState, StateHandle, StateOperator, StateUpdater}; use async_trait::async_trait; + use std::convert::Infallible; use std::time::Duration; use tokio::io; use tokio::io::AsyncWriteExt; @@ -229,6 +249,13 @@ mod test { #[async_trait] impl StateOperator for PanicOnGreaterThanTen { type StateInput = UsizeCounter; + type LoadError = Infallible; + + fn try_load( + _settings: &::Settings, + ) -> Result, Self::LoadError> { + Ok(None) + } fn from_settings(_settings: ::Settings) -> Self { Self diff --git a/overwatch-rs/tests/state_handling.rs b/overwatch-rs/tests/state_handling.rs index fd58431..ac65186 100644 --- a/overwatch-rs/tests/state_handling.rs +++ b/overwatch-rs/tests/state_handling.rs @@ -5,6 +5,7 @@ use overwatch_rs::services::handle::{ServiceHandle, ServiceStateHandle}; use overwatch_rs::services::relay::RelayMessage; use overwatch_rs::services::state::{ServiceState, StateOperator}; use overwatch_rs::services::{ServiceCore, ServiceData, ServiceId}; +use std::convert::Infallible; use std::time::Duration; use tokio::io::{self, AsyncWriteExt}; use tokio::time::sleep; @@ -49,6 +50,13 @@ pub struct CounterStateOperator; #[async_trait] impl StateOperator for CounterStateOperator { type StateInput = CounterState; + type LoadError = Infallible; + + fn try_load( + _settings: &::Settings, + ) -> Result, Self::LoadError> { + Ok(None) + } fn from_settings(_settings: ::Settings) -> Self { CounterStateOperator diff --git a/overwatch-rs/tests/try_load.rs b/overwatch-rs/tests/try_load.rs new file mode 100644 index 0000000..0fbc16a --- /dev/null +++ b/overwatch-rs/tests/try_load.rs @@ -0,0 +1,113 @@ +use std::thread; +use std::time::Duration; +// Crates +use async_trait::async_trait; +use overwatch_derive::Services; +use overwatch_rs::overwatch::OverwatchRunner; +use overwatch_rs::services::handle::{ServiceHandle, ServiceStateHandle}; +use overwatch_rs::services::relay::NoMessage; +use overwatch_rs::services::state::{ServiceState, StateOperator}; +use overwatch_rs::services::{ServiceCore, ServiceData, ServiceId}; +use overwatch_rs::DynError; +use tokio::sync::broadcast; +use tokio::sync::broadcast::error::SendError; + +#[derive(Clone)] +struct TryLoadState; + +impl ServiceState for TryLoadState { + type Settings = TryLoadSettings; + type Error = DynError; + fn from_settings(settings: &Self::Settings) -> Result { + settings + .origin_sender + .send(String::from("ServiceState::from_settings"))?; + Ok(Self {}) + } +} + +#[derive(Clone)] +struct TryLoadOperator; + +#[async_trait] +impl StateOperator for TryLoadOperator { + type StateInput = TryLoadState; + type LoadError = SendError; + + fn try_load( + settings: &::Settings, + ) -> Result, Self::LoadError> { + settings + .origin_sender + .send(String::from("StateOperator::try_load"))?; + Ok(Some(Self::StateInput {})) + } + + fn from_settings(_settings: ::Settings) -> Self { + Self {} + } + + async fn run(&mut self, _state: Self::StateInput) {} +} + +#[derive(Debug, Clone)] +struct TryLoadSettings { + origin_sender: broadcast::Sender, +} + +struct TryLoad { + service_state_handle: ServiceStateHandle, +} + +impl ServiceData for TryLoad { + const SERVICE_ID: ServiceId = "try_load"; + type Settings = TryLoadSettings; + type State = TryLoadState; + type StateOperator = TryLoadOperator; + type Message = NoMessage; +} + +#[async_trait] +impl ServiceCore for TryLoad { + fn init( + service_state: ServiceStateHandle, + _initial_state: Self::State, + ) -> Result { + Ok(Self { + service_state_handle: service_state, + }) + } + + async fn run(self) -> Result<(), DynError> { + let Self { + service_state_handle, + .. + } = self; + + service_state_handle.overwatch_handle.shutdown().await; + Ok(()) + } +} + +#[derive(Services)] +struct TryLoadApp { + try_load: ServiceHandle, +} + +#[test] +fn load_state_from_operator() { + // Create a sender that will be called wherever the state is loaded + let (origin_sender, mut origin_receiver) = broadcast::channel(1); + let settings = TryLoadAppServiceSettings { + try_load: TryLoadSettings { origin_sender }, + }; + + // Run the app + let app = OverwatchRunner::::run(settings, None).unwrap(); + app.wait_finished(); + + // Check if the origin was called + thread::sleep(Duration::from_secs(1)); + let origin = origin_receiver.try_recv().expect("Value was not sent"); + assert_eq!(origin, "StateOperator::try_load"); +}