Skip to content

Commit

Permalink
Model return slot as parameter in lowering (#4457)
Browse files Browse the repository at this point in the history
Co-authored-by: Richard Smith <[email protected]>
  • Loading branch information
geoffromer and zygoloid authored Nov 1, 2024
1 parent 145c44b commit ac5cc33
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 36 deletions.
113 changes: 85 additions & 28 deletions toolchain/lower/file_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,14 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id,
implicit_param_patterns.size() + param_patterns.size();
param_types.reserve(max_llvm_params);
param_inst_ids.reserve(max_llvm_params);
auto return_param_id = SemIR::InstId::Invalid;
if (return_info.has_return_slot()) {
param_types.push_back(return_type->getPointerTo());
param_inst_ids.push_back(function.return_slot_id);
return_param_id = sem_ir()
.insts()
.GetAs<SemIR::ReturnSlot>(function.return_slot_id)
.storage_id;
param_inst_ids.push_back(return_param_id);
}
for (auto param_pattern_id : llvm::concat<const SemIR::InstId>(
implicit_param_patterns, param_patterns)) {
Expand Down Expand Up @@ -280,7 +285,7 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id,
for (auto [inst_id, arg] :
llvm::zip_equal(param_inst_ids, llvm_function->args())) {
auto name_id = SemIR::NameId::Invalid;
if (inst_id == function.return_slot_id) {
if (inst_id == return_param_id) {
name_id = SemIR::NameId::ReturnSlot;
arg.addAttr(
llvm::Attribute::getWithStructRetType(llvm_context(), return_type));
Expand Down Expand Up @@ -324,51 +329,103 @@ auto FileContext::BuildFunctionDefinition(SemIR::FunctionId function_id)
sem_ir().inst_blocks().GetOrEmpty(function.implicit_param_refs_id);
auto param_refs = sem_ir().inst_blocks().GetOrEmpty(function.param_refs_id);
int param_index = 0;
if (SemIR::ReturnTypeInfo::ForFunction(sem_ir(), function, specific_id)
.has_return_slot()) {
function_lowering.SetLocal(function.return_slot_id,
llvm_function->getArg(param_index));
++param_index;
}
// The SemIR calling-convention parameters of the function, in order of
// runtime index. This is a transitional step toward generating this list
// in the check phase, which is why we're using the runtime index order
// even though it's less convenient for this usage.
llvm::SmallVector<SemIR::InstId> calling_convention_param_ids;
// This is an upper bound on the size because `self` and the return slot
// are the only runtime parameters that don't appear in the explicit
// parameter list.
calling_convention_param_ids.reserve(param_refs.size() + 2);
bool has_return_slot =
SemIR::ReturnTypeInfo::ForFunction(sem_ir(), function, specific_id)
.has_return_slot();
for (auto param_ref_id :
llvm::concat<const SemIR::InstId>(implicit_param_refs, param_refs)) {
auto param_info =
SemIR::Function::GetParamFromParamRefId(sem_ir(), param_ref_id);
if (!param_info.inst.runtime_index.is_valid()) {
continue;
if (param_info.inst.runtime_index.is_valid()) {
calling_convention_param_ids.push_back(param_info.inst_id);
}
}
if (has_return_slot) {
auto return_slot =
sem_ir().insts().GetAs<SemIR::ReturnSlot>(function.return_slot_id);
calling_convention_param_ids.push_back(return_slot.storage_id);
}

// TODO: find a way to ensure this code and the function-call lowering use
// the same parameter ordering.

// Lowers the given parameter. Must be called in LLVM calling convention
// parameter order.
auto lower_param = [&](SemIR::InstId param_id) {
// Get the value of the parameter from the function argument.
auto param_type_id = param_info.inst.type_id;
llvm::Value* param_value = llvm::PoisonValue::get(GetType(param_type_id));
if (SemIR::ValueRepr::ForType(sem_ir(), param_type_id).kind !=
auto param_inst = sem_ir().insts().GetAs<SemIR::AnyParam>(param_id);
llvm::Value* param_value =
llvm::PoisonValue::get(GetType(param_inst.type_id));
if (SemIR::ValueRepr::ForType(sem_ir(), param_inst.type_id).kind !=
SemIR::ValueRepr::None) {
param_value = llvm_function->getArg(param_index);
++param_index;
}

// The value of the parameter is the value of the argument.
function_lowering.SetLocal(param_info.inst_id, param_value);

// Match the portion of the pattern corresponding to the parameter against
// the parameter value. For now this is always a single name binding,
// possibly wrapped in `addr`.
//
// TODO: Support general patterns here.
auto bind_name_id = param_ref_id;
auto bind_name = sem_ir().insts().Get(bind_name_id);
CARBON_CHECK(bind_name.Is<SemIR::BindName>());
function_lowering.SetLocal(bind_name_id, param_value);
function_lowering.SetLocal(param_id, param_value);
};

// The subset of calling_convention_param_id that is in sequential order.
llvm::ArrayRef<SemIR::InstId> sequential_param_ids =
calling_convention_param_ids;

// The LLVM calling convention has the return slot first rather than last.
if (has_return_slot) {
lower_param(calling_convention_param_ids.back());

sequential_param_ids = sequential_param_ids.drop_back();
}
for (auto param_id : sequential_param_ids) {
lower_param(param_id);
}

// Lower all blocks.
for (auto block_id : body_block_ids) {
auto decl_block_id = SemIR::InstBlockId::Invalid;
if (function_id == sem_ir().global_ctor_id()) {
decl_block_id = SemIR::InstBlockId::Empty;
} else {
decl_block_id = sem_ir()
.insts()
.GetAs<SemIR::FunctionDecl>(function.latest_decl_id())
.decl_block_id;
}

// Lowers the contents of block_id into the corresponding LLVM block,
// creating it if it doesn't already exist.
auto lower_block = [&](SemIR::InstBlockId block_id) {
CARBON_VLOG("Lowering {0}\n", block_id);
auto* llvm_block = function_lowering.GetBlock(block_id);
// Keep the LLVM blocks in lexical order.
llvm_block->moveBefore(llvm_function->end());
function_lowering.builder().SetInsertPoint(llvm_block);
function_lowering.LowerBlock(block_id);
function_lowering.LowerBlockContents(block_id);
};

lower_block(decl_block_id);

// If the decl block is empty, reuse it as the first body block. We don't do
// this when the decl block is non-empty so that any branches back to the
// first body block don't also re-execute the decl.
llvm::BasicBlock* block = function_lowering.builder().GetInsertBlock();
if (block->empty() &&
function_lowering.TryToReuseBlock(body_block_ids.front(), block)) {
// Reuse this block as the first block of the function body.
} else {
function_lowering.builder().CreateBr(
function_lowering.GetBlock(body_block_ids.front()));
}

// Lower all blocks.
for (auto block_id : body_block_ids) {
lower_block(block_id);
}

// LLVM requires that the entry block has no predecessors.
Expand Down
2 changes: 1 addition & 1 deletion toolchain/lower/function_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ auto FunctionContext::TryToReuseBlock(SemIR::InstBlockId block_id,
return true;
}

auto FunctionContext::LowerBlock(SemIR::InstBlockId block_id) -> void {
auto FunctionContext::LowerBlockContents(SemIR::InstBlockId block_id) -> void {
for (auto inst_id : sem_ir().inst_blocks().Get(block_id)) {
LowerInst(inst_id);
}
Expand Down
2 changes: 1 addition & 1 deletion toolchain/lower/function_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class FunctionContext {
-> bool;

// Builds LLVM IR for the sequence of instructions in `block_id`.
auto LowerBlock(SemIR::InstBlockId block_id) -> void;
auto LowerBlockContents(SemIR::InstBlockId block_id) -> void;

// Builds LLVM IR for the specified instruction.
auto LowerInst(SemIR::InstId inst_id) -> void;
Expand Down
15 changes: 9 additions & 6 deletions toolchain/lower/handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,20 @@ auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,

auto HandleInst(FunctionContext& /*context*/, SemIR::InstId /*inst_id*/,
SemIR::OutParam /*inst*/) -> void {
CARBON_FATAL("Parameters should be lowered by `BuildFunctionDefinition`");
// Parameters are lowered by `BuildFunctionDefinition`.
}

auto HandleInst(FunctionContext& /*context*/, SemIR::InstId /*inst_id*/,
SemIR::ValueParam /*inst*/) -> void {
CARBON_FATAL("Parameters should be lowered by `BuildFunctionDefinition`");
// Parameters are lowered by `BuildFunctionDefinition`.
}

auto HandleInst(FunctionContext& /*context*/, SemIR::InstId /*inst_id*/,
SemIR::ReturnSlot /*inst*/) -> void {
CARBON_FATAL("Return slots should be lowered by `BuildFunctionDefinition`");
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::ReturnSlot inst) -> void {
if (SemIR::InitRepr::ForType(context.sem_ir(), inst.type_id).kind ==
SemIR::InitRepr::InPlace) {
context.SetLocal(inst_id, context.GetValue(inst.storage_id));
}
}

auto HandleInst(FunctionContext& context, SemIR::InstId /*inst_id*/,
Expand Down Expand Up @@ -226,7 +229,7 @@ auto HandleInst(FunctionContext& /*context*/, SemIR::InstId /*inst_id*/,

auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::SpliceBlock inst) -> void {
context.LowerBlock(inst.block_id);
context.LowerBlockContents(inst.block_id);
context.SetLocal(inst_id, context.GetValue(inst.result_id));
}

Expand Down

0 comments on commit ac5cc33

Please sign in to comment.