Skip to content

Commit

Permalink
Improve access checking code (#4317)
Browse files Browse the repository at this point in the history
This change accomplishes the TODOs for access checking. More
specifically it,
- makes `SemIR::AccessKind` formattable using `llvm::formatv`.
- makes use of `LookupUnqualifiedName` to find `Self`.
  • Loading branch information
brymer-meneses authored Sep 17, 2024
1 parent 1d90455 commit da40c8b
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 87 deletions.
22 changes: 11 additions & 11 deletions toolchain/check/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,8 @@ auto Context::LookupNameInDecl(SemIR::LocId loc_id, SemIR::NameId name_id,
}

auto Context::LookupUnqualifiedName(Parse::NodeId node_id,
SemIR::NameId name_id) -> LookupResult {
SemIR::NameId name_id, bool required)
-> LookupResult {
// TODO: Check for shadowed lookup results.

// Find the results from ancestor lexical scopes. These will be combined with
Expand Down Expand Up @@ -328,7 +329,10 @@ auto Context::LookupUnqualifiedName(Parse::NodeId node_id,
}

// We didn't find anything at all.
DiagnoseNameNotFound(node_id, name_id);
if (required) {
DiagnoseNameNotFound(node_id, name_id);
}

return {.specific_id = SemIR::SpecificId::Invalid,
.inst_id = SemIR::InstId::BuiltinError};
}
Expand Down Expand Up @@ -368,18 +372,14 @@ static auto DiagnoseInvalidQualifiedNameAccess(Context& context, SemIRLoc loc,
// TODO: Support scoped entities other than just classes.
auto class_info = context.classes().Get(class_type->class_id);

// TODO: Support passing AccessKind to diagnostics.
CARBON_DIAGNOSTIC(ClassInvalidMemberAccess, Error,
"Cannot access {0} member `{1}` of type `{2}`.",
llvm::StringLiteral, SemIR::NameId, SemIR::TypeId);
SemIR::AccessKind, SemIR::NameId, SemIR::TypeId);
CARBON_DIAGNOSTIC(ClassMemberDefinition, Note,
"The {0} member `{1}` is defined here.",
llvm::StringLiteral, SemIR::NameId);
"The {0} member `{1}` is defined here.", SemIR::AccessKind,
SemIR::NameId);

auto parent_type_id = class_info.self_type_id;
auto access_desc = access_kind == SemIR::AccessKind::Private
? llvm::StringLiteral("private")
: llvm::StringLiteral("protected");

