Skip to content

Commit

Permalink
introduce parallel salsa
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbarsky committed Aug 26, 2024
1 parent f608ff8 commit 8e4ca38
Show file tree
Hide file tree
Showing 12 changed files with 109 additions and 15 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ salsa-macro-rules = { version = "0.1.0", path = "components/salsa-macro-rules" }
salsa-macros = { path = "components/salsa-macros" }
smallvec = "1"
lazy_static = "1"
rayon = "1.10.0"

[dev-dependencies]
annotate-snippets = "0.11.4"
Expand Down
2 changes: 1 addition & 1 deletion examples/calc/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::{Arc, Mutex};

// ANCHOR: db_struct
#[salsa::db]
#[derive(Default)]
#[derive(Default, Clone)]
pub struct CalcDatabaseImpl {
storage: salsa::Storage<Self>,

Expand Down
15 changes: 11 additions & 4 deletions examples/lazy-input/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#![allow(unreachable_patterns)]
// FIXME(rust-lang/rust#129031): regression in nightly
use std::{path::PathBuf, sync::Mutex, time::Duration};
use std::{
path::PathBuf,
sync::{Arc, Mutex},
time::Duration,
};

use crossbeam::channel::{unbounded, Sender};
use dashmap::{mapref::entry::Entry, DashMap};
Expand Down Expand Up @@ -77,11 +81,12 @@ trait Db: salsa::Database {
}

#[salsa::db]
#[derive(Clone)]
struct LazyInputDatabase {
storage: Storage<Self>,
logs: Mutex<Vec<String>>,
logs: Arc<Mutex<Vec<String>>>,
files: DashMap<PathBuf, File>,
file_watcher: Mutex<Debouncer<RecommendedWatcher>>,
file_watcher: Arc<Mutex<Debouncer<RecommendedWatcher>>>,
}

impl LazyInputDatabase {
Expand All @@ -90,7 +95,9 @@ impl LazyInputDatabase {
storage: Default::default(),
logs: Default::default(),
files: DashMap::new(),
file_watcher: Mutex::new(new_debouncer(Duration::from_secs(1), tx).unwrap()),
file_watcher: Arc::new(Mutex::new(
new_debouncer(Duration::from_secs(1), tx).unwrap(),
)),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl dyn Database {
///
/// # Panics
///
/// If the view has not been added to the database (see [`DatabaseView`][])
/// If the view has not been added to the database (see [`crate::views::Views`]).
#[track_caller]
pub fn as_view<DbView: ?Sized + Database>(&self) -> &DbView {
self.zalsa().views().try_view_as(self).unwrap()
Expand Down
2 changes: 1 addition & 1 deletion src/database_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{self as salsa, Database, Event, Storage};
#[salsa::db]
/// Default database implementation that you can use if you don't
/// require any custom user data.
#[derive(Default)]
#[derive(Default, Clone)]
pub struct DatabaseImpl {
storage: Storage<Self>,
}
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod input;
mod interned;
mod key;
mod nonce;
mod par_map;
mod revision;
mod runtime;
mod salsa_struct;
Expand Down Expand Up @@ -45,6 +46,7 @@ pub use self::storage::Storage;
pub use self::update::Update;
pub use self::zalsa::IngredientIndex;
pub use crate::attach::with_attached_database;
pub use par_map::par_map;
pub use salsa_macros::accumulator;
pub use salsa_macros::db;
pub use salsa_macros::input;
Expand Down
25 changes: 25 additions & 0 deletions src/par_map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator};

use crate::Database;

pub fn par_map<Db, D, E, C>(
db: &Db,
inputs: impl IntoParallelIterator<Item = D>,
op: fn(&Db, D) -> E,
) -> C
where
Db: Database,
D: Send,
E: Send + Sync,
C: FromParallelIterator<E>,
{
dbg!(db.zalsa().views());
let fork_db_a: Box<dyn Database> = db.fork_db();
let db: &Db = fork_db_a.as_view::<Db>();

todo!()
// inputs
// .into_par_iter()
// .map_with(wrapper, |db, element| op(db, element))
// .collect()
}
6 changes: 5 additions & 1 deletion src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
///
/// The `storage` and `storage_mut` fields must both return a reference to the same
/// storage field which must be owned by `self`.
pub unsafe trait HasStorage: Database + Sized {
pub unsafe trait HasStorage: Database + Clone + Sized {
fn storage(&self) -> &Storage<Self>;
fn storage_mut(&mut self) -> &mut Storage<Self>;
}
Expand Down Expand Up @@ -108,6 +108,10 @@ unsafe impl<T: HasStorage> ZalsaDatabase for T {
fn zalsa_local(&self) -> &ZalsaLocal {
&self.storage().zalsa_local
}

fn fork_db(&self) -> Box<dyn Database> {
Box::new(self.clone())
}
}

impl<Db: Database> RefUnwindSafe for Storage<Db> {}
Expand Down
4 changes: 4 additions & 0 deletions src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ impl<T: Slot> RefUnwindSafe for Page<T> {}
#[derive(Copy, Clone, Debug)]
pub struct PageIndex(usize);

unsafe impl Send for PageIndex {}

unsafe impl Sync for PageIndex {}

#[derive(Copy, Clone, Debug)]
pub struct SlotIndex(usize);

Expand Down
6 changes: 5 additions & 1 deletion src/zalsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ pub unsafe trait ZalsaDatabase: Any {
/// Access the thread-local state associated with this database
#[doc(hidden)]
fn zalsa_local(&self) -> &ZalsaLocal;

/// Clone the database.
#[doc(hidden)]
fn fork_db(&self) -> Box<dyn Database>;
}

pub fn views<Db: ?Sized + Database>(db: &Db) -> &Views {
Expand Down Expand Up @@ -152,7 +156,7 @@ impl Zalsa {
}
}

pub(crate) fn views(&self) -> &Views {
pub fn views(&self) -> &Views {
&self.views_of
}

Expand Down
14 changes: 8 additions & 6 deletions tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@

#![allow(dead_code)]

use std::sync::{Arc, Mutex};

use salsa::{Database, Storage};

/// Logging userdata: provides [`LogDatabase`][] trait.
///
/// If you wish to use it along with other userdata,
/// you can also embed it in another struct and implement [`HasLogger`][] for that struct.
#[derive(Default)]
#[derive(Clone, Default)]
pub struct Logger {
logs: std::sync::Mutex<Vec<String>>,
logs: Arc<Mutex<Vec<String>>>,
}

/// Trait implemented by databases that lets them log events.
Expand Down Expand Up @@ -48,7 +50,7 @@ impl<Db: HasLogger + Database> LogDatabase for Db {}

/// Database that provides logging but does not log salsa event.
#[salsa::db]
#[derive(Default)]
#[derive(Clone, Default)]
pub struct LoggerDatabase {
storage: Storage<Self>,
logger: Logger,
Expand All @@ -67,7 +69,7 @@ impl Database for LoggerDatabase {

/// Database that provides logging and logs salsa events.
#[salsa::db]
#[derive(Default)]
#[derive(Clone, Default)]
pub struct EventLoggerDatabase {
storage: Storage<Self>,
logger: Logger,
Expand All @@ -87,7 +89,7 @@ impl HasLogger for EventLoggerDatabase {
}

#[salsa::db]
#[derive(Default)]
#[derive(Clone, Default)]
pub struct DiscardLoggerDatabase {
storage: Storage<Self>,
logger: Logger,
Expand All @@ -114,7 +116,7 @@ impl HasLogger for DiscardLoggerDatabase {
}

#[salsa::db]
#[derive(Default)]
#[derive(Clone, Default)]
pub struct ExecuteValidateLoggerDatabase {
storage: Storage<Self>,
logger: Logger,
Expand Down
45 changes: 45 additions & 0 deletions tests/parallel_map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use common::{LogDatabase, LoggerDatabase};
use salsa::{plumbing::ZalsaDatabase, Database};

mod common;

#[salsa::input]
struct ParallelInput {
field: u32,
}

#[salsa::tracked]
struct TrackedData<'db> {
field: u32,
}

#[salsa::tracked]
fn tracked_fn(db: &dyn LogDatabase, input: ParallelInput) -> u32 {
db.push_log(format!("tracked_fn({input:?})"));
let t = TrackedData::new(db, input.field(db) * 2);
tracked_fn_extra::specify(db, t, 2222);
tracked_fn_extra(db, t)
}

#[salsa::tracked(specify)]
fn tracked_fn_extra<'db>(db: &dyn LogDatabase, input: TrackedData<'db>) -> u32 {
db.push_log(format!("tracked_fn_extra({input:?})"));
0
}

#[test]
fn execute() {
let db = common::LoggerDatabase::default();

let inputs = (1..=10)
.map(|field| {
let input = ParallelInput::new(&db, field);
let i = tracked_fn(&db, input);
input
})
.collect::<Vec<ParallelInput>>();

dbg!(db.zalsa().views());

let _foo: Vec<u32> = salsa::par_map(&db, inputs, |db, field| field.field(db));
}

0 comments on commit 8e4ca38

Please sign in to comment.