
Commit

♻️ Clean up moar
Philogy committed Mar 20, 2024
1 parent 14add0d commit b3a775b
Showing 3 changed files with 24 additions and 38 deletions.
1 change: 1 addition & 0 deletions src/parser/utils.rs
@@ -18,6 +18,7 @@ impl<O, E: Error<Token>, P: Parser<Token, O, Error = E>> TokenParser<O, E> for P
 pub trait OrDefaultParser<I: Clone, O: Default, E: Error<I>>:
     Parser<I, O, Error = E> + Sized
 {
+    #[allow(clippy::type_complexity)]
     fn or_default(self) -> Map<OrNot<Self>, fn(Option<O>) -> O, Option<O>> {
         self.or_not().map(Option::unwrap_or_default)
     }
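The added `#[allow(clippy::type_complexity)]` silences the lint that fires on the fully spelled-out combinator return type. A standalone sketch of the trade-off, using stand-in `Map`/`OrNot` types rather than the crate's real parser combinators: naming the concrete type keeps `or_default` allocation-free (no `Box<dyn Parser>`), at the cost of the nested generic signature that trips the lint.

use std::marker::PhantomData;

// Stand-in combinator types, shaped like the ones in the signature above.
struct OrNot<P>(P);
struct Map<P, F, O>(P, F, PhantomData<O>);

trait ParserExt: Sized {
    // Spelling out the concrete type gives callers a nameable, static type;
    // clippy flags the nesting, hence the allow.
    #[allow(clippy::type_complexity)]
    fn or_default<O: Default>(self) -> Map<OrNot<Self>, fn(Option<O>) -> O, Option<O>> {
        Map(OrNot(self), Option::unwrap_or_default, PhantomData)
    }
}

impl<P> ParserExt for P {}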
38 changes: 12 additions & 26 deletions src/scheduling/astar.rs
@@ -131,21 +131,10 @@ pub struct Explored {
 type ExploredMap = HashMap<u64, Explored, BuildHasherDefault<NoopHasher>>;
 type ScheduleQueue = BinaryHeap<ScheduleNode>;
 
-struct FastHasher(ahash::AHasher);
-
-impl FastHasher {
-    fn new() -> Self {
-        Self(ahash::AHasher::default())
-    }
-
-    fn hash_one_off<T: Hash>(&mut self, value: &T) -> u64 {
-        let buf = &mut self.0 as *mut ahash::AHasher as *mut u64;
-        unsafe { *buf = 0 };
-
-        value.hash(&mut self.0);
-
-        self.0.finish()
-    }
+fn hash_one_off<T: Hash>(value: &T) -> u64 {
+    let mut hashooor = ahash::AHasher::default();
+    value.hash(&mut hashooor);
+    hashooor.finish()
 }
 
 #[derive(Clone, Debug, Default)]
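The removed `FastHasher` tried to reuse one hasher by casting `&mut ahash::AHasher` to `*mut u64` and zeroing through it, which only overwrites the first eight bytes of the hasher's state and depends on `AHasher`'s private layout. The replacement just builds a fresh default hasher per call. A sketch of the same one-off hashing via the standard library (assumes Rust 1.71+ for `BuildHasher::hash_one`, plus the `ahash` crate already in use here):

use std::hash::{BuildHasher, BuildHasherDefault, Hash};

// Builds a fresh `AHasher::default()` per call, feeds it the value, and
// finishes it — the same shape as the new `hash_one_off` above.
fn hash_one_off_std<T: Hash>(value: &T) -> u64 {
    BuildHasherDefault::<ahash::AHasher>::default().hash_one(value)
}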
@@ -191,12 +180,10 @@ pub trait AStarScheduler: Sized + Sync + Send {
             at_end: start.all_done(),
         });
 
-        let mut hasher = FastHasher::new();
-
         // 1. Pop top of priority queue (node closest to end according to actual cost + estimated
         // remaining distance).
         while let Some(mut node) = queue.pop() {
-            let came_from = hasher.hash_one_off(&node.state);
+            let came_from = hash_one_off(&node.state);
             // 2a. If the shortest node is the end we know we found our solution, accumulate the
             // steps and return.
             if node.at_end {
@@ -235,13 +222,13 @@
                 }
                 let new_cost = node.cost + steps.iter().map(|step| step.cost()).sum::<u32>();
                 tracker.total_explored += 1;
-                let new_state_hash = hasher.hash_one_off(&new_state);
+                let new_state_hash = hash_one_off(&new_state);
 
-                let new_cost_better = match explored.get(&new_state_hash) {
+                match explored.get(&new_state_hash) {
                     Some(e) => new_cost < e.cost,
                     None => true,
-                };
-                if new_cost_better {
+                }
+                .then(|| {
                     let out = explored.insert(
                         new_state_hash,
                         Explored {
@@ -252,14 +239,13 @@
                     );
                     tracker.total_collisions += if out.is_some() { 1 } else { 0 };
                     let score = new_cost + self.estimate_remaining_cost(info, &new_state, new_cost);
-                    return Some(ScheduleNode {
+                    ScheduleNode {
                         state: new_state,
                         cost: new_cost,
                         score,
                         at_end,
-                    });
-                }
-                None
+                    }
+                })
             }));
         }
 
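The `if new_cost_better { ... return Some(...) } None` tail becomes a single expression: the `match` evaluates to a `bool`, and `bool::then` maps `true` to `Some(closure result)` and `false` to `None`, which is exactly what the surrounding `filter_map`-style adapter needs. A minimal illustration of the pattern (hypothetical `keep_if_better` helper, not from this codebase):

// The match yields a bool; `.then(|| ...)` lifts it into an Option so worse
// candidates are dropped without an explicit `if`/`return Some`/`None`.
fn keep_if_better(new_cost: u32, old_cost: Option<u32>) -> Option<u32> {
    match old_cost {
        Some(old) => new_cost < old,
        None => true,
    }
    .then(|| new_cost)
}

fn main() {
    assert_eq!(keep_if_better(3, Some(5)), Some(3)); // cheaper: kept
    assert_eq!(keep_if_better(7, Some(5)), None); // costlier: dropped
    assert_eq!(keep_if_better(7, None), Some(7)); // unexplored: kept
}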
23 changes: 11 additions & 12 deletions src/transformer/ir_gen.rs
@@ -126,18 +126,18 @@ impl SemanticContext {
 
     /// Returns the IDs of the nodes that the newly inserted node is now dependent on (non-operand
     /// semantic dependency).
-    pub fn record_write(&mut self, dependency: &String, id: CompNodeId) {
+    pub fn record_write(&mut self, dependency: &str, id: CompNodeId) {
         let mut pre_deps = self
             .last_reads
-            .insert(dependency.clone(), Vec::new())
+            .insert(dependency.to_string(), Vec::new())
             .unwrap_or_default();
-        pre_deps.extend(self.last_write.insert(dependency.clone(), id));
+        pre_deps.extend(self.last_write.insert(dependency.to_string(), id));
 
         self.nodes_sources[id].0.post.extend(pre_deps);
     }
 }
 
-fn unspan<T: Clone + Debug>(spanned: &Vec<Spanned<T>>) -> Vec<T> {
+fn unspan<T: Clone + Debug>(spanned: &[Spanned<T>]) -> Vec<T> {
     spanned.iter().map(Spanned::unwrap_ref).cloned().collect()
 }

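`record_write` now takes `&str` instead of `&String` (and `unspan` takes `&[Spanned<T>]` instead of `&Vec<Spanned<T>>`), the idiomatic borrowed forms: callers can pass literals or slices directly, `&String` still coerces, and the maps keep owning `String` keys, so `to_string()` allocates only at the two insertion sites, exactly where `clone()` did before. A sketch of the signature change with a hypothetical stand-in map (not the crate's `last_reads`/`last_write` types):

use std::collections::HashMap;

// Taking `&str` keeps the API flexible; the map owns `String` keys, so the
// allocation happens once, at insertion.
fn record(map: &mut HashMap<String, u32>, key: &str, id: u32) {
    map.insert(key.to_string(), id);
}

fn main() {
    let mut writes = HashMap::new();
    record(&mut writes, "balance", 1); // literal: no `&"balance".to_string()` dance
    let owned = String::from("owner");
    record(&mut writes, &owned, 2); // `&String` auto-derefs to `&str`
    assert_eq!(writes.len(), 2);
}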
@@ -235,7 +235,7 @@ fn set_blocked_count(input_ids: &[CompNodeId], output_ids: &[CompNodeId], nodes:
     let total = nodes.len();
 
     let mut blocked_by = vec![0u32; total];
-    let mut stack_count = vec![0u32; total];
+    let mut stack_counts = vec![0u32; total];
 
     for node in nodes.iter() {
         for post_id in node.post.iter() {
@@ -244,23 +244,22 @@
         for dep_id in node.operands.iter() {
             blocked_by[*dep_id] += 1;
             // Blocked once as an argument.
-            stack_count[*dep_id] += 1;
+            stack_counts[*dep_id] += 1;
         }
     }
 
     for output_id in output_ids.iter() {
-        stack_count[*output_id] += 1;
+        stack_counts[*output_id] += 1;
     }
 
     for id in 0..total {
-        let required_dedups = stack_count[id].max(1) - 1;
+        let required_dedups = stack_counts[id].max(1) - 1;
         *nodes[id].blocked_by.as_mut().unwrap() += required_dedups + blocked_by[id];
     }
 
-    for id in 0..total {
-        if nodes[id].blocked_by.unwrap() == 0 && input_ids.contains(&id) && output_ids.contains(&id)
-        {
-            nodes[id].blocked_by = None;
+    for (id, node) in nodes.iter_mut().enumerate() {
+        if node.blocked_by.unwrap() == 0 && input_ids.contains(&id) && output_ids.contains(&id) {
+            node.blocked_by = None;
         }
     }
 }
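The last loop swaps indexed access in `for id in 0..total` for `iter_mut().enumerate()`, which yields each index together with a mutable reference, avoids a bounds-checked `nodes[id]` per iteration, and lets the `{` rejoin the `if` line. A sketch of the shape of that refactor, with hypothetical names:

// Before: `for id in 0..total { if flags[id] == ... }` — after: enumerate
// over mutable references, so the index and the element arrive together.
fn clear_unblocked(flags: &mut [Option<u32>], pinned: &[usize]) {
    for (id, flag) in flags.iter_mut().enumerate() {
        if *flag == Some(0) && pinned.contains(&id) {
            *flag = None;
        }
    }
}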
