Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added Fn trait #6659

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions corelib/src/ops.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ use range::RangeOp;

mod function;
pub use function::FnOnce;
pub use function::Fn;
21 changes: 21 additions & 0 deletions corelib/src/ops/function.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,24 @@ pub trait FnOnce<T, Args> {
/// Performs the call operation.
fn call(self: T, args: Args) -> Self::Output;
}

/// An implementation of `FnOnce` when `Fn` is implemented.
/// Makes sure we can always pass an `Fn` to a function that expects an `FnOnce`.
impl FnOnceImpl<T, Args, +Destruct<T>, +Fn<T, Args>> of FnOnce<T, Args> {
type Output = Fn::<T, Args>::Output;
fn call(self: T, args: Args) -> Self::Output {
Fn::call(@self, args)
}
}

/// The version of the call operator that takes a by-snapshot receiver.
///
/// Instances of `Fn` can be called multiple times.
///
/// `Fn` is implemented automatically by closures that capture only copyable variables.
pub trait Fn<T, Args> {
/// The returned type after the call operator is used.
type Output;
/// Performs the call operation.
fn call(self: @T, args: Args) -> Self::Output;
}
12 changes: 12 additions & 0 deletions corelib/src/test/language_features/closure_test.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,15 @@ fn option_map_test() {
assert_eq!(option_map(Option::Some(2), |x| Option::Some(x)), Option::Some(Option::Some(2)));
}

fn array_map<T, F, impl Fn: core::ops::Fn<F, (T,)>, +Drop<T>, +Drop<F>, +Drop<Fn::Output>>(
arr: [T; 2], f: F,
) -> [core::ops::Fn::<F, (T,)>::Output; 2] {
let [a, b] = arr;
[f(a), f(b)]
}

#[test]
fn array_map_test() {
assert_eq!(array_map([2, 3], |x| x + 3), [5, 6]);
}

28 changes: 15 additions & 13 deletions crates/cairo-lang-lowering/src/borrow_check/test_data/closure
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,13 @@ blk0 (root):
Statements:
(v0: core::felt252) <- 8
(v1: {[email protected]:3:13: 3:16}) <- struct_construct(v0{`x`})
(v2: core::felt252) <- 2
(v3: (core::felt252,)) <- struct_construct(v2{`2`})
(v4: core::felt252) <- Generated core::ops::function::FnOnce::<{[email protected]:3:13: 3:16}, (core::felt252,)>::call(v1{`c`}, v3{`c(2)`})
(v5: core::felt252) <- core::Felt252Add::add(v4{`y`}, v0{`x`})
(v2: {[email protected]:3:13: 3:16}, v3: @{[email protected]:3:13: 3:16}) <- snapshot(v1{`c`})
(v4: core::felt252) <- 2
(v5: (core::felt252,)) <- struct_construct(v4{`2`})
(v6: core::felt252) <- Generated core::ops::function::Fn::<{[email protected]:3:13: 3:16}, (core::felt252,)>::call(v3{`c`}, v5{`c(2)`})
(v7: core::felt252) <- core::Felt252Add::add(v6{`y`}, v0{`x`})
End:
Return(v5)
Return(v7)

//! > ==========================================================================

Expand Down Expand Up @@ -314,15 +315,16 @@ Statements:
(v3: core::array::Array::<core::felt252>, v2: ()) <- core::array::ArrayImpl::<core::felt252>::append(v0{`__array_builder_macro_result__`}, v1{`99_felt252`})
(v4: core::array::Array::<core::felt252>, v5: @core::array::Array::<core::felt252>) <- snapshot(v3{`x`})
(v6: {[email protected]:3:13: 3:16}) <- struct_construct(v5{`|a| { (@x).len() * (a + 3) }`})
(v7: core::integer::u32) <- 2
(v8: (core::integer::u32,)) <- struct_construct(v7{`2`})
(v9: core::integer::u32) <- Generated core::ops::function::FnOnce::<{[email protected]:3:13: 3:16}, (core::integer::u32,)>::call(v6{`c`}, v8{`c(2)`})
(v10: core::array::Array::<core::felt252>, v11: @core::array::Array::<core::felt252>) <- snapshot(v4{`x`})
(v12: core::integer::u32) <- 0
(v13: @core::felt252) <- core::ops::index::DeprecatedIndexViewImpl::<core::array::Array::<core::felt252>, core::integer::u32, @core::felt252, core::array::ArrayIndex::<core::felt252>>::index(v11{`x`}, v12{`0`})
(v14: core::felt252) <- desnap(v13{`0`})
(v7: {[email protected]:3:13: 3:16}, v8: @{[email protected]:3:13: 3:16}) <- snapshot(v6{`c`})
(v9: core::integer::u32) <- 2
(v10: (core::integer::u32,)) <- struct_construct(v9{`2`})
(v11: core::integer::u32) <- Generated core::ops::function::Fn::<{[email protected]:3:13: 3:16}, (core::integer::u32,)>::call(v8{`c`}, v10{`c(2)`})
(v12: core::array::Array::<core::felt252>, v13: @core::array::Array::<core::felt252>) <- snapshot(v4{`x`})
(v14: core::integer::u32) <- 0
(v15: @core::felt252) <- core::ops::index::DeprecatedIndexViewImpl::<core::array::Array::<core::felt252>, core::integer::u32, @core::felt252, core::array::ArrayIndex::<core::felt252>>::index(v13{`x`}, v14{`0`})
(v16: core::felt252) <- desnap(v15{`0`})
End:
Return(v14)
Return(v16)

