From 0668aec5bc18170d08c4f844877621921ace0718 Mon Sep 17 00:00:00 2001 From: David Barsky Date: Fri, 23 Aug 2024 14:30:59 -0400 Subject: [PATCH] introduce parallel salsa --- Cargo.toml | 1 + examples/calc/db.rs | 2 +- examples/lazy-input/main.rs | 15 +++++++++++---- src/database.rs | 2 +- src/database_impl.rs | 2 +- src/lib.rs | 2 ++ src/par_map.rs | 29 +++++++++++++++++++++++++++++ src/storage.rs | 6 +++++- src/zalsa.rs | 4 ++++ tests/common/mod.rs | 14 ++++++++------ tests/parallel/main.rs | 1 + tests/parallel/parallel_map.rs | 17 +++++++++++++++++ 12 files changed, 81 insertions(+), 14 deletions(-) create mode 100644 src/par_map.rs create mode 100644 tests/parallel/parallel_map.rs diff --git a/Cargo.toml b/Cargo.toml index d3815b18..96a2f9c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ salsa-macro-rules = { version = "0.1.0", path = "components/salsa-macro-rules" } salsa-macros = { path = "components/salsa-macros" } smallvec = "1" lazy_static = "1" +rayon = "1.10.0" [dev-dependencies] annotate-snippets = "0.11.4" diff --git a/examples/calc/db.rs b/examples/calc/db.rs index 2873ed5b..924205c2 100644 --- a/examples/calc/db.rs +++ b/examples/calc/db.rs @@ -2,7 +2,7 @@ use std::sync::{Arc, Mutex}; // ANCHOR: db_struct #[salsa::db] -#[derive(Default)] +#[derive(Default, Clone)] pub struct CalcDatabaseImpl { storage: salsa::Storage, diff --git a/examples/lazy-input/main.rs b/examples/lazy-input/main.rs index 792b7f34..ff998fa3 100644 --- a/examples/lazy-input/main.rs +++ b/examples/lazy-input/main.rs @@ -1,6 +1,10 @@ #![allow(unreachable_patterns)] // FIXME(rust-lang/rust#129031): regression in nightly -use std::{path::PathBuf, sync::Mutex, time::Duration}; +use std::{ + path::PathBuf, + sync::{Arc, Mutex}, + time::Duration, +}; use crossbeam::channel::{unbounded, Sender}; use dashmap::{mapref::entry::Entry, DashMap}; @@ -77,11 +81,12 @@ trait Db: salsa::Database { } #[salsa::db] +#[derive(Clone)] struct LazyInputDatabase { storage: Storage, - logs: Mutex>, + logs: Arc>>, files: DashMap, - file_watcher: Mutex>, + file_watcher: Arc>>, } impl LazyInputDatabase { @@ -90,7 +95,9 @@ impl LazyInputDatabase { storage: Default::default(), logs: Default::default(), files: DashMap::new(), - file_watcher: Mutex::new(new_debouncer(Duration::from_secs(1), tx).unwrap()), + file_watcher: Arc::new(Mutex::new( + new_debouncer(Duration::from_secs(1), tx).unwrap(), + )), } } } diff --git a/src/database.rs b/src/database.rs index 5a32bd9b..a978df0e 100644 --- a/src/database.rs +++ b/src/database.rs @@ -90,7 +90,7 @@ impl dyn Database { /// /// # Panics /// - /// If the view has not been added to the database (see [`DatabaseView`][]) + /// If the view has not been added to the database (see [`crate::views::Views`]). #[track_caller] pub fn as_view(&self) -> &DbView { self.zalsa().views().try_view_as(self).unwrap() diff --git a/src/database_impl.rs b/src/database_impl.rs index 71da9fff..e31c6ed7 100644 --- a/src/database_impl.rs +++ b/src/database_impl.rs @@ -3,7 +3,7 @@ use crate::{self as salsa, Database, Event, Storage}; #[salsa::db] /// Default database implementation that you can use if you don't /// require any custom user data. -#[derive(Default)] +#[derive(Default, Clone)] pub struct DatabaseImpl { storage: Storage, } diff --git a/src/lib.rs b/src/lib.rs index c23d9c2e..4cd1921f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ mod input; mod interned; mod key; mod nonce; +mod par_map; mod revision; mod runtime; mod salsa_struct; @@ -45,6 +46,7 @@ pub use self::storage::Storage; pub use self::update::Update; pub use self::zalsa::IngredientIndex; pub use crate::attach::with_attached_database; +pub use par_map::par_map; pub use salsa_macros::accumulator; pub use salsa_macros::db; pub use salsa_macros::input; diff --git a/src/par_map.rs b/src/par_map.rs new file mode 100644 index 00000000..64646b0e --- /dev/null +++ b/src/par_map.rs @@ -0,0 +1,29 @@ +use std::sync::Arc; + +use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator}; + +use crate::Database; + +pub fn par_map( + db: &Db, + inputs: impl IntoParallelIterator, + op: fn(&Db, D) -> E, +) -> C +where + Db: ?Sized + Database, + D: Send, + E: Send + Sync, + C: FromParallelIterator, +{ + let fork_db_a: Box = db.fork_db(); + let fork_db_a: &Db = fork_db_a.as_view::(); + // not sure what to use here, but here are the properties I need: + // - "implements" `Clone` via `db.fork_db()` + // - is `Send`, but absolutely not `Sync` + let wrapper = todo!(); + + inputs + .into_par_iter() + .map_with(&wrapper, |db, element| op(db, element)) + .collect() +} diff --git a/src/storage.rs b/src/storage.rs index c9e1273b..40986291 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -15,7 +15,7 @@ use crate::{ /// /// The `storage` and `storage_mut` fields must both return a reference to the same /// storage field which must be owned by `self`. -pub unsafe trait HasStorage: Database + Sized { +pub unsafe trait HasStorage: Database + Clone + Sized { fn storage(&self) -> &Storage; fn storage_mut(&mut self) -> &mut Storage; } @@ -108,6 +108,10 @@ unsafe impl ZalsaDatabase for T { fn zalsa_local(&self) -> &ZalsaLocal { &self.storage().zalsa_local } + + fn fork_db(&self) -> Box { + Box::new(self.clone()) + } } impl RefUnwindSafe for Storage {} diff --git a/src/zalsa.rs b/src/zalsa.rs index ab52814d..42242f7a 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -47,6 +47,10 @@ pub unsafe trait ZalsaDatabase: Any { /// Access the thread-local state associated with this database #[doc(hidden)] fn zalsa_local(&self) -> &ZalsaLocal; + + /// Clone the database. + #[doc(hidden)] + fn fork_db(&self) -> Box; } pub fn views(db: &Db) -> &Views { diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 4c4e9fc7..19f818b6 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -2,15 +2,17 @@ #![allow(dead_code)] +use std::sync::{Arc, Mutex}; + use salsa::{Database, Storage}; /// Logging userdata: provides [`LogDatabase`][] trait. /// /// If you wish to use it along with other userdata, /// you can also embed it in another struct and implement [`HasLogger`][] for that struct. -#[derive(Default)] +#[derive(Clone, Default)] pub struct Logger { - logs: std::sync::Mutex>, + logs: Arc>>, } /// Trait implemented by databases that lets them log events. @@ -48,7 +50,7 @@ impl LogDatabase for Db {} /// Database that provides logging but does not log salsa event. #[salsa::db] -#[derive(Default)] +#[derive(Clone, Default)] pub struct LoggerDatabase { storage: Storage, logger: Logger, @@ -67,7 +69,7 @@ impl Database for LoggerDatabase { /// Database that provides logging and logs salsa events. #[salsa::db] -#[derive(Default)] +#[derive(Clone, Default)] pub struct EventLoggerDatabase { storage: Storage, logger: Logger, @@ -87,7 +89,7 @@ impl HasLogger for EventLoggerDatabase { } #[salsa::db] -#[derive(Default)] +#[derive(Clone, Default)] pub struct DiscardLoggerDatabase { storage: Storage, logger: Logger, @@ -114,7 +116,7 @@ impl HasLogger for DiscardLoggerDatabase { } #[salsa::db] -#[derive(Default)] +#[derive(Clone, Default)] pub struct ExecuteValidateLoggerDatabase { storage: Storage, logger: Logger, diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 578a83cb..e01e4654 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -5,4 +5,5 @@ mod parallel_cycle_all_recover; mod parallel_cycle_mid_recover; mod parallel_cycle_none_recover; mod parallel_cycle_one_recover; +mod parallel_map; mod signal; diff --git a/tests/parallel/parallel_map.rs b/tests/parallel/parallel_map.rs new file mode 100644 index 00000000..d9f03907 --- /dev/null +++ b/tests/parallel/parallel_map.rs @@ -0,0 +1,17 @@ +use crate::setup::Knobs; + +#[salsa::input] +struct ParallelInput { + field: i32, +} + +#[test] +fn execute() { + let db = Knobs::default(); + + let inputs = (1..=10) + .map(|field| ParallelInput::new(&db, field)) + .collect::>(); + + let _foo: Vec = salsa::par_map(&db, inputs, |db, field| field.field(db)); +}