Skip to content

Commit

Permalink
Implement syntactic merge checks for parameters. (#4149)
Browse files Browse the repository at this point in the history
Note this isn't implementing checking through imports. The parse node
there is harder to access through the context, so would require
examining the entity in order to get the import declaration, to get at
the ImportIR. We also don't have a parse tree attached in that case, and
would need to add one to SemIR::File. But I believe we do want to add
that, so it's explicitly a TODO.

Note GetTokenText re-lexes literal values, so there's a bit of potential
overhead there. Not sure if we want a more efficient manner for
comparing in cases like this.
  • Loading branch information
jonmeow authored Jul 23, 2024
1 parent 07c286e commit db02265
Show file tree
Hide file tree
Showing 26 changed files with 2,738 additions and 26 deletions.
14 changes: 8 additions & 6 deletions toolchain/check/decl_name_stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,9 @@ auto DeclNameStack::ResolveAsScope(const NameContext& name_context,
return InvalidResult;
}

auto new_params =
DeclParams(name.name_loc_id, name.implicit_params_id, name.params_id);
auto new_params = DeclParams(name.name_loc_id, name.first_param_node_id,
name.last_param_node_id, name.implicit_params_id,
name.params_id);

// Find the scope corresponding to the resolved instruction.
CARBON_KIND_SWITCH(context_->insts().Get(name_context.resolved_inst_id)) {
Expand Down Expand Up @@ -404,10 +405,11 @@ auto DeclNameStack::ResolveAsScope(const NameContext& name_context,
case CARBON_KIND(SemIR::Namespace resolved_inst): {
auto scope_id = resolved_inst.name_scope_id;
auto& scope = context_->name_scopes().Get(scope_id);
if (!CheckRedeclParamsMatch(*context_, new_params,
DeclParams(name_context.resolved_inst_id,
SemIR::InstBlockId::Invalid,
SemIR::InstBlockId::Invalid))) {
if (!CheckRedeclParamsMatch(
*context_, new_params,
DeclParams(name_context.resolved_inst_id, Parse::NodeId::Invalid,
Parse::NodeId::Invalid, SemIR::InstBlockId::Invalid,
SemIR::InstBlockId::Invalid))) {
return InvalidResult;
}
if (scope.is_closed_import) {
Expand Down
6 changes: 4 additions & 2 deletions toolchain/check/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ namespace Carbon::Check {
auto CheckFunctionTypeMatches(Context& context,
const SemIR::Function& new_function,
const SemIR::Function& prev_function,
Substitutions substitutions) -> bool {
Substitutions substitutions, bool check_syntax)
-> bool {
if (!CheckRedeclParamsMatch(context, DeclParams(new_function),
DeclParams(prev_function), substitutions)) {
DeclParams(prev_function), substitutions,
check_syntax)) {
return false;
}

Expand Down
5 changes: 2 additions & 3 deletions toolchain/check/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@ struct SuspendedFunction {
// Checks that `new_function` has the same parameter types and return type as
// `prev_function`, applying the specified set of substitutions to the
// previous function. Prints a suitable diagnostic and returns false if not.
// Note that this doesn't include the syntactic check that's performed for
// redeclarations.
auto CheckFunctionTypeMatches(Context& context,
const SemIR::Function& new_function,
const SemIR::Function& prev_function,
Substitutions substitutions) -> bool;
Substitutions substitutions, bool check_syntax)
-> bool;

// Checks that the return type of the specified function is complete, issuing an
// error if not. This computes the return slot usage for the function if
Expand Down
2 changes: 2 additions & 0 deletions toolchain/check/global_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ auto GlobalInit::Finalize() -> void {
.parent_scope_id = SemIR::NameScopeId::Package,
.decl_id = SemIR::InstId::Invalid,
.generic_id = SemIR::GenericId::Invalid,
.first_param_node_id = Parse::NodeId::Invalid,
.last_param_node_id = Parse::NodeId::Invalid,
.implicit_param_refs_id = SemIR::InstBlockId::Invalid,
.param_refs_id = SemIR::InstBlockId::Empty,
.return_storage_id = SemIR::InstId::Invalid,
Expand Down
4 changes: 4 additions & 0 deletions toolchain/check/handle_class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ static auto MergeClassRedecl(Context& context, SemIRLoc new_loc,
}

if (new_is_definition) {
prev_class.first_param_node_id = new_class.first_param_node_id;
prev_class.last_param_node_id = new_class.last_param_node_id;
prev_class.implicit_param_refs_id = new_class.implicit_param_refs_id;
prev_class.param_refs_id = new_class.param_refs_id;
prev_class.definition_id = new_class.definition_id;
Expand Down Expand Up @@ -225,6 +227,8 @@ static auto BuildClassDecl(Context& context, Parse::AnyClassDeclId node_id,
.name_id = name_context.name_id_for_new_inst(),
.parent_scope_id = name_context.parent_scope_id_for_new_inst(),
.generic_id = SemIR::GenericId::Invalid,
.first_param_node_id = name.first_param_node_id,
.last_param_node_id = name.last_param_node_id,
.implicit_param_refs_id = name.implicit_params_id,
.param_refs_id = name.params_id,
// `.self_type_id` depends on the ClassType, so is set below.
Expand Down
7 changes: 6 additions & 1 deletion toolchain/check/handle_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ static auto MergeFunctionRedecl(Context& context, SemIRLoc new_loc,
SemIR::ImportIRId prev_import_ir_id) -> bool {
auto& prev_function = context.functions().Get(prev_function_id);

if (!CheckFunctionTypeMatches(context, new_function, prev_function, {})) {
if (!CheckFunctionTypeMatches(context, new_function, prev_function, {},
/*check_syntax=*/true)) {
return false;
}

Expand All @@ -118,6 +119,8 @@ static auto MergeFunctionRedecl(Context& context, SemIRLoc new_loc,
// Track the signature from the definition, so that IDs in the body
// match IDs in the signature.
prev_function.definition_id = new_function.definition_id;
prev_function.first_param_node_id = new_function.first_param_node_id;
prev_function.last_param_node_id = new_function.last_param_node_id;
prev_function.implicit_param_refs_id = new_function.implicit_param_refs_id;
prev_function.param_refs_id = new_function.param_refs_id;
prev_function.return_storage_id = new_function.return_storage_id;
Expand Down Expand Up @@ -252,6 +255,8 @@ static auto BuildFunctionDecl(Context& context,
.parent_scope_id = name_context.parent_scope_id_for_new_inst(),
.decl_id = decl_id,
.generic_id = SemIR::GenericId::Invalid,
.first_param_node_id = name.first_param_node_id,
.last_param_node_id = name.last_param_node_id,
.implicit_param_refs_id = name.implicit_params_id,
.param_refs_id = name.params_id,
.return_storage_id = return_storage_id,
Expand Down
2 changes: 2 additions & 0 deletions toolchain/check/handle_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ auto HandleParseNode(Context& context, Parse::ImplIntroducerId node_id)
// Handles an `ImplForall` parse node: consumes the implicit parameter list
// (and its start marker) from the node stack and re-pushes the parameter block
// under this node's ID. Always returns true.
auto HandleParseNode(Context& context, Parse::ImplForallId node_id) -> bool {
  auto& node_stack = context.node_stack();

  // The list's value sits on top of the stack, with its start marker
  // directly beneath it.
  auto implicit_params_id =
      node_stack.Pop<Parse::NodeKind::ImplicitParamList>();
  node_stack.PopAndDiscardSoloNodeId<Parse::NodeKind::ImplicitParamListStart>();

  node_stack.Push(node_id, implicit_params_id);
  return true;
}
Expand Down
5 changes: 4 additions & 1 deletion toolchain/check/handle_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ static auto BuildInterfaceDecl(Context& context,
// now we just check the generic parameters match.
if (CheckRedeclParamsMatch(
context,
DeclParams(interface_decl_id, name.implicit_params_id,
DeclParams(interface_decl_id, name.first_param_node_id,
name.last_param_node_id, name.implicit_params_id,
name.params_id),
DeclParams(context.interfaces().Get(
existing_interface_decl->interface_id)))) {
Expand Down Expand Up @@ -90,6 +91,8 @@ static auto BuildInterfaceDecl(Context& context,
.name_id = name_context.name_id_for_new_inst(),
.parent_scope_id = name_context.parent_scope_id_for_new_inst(),
.generic_id = generic_id,
.first_param_node_id = name.first_param_node_id,
.last_param_node_id = name.last_param_node_id,
.implicit_param_refs_id = name.implicit_params_id,
.param_refs_id = name.params_id,
.decl_id = interface_decl_id};
Expand Down
8 changes: 4 additions & 4 deletions toolchain/check/handle_pattern_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ auto HandleParseNode(Context& context, Parse::ImplicitParamListStartId node_id)

auto HandleParseNode(Context& context, Parse::ImplicitParamListId node_id)
-> bool {
// Note the Start node remains on the stack, where the param list handler can
// make use of it.
auto refs_id = context.param_and_arg_refs_stack().EndAndPop(
Parse::NodeKind::ImplicitParamListStart);
context.node_stack()
.PopAndDiscardSoloNodeId<Parse::NodeKind::ImplicitParamListStart>();
context.node_stack().Push(node_id, refs_id);
// The implicit parameter list's scope extends to the end of the following
// parameter list.
Expand All @@ -40,10 +40,10 @@ auto HandleParseNode(Context& context, Parse::PatternListCommaId /*node_id*/)
}

// Handles a `TuplePattern` parse node: finishes the pending parameter refs
// that were started at `TuplePatternStart` and pushes the resulting block for
// this node. Always returns true.
auto HandleParseNode(Context& context, Parse::TuplePatternId node_id) -> bool {
  // Note the Start node remains on the stack, where the param list handler can
  // make use of it. It must not be discarded here: whoever pops this
  // TuplePattern is responsible for popping the TuplePatternStart beneath it.
  auto refs_id = context.param_and_arg_refs_stack().EndAndPop(
      Parse::NodeKind::TuplePatternStart);
  context.node_stack().Push(node_id, refs_id);
  return true;
}
Expand Down
3 changes: 2 additions & 1 deletion toolchain/check/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ static auto CheckAssociatedFunctionImplementation(
// synthesize a suitable thunk.
if (!CheckFunctionTypeMatches(
context, context.functions().Get(impl_function_decl->function_id),
context.functions().Get(interface_function_id), substitutions)) {
context.functions().Get(interface_function_id), substitutions,
/*check_syntax=*/false)) {
return SemIR::InstId::BuiltinError;
}
return impl_decl_id;
Expand Down
6 changes: 6 additions & 0 deletions toolchain/check/import_ref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,8 @@ class ImportRefResolver {
// importing the parameters.
.parent_scope_id = SemIR::NameScopeId::Invalid,
.generic_id = generic_id,
.first_param_node_id = Parse::NodeId::Invalid,
.last_param_node_id = Parse::NodeId::Invalid,
.implicit_param_refs_id = import_class.implicit_param_refs_id.is_valid()
? SemIR::InstBlockId::Empty
: SemIR::InstBlockId::Invalid,
Expand Down Expand Up @@ -1143,6 +1145,8 @@ class ImportRefResolver {
.parent_scope_id = parent_scope_id,
.decl_id = function_decl_id,
.generic_id = generic_id,
.first_param_node_id = Parse::NodeId::Invalid,
.last_param_node_id = Parse::NodeId::Invalid,
.implicit_param_refs_id = GetLocalParamRefsId(
function.implicit_param_refs_id, implicit_param_const_ids),
.param_refs_id =
Expand Down Expand Up @@ -1249,6 +1253,8 @@ class ImportRefResolver {
// importing the parameters.
.parent_scope_id = SemIR::NameScopeId::Invalid,
.generic_id = generic_id,
.first_param_node_id = Parse::NodeId::Invalid,
.last_param_node_id = Parse::NodeId::Invalid,
.implicit_param_refs_id =
import_interface.implicit_param_refs_id.is_valid()
? SemIR::InstBlockId::Empty
Expand Down
85 changes: 82 additions & 3 deletions toolchain/check/merge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,22 +268,101 @@ static auto CheckRedeclParams(Context& context, SemIRLoc new_decl_loc,
return true;
}

// Reports whether two parse nodes are syntactically identical: they must have
// the same node kind and their tokens must have the same text.
// TODO: Detect raw identifiers (will require token changes).
static auto IsNodeSyntaxEqual(Context& context, Parse::NodeId new_node_id,
                              Parse::NodeId prev_node_id) -> bool {
  auto&& tree = context.parse_tree();
  if (tree.node_kind(new_node_id) != tree.node_kind(prev_node_id)) {
    return false;
  }

  // TODO: Should there be a trivial way to check if we need to check spellings?
  // Identifiers and literals need their text checked for cross-file matching,
  // but not intra-file. Keywords and operators shouldn't need the token text
  // examined at all.
  auto spelling_of = [&](Parse::NodeId node_id) {
    return context.tokens().GetTokenText(tree.node_token(node_id));
  };
  auto new_spelling = spelling_of(new_node_id);
  auto prev_spelling = spelling_of(prev_node_id);
  return new_spelling == prev_spelling;
}

// Compares the parameter syntax of a redeclaration against the previous
// declaration, node by node over the two [first, last] parse node ranges.
// Emits a diagnostic at the first differing pair of nodes and returns false.
// Returns true if no difference is found, or if either declaration has no
// parse nodes available to compare.
static auto CheckRedeclParamSyntax(Context& context,
                                   Parse::NodeId new_first_param_node_id,
                                   Parse::NodeId new_last_param_node_id,
                                   Parse::NodeId prev_first_param_node_id,
                                   Parse::NodeId prev_last_param_node_id)
    -> bool {
  // Parse nodes may not always be available to compare.
  // TODO: Support cross-file syntax checks. Right now imports provide invalid
  // nodes, and we'll need to follow the declaration to its original file to
  // get the parse tree.
  const bool can_compare = new_first_param_node_id.is_valid() &&
                           prev_first_param_node_id.is_valid();
  if (!can_compare) {
    return true;
  }

  // A valid first node must always be paired with a valid last node.
  CARBON_CHECK(new_last_param_node_id.is_valid())
      << "new_last_param_node_id.is_valid should match "
         "new_first_param_node_id.is_valid";
  CARBON_CHECK(prev_last_param_node_id.is_valid())
      << "prev_last_param_node_id.is_valid should match "
         "prev_first_param_node_id.is_valid";

  auto&& tree = context.parse_tree();
  auto new_range =
      tree.postorder(new_first_param_node_id, new_last_param_node_id);
  auto prev_range =
      tree.postorder(prev_first_param_node_id, prev_last_param_node_id);

  // llvm::zip stops at the end of the shorter range. A length mismatch isn't
  // checked separately: both ranges include the parameter brackets, so ranges
  // of different lengths are expected to differ at some node inside the
  // compared prefix.
  for (auto [new_node_id, prev_node_id] : llvm::zip(new_range, prev_range)) {
    if (IsNodeSyntaxEqual(context, new_node_id, prev_node_id)) {
      continue;
    }

    // First mismatch: report it against both declarations and stop.
    CARBON_DIAGNOSTIC(RedeclParamSyntaxDiffers, Error,
                      "Redeclaration syntax differs here.");
    CARBON_DIAGNOSTIC(RedeclParamSyntaxPrevious, Note,
                      "Comparing with previous declaration here.");
    context.emitter()
        .Build(new_node_id, RedeclParamSyntaxDiffers)
        .Note(prev_node_id, RedeclParamSyntaxPrevious)
        .Emit();
    return false;
  }

  return true;
}

auto CheckRedeclParamsMatch(Context& context, const DeclParams& new_entity,
const DeclParams& prev_entity,
Substitutions substitutions) -> bool {
Substitutions substitutions, bool check_syntax)
-> bool {
if (EntityHasParamError(context, new_entity) ||
EntityHasParamError(context, prev_entity)) {
return false;
}
if (!CheckRedeclParams(context, new_entity.loc,
new_entity.implicit_param_refs_id, prev_entity.loc,
prev_entity.implicit_param_refs_id, "implicit ",
substitutions) ||
!CheckRedeclParams(context, new_entity.loc, new_entity.param_refs_id,
substitutions)) {
return false;
}
if (!CheckRedeclParams(context, new_entity.loc, new_entity.param_refs_id,
prev_entity.loc, prev_entity.param_refs_id, "",
substitutions)) {
return false;
}
if (check_syntax &&
!CheckRedeclParamSyntax(context, new_entity.first_param_node_id,
new_entity.last_param_node_id,
prev_entity.first_param_node_id,
prev_entity.last_param_node_id)) {
return false;
}
return true;
}

Expand Down
19 changes: 16 additions & 3 deletions toolchain/check/merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,30 @@ struct DeclParams {
template <typename Entity>
explicit DeclParams(const Entity& entity)
: loc(entity.decl_id),
first_param_node_id(entity.first_param_node_id),
last_param_node_id(entity.last_param_node_id),
implicit_param_refs_id(entity.implicit_param_refs_id),
param_refs_id(entity.param_refs_id) {}

DeclParams(SemIRLoc loc, SemIR::InstBlockId implicit_params_id,
DeclParams(SemIRLoc loc, Parse::NodeId first_param_node_id,
Parse::NodeId last_param_node_id,
SemIR::InstBlockId implicit_params_id,
SemIR::InstBlockId params_id)
: loc(loc),
first_param_node_id(first_param_node_id),
last_param_node_id(last_param_node_id),
implicit_param_refs_id(implicit_params_id),
param_refs_id(params_id) {}

// The location of the declaration of the entity.
SemIRLoc loc;

// Parse tree bounds for the parameters, including both implicit and explicit
// parameters. These will be compared to match between declaration and
// definition.
Parse::NodeId first_param_node_id;
Parse::NodeId last_param_node_id;

// The implicit parameters of the entity. Can be Invalid if there is no
// implicit parameter list.
SemIR::InstBlockId implicit_param_refs_id;
Expand All @@ -71,8 +84,8 @@ struct DeclParams {
// returns false.
auto CheckRedeclParamsMatch(Context& context, const DeclParams& new_entity,
const DeclParams& prev_entity,
Substitutions substitutions = Substitutions())
-> bool;
Substitutions substitutions = Substitutions(),
bool check_syntax = true) -> bool;

} // namespace Carbon::Check

Expand Down
25 changes: 25 additions & 0 deletions toolchain/check/name_component.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,40 @@
namespace Carbon::Check {

auto PopNameComponent(Context& context) -> NameComponent {
Parse::NodeId first_param_node_id = Parse::InvalidNodeId();
Parse::NodeId last_param_node_id = Parse::InvalidNodeId();

// Explicit params.
auto [params_loc_id, params_id] =
context.node_stack().PopWithNodeIdIf<Parse::NodeKind::TuplePattern>();
if (params_id) {
first_param_node_id =
context.node_stack()
.PopForSoloNodeId<Parse::NodeKind::TuplePatternStart>();
last_param_node_id = params_loc_id;
}

// Implicit params.
auto [implicit_params_loc_id, implicit_params_id] =
context.node_stack()
.PopWithNodeIdIf<Parse::NodeKind::ImplicitParamList>();
if (implicit_params_id) {
// Implicit params always come before explicit params.
first_param_node_id =
context.node_stack()
.PopForSoloNodeId<Parse::NodeKind::ImplicitParamListStart>();
// Only use the end of implicit params if there weren't explicit params.
if (last_param_node_id.is_valid()) {
last_param_node_id = params_loc_id;
}
}

auto [name_loc_id, name_id] = context.node_stack().PopNameWithNodeId();
return {
.name_loc_id = name_loc_id,
.name_id = name_id,
.first_param_node_id = first_param_node_id,
.last_param_node_id = last_param_node_id,
.implicit_params_loc_id = implicit_params_loc_id,
.implicit_params_id =
implicit_params_id.value_or(SemIR::InstBlockId::Invalid),
Expand Down
Loading

0 comments on commit db02265

Please sign in to comment.