//! > ==========================================================================

Expand Down
1 change: 1 addition & 0 deletions crates/cairo-lang-lowering/src/ids.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ impl FunctionLongId {
semantic::corelib::destruct_trait_fn(semantic_db),
semantic::corelib::panic_destruct_trait_fn(semantic_db),
semantic::corelib::fn_once_call_trait_fn(semantic_db),
semantic::corelib::fn_call_trait_fn(semantic_db),
]
.contains(&function)
);
Expand Down
140 changes: 85 additions & 55 deletions crates/cairo-lang-lowering/src/lower/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{iter, vec};
use std::vec;

use block_builder::BlockBuilder;
use cairo_lang_debug::DebugWithDb;
Expand All @@ -7,7 +7,7 @@ use cairo_lang_diagnostics::{Diagnostics, Maybe};
use cairo_lang_semantic::corelib::{ErrorPropagationType, unwrap_error_propagation_type};
use cairo_lang_semantic::db::SemanticGroup;
use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId};
use cairo_lang_semantic::items::imp::{GeneratedImplItems, GeneratedImplLongId, ImplLongId};
use cairo_lang_semantic::items::imp::ImplLongId;
use cairo_lang_semantic::usage::MemberPath;
use cairo_lang_semantic::{
ConcreteFunction, ConcreteTraitLongId, ExprVar, LocalVariable, VarId, corelib,
Expand Down Expand Up @@ -1780,11 +1780,11 @@ fn add_closure_call_function(
encapsulated_ctx: &mut LoweringContext<'_, '_>,
expr: &semantic::ExprClosure,
closure_info: &ClosureInfo,
trait_id: cairo_lang_defs::ids::TraitId,
) -> Maybe<()> {
let semantic_db = encapsulated_ctx.db.upcast();
let semantic_db: &dyn SemanticGroup = encapsulated_ctx.db.upcast();
let closure_ty = extract_matches!(expr.ty.lookup_intern(semantic_db), TypeLongId::Closure);
let expr_location = encapsulated_ctx.get_location(expr.stable_ptr.untyped());
let trait_id = semantic::corelib::fn_once_trait(semantic_db);
let parameters_ty = TypeLongId::Tuple(closure_ty.param_tys.clone()).intern(semantic_db);
let concrete_trait = ConcreteTraitLongId {
trait_id,
Expand All @@ -1794,18 +1794,26 @@ fn add_closure_call_function(
],
}
.intern(semantic_db);
let trait_function = semantic::corelib::fn_once_call_trait_fn(semantic_db);

let ret_ty = semantic_db.trait_type_by_name(trait_id, "Output".into()).unwrap().unwrap();
let impl_id = ImplLongId::GeneratedImpl(
GeneratedImplLongId {
concrete_trait,
generic_params: vec![],
impl_items: GeneratedImplItems(iter::once((ret_ty, closure_ty.ret_ty)).collect()),
}
.intern(semantic_db),
)
.intern(semantic_db);
let Ok(impl_id) = semantic::types::get_impl_at_context(
semantic_db,
encapsulated_ctx.variables.lookup_context.clone(),
concrete_trait,
None,
) else {
// If the impl doesn't exist, there won't be a call to the call-function, so we don't need
// to generate it.
return Ok(());
};
if !matches!(impl_id.lookup_intern(semantic_db), ImplLongId::GeneratedImpl(_)) {
// If the impl is not generated, we don't need to generate a lowering for it.
return Ok(());
}

let trait_function: cairo_lang_defs::ids::TraitFunctionId = semantic_db
.trait_function_by_name(trait_id, "call".into())
.unwrap()
.expect("Call function must exist for an Fn trait.");

let generic_function =
GenericFunctionId::Impl(ImplGenericFunctionId { impl_id, function: trait_function });
let function = semantic::FunctionLongId {
Expand All @@ -1827,49 +1835,65 @@ fn add_closure_call_function(
let root_block_id = alloc_empty_block(&mut ctx);
let mut builder = BlockBuilder::root(&mut ctx, root_block_id);

let (closure_param_var_id, closure_var) = if trait_id
== semantic::corelib::fn_once_trait(semantic_db)
{
// If the closure is FnOnce, the closure is passed by value.
let closure_param_var = ctx.new_var(VarRequest { ty: expr.ty, location: expr_location });
let closure_var = VarUsage { var_id: closure_param_var, location: expr_location };
(closure_param_var, closure_var)
} else {
// If the closure is Fn the closure argument will be a snapshot, so we need to desnap it.
let closure_param_var = ctx.new_var(VarRequest {
ty: wrap_in_snapshots(semantic_db, expr.ty, 1),
location: expr_location,
});

let closure_var = generators::Desnap {
input: VarUsage { var_id: closure_param_var, location: expr_location },
location: expr_location,
}
.add(&mut ctx, &mut builder.statements);
(closure_param_var, closure_var)
};
let parameters: Vec<VariableId> = [
ctx.new_var(VarRequest { ty: expr.ty, location: expr_location }),
closure_param_var_id,
ctx.new_var(VarRequest { ty: parameters_ty, location: expr_location }),
]
.into();

let root_ok = {
let captured_vars = generators::StructDestructure {
input: VarUsage { var_id: parameters[0], location: expr_location },
var_reqs: chain!(closure_info.members.iter(), closure_info.snapshots.iter())
.map(|(_, ty)| VarRequest { ty: *ty, location: expr_location })
.collect_vec(),
}
.add(&mut ctx, &mut builder.statements);
for (i, (param, _)) in closure_info.members.iter().enumerate() {
builder.semantics.introduce(param.clone(), captured_vars[i]);
}
for (i, (param, _)) in closure_info.snapshots.iter().enumerate() {
builder
.snapped_semantics
.insert(param.clone(), captured_vars[i + closure_info.members.len()]);
}
let param_vars = generators::StructDestructure {
input: VarUsage { var_id: parameters[1], location: expr_location },
var_reqs: closure_ty
.param_tys
.iter()
.map(|ty| VarRequest { ty: *ty, location: expr_location })
.collect_vec(),
}
.add(&mut ctx, &mut builder.statements);
for (param_var, param) in param_vars.into_iter().zip(expr.params.iter()) {
builder
.semantics
.introduce((&parameter_as_member_path(param.clone())).into(), param_var);
}
let lowered_expr = lower_expr(&mut ctx, &mut builder, expr.body);
let maybe_sealed_block = lowered_expr_to_block_scope_end(&mut ctx, builder, lowered_expr);
maybe_sealed_block.and_then(|block_sealed| {
wrap_sealed_block_as_function(&mut ctx, block_sealed, expr.stable_ptr.untyped())?;
Ok(root_block_id)
})
};
let captured_vars = generators::StructDestructure {
input: closure_var,
var_reqs: chain!(closure_info.members.iter(), closure_info.snapshots.iter())
.map(|(_, ty)| VarRequest { ty: *ty, location: expr_location })
.collect_vec(),
}
.add(&mut ctx, &mut builder.statements);
for (i, (param, _)) in closure_info.members.iter().enumerate() {
builder.semantics.introduce(param.clone(), captured_vars[i]);
}
for (i, (param, _)) in closure_info.snapshots.iter().enumerate() {
builder
.snapped_semantics
.insert(param.clone(), captured_vars[i + closure_info.members.len()]);
}
let param_vars = generators::StructDestructure {
input: VarUsage { var_id: parameters[1], location: expr_location },
var_reqs: closure_ty
.param_tys
.iter()
.map(|ty| VarRequest { ty: *ty, location: expr_location })
.collect_vec(),
}
.add(&mut ctx, &mut builder.statements);
for (param_var, param) in param_vars.into_iter().zip(expr.params.iter()) {
builder.semantics.introduce((&parameter_as_member_path(param.clone())).into(), param_var);
}
let lowered_expr = lower_expr(&mut ctx, &mut builder, expr.body);
let maybe_sealed_block = lowered_expr_to_block_scope_end(&mut ctx, builder, lowered_expr);
let root_ok = maybe_sealed_block.and_then(|block_sealed| {
wrap_sealed_block_as_function(&mut ctx, block_sealed, expr.stable_ptr.untyped())?;
Ok(root_block_id)
});
let blocks = root_ok
.map(|_| ctx.blocks.build().expect("Root block must exist."))
.unwrap_or_else(FlatBlocks::new_errored);
Expand Down Expand Up @@ -1910,8 +1934,14 @@ fn lower_expr_closure(
ctx,
expr,
builder.semantics.closures.get(&capture_var_usage.var_id).unwrap(),
if ctx.variables[capture_var_usage.var_id].copyable.is_ok() {
semantic::corelib::fn_trait(ctx.db.upcast())
} else {
semantic::corelib::fn_once_trait(ctx.db.upcast())
},
)
.map_err(LoweringFlowError::Failed)?;

Ok(closure_variable)
}

Expand Down
18 changes: 10 additions & 8 deletions crates/cairo-lang-lowering/src/lower/test_data/closure
Original file line number Diff line number Diff line change
Expand Up @@ -277,23 +277,25 @@ End:
Return()


Generated core::ops::function::FnOnce::call lowering for source location:
Generated core::ops::function::Fn::call lowering for source location:
let c = || a;
^^

Parameters: v0: {[email protected]:6:14: 6:16}, v1: ()
Parameters: v0: @{[email protected]:6:14: 6:16}, v2: ()
blk0 (root):
Statements:
(v2: core::integer::u32) <- struct_destructure(v0)
() <- struct_destructure(v1)
(v1: {[email protected]:6:14: 6:16}) <- desnap(v0)
(v3: core::integer::u32) <- struct_destructure(v1)
() <- struct_destructure(v2)
End:
Return(v2)
Return(v3)


Final lowering:
Parameters: v0: {[email protected]:6:14: 6:16}, v1: ()
Parameters: v0: @{[email protected]:6:14: 6:16}, v1: ()
blk0 (root):
Statements:
(v2: core::integer::u32) <- struct_destructure(v0)
(v2: {[email protected]:6:14: 6:16}) <- desnap(v0)
(v3: core::integer::u32) <- struct_destructure(v2)
End:
Return(v2)
Return(v3)
12 changes: 12 additions & 0 deletions crates/cairo-lang-semantic/src/corelib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,10 +621,22 @@ pub fn fn_once_trait(db: &dyn SemanticGroup) -> TraitId {
get_core_trait(db, CoreTraitContext::Ops, "FnOnce".into())
}

pub fn fn_trait(db: &dyn SemanticGroup) -> TraitId {
get_core_trait(db, CoreTraitContext::Ops, "Fn".into())
}

pub fn fn_traits(db: &dyn SemanticGroup) -> [TraitId; 2] {
[fn_trait(db), fn_once_trait(db)]
}

pub fn fn_once_call_trait_fn(db: &dyn SemanticGroup) -> TraitFunctionId {
get_core_trait_fn(db, CoreTraitContext::Ops, "FnOnce".into(), "call".into())
}

pub fn fn_call_trait_fn(db: &dyn SemanticGroup) -> TraitFunctionId {
get_core_trait_fn(db, CoreTraitContext::Ops, "Fn".into(), "call".into())
}

pub fn copy_trait(db: &dyn SemanticGroup) -> TraitId {
get_core_trait(db, CoreTraitContext::TopLevel, "Copy".into())
}
Expand Down
24 changes: 21 additions & 3 deletions crates/cairo-lang-semantic/src/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -860,8 +860,19 @@ impl DiagnosticEntry for SemanticDiagnostic {
)
}
}
SemanticDiagnosticKind::TypeEqualTraitReImplementation => {
"Type equals trait should not be re-implemented.".into()
SemanticDiagnosticKind::CallExpressionRequiresFunction { ty, inference_errors } => {
if inference_errors.is_empty() {
format!("Call expression requires a function, found `{}`.", ty.format(db))
} else {
format!(
"Call expression requires a function, found `{}`.\n{}",
ty.format(db),
inference_errors.format(db)
)
}
}
SemanticDiagnosticKind::CompilerTraitReImplementation { trait_id } => {
format!("Trait `{}` should not be re-implemented.", trait_id.full_path(db.upcast()))
}
SemanticDiagnosticKind::ClosureInGlobalScope => {
"Closures are not allowed in this context.".into()
Expand Down Expand Up @@ -1189,7 +1200,12 @@ pub enum SemanticDiagnosticKind {
trait_name: SmolStr,
inference_errors: TraitInferenceErrors,
},
CallExpressionRequiresFunction {
ty: semantic::TypeId,
inference_errors: TraitInferenceErrors,
},
MultipleImplementationOfIndexOperator(semantic::TypeId),

UnsupportedInlineArguments,
RedundantInlineAttribute,
InlineAttrForExternFunctionNotAllowed,
Expand Down Expand Up @@ -1231,7 +1247,9 @@ pub enum SemanticDiagnosticKind {
DerefCycle {
deref_chain: String,
},
TypeEqualTraitReImplementation,
CompilerTraitReImplementation {
trait_id: TraitId,
},
ClosureInGlobalScope,
MaybeMissingColonColon,
CallingShadowedFunction {
Expand Down
Loading