Skip to content

Commit

Permalink
introduce parallel salsa
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbarsky committed Aug 23, 2024
1 parent f608ff8 commit 0668aec
Show file tree
Hide file tree
Showing 12 changed files with 81 additions and 14 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ salsa-macro-rules = { version = "0.1.0", path = "components/salsa-macro-rules" }
salsa-macros = { path = "components/salsa-macros" }
smallvec = "1"
lazy_static = "1"
rayon = "1.10.0"

[dev-dependencies]
annotate-snippets = "0.11.4"
Expand Down
2 changes: 1 addition & 1 deletion examples/calc/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::{Arc, Mutex};

// ANCHOR: db_struct
#[salsa::db]
#[derive(Default)]
#[derive(Default, Clone)]
pub struct CalcDatabaseImpl {
storage: salsa::Storage<Self>,

Expand Down
15 changes: 11 additions & 4 deletions examples/lazy-input/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#![allow(unreachable_patterns)]
// FIXME(rust-lang/rust#129031): regression in nightly
use std::{path::PathBuf, sync::Mutex, time::Duration};
use std::{
path::PathBuf,
sync::{Arc, Mutex},
time::Duration,
};

use crossbeam::channel::{unbounded, Sender};
use dashmap::{mapref::entry::Entry, DashMap};
Expand Down Expand Up @@ -77,11 +81,12 @@ trait Db: salsa::Database {
}

#[salsa::db]
#[derive(Clone)]
struct LazyInputDatabase {
storage: Storage<Self>,
logs: Mutex<Vec<String>>,
logs: Arc<Mutex<Vec<String>>>,
files: DashMap<PathBuf, File>,
file_watcher: Mutex<Debouncer<RecommendedWatcher>>,
file_watcher: Arc<Mutex<Debouncer<RecommendedWatcher>>>,
}

impl LazyInputDatabase {
Expand All @@ -90,7 +95,9 @@ impl LazyInputDatabase {
storage: Default::default(),
logs: Default::default(),
files: DashMap::new(),
file_watcher: Mutex::new(new_debouncer(Duration::from_secs(1), tx).unwrap()),
file_watcher: Arc::new(Mutex::new(
new_debouncer(Duration::from_secs(1), tx).unwrap(),
)),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl dyn Database {
///
/// # Panics
///
/// If the view has not been added to the database (see [`DatabaseView`][])
/// If the view has not been added to the database (see [`crate::views::Views`]).
#[track_caller]
pub fn as_view<DbView: ?Sized + Database>(&self) -> &DbView {
self.zalsa().views().try_view_as(self).unwrap()
Expand Down
2 changes: 1 addition & 1 deletion src/database_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{self as salsa, Database, Event, Storage};
#[salsa::db]
/// Default database implementation that you can use if you don't
/// require any custom user data.
#[derive(Default)]
#[derive(Default, Clone)]
pub struct DatabaseImpl {
storage: Storage<Self>,
}
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod input;
mod interned;
mod key;
mod nonce;
mod par_map;
mod revision;
mod runtime;
mod salsa_struct;
Expand Down Expand Up @@ -45,6 +46,7 @@ pub use self::storage::Storage;
pub use self::update::Update;
pub use self::zalsa::IngredientIndex;
pub use crate::attach::with_attached_database;
pub use par_map::par_map;
pub use salsa_macros::accumulator;
pub use salsa_macros::db;
pub use salsa_macros::input;
Expand Down
29 changes: 29 additions & 0 deletions src/par_map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use std::sync::Arc;

use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator};

use crate::Database;

pub fn par_map<Db, D, E, C>(
db: &Db,
inputs: impl IntoParallelIterator<Item = D>,
op: fn(&Db, D) -> E,
) -> C
where
Db: ?Sized + Database,
D: Send,
E: Send + Sync,
C: FromParallelIterator<E>,
{
let fork_db_a: Box<dyn Database> = db.fork_db();
let fork_db_a: &Db = fork_db_a.as_view::<Db>();
// not sure what to use here, but here are the properties I need:
// - "implements" `Clone` via `db.fork_db()`
// - is `Send`, but absolutely not `Sync`
let wrapper = todo!();

inputs
.into_par_iter()
.map_with(&wrapper, |db, element| op(db, element))
.collect()
}
6 changes: 5 additions & 1 deletion src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
///
/// The `storage` and `storage_mut` fields must both return a reference to the same
/// storage field which must be owned by `self`.
pub unsafe trait HasStorage: Database + Sized {
pub unsafe trait HasStorage: Database + Clone + Sized {
fn storage(&self) -> &Storage<Self>;
fn storage_mut(&mut self) -> &mut Storage<Self>;
}
Expand Down Expand Up @@ -108,6 +108,10 @@ unsafe impl<T: HasStorage> ZalsaDatabase for T {
fn zalsa_local(&self) -> &ZalsaLocal {
&self.storage().zalsa_local
}

fn fork_db(&self) -> Box<dyn Database> {
Box::new(self.clone())
}
}

impl<Db: Database> RefUnwindSafe for Storage<Db> {}
Expand Down
4 changes: 4 additions & 0 deletions src/zalsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ pub unsafe trait ZalsaDatabase: Any {
/// Access the thread-local state associated with this database
#[doc(hidden)]
fn zalsa_local(&self) -> &ZalsaLocal;

/// Clone the database.
#[doc(hidden)]
fn fork_db(&self) -> Box<dyn Database>;
}

pub fn views<Db: ?Sized + Database>(db: &Db) -> &Views {
Expand Down
14 changes: 8 additions & 6 deletions tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@

#![allow(dead_code)]

use std::sync::{Arc, Mutex};

use salsa::{Database, Storage};

/// Logging userdata: provides [`LogDatabase`][] trait.
///
/// If you wish to use it along with other userdata,
/// you can also embed it in another struct and implement [`HasLogger`][] for that struct.
#[derive(Default)]
#[derive(Clone, Default)]
pub struct Logger {
logs: std::sync::Mutex<Vec<String>>,
logs: Arc<Mutex<Vec<String>>>,
}

/// Trait implemented by databases that lets them log events.
Expand Down Expand Up @@ -48,7 +50,7 @@ impl<Db: HasLogger + Database> LogDatabase for Db {}

/// Database that provides logging but does not log salsa event.
#[salsa::db]
#[derive(Default)]
#[derive(Clone, Default)]
pub struct LoggerDatabase {
storage: Storage<Self>,
logger: Logger,
Expand All @@ -67,7 +69,7 @@ impl Database for LoggerDatabase {

/// Database that provides logging and logs salsa events.
#[salsa::db]
#[derive(Default)]
#[derive(Clone, Default)]
pub struct EventLoggerDatabase {
storage: Storage<Self>,
logger: Logger,
Expand All @@ -87,7 +89,7 @@ impl HasLogger for EventLoggerDatabase {
}

#[salsa::db]
#[derive(Default)]
#[derive(Clone, Default)]
pub struct DiscardLoggerDatabase {
storage: Storage<Self>,
logger: Logger,
Expand All @@ -114,7 +116,7 @@ impl HasLogger for DiscardLoggerDatabase {
}

#[salsa::db]
#[derive(Default)]
#[derive(Clone, Default)]
pub struct ExecuteValidateLoggerDatabase {
storage: Storage<Self>,
logger: Logger,
Expand Down
1 change: 1 addition & 0 deletions tests/parallel/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ mod parallel_cycle_all_recover;
mod parallel_cycle_mid_recover;
mod parallel_cycle_none_recover;
mod parallel_cycle_one_recover;
mod parallel_map;
mod signal;
17 changes: 17 additions & 0 deletions tests/parallel/parallel_map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use crate::setup::Knobs;

#[salsa::input]
struct ParallelInput {
field: i32,
}

#[test]
fn execute() {
let db = Knobs::default();

let inputs = (1..=10)
.map(|field| ParallelInput::new(&db, field))
.collect::<Vec<ParallelInput>>();

let _foo: Vec<i32> = salsa::par_map(&db, inputs, |db, field| field.field(db));
}

0 comments on commit 0668aec

Please sign in to comment.