if (access_kind == SemIR::AccessKind::Private && is_parent_access) {
if (auto base_decl = context.insts().TryGetAsIfValid<SemIR::BaseDecl>(
Expand All @@ -395,9 +395,9 @@ static auto DiagnoseInvalidQualifiedNameAccess(Context& context, SemIRLoc loc,
}

context.emitter()
.Build(loc, ClassInvalidMemberAccess, access_desc, name_id,
.Build(loc, ClassInvalidMemberAccess, access_kind, name_id,
parent_type_id)
.Note(scope_result_id, ClassMemberDefinition, access_desc, name_id)
.Note(scope_result_id, ClassMemberDefinition, access_kind, name_id)
.Emit();
}

Expand Down
4 changes: 2 additions & 2 deletions toolchain/check/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ class Context {
SemIR::NameScopeId scope_id) -> SemIR::InstId;

// Performs an unqualified name lookup, returning the referenced instruction.
auto LookupUnqualifiedName(Parse::NodeId node_id, SemIR::NameId name_id)
-> LookupResult;
auto LookupUnqualifiedName(Parse::NodeId node_id, SemIR::NameId name_id,
bool required = true) -> LookupResult;

// Performs a name lookup in a specified scope, returning the referenced
// instruction. Does not look into extended scopes. Returns an invalid
Expand Down
34 changes: 3 additions & 31 deletions toolchain/check/member_access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,41 +100,13 @@ static auto IsInstanceMethod(const SemIR::File& sem_ir,
return false;
}

// Returns the FunctionId of the current function if it exists.
static auto GetCurrentFunction(Context& context)
-> std::optional<SemIR::FunctionId> {
if (context.return_scope_stack().empty()) {
return std::nullopt;
}

return context.insts()
.GetAs<SemIR::FunctionDecl>(context.return_scope_stack().back().decl_id)
.function_id;
}

// Returns the highest allowed access. For example, if this returns `Protected`
// then only `Public` and `Protected` accesses are allowed--not `Private`.
static auto GetHighestAllowedAccess(Context& context, SemIRLoc loc,
static auto GetHighestAllowedAccess(Context& context, SemIR::LocId loc_id,
SemIR::ConstantId name_scope_const_id)
-> SemIR::AccessKind {
// TODO: Maybe use LookupUnqualifiedName for `Self` to support things like
// `var x: Self.ParentProtectedType`?
auto current_function = GetCurrentFunction(context);
// If `current_function` is a `nullopt` then we're accessing from a global
// variable.
if (!current_function) {
return SemIR::AccessKind::Public;
}

auto scope_id = context.functions().Get(*current_function).parent_scope_id;
if (!scope_id.is_valid()) {
return SemIR::AccessKind::Public;
}
auto scope = context.name_scopes().Get(scope_id);

// Lookup the inst for `Self` in the parent scope of the current function.
auto [self_type_inst_id, _] = context.LookupNameInExactScope(
loc, SemIR::NameId::SelfType, scope_id, scope);
auto [_, self_type_inst_id] = context.LookupUnqualifiedName(
loc_id.node_id(), SemIR::NameId::SelfType, /*required=*/false);
if (!self_type_inst_id.is_valid()) {
return SemIR::AccessKind::Public;
}
Expand Down
74 changes: 31 additions & 43 deletions toolchain/check/testdata/class/access_modifers.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -128,35 +128,21 @@ class A {
// CHECK:STDERR: ^
// CHECK:STDERR:
let x: i32 = A.x;
// CHECK:STDERR: fail_global_access.carbon:[[@LINE+7]]:14: ERROR: Cannot access private member `y` of type `A`.
// CHECK:STDERR: fail_global_access.carbon:[[@LINE+6]]:14: ERROR: Cannot access private member `y` of type `A`.
// CHECK:STDERR: let y: i32 = A.y;
// CHECK:STDERR: ^~~
// CHECK:STDERR: fail_global_access.carbon:[[@LINE-14]]:15: The private member `y` is defined here.
// CHECK:STDERR: private let y: i32 = 5;
// CHECK:STDERR: ^
// CHECK:STDERR:
let y: i32 = A.y;

// --- fail_todo_global_self_access.carbon
// --- self_access.carbon

library "[[@TEST_NAME]]";

class A {
private let internal: i32 = 10;
// CHECK:STDERR: fail_todo_global_self_access.carbon:[[@LINE+13]]:16: ERROR: Member access into incomplete class `A`.
// CHECK:STDERR: let y: i32 = Self.internal;
// CHECK:STDERR: ^~~~~~~~~~~~~
// CHECK:STDERR: fail_todo_global_self_access.carbon:[[@LINE-5]]:1: Class is incomplete within its definition.
// CHECK:STDERR: class A {
// CHECK:STDERR: ^~~~~~~~~
// CHECK:STDERR:
// CHECK:STDERR: fail_todo_global_self_access.carbon:[[@LINE+6]]:16: ERROR: Cannot access private member `internal` of type `A`.
// CHECK:STDERR: let y: i32 = Self.internal;
// CHECK:STDERR: ^~~~~~~~~~~~~
// CHECK:STDERR: fail_todo_global_self_access.carbon:[[@LINE-11]]:15: The private member `internal` is defined here.
// CHECK:STDERR: private let internal: i32 = 10;
// CHECK:STDERR: ^~~~~~~~
let y: i32 = Self.internal;
private fn F() {}
private fn G() { Self.F(); }
}

// CHECK:STDOUT: --- fail_private_field_access.carbon
Expand Down Expand Up @@ -554,9 +540,9 @@ class A {
// CHECK:STDOUT: %int.make_type_32.loc16: init type = call constants.%Int32() [template = i32]
// CHECK:STDOUT: %.loc16_8.1: type = value_of_initializer %int.make_type_32.loc16 [template = i32]
// CHECK:STDOUT: %.loc16_8.2: type = converted %int.make_type_32.loc16, %.loc16_8.1 [template = i32]
// CHECK:STDOUT: %int.make_type_32.loc24: init type = call constants.%Int32() [template = i32]
// CHECK:STDOUT: %.loc24_8.1: type = value_of_initializer %int.make_type_32.loc24 [template = i32]
// CHECK:STDOUT: %.loc24_8.2: type = converted %int.make_type_32.loc24, %.loc24_8.1 [template = i32]
// CHECK:STDOUT: %int.make_type_32.loc23: init type = call constants.%Int32() [template = i32]
// CHECK:STDOUT: %.loc23_8.1: type = value_of_initializer %int.make_type_32.loc23 [template = i32]
// CHECK:STDOUT: %.loc23_8.2: type = converted %int.make_type_32.loc23, %.loc23_8.1 [template = i32]
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: class @A {
Expand Down Expand Up @@ -584,26 +570,27 @@ class A {
// CHECK:STDOUT: %A.ref.loc16: type = name_ref A, file.%A.decl [template = constants.%A]
// CHECK:STDOUT: %x.ref: <error> = name_ref x, <error> [template = <error>]
// CHECK:STDOUT: %x: i32 = bind_name x, <error>
// CHECK:STDOUT: %A.ref.loc24: type = name_ref A, file.%A.decl [template = constants.%A]
// CHECK:STDOUT: %A.ref.loc23: type = name_ref A, file.%A.decl [template = constants.%A]
// CHECK:STDOUT: %y.ref: <error> = name_ref y, <error> [template = <error>]
// CHECK:STDOUT: %y: i32 = bind_name y, <error>
// CHECK:STDOUT: return
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: --- fail_todo_global_self_access.carbon
// CHECK:STDOUT: --- self_access.carbon
// CHECK:STDOUT:
// CHECK:STDOUT: constants {
// CHECK:STDOUT: %A: type = class_type @A [template]
// CHECK:STDOUT: %Int32.type: type = fn_type @Int32 [template]
// CHECK:STDOUT: %F.type: type = fn_type @F [template]
// CHECK:STDOUT: %.1: type = tuple_type () [template]
// CHECK:STDOUT: %Int32: %Int32.type = struct_value () [template]
// CHECK:STDOUT: %.2: i32 = int_literal 10 [template]
// CHECK:STDOUT: %.3: type = struct_type {} [template]
// CHECK:STDOUT: %F: %F.type = struct_value () [template]
// CHECK:STDOUT: %G.type: type = fn_type @G [template]
// CHECK:STDOUT: %G: %G.type = struct_value () [template]
// CHECK:STDOUT: %.2: type = struct_type {} [template]
// CHECK:STDOUT: %.3: type = ptr_type %.2 [template]
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: imports {
// CHECK:STDOUT: %Core: <namespace> = namespace file.%Core.import, [template] {
// CHECK:STDOUT: .Int32 = %import_ref
// CHECK:STDOUT: import Core//prelude
// CHECK:STDOUT: import Core//prelude/operators
// CHECK:STDOUT: import Core//prelude/types
Expand All @@ -613,7 +600,6 @@ class A {
// CHECK:STDOUT: import Core//prelude/operators/comparison
// CHECK:STDOUT: import Core//prelude/types/bool
// CHECK:STDOUT: }
// CHECK:STDOUT: %import_ref: %Int32.type = import_ref Core//prelude/types, inst+4, loaded [template = constants.%Int32]
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: file {
Expand All @@ -626,23 +612,25 @@ class A {
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: class @A {
// CHECK:STDOUT: %int.make_type_32.loc5: init type = call constants.%Int32() [template = i32]
// CHECK:STDOUT: %.loc5_25.1: type = value_of_initializer %int.make_type_32.loc5 [template = i32]
// CHECK:STDOUT: %.loc5_25.2: type = converted %int.make_type_32.loc5, %.loc5_25.1 [template = i32]
// CHECK:STDOUT: %.loc5_31: i32 = int_literal 10 [template = constants.%.2]
// CHECK:STDOUT: %internal: i32 = bind_name internal, %.loc5_31
// CHECK:STDOUT: %int.make_type_32.loc19: init type = call constants.%Int32() [template = i32]
// CHECK:STDOUT: %.loc19_10.1: type = value_of_initializer %int.make_type_32.loc19 [template = i32]
// CHECK:STDOUT: %.loc19_10.2: type = converted %int.make_type_32.loc19, %.loc19_10.1 [template = i32]
// CHECK:STDOUT: %Self.ref: type = name_ref Self, constants.%A [template = constants.%A]
// CHECK:STDOUT: %internal.ref: <error> = name_ref internal, <error> [template = <error>]
// CHECK:STDOUT: %y: i32 = bind_name y, <error>
// CHECK:STDOUT: %F.decl: %F.type = fn_decl @F [template = constants.%F] {}
// CHECK:STDOUT: %G.decl: %G.type = fn_decl @G [template = constants.%G] {}
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = constants.%A
// CHECK:STDOUT: .internal [private] = %internal
// CHECK:STDOUT: .y = %y
// CHECK:STDOUT: .F [private] = %F.decl
// CHECK:STDOUT: .G [private] = %G.decl
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @Int32() -> type = "int.make_type_32";
// CHECK:STDOUT: fn @F() {
// CHECK:STDOUT: !entry:
// CHECK:STDOUT: return
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @G() {
// CHECK:STDOUT: !entry:
// CHECK:STDOUT: %Self.ref: type = name_ref Self, constants.%A [template = constants.%A]
// CHECK:STDOUT: %F.ref: %F.type = name_ref F, @A.%F.decl [template = constants.%F]
// CHECK:STDOUT: %F.call: init %.1 = call %F.ref()
// CHECK:STDOUT: return
// CHECK:STDOUT: }
// CHECK:STDOUT:
23 changes: 23 additions & 0 deletions toolchain/sem_ir/name_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,29 @@ enum class AccessKind : int8_t {
Private,
};

} // namespace Carbon::SemIR

template <>
struct llvm::format_provider<Carbon::SemIR::AccessKind> {
using AccessKind = Carbon::SemIR::AccessKind;
static void format(const AccessKind& loc, raw_ostream& out,
StringRef /*style*/) {
switch (loc) {
case AccessKind::Private:
out << "private";
break;
case AccessKind::Protected:
out << "protected";
break;
case AccessKind::Public:
out << "public";
break;
}
}
};

namespace Carbon::SemIR {

struct NameScope : Printable<NameScope> {
struct Entry {
NameId name_id;
Expand Down

0 comments on commit da40c8b

Please sign in to comment.