Skip to content

Commit

Permalink
allow more loops to be converted to iter
Browse files Browse the repository at this point in the history
  • Loading branch information
lucarlig committed Mar 19, 2024
1 parent f711a7c commit 9552a6a
Show file tree
Hide file tree
Showing 10 changed files with 219 additions and 108 deletions.
49 changes: 14 additions & 35 deletions lints/for_each/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,24 @@
#![feature(let_chains)]

extern crate rustc_errors;
extern crate rustc_hash;
extern crate rustc_hir;
extern crate rustc_hir_typeck;
extern crate rustc_infer;
extern crate rustc_middle;

use clippy_utils::{higher::ForLoop, ty::is_copy};
mod variable_check;

use clippy_utils::higher::ForLoop;
use rustc_errors::Applicability;
use rustc_hir::{
def::Res,
intravisit::{walk_expr, Visitor},
Expr, ExprKind, HirId, Node,
Expr, ExprKind,
};
use rustc_lint::{LateContext, LateLintPass, LintContext};
use utils::span_to_snippet_macro;
use variable_check::check_variables;

dylint_linting::declare_late_lint! {
/// ### What it does
/// parallelize iterators using rayon
Expand Down Expand Up @@ -48,16 +55,9 @@ impl<'tcx> LateLintPass<'tcx> for ForEach {
let src_map = cx.sess().source_map();

// Make sure we ignore cases that require a try_foreach
let mut validator = Validator {
is_valid: true,
is_arg: true,
arg_variables: vec![],
cx,
};
validator.visit_expr(arg);
validator.is_arg = false;
let mut validator = Validator { is_valid: true };
validator.visit_expr(body);
if !validator.is_valid {
if !validator.is_valid || !check_variables(cx, body) {
return;
}
// Check whether the iter is explicit
Expand Down Expand Up @@ -122,38 +122,17 @@ impl Visitor<'_> for IterExplorer {
}
}

struct Validator<'a, 'tcx> {
struct Validator {
is_valid: bool,
cx: &'a LateContext<'tcx>,
is_arg: bool,
arg_variables: Vec<HirId>,
}

impl<'a, 'tcx> Visitor<'_> for Validator<'a, 'tcx> {
impl Visitor<'_> for Validator {
fn visit_expr(&mut self, ex: &Expr) {
match &ex.kind {
ExprKind::Loop(_, _, _, _)
| ExprKind::Closure(_)
| ExprKind::Ret(_)
| ExprKind::Break(_, _) => self.is_valid = false,
ExprKind::Path(ref path) => {
if let Res::Local(hir_id) = self.cx.typeck_results().qpath_res(path, ex.hir_id) {
if let Node::Local(local) = self.cx.tcx.parent_hir_node(hir_id) {
if self.is_arg {
self.arg_variables.push(local.hir_id);
return;
}

if let Some(expr) = local.init
&& !self.arg_variables.contains(&local.hir_id)
{
let ty = self.cx.typeck_results().expr_ty(expr);
self.is_valid &= is_copy(self.cx, ty)
}
}
}
}

_ => walk_expr(self, ex),
}
}
Expand Down
123 changes: 123 additions & 0 deletions lints/for_each/src/variable_check.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
use clippy_utils::visitors::for_each_expr_with_closures;
use rustc_hash::FxHashSet;
use rustc_hir as hir;
use rustc_hir_typeck::expr_use_visitor as euv;
use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
use rustc_lint::LateContext;
use rustc_middle::{
hir::map::associated_body,
mir::FakeReadCause,
ty::{self, Ty, UpvarId, UpvarPath},
};
use std::{collections::HashSet, ops::ControlFlow};

pub struct MutablyUsedVariablesCtxt<'tcx> {
all_vars: FxHashSet<Ty<'tcx>>,
prev_bind: Option<hir::HirId>,
/// In async functions, the inner AST is composed of multiple layers until we reach the code
/// defined by the user. Because of that, some variables are marked as mutably borrowed even
/// though they're not. This field lists the `HirId` that should not be considered as mutable
/// use of a variable.
prev_move_to_closure: hir::HirIdSet,
}

// TODO: remove repetation is this two function almost identical
pub fn check_variables<'tcx>(cx: &LateContext<'tcx>, ex: &'tcx hir::Expr) -> bool {
let MutablyUsedVariablesCtxt { all_vars, .. } = {
let body_owner = ex.hir_id.owner.def_id;

let mut ctx = MutablyUsedVariablesCtxt {
all_vars: FxHashSet::default(),
prev_bind: None,
prev_move_to_closure: hir::HirIdSet::default(),
};
let infcx = cx.tcx.infer_ctxt().build();
euv::ExprUseVisitor::new(
&mut ctx,
&infcx,
body_owner,
cx.param_env,
cx.typeck_results(),
)
.walk_expr(ex);

let mut checked_closures = FxHashSet::default();

// We retrieve all the closures declared in the function because they will not be found
// by `euv::Delegate`.
let mut closures: FxHashSet<hir::def_id::LocalDefId> = FxHashSet::default();
for_each_expr_with_closures(cx, ex, |expr| {
if let hir::ExprKind::Closure(closure) = expr.kind {
closures.insert(closure.def_id);
}
ControlFlow::<()>::Continue(())
});
check_closures(&mut ctx, cx, &infcx, &mut checked_closures, closures);

ctx
};

all_vars.is_empty()
}

