Skip to content

Commit

Permalink
Convert StructTypeField to a specific type. (#4492)
Browse files Browse the repository at this point in the history
This converts `StructTypeField` from an instruction to a dedicated type,
with its own store. This had originated from discussing how
`.GetAs<SemIR::StructTypeField>` was more prevalent than for other
instructions, but is probably more interesting for the storage savings
(16 bytes StructTypeField + 4 byte LocId + 4 byte InstId -> 8 byte
StructTypeField).

Due to the different structure, these now have their own stack during
construction, reducing (but not eliminating) `args_type_info_stack_`
use-cases.

The test changes of different InstIds is expected because structs and
classes generate fewer instructions now. Other than that, results should
remain the same.

I'm generally trying to avoid unrelated cleanup here due to the PR size,
though I did scrutinize the `VerifyOnFinish` calls, adding one and
commenting others (putting them in member order because that's how I was
checking what was verified and what wasn't).
  • Loading branch information
jonmeow authored Nov 6, 2024
1 parent 7977a9c commit be56ff8
Show file tree
Hide file tree
Showing 48 changed files with 530 additions and 431 deletions.
3 changes: 3 additions & 0 deletions common/array_stack.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ class ArrayStack {
// Returns the current number of values in all arrays.
auto all_values_size() const -> size_t { return values_.size(); }

// Returns true if the stack has no arrays pushed.
auto empty() const -> bool { return array_offsets_.empty(); }

private:
// For each pushed array, the start index in elements_.
llvm::SmallVector<int32_t> array_offsets_;
Expand Down
45 changes: 23 additions & 22 deletions toolchain/check/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,15 @@ auto Context::VerifyOnFinish() -> void {
// various pieces of context go out of scope. At this point, nothing should
// remain.
// node_stack_ will still contain top-level entities.
scope_stack_.VerifyOnFinish();
inst_block_stack_.VerifyOnFinish();
pattern_block_stack_.VerifyOnFinish();
param_and_arg_refs_stack_.VerifyOnFinish();
args_type_info_stack_.VerifyOnFinish();
CARBON_CHECK(struct_type_fields_stack_.empty());
// TODO: Add verification for decl_name_stack_ and
// decl_introducer_state_stack_.
scope_stack_.VerifyOnFinish();
// TODO: Add verification for generic_region_stack_.
}

auto Context::GetOrAddInst(SemIR::LocIdAndInst loc_id_and_inst)
Expand Down Expand Up @@ -847,10 +852,8 @@ class TypeCompleter {
break;
}
case CARBON_KIND(SemIR::StructType inst): {
for (auto field_id : context_.inst_blocks().Get(inst.fields_id)) {
Push(context_.insts()
.GetAs<SemIR::StructTypeField>(field_id)
.field_type_id);
for (auto field : context_.struct_type_fields().Get(inst.fields_id)) {
Push(field.type_id);
}
break;
}
Expand Down Expand Up @@ -981,33 +984,30 @@ class TypeCompleter {
auto BuildValueReprForInst(SemIR::TypeId type_id,
SemIR::StructType struct_type) const
-> SemIR::ValueRepr {
// TODO: Share more code with tuples.
auto fields = context_.inst_blocks().Get(struct_type.fields_id);
auto fields = context_.struct_type_fields().Get(struct_type.fields_id);
if (fields.empty()) {
return MakeEmptyValueRepr();
}

// Find the value representation for each field, and construct a struct
// of value representations.
llvm::SmallVector<SemIR::InstId> value_rep_fields;
llvm::SmallVector<SemIR::StructTypeField> value_rep_fields;
value_rep_fields.reserve(fields.size());
bool same_as_object_rep = true;
for (auto field_id : fields) {
auto field = context_.insts().GetAs<SemIR::StructTypeField>(field_id);
auto field_value_rep = GetNestedValueRepr(field.field_type_id);
if (field_value_rep.type_id != field.field_type_id) {
for (auto field : fields) {
auto field_value_rep = GetNestedValueRepr(field.type_id);
if (field_value_rep.type_id != field.type_id) {
same_as_object_rep = false;
field.field_type_id = field_value_rep.type_id;
field_id = context_.constant_values().GetInstId(
TryEvalInst(context_, SemIR::InstId::Invalid, field));
field.type_id = field_value_rep.type_id;
}
value_rep_fields.push_back(field_id);
value_rep_fields.push_back(field);
}

auto value_rep = same_as_object_rep
? type_id
: context_.GetStructType(
context_.inst_blocks().Add(value_rep_fields));
auto value_rep =
same_as_object_rep
? type_id
: context_.GetStructType(
context_.struct_type_fields().AddCanonical(value_rep_fields));
return BuildStructOrTupleValueRepr(fields.size(), value_rep,
same_as_object_rep);
}
Expand Down Expand Up @@ -1243,8 +1243,9 @@ static auto GetCompleteTypeImpl(Context& context, EachArgT... each_arg)
return type_id;
}

auto Context::GetStructType(SemIR::InstBlockId refs_id) -> SemIR::TypeId {
return GetTypeImpl<SemIR::StructType>(*this, refs_id);
auto Context::GetStructType(SemIR::StructTypeFieldsId fields_id)
-> SemIR::TypeId {
return GetTypeImpl<SemIR::StructType>(*this, fields_id);
}

auto Context::GetTupleType(llvm::ArrayRef<SemIR::TypeId> type_ids)
Expand Down
16 changes: 13 additions & 3 deletions toolchain/check/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,8 @@ class Context {
// Returns a pointer type whose pointee type is `pointee_type_id`.
auto GetPointerType(SemIR::TypeId pointee_type_id) -> SemIR::TypeId;

// Returns a struct type with the given fields, which should be a block of
// `StructTypeField`s.
auto GetStructType(SemIR::InstBlockId refs_id) -> SemIR::TypeId;
// Returns a struct type with the given fields.
auto GetStructType(SemIR::StructTypeFieldsId fields_id) -> SemIR::TypeId;

// Returns a tuple type with the given element types.
auto GetTupleType(llvm::ArrayRef<SemIR::TypeId> type_ids) -> SemIR::TypeId;
Expand Down Expand Up @@ -467,6 +466,10 @@ class Context {
return args_type_info_stack_;
}

auto struct_type_fields_stack() -> ArrayStack<SemIR::StructTypeField>& {
return struct_type_fields_stack_;
}

auto decl_name_stack() -> DeclNameStack& { return decl_name_stack_; }

auto decl_introducer_state_stack() -> DeclIntroducerStateStack& {
Expand Down Expand Up @@ -527,6 +530,9 @@ class Context {
auto name_scopes() -> SemIR::NameScopeStore& {
return sem_ir().name_scopes();
}
auto struct_type_fields() -> SemIR::StructTypeFieldsStore& {
return sem_ir().struct_type_fields();
}
auto types() -> SemIR::TypeStore& { return sem_ir().types(); }
auto type_blocks() -> SemIR::BlockValueStore<SemIR::TypeBlockId>& {
return sem_ir().type_blocks();
Expand Down Expand Up @@ -613,6 +619,10 @@ class Context {
// arguments.
InstBlockStack args_type_info_stack_;

// The stack of StructTypeFields for in-progress StructTypeLiterals and Class
// object representations.
ArrayStack<SemIR::StructTypeField> struct_type_fields_stack_;

// The stack used for qualified declaration name construction.
DeclNameStack decl_name_stack_;

Expand Down
29 changes: 11 additions & 18 deletions toolchain/check/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,10 @@ static auto ConvertStructToStructOrClass(Context& context,
std::is_same_v<SemIR::ClassElementAccess, TargetAccessInstT>;

auto& sem_ir = context.sem_ir();
auto src_elem_fields = sem_ir.inst_blocks().Get(src_type.fields_id);
auto dest_elem_fields = sem_ir.inst_blocks().Get(dest_type.fields_id);
bool dest_has_vptr =
!dest_elem_fields.empty() &&
sem_ir.insts()
.GetAs<SemIR::StructTypeField>(dest_elem_fields.front())
.name_id == SemIR::NameId::Vptr;
auto src_elem_fields = sem_ir.struct_type_fields().Get(src_type.fields_id);
auto dest_elem_fields = sem_ir.struct_type_fields().Get(dest_type.fields_id);
bool dest_has_vptr = !dest_elem_fields.empty() &&
dest_elem_fields.front().name_id == SemIR::NameId::Vptr;
auto dest_elem_fields_size = dest_elem_fields.size() - dest_has_vptr;

auto value = sem_ir.insts().Get(value_id);
Expand Down Expand Up @@ -432,9 +429,8 @@ static auto ConvertStructToStructOrClass(Context& context,
// Prepare to look up fields in the source by index.
Map<SemIR::NameId, int32_t> src_field_indexes;
if (src_type.fields_id != dest_type.fields_id) {
for (auto [i, field_id] : llvm::enumerate(src_elem_fields)) {
auto result = src_field_indexes.Insert(
context.insts().GetAs<SemIR::StructTypeField>(field_id).name_id, i);
for (auto [i, field] : llvm::enumerate(src_elem_fields)) {
auto result = src_field_indexes.Insert(field.name_id, i);
CARBON_CHECK(result.is_inserted(), "Duplicate field in source structure");
}
}
Expand All @@ -460,9 +456,7 @@ static auto ConvertStructToStructOrClass(Context& context,
: SemIR::CopyOnWriteInstBlock(
sem_ir, SemIR::CopyOnWriteInstBlock::UninitializedBlock{
dest_elem_fields.size()});
for (auto [i, dest_field_id] : llvm::enumerate(dest_elem_fields)) {
auto dest_field =
sem_ir.insts().GetAs<SemIR::StructTypeField>(dest_field_id);
for (auto [i, dest_field] : llvm::enumerate(dest_elem_fields)) {
if (dest_field.name_id == SemIR::NameId::Vptr) {
// TODO: Initialize the vptr to point to a vtable.
new_block.Set(i, SemIR::InstId::BuiltinError);
Expand Down Expand Up @@ -494,16 +488,15 @@ static auto ConvertStructToStructOrClass(Context& context,
return SemIR::InstId::BuiltinError;
}
}
auto src_field = sem_ir.insts().GetAs<SemIR::StructTypeField>(
src_elem_fields[src_field_index]);
auto src_field = src_elem_fields[src_field_index];

// TODO: This call recurses back into conversion. Switch to an iterative
// approach.
auto init_id =
ConvertAggregateElement<SemIR::StructAccess, TargetAccessInstT>(
context, value_loc_id, value_id, src_field.field_type_id,
literal_elems, inner_kind, target.init_id, dest_field.field_type_id,
target.init_block, src_field_index);
context, value_loc_id, value_id, src_field.type_id, literal_elems,
inner_kind, target.init_id, dest_field.type_id, target.init_block,
src_field_index);
if (init_id == SemIR::InstId::BuiltinError) {
return SemIR::InstId::BuiltinError;
}
Expand Down
33 changes: 30 additions & 3 deletions toolchain/check/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,36 @@ static auto GetConstantValue(EvalContext& eval_context,
return eval_context.inst_blocks().AddCanonical(const_insts);
}

// Compute the constant value of a type block. This may be different from the
// input type block if we have known generic arguments.
static auto GetConstantValue(EvalContext& eval_context,
SemIR::StructTypeFieldsId fields_id, Phase* phase)
-> SemIR::StructTypeFieldsId {
if (!fields_id.is_valid()) {
return SemIR::StructTypeFieldsId::Invalid;
}
auto fields = eval_context.context().struct_type_fields().Get(fields_id);
llvm::SmallVector<SemIR::StructTypeField> new_fields;
for (auto field : fields) {
auto new_type_id = GetConstantValue(eval_context, field.type_id, phase);
if (!new_type_id.is_valid()) {
return SemIR::StructTypeFieldsId::Invalid;
}

// Once we leave the small buffer, we know the first few elements are all
// constant, so it's likely that the entire block is constant. Resize to the
// target size given that we're going to allocate memory now anyway.
if (new_fields.size() == new_fields.capacity()) {
new_fields.reserve(fields.size());
}

new_fields.push_back({.name_id = field.name_id, .type_id = new_type_id});
}
// TODO: If the new block is identical to the original block, and we know the
// old ID was canonical, return the original ID.
return eval_context.context().struct_type_fields().AddCanonical(new_fields);
}

// Compute the constant value of a type block. This may be different from the
// input type block if we have known generic arguments.
static auto GetConstantValue(EvalContext& eval_context,
Expand Down Expand Up @@ -1235,9 +1265,6 @@ static auto TryEvalInstInContext(EvalContext& eval_context,
case SemIR::StructType::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::StructType::fields_id);
case SemIR::StructTypeField::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::StructTypeField::field_type_id);
case SemIR::StructValue::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::StructValue::type_id,
Expand Down
11 changes: 4 additions & 7 deletions toolchain/check/handle_binding_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,12 @@ static auto HandleAnyBindingPattern(Context& context, Parse::NodeId node_id,
binding_id,
{.type_id = field_type_id,
.name_id = name_id,
.index = SemIR::ElementIndex(context.args_type_info_stack()
.PeekCurrentBlockContents()
.size())});
.index = SemIR::ElementIndex(
context.struct_type_fields_stack().PeekArray().size())});

// Add a corresponding field to the object representation of the class.
context.args_type_info_stack().AddInstId(
context.AddInstInNoBlock<SemIR::StructTypeField>(
binding_id,
{.name_id = name_id, .field_type_id = cast_type_id}));
context.struct_type_fields_stack().AppendToTop(
{.name_id = name_id, .type_id = cast_type_id});
context.node_stack().Push(node_id, field_id);
break;
}
Expand Down
46 changes: 24 additions & 22 deletions toolchain/check/handle_class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ auto HandleParseNode(Context& context, Parse::ClassDefinitionStartId node_id)

context.inst_block_stack().Push();
context.node_stack().Push(node_id, class_id);
context.args_type_info_stack().Push();
context.struct_type_fields_stack().PushArray();

// TODO: Handle the case where there's control flow in the class body. For
// example:
Expand Down Expand Up @@ -533,11 +533,10 @@ auto HandleParseNode(Context& context, Parse::BaseDeclId node_id) -> bool {
auto field_type_id =
context.GetUnboundElementType(class_info.self_type_id, base_info.type_id);
class_info.base_id = context.AddInst<SemIR::BaseDecl>(
node_id,
{.type_id = field_type_id,
.base_type_id = base_info.type_id,
.index = SemIR::ElementIndex(
context.args_type_info_stack().PeekCurrentBlockContents().size())});
node_id, {.type_id = field_type_id,
.base_type_id = base_info.type_id,
.index = SemIR::ElementIndex(
context.struct_type_fields_stack().PeekArray().size())});

if (base_info.type_id != SemIR::TypeId::Error) {
auto base_class_info = context.classes().Get(
Expand All @@ -548,10 +547,8 @@ auto HandleParseNode(Context& context, Parse::BaseDeclId node_id) -> bool {
// Add a corresponding field to the object representation of the class.
// TODO: Consider whether we want to use `partial T` here.
// TODO: Should we diagnose if there are already any fields?
context.args_type_info_stack().AddInstId(
context.AddInstInNoBlock<SemIR::StructTypeField>(
node_id, {.name_id = SemIR::NameId::Base,
.field_type_id = base_info.type_id}));
context.struct_type_fields_stack().AppendToTop(
{.name_id = SemIR::NameId::Base, .type_id = base_info.type_id});

// Bind the name `base` in the class to the base field.
context.decl_name_stack().AddNameOrDiagnoseDuplicate(
Expand All @@ -576,7 +573,7 @@ auto HandleParseNode(Context& context, Parse::BaseDeclId node_id) -> bool {
static auto CheckCompleteAdapterClassType(Context& context,
Parse::NodeId node_id,
SemIR::ClassId class_id,
SemIR::InstBlockId fields_id)
SemIR::StructTypeFieldsId fields_id)
-> SemIR::InstId {
const auto& class_info = context.classes().Get(class_id);
if (class_info.base_id.is_valid()) {
Expand All @@ -589,14 +586,17 @@ static auto CheckCompleteAdapterClassType(Context& context,
return SemIR::InstId::BuiltinError;
}

if (!context.inst_blocks().Get(fields_id).empty()) {
auto first_field_id = context.inst_blocks().Get(fields_id).front();
if (auto fields = context.struct_type_fields().Get(fields_id);
!fields.empty()) {
auto [first_field_inst_id, _] = context.LookupNameInExactScope(
node_id, fields.front().name_id, class_info.scope_id,
context.name_scopes().Get(class_info.scope_id));
CARBON_DIAGNOSTIC(AdaptWithFields, Error, "adapter with fields");
CARBON_DIAGNOSTIC(AdaptWithFieldHere, Note,
"first field declaration is here");
context.emitter()
.Build(class_info.adapt_id, AdaptWithFields)
.Note(first_field_id, AdaptWithFieldHere)
.Note(first_field_inst_id, AdaptWithFieldHere)
.Emit();
return SemIR::InstId::BuiltinError;
}
Expand Down Expand Up @@ -649,7 +649,9 @@ static auto CheckCompleteClassType(Context& context, Parse::NodeId node_id,
SemIR::ClassId class_id) -> SemIR::InstId {
auto& class_info = context.classes().Get(class_id);
if (class_info.adapt_id.is_valid()) {
auto fields_id = context.args_type_info_stack().Pop();
auto fields_id = context.struct_type_fields().AddCanonical(
context.struct_type_fields_stack().PeekArray());
context.struct_type_fields_stack().PopArray();

return CheckCompleteAdapterClassType(context, node_id, class_id, fields_id);
}
Expand All @@ -666,15 +668,15 @@ static auto CheckCompleteClassType(Context& context, Parse::NodeId node_id,
}

if (defining_vtable_ptr) {
context.args_type_info_stack().AddFrontInstId(
context.AddInstInNoBlock<SemIR::StructTypeField>(
Parse::NodeId::Invalid,
{.name_id = SemIR::NameId::Vptr,
.field_type_id = context.GetPointerType(
context.GetBuiltinType(SemIR::BuiltinInstKind::VtableType))}));
context.struct_type_fields_stack().PrependToTop(
{.name_id = SemIR::NameId::Vptr,
.type_id = context.GetPointerType(
context.GetBuiltinType(SemIR::BuiltinInstKind::VtableType))});
}

auto fields_id = context.args_type_info_stack().Pop();
auto fields_id = context.struct_type_fields().AddCanonical(
context.struct_type_fields_stack().PeekArray());
context.struct_type_fields_stack().PopArray();

return context.AddInst<SemIR::CompleteTypeWitness>(
node_id,
Expand Down
Loading

0 comments on commit be56ff8

Please sign in to comment.