Skip to content

Commit

Permalink
Consolidate caller match in one function call (#4446)
Browse files Browse the repository at this point in the history
Co-authored-by: Jon Ross-Perkins <[email protected]>
  • Loading branch information
geoffromer and jonmeow authored Oct 29, 2024
1 parent 89eed42 commit e20e8bf
Show file tree
Hide file tree
Showing 61 changed files with 460 additions and 496 deletions.
7 changes: 4 additions & 3 deletions toolchain/check/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,10 @@ auto PerformCall(Context& context, SemIR::LocId loc_id, SemIR::InstId callee_id,
}

// Convert the arguments to match the parameters.
auto converted_args_id = ConvertCallArgs(
context, loc_id, callee_function.self_id, arg_ids, return_slot_arg_id,
CalleeParamsInfo(callable), *callee_specific_id);
auto converted_args_id =
ConvertCallArgs(context, loc_id, callee_function.self_id, arg_ids,
return_slot_arg_id, CalleeParamsInfo(callable),
callable.return_slot_pattern_id, *callee_specific_id);
auto call_inst_id =
context.AddInst<SemIR::Call>(loc_id, {.type_id = return_info.type_id,
.callee_id = callee_id,
Expand Down
119 changes: 24 additions & 95 deletions toolchain/check/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1136,46 +1136,12 @@ auto ConvertForExplicitAs(Context& context, Parse::NodeId as_node,
{.kind = ConversionTarget::ExplicitAs, .type_id = type_id});
}

CARBON_DIAGNOSTIC(InCallToFunction, Note, "calling function declared here");

// Convert the object argument in a method call to match the `self` parameter.
static auto ConvertSelf(Context& context, SemIR::LocId call_loc_id,
SemIRLoc callee_loc,
SemIR::SpecificId callee_specific_id,
SemIR::InstId self_param_id, SemIR::InstId self_id)
-> SemIR::InstId {
if (!self_id.is_valid()) {
CARBON_DIAGNOSTIC(MissingObjectInMethodCall, Error,
"missing object argument in method call");
context.emitter()
.Build(call_loc_id, MissingObjectInMethodCall)
.Note(callee_loc, InCallToFunction)
.Emit();
return SemIR::InstId::BuiltinError;
}

bool addr_pattern = context.insts().Is<SemIR::AddrPattern>(self_param_id);
DiagnosticAnnotationScope annotate_diagnostics(
&context.emitter(), [&](auto& builder) {
CARBON_DIAGNOSTIC(InCallToFunctionSelf, Note,
"initializing `{0:addr self|self}` parameter of "
"method declared here",
BoolAsSelect);
builder.Note(self_param_id, InCallToFunctionSelf, addr_pattern);
});

return CallerPatternMatch(context, callee_specific_id, self_param_id,
self_id);
}

// TODO: consider moving this to pattern_match.h
auto ConvertCallArgs(Context& context, SemIR::LocId call_loc_id,
SemIR::InstId self_id,
llvm::ArrayRef<SemIR::InstId> arg_refs,
SemIR::InstId return_slot_arg_id,
const CalleeParamsInfo& callee,
SemIR::SpecificId callee_specific_id)
-> SemIR::InstBlockId {
auto ConvertCallArgs(
Context& context, SemIR::LocId call_loc_id, SemIR::InstId self_id,
llvm::ArrayRef<SemIR::InstId> arg_refs, SemIR::InstId return_slot_arg_id,
const CalleeParamsInfo& callee, SemIR::InstId return_slot_pattern_id,
SemIR::SpecificId callee_specific_id) -> SemIR::InstBlockId {
auto implicit_param_patterns =
context.inst_blocks().GetOrEmpty(callee.implicit_param_patterns_id);
auto param_patterns =
Expand All @@ -1184,68 +1150,31 @@ auto ConvertCallArgs(Context& context, SemIR::LocId call_loc_id,
// The caller should have ensured this callee has the right arity.
CARBON_CHECK(arg_refs.size() == param_patterns.size());

// Start building a block to hold the converted arguments.
llvm::SmallVector<SemIR::InstId> args;
args.reserve(implicit_param_patterns.size() + param_patterns.size() +
return_slot_arg_id.is_valid());

// Check implicit parameters.
// Find self parameter pattern.
// TODO: Do this during initial traversal of implicit params.
auto self_param_id = SemIR::InstId::Invalid;
for (auto implicit_param_id : implicit_param_patterns) {
if (implicit_param_id == SemIR::InstId::BuiltinError) {
return SemIR::InstBlockId::Invalid;
if (SemIR::Function::GetNameFromPatternId(
context.sem_ir(), implicit_param_id) == SemIR::NameId::SelfValue) {
CARBON_CHECK(!self_param_id.is_valid());
self_param_id = implicit_param_id;
}
auto param_pattern_info = SemIR::Function::GetParamPatternInfoFromPatternId(
context.sem_ir(), implicit_param_id);
if (param_pattern_info.GetNameId(context.sem_ir()) ==
SemIR::NameId::SelfValue) {
auto converted_self_id =
ConvertSelf(context, call_loc_id, callee.callee_loc,
callee_specific_id, implicit_param_id, self_id);
if (converted_self_id == SemIR::InstId::BuiltinError) {
return SemIR::InstBlockId::Invalid;
}
args.push_back(converted_self_id);
} else {
CARBON_CHECK(!param_pattern_info.inst.runtime_index.is_valid(),
"Unexpected implicit parameter passed at runtime");
}
}

// Check type conversions per-element.
for (auto [i, arg_id, param_pattern_id] :
llvm::enumerate(arg_refs, param_patterns)) {
auto runtime_index = SemIR::Function::GetParamPatternInfoFromPatternId(
context.sem_ir(), param_pattern_id)
.inst.runtime_index;
if (!runtime_index.is_valid()) {
// Not a runtime parameter: we don't pass an argument.
continue;
}

DiagnosticAnnotationScope annotate_diagnostics(
&context.emitter(), [&](auto& builder) {
CARBON_DIAGNOSTIC(InCallToFunctionParam, Note,
"initializing function parameter");
builder.Note(param_pattern_id, InCallToFunctionParam);
});

auto converted_arg_id = CallerPatternMatch(context, callee_specific_id,
param_pattern_id, arg_id);
if (converted_arg_id == SemIR::InstId::BuiltinError) {
return SemIR::InstBlockId::Invalid;
}

CARBON_CHECK(static_cast<int32_t>(args.size()) == runtime_index.index,
"Parameters not numbered in order.");
args.push_back(converted_arg_id);
}

// Track the return storage, if present.
if (return_slot_arg_id.is_valid()) {
args.push_back(return_slot_arg_id);
if (self_param_id.is_valid() && !self_id.is_valid()) {
CARBON_DIAGNOSTIC(MissingObjectInMethodCall, Error,
"missing object argument in method call");
CARBON_DIAGNOSTIC(InCallToFunction, Note, "calling function declared here");
context.emitter()
.Build(call_loc_id, MissingObjectInMethodCall)
.Note(callee.callee_loc, InCallToFunction)
.Emit();
self_id = SemIR::InstId::BuiltinError;
}

return context.inst_blocks().AddOrEmpty(args);
return CallerPatternMatch(context, callee_specific_id, self_param_id,
callee.param_patterns_id, return_slot_pattern_id,
self_id, arg_refs, return_slot_arg_id);
}

auto ExprAsType(Context& context, SemIR::LocId loc_id, SemIR::InstId value_id)
Expand Down
21 changes: 10 additions & 11 deletions toolchain/check/convert.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,11 @@ auto ConvertForExplicitAs(Context& context, Parse::NodeId as_node,
SemIR::InstId value_id, SemIR::TypeId type_id)
-> SemIR::InstId;

// Information about the parameters of a callee. This information is extracted
// from the EntityWithParamsBase before calling ConvertCallArgs, because
// conversion can trigger importing of more entities, which can invalidate the
// reference to the callee.
// Information about the syntactic parameters of a callee (excluding the return
// slot, for example). This information is extracted from the
// EntityWithParamsBase before calling ConvertCallArgs, because conversion can
// trigger importing of more entities, which can invalidate the reference to the
// callee.
struct CalleeParamsInfo {
explicit CalleeParamsInfo(const SemIR::EntityWithParamsBase& callee)
: callee_loc(callee.latest_decl_id()),
Expand All @@ -114,13 +115,11 @@ struct CalleeParamsInfo {
// Implicitly converts a set of arguments to match the parameter types in a
// function call. Returns a block containing the converted implicit and explicit
// argument values for runtime parameters.
auto ConvertCallArgs(Context& context, SemIR::LocId call_loc_id,
SemIR::InstId self_id,
llvm::ArrayRef<SemIR::InstId> arg_refs,
SemIR::InstId return_slot_arg_id,
const CalleeParamsInfo& callee,
SemIR::SpecificId callee_specific_id)
-> SemIR::InstBlockId;
auto ConvertCallArgs(
Context& context, SemIR::LocId call_loc_id, SemIR::InstId self_id,
llvm::ArrayRef<SemIR::InstId> arg_refs, SemIR::InstId return_slot_arg_id,
const CalleeParamsInfo& callee, SemIR::InstId return_slot_pattern_id,
SemIR::SpecificId callee_specific_id) -> SemIR::InstBlockId;

// A type that has been converted for use as a type expression.
struct TypeExpr {
Expand Down
3 changes: 2 additions & 1 deletion toolchain/check/global_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ auto GlobalInit::Finalize() -> void {
.extern_library_id = SemIR::LibraryNameId::Invalid,
.non_owning_decl_id = SemIR::InstId::Invalid,
.first_owning_decl_id = SemIR::InstId::Invalid},
{.return_slot_id = SemIR::InstId::Invalid,
{.return_slot_pattern_id = SemIR::InstId::Invalid,
.return_slot_id = SemIR::InstId::Invalid,
.body_block_ids = {SemIR::InstBlockId::GlobalInit}}}));
}

Expand Down
3 changes: 2 additions & 1 deletion toolchain/check/handle_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ static auto BuildFunctionDecl(Context& context,
auto function_info =
SemIR::Function{{name_context.MakeEntityWithParamsBase(
name, decl_id, is_extern, introducer.extern_library)},
{.return_slot_id = name.return_slot_id,
{.return_slot_pattern_id = name.return_slot_pattern_id,
.return_slot_id = name.return_slot_id,
.virtual_modifier = virtual_modifier}};
if (is_definition) {
function_info.definition_id = decl_id;
Expand Down
33 changes: 32 additions & 1 deletion toolchain/check/import_ref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,34 @@ class ImportRefResolver {
return context_.inst_blocks().Add(new_patterns);
}

// Returns a version of import_return_slot_pattern_id localized to the current
// IR.
auto GetLocalReturnSlotPatternId(SemIR::InstId import_return_slot_pattern_id)
-> SemIR::InstId {
if (!import_return_slot_pattern_id.is_valid()) {
return SemIR::InstId::Invalid;
}

auto param_pattern = import_ir_.insts().GetAs<SemIR::OutParamPattern>(
import_return_slot_pattern_id);
auto return_slot_pattern =
import_ir_.insts().GetAs<SemIR::ReturnSlotPattern>(
param_pattern.subpattern_id);
auto type_id = context_.GetTypeIdForTypeConstant(
GetLocalConstantIdChecked(return_slot_pattern.type_id));

auto new_return_slot_pattern_id = context_.AddInstInNoBlock(
context_.MakeImportedLocAndInst<SemIR::ReturnSlotPattern>(
AddImportIRInst(param_pattern.subpattern_id),
{.type_id = type_id, .type_inst_id = SemIR::InstId::Invalid}));
return context_.AddInstInNoBlock(
context_.MakeImportedLocAndInst<SemIR::OutParamPattern>(
AddImportIRInst(import_return_slot_pattern_id),
{.type_id = type_id,
.subpattern_id = new_return_slot_pattern_id,
.runtime_index = param_pattern.runtime_index}));
}

// Translates a NameId from the import IR to a local NameId.
auto GetLocalNameId(SemIR::NameId import_name_id) -> SemIR::NameId {
if (auto ident_id = import_name_id.AsIdentifierId(); ident_id.is_valid()) {
Expand Down Expand Up @@ -1600,7 +1628,8 @@ class ImportRefResolver {
// Start with an incomplete function.
function_decl.function_id = context_.functions().Add(
{GetIncompleteLocalEntityBase(function_decl_id, import_function),
{.return_slot_id = SemIR::InstId::Invalid,
{.return_slot_pattern_id = SemIR::InstId::Invalid,
.return_slot_id = SemIR::InstId::Invalid,
.builtin_function_kind = import_function.builtin_function_kind}});

function_decl.type_id =
Expand Down Expand Up @@ -1671,6 +1700,8 @@ class ImportRefResolver {
GetLocalParamRefsId(import_function.param_refs_id);
new_function.param_patterns_id =
GetLocalParamPatternsId(import_function.param_patterns_id);
new_function.return_slot_pattern_id =
GetLocalReturnSlotPatternId(import_function.return_slot_pattern_id);
SetGenericData(import_function.generic_id, new_function.generic_id,
generic_data);

Expand Down
Loading

0 comments on commit e20e8bf

Please sign in to comment.