pub fn check_closures<'tcx, S: ::std::hash::BuildHasher>(
ctx: &mut MutablyUsedVariablesCtxt<'tcx>,
cx: &LateContext<'tcx>,
infcx: &InferCtxt<'tcx>,
checked_closures: &mut HashSet<hir::def_id::LocalDefId, S>,
closures: HashSet<hir::def_id::LocalDefId, S>,
) {
let hir = cx.tcx.hir();
for closure in closures {
if !checked_closures.insert(closure) {
continue;
}
ctx.prev_bind = None;
ctx.prev_move_to_closure.clear();
if let Some(body) = cx
.tcx
.opt_hir_node_by_def_id(closure)
.and_then(associated_body)
.map(|(_, body_id)| hir.body(body_id))
{
euv::ExprUseVisitor::new(ctx, infcx, closure, cx.param_env, cx.typeck_results())
.consume_body(body);
}
}
}

impl<'tcx> euv::Delegate<'tcx> for MutablyUsedVariablesCtxt<'tcx> {
#[allow(clippy::if_same_then_else)]
fn consume(&mut self, cmt: &euv::PlaceWithHirId<'tcx>, _: hir::HirId) {
if let euv::Place {
base:
euv::PlaceBase::Local(_)
| euv::PlaceBase::Upvar(UpvarId {
var_path: UpvarPath { hir_id: _ },
..
}),
base_ty,
..
} = &cmt.place
{
self.all_vars.insert(*base_ty);
}
}

#[allow(clippy::if_same_then_else)]
fn borrow(&mut self, _: &euv::PlaceWithHirId<'tcx>, _: hir::HirId, _: ty::BorrowKind) {}

fn mutate(&mut self, _: &euv::PlaceWithHirId<'tcx>, _id: hir::HirId) {}
fn copy(&mut self, _: &euv::PlaceWithHirId<'tcx>, _: hir::HirId) {}
fn fake_read(
&mut self,
_: &rustc_hir_typeck::expr_use_visitor::PlaceWithHirId<'tcx>,
_: FakeReadCause,
_id: hir::HirId,
) {
}

fn bind(&mut self, _: &euv::PlaceWithHirId<'tcx>, id: hir::HirId) {
self.prev_bind = Some(id);
}
}
19 changes: 19 additions & 0 deletions lints/for_each/ui/main.fixed
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ impl MyBuilder {
}
}

struct LocalQueue {}

impl LocalQueue {
fn new() -> Self {
Self {}
}
}

// no
fn build_request_builder() {
let headers = vec![("Key1", "Value1"), ("Key2", "Value2")];
Expand Down Expand Up @@ -73,6 +81,7 @@ fn nested_loop() {
}
}

// for_each
fn get_upload_file_total_size() -> u64 {
let some_num = vec![0; 10];
let mut file_total_size = 0;
Expand All @@ -95,4 +104,14 @@ fn return_loop() {
}
}

// for_each
fn local_into_iter() {
let thread_num = 10;
let mut locals = vec![];

(0..thread_num).into_iter().for_each(|_| {
locals.push(LocalQueue::new());
});
}

// TODO: double capture
19 changes: 19 additions & 0 deletions lints/for_each/ui/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ impl MyBuilder {
}
}

