Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model return slot as parameter in lowering #4457

Merged
merged 7 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 85 additions & 28 deletions toolchain/lower/file_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,14 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id,
implicit_param_patterns.size() + param_patterns.size();
param_types.reserve(max_llvm_params);
param_inst_ids.reserve(max_llvm_params);
auto return_param_id = SemIR::InstId::Invalid;
if (return_info.has_return_slot()) {
param_types.push_back(return_type->getPointerTo());
param_inst_ids.push_back(function.return_slot_id);
return_param_id = sem_ir()
.insts()
.GetAs<SemIR::ReturnSlot>(function.return_slot_id)
.storage_id;
param_inst_ids.push_back(return_param_id);
}
for (auto param_pattern_id : llvm::concat<const SemIR::InstId>(
implicit_param_patterns, param_patterns)) {
Expand Down Expand Up @@ -280,7 +285,7 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id,
for (auto [inst_id, arg] :
llvm::zip_equal(param_inst_ids, llvm_function->args())) {
auto name_id = SemIR::NameId::Invalid;
if (inst_id == function.return_slot_id) {
if (inst_id == return_param_id) {
name_id = SemIR::NameId::ReturnSlot;
arg.addAttr(
llvm::Attribute::getWithStructRetType(llvm_context(), return_type));
Expand Down Expand Up @@ -324,51 +329,103 @@ auto FileContext::BuildFunctionDefinition(SemIR::FunctionId function_id)
sem_ir().inst_blocks().GetOrEmpty(function.implicit_param_refs_id);
auto param_refs = sem_ir().inst_blocks().GetOrEmpty(function.param_refs_id);
int param_index = 0;
if (SemIR::ReturnTypeInfo::ForFunction(sem_ir(), function, specific_id)
.has_return_slot()) {
function_lowering.SetLocal(function.return_slot_id,
llvm_function->getArg(param_index));
++param_index;
}
// The SemIR calling-convention parameters of the function, in order of
// runtime index. This is a transitional step toward generating this list
// in the check phase, which is why we're using the runtime index order
// even though it's less convenient for this usage.
zygoloid marked this conversation as resolved.
Show resolved Hide resolved
llvm::SmallVector<SemIR::InstId> calling_convention_param_ids;
zygoloid marked this conversation as resolved.
Show resolved Hide resolved
// This is an upper bound on the size because `self` and the return slot
// are the only runtime parameters that don't appear in the explicit
// parameter list.
calling_convention_param_ids.reserve(param_refs.size() + 2);
bool has_return_slot =
SemIR::ReturnTypeInfo::ForFunction(sem_ir(), function, specific_id)
.has_return_slot();
for (auto param_ref_id :
llvm::concat<const SemIR::InstId>(implicit_param_refs, param_refs)) {
auto param_info =
SemIR::Function::GetParamFromParamRefId(sem_ir(), param_ref_id);
if (!param_info.inst.runtime_index.is_valid()) {
continue;
if (param_info.inst.runtime_index.is_valid()) {
calling_convention_param_ids.push_back(param_info.inst_id);
}
}
if (has_return_slot) {
auto return_slot =
sem_ir().insts().GetAs<SemIR::ReturnSlot>(function.return_slot_id);
calling_convention_param_ids.push_back(return_slot.storage_id);
}

// TODO: find a way to ensure this code and the function-call lowering use
// the same parameter ordering.

// Lowers the given parameter. Must be called in LLVM calling convention
// parameter order.
auto lower_param = [&](SemIR::InstId param_id) {
// Get the value of the parameter from the function argument.
auto param_type_id = param_info.inst.type_id;
llvm::Value* param_value = llvm::PoisonValue::get(GetType(param_type_id));
if (SemIR::ValueRepr::ForType(sem_ir(), param_type_id).kind !=
auto param_inst = sem_ir().insts().GetAs<SemIR::AnyParam>(param_id);
llvm::Value* param_value =
llvm::PoisonValue::get(GetType(param_inst.type_id));
if (SemIR::ValueRepr::ForType(sem_ir(), param_inst.type_id).kind !=
SemIR::ValueRepr::None) {
param_value = llvm_function->getArg(param_index);
++param_index;
}

// The value of the parameter is the value of the argument.
function_lowering.SetLocal(param_info.inst_id, param_value);

// Match the portion of the pattern corresponding to the parameter against
// the parameter value. For now this is always a single name binding,
// possibly wrapped in `addr`.
//
// TODO: Support general patterns here.
auto bind_name_id = param_ref_id;
auto bind_name = sem_ir().insts().Get(bind_name_id);
CARBON_CHECK(bind_name.Is<SemIR::BindName>());
function_lowering.SetLocal(bind_name_id, param_value);
function_lowering.SetLocal(param_id, param_value);
};

// The subset of calling_convention_param_id that is in sequential order.
llvm::ArrayRef<SemIR::InstId> sequential_param_ids =
calling_convention_param_ids;

// The LLVM calling convention has the return slot first rather than last.
if (has_return_slot) {
lower_param(calling_convention_param_ids.back());

sequential_param_ids = sequential_param_ids.drop_back();
}
for (auto param_id : sequential_param_ids) {
lower_param(param_id);
}

// Lower all blocks.
for (auto block_id : body_block_ids) {
auto decl_block_id = SemIR::InstBlockId::Invalid;
if (function_id == sem_ir().global_ctor_id()) {
decl_block_id = SemIR::InstBlockId::Empty;
} else {
decl_block_id = sem_ir()
.insts()
.GetAs<SemIR::FunctionDecl>(function.latest_decl_id())
.decl_block_id;
}

// Lowers the contents of block_id into the corresponding LLVM block,
// creating it if it doesn't already exist.
auto lower_block = [&](SemIR::InstBlockId block_id) {
CARBON_VLOG("Lowering {0}\n", block_id);
auto* llvm_block = function_lowering.GetBlock(block_id);
// Keep the LLVM blocks in lexical order.
llvm_block->moveBefore(llvm_function->end());
function_lowering.builder().SetInsertPoint(llvm_block);
function_lowering.LowerBlock(block_id);
function_lowering.LowerBlockContents(block_id);
};

lower_block(decl_block_id);

// If the decl block is empty, reuse it as the first body block. We don't do
// this when the decl block is non-empty so that any branches back to the
// first body block don't also re-execute the decl.
llvm::BasicBlock* block = function_lowering.builder().GetInsertBlock();
if (block->empty() &&
function_lowering.TryToReuseBlock(body_block_ids.front(), block)) {
// Reuse this block as the first block of the function body.
} else {
function_lowering.builder().CreateBr(
function_lowering.GetBlock(body_block_ids.front()));
}
zygoloid marked this conversation as resolved.
Show resolved Hide resolved

// Lower all blocks.
for (auto block_id : body_block_ids) {
lower_block(block_id);
}

// LLVM requires that the entry block has no predecessors.
Expand Down
2 changes: 1 addition & 1 deletion toolchain/lower/function_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ auto FunctionContext::TryToReuseBlock(SemIR::InstBlockId block_id,
return true;
}

auto FunctionContext::LowerBlock(SemIR::InstBlockId block_id) -> void {
auto FunctionContext::LowerBlockContents(SemIR::InstBlockId block_id) -> void {
for (auto inst_id : sem_ir().inst_blocks().Get(block_id)) {
LowerInst(inst_id);
}
Expand Down
2 changes: 1 addition & 1 deletion toolchain/lower/function_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class FunctionContext {
-> bool;

// Builds LLVM IR for the sequence of instructions in `block_id`.
auto LowerBlock(SemIR::InstBlockId block_id) -> void;
auto LowerBlockContents(SemIR::InstBlockId block_id) -> void;

// Builds LLVM IR for the specified instruction.
auto LowerInst(SemIR::InstId inst_id) -> void;
Expand Down
15 changes: 9 additions & 6 deletions toolchain/lower/handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,20 @@ auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,

auto HandleInst(FunctionContext& /*context*/, SemIR::InstId /*inst_id*/,
SemIR::OutParam /*inst*/) -> void {
CARBON_FATAL("Parameters should be lowered by `BuildFunctionDefinition`");
// Parameters are lowered by `BuildFunctionDefinition`.
}

auto HandleInst(FunctionContext& /*context*/, SemIR::InstId /*inst_id*/,
SemIR::ValueParam /*inst*/) -> void {
CARBON_FATAL("Parameters should be lowered by `BuildFunctionDefinition`");
// Parameters are lowered by `BuildFunctionDefinition`.
}

auto HandleInst(FunctionContext& /*context*/, SemIR::InstId /*inst_id*/,
SemIR::ReturnSlot /*inst*/) -> void {
CARBON_FATAL("Return slots should be lowered by `BuildFunctionDefinition`");
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::ReturnSlot inst) -> void {
if (SemIR::InitRepr::ForType(context.sem_ir(), inst.type_id).kind ==
SemIR::InitRepr::InPlace) {
context.SetLocal(inst_id, context.GetValue(inst.storage_id));
}
}

auto HandleInst(FunctionContext& context, SemIR::InstId /*inst_id*/,
Expand Down Expand Up @@ -221,7 +224,7 @@ auto HandleInst(FunctionContext& context, SemIR::InstId /*inst_id*/,

auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::SpliceBlock inst) -> void {
context.LowerBlock(inst.block_id);
context.LowerBlockContents(inst.block_id);
context.SetLocal(inst_id, context.GetValue(inst.result_id));
}

Expand Down
Loading