struct LocalQueue {}

impl LocalQueue {
fn new() -> Self {
Self {}
}
}

// no
fn build_request_builder() {
let headers = vec![("Key1", "Value1"), ("Key2", "Value2")];
Expand Down Expand Up @@ -73,6 +81,7 @@ fn nested_loop() {
}
}

// for_each
fn get_upload_file_total_size() -> u64 {
let some_num = vec![0; 10];
let mut file_total_size = 0;
Expand All @@ -95,4 +104,14 @@ fn return_loop() {
}
}

// for_each
fn local_into_iter() {
let thread_num = 10;
let mut locals = vec![];

for _ in 0..thread_num {
locals.push(LocalQueue::new());
}
}

// TODO: double capture
25 changes: 20 additions & 5 deletions lints/for_each/ui/main.stderr
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
warning: use a for_each to enable iterator refinement
--> $DIR/main.rs:35:5
--> $DIR/main.rs:43:5
|
LL | / for x in 1..=100 {
LL | | println!("{x}");
Expand All @@ -15,7 +15,7 @@ LL + });
|

warning: use a for_each to enable iterator refinement
--> $DIR/main.rs:44:5
--> $DIR/main.rs:52:5
|
LL | / for a in vec_a {
LL | | if a == 1 {
Expand All @@ -36,7 +36,7 @@ LL + });
|

warning: use a for_each to enable iterator refinement
--> $DIR/main.rs:70:9
--> $DIR/main.rs:78:9
|
LL | / for b in &vec_b {
LL | | dbg!(a, b);
Expand All @@ -51,7 +51,7 @@ LL + });
|

warning: use a for_each to enable iterator refinement
--> $DIR/main.rs:79:5
--> $DIR/main.rs:88:5
|
LL | / for _ in 0..some_num.len() {
LL | | let (_, upload_size) = (true, 99);
Expand All @@ -67,5 +67,20 @@ LL + file_total_size += upload_size;
LL + });
|

warning: 4 warnings emitted
warning: use a for_each to enable iterator refinement
--> $DIR/main.rs:112:5
|
LL | / for _ in 0..thread_num {
LL | | locals.push(LocalQueue::new());
LL | | }
| |_____^
|
help: try using `for_each` on the iterator
|
LL ~ (0..thread_num).into_iter().for_each(|_| {
LL + locals.push(LocalQueue::new());
LL + });
|

warning: 5 warnings emitted

2 changes: 1 addition & 1 deletion lints/map/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ rustc_private = true
rlib = ["dylint_linting/constituent"]

[[example]]
name = "fold_main"
name = "map_main"
path = "ui/main.rs"

[lints.clippy]
Expand Down
2 changes: 1 addition & 1 deletion utils/src/constants.rs → lints/par_iter/src/constants.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub const TRAIT_PATHS: &[&[&str]] = &[
pub(crate) const TRAIT_PATHS: &[&[&str]] = &[
&["rayon", "iter", "IntoParallelIterator"],
&["rayon", "iter", "ParallelIterator"],
&["rayon", "iter", "IndexedParallelIterator"],
Expand Down
10 changes: 8 additions & 2 deletions lints/par_iter/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@

extern crate rustc_data_structures;
extern crate rustc_errors;
extern crate rustc_hash;
extern crate rustc_hir;

extern crate rustc_hir_typeck;
extern crate rustc_infer;
extern crate rustc_middle;
extern crate rustc_span;
extern crate rustc_trait_selection;

mod constants;
mod variable_check;

use clippy_utils::{get_parent_expr, get_trait_def_id};
use rustc_data_structures::fx::FxHashSet;
Expand All @@ -17,7 +23,7 @@ use rustc_hir::{self as hir};
use rustc_lint::{LateContext, LateLintPass, LintContext};
use rustc_middle::ty::{self, Ty};
use rustc_span::sym;
use utils::variable_check::{
use variable_check::{
check_implements_par_iter, check_trait_impl, check_variables, generate_suggestion,
is_type_valid,
};
Expand Down
Loading

0 comments on commit 9552a6a

Please sign in to comment.