Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LSP: fix "call hierarchy" across files #26040

Merged
merged 15 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion frontend/lib/resolution/resolution-queries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2150,8 +2150,8 @@ ApplicabilityResult instantiateSignature(ResolutionContext* rc,
const TypedFnSignature* parentSignature = sig->parentFn();
if (parentSignature) {
for (auto up = parentSignature; up; up = up->parentFn()) {
CHPL_ASSERT(!up->needsInstantiation());
if (up->needsInstantiation()) {
CHPL_UNIMPL("parent function needs instantiation");
return ApplicabilityResult::failure(sig->id(), FAIL_CANDIDATE_OTHER);
}
}
Expand Down
2 changes: 1 addition & 1 deletion tools/chapel-py/src/python-types.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ template <typename T>
std::vector<T> unwrapVector(ContextObject* CONTEXT, PyObject* vec) {
std::vector<T> toReturn(PyList_Size(vec));
for (ssize_t i = 0; i < PyList_Size(vec); i++) {
toReturn.push_back(PythonReturnTypeInfo<T>::unwrap(CONTEXT, PyList_GetItem(vec, i)));
toReturn[i] = PythonReturnTypeInfo<T>::unwrap(CONTEXT, PyList_GetItem(vec, i));
}
return toReturn;
}
Expand Down
102 changes: 77 additions & 25 deletions tools/chpl-language-server/src/chpl-language-server.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,8 @@ def __init__(self, file: str, config: Optional["WorkspaceConfig"]):
self.context: chapel.Context = chapel.Context()
self.file_infos: List["FileInfo"] = []
self.global_uses: Dict[str, List[References]] = defaultdict(list)
self.instantiation_ids: Dict[chapel.TypedSignature, str] = {}
jabraham17 marked this conversation as resolved.
Show resolved Hide resolved
self.instantiation_id_counter = 0

if config:
file_config = config.for_file(file)
Expand All @@ -508,6 +510,27 @@ def __init__(self, file: str, config: Optional["WorkspaceConfig"]):

self.context.set_module_paths(self.module_paths, self.file_paths)

def register_signature(self, sig: chapel.TypedSignature) -> str:
"""
The language server can't send over typed signatures directly for
situations such as call hierarchy items (but we need to reason about
instantiations). Instead, keep a global unique ID for each signature,
and use that to identify them.
"""
if sig in self.instantiation_ids:
return self.instantiation_ids[sig]

self.instantiation_id_counter += 1
uid = str(self.instantiation_id_counter)
self.instantiation_ids[sig] = uid
return uid

def retrieve_signature(self, uid: str) -> Optional[chapel.TypedSignature]:
for sig, sig_uid in self.instantiation_ids.items():
if sig_uid == uid:
return sig
return None

def new_file_info(
self, uri: str, use_resolver: bool
) -> Tuple["FileInfo", List[Any]]:
Expand Down Expand Up @@ -1007,7 +1030,9 @@ def __init__(self, config: CLSConfig):
super().__init__("chpl-language-server", "v0.1")

self.contexts: Dict[str, ContextContainer] = {}
self.file_infos: Dict[str, FileInfo] = {}
self.context_ids: Dict[ContextContainer, str] = {}
self.context_id_counter = 0
self.file_infos: Dict[Tuple[str, Optional[str]], FileInfo] = {}
self.configurations: Dict[str, WorkspaceConfig] = {}

self.use_resolver: bool = config.get("resolver")
Expand Down Expand Up @@ -1107,17 +1132,28 @@ def get_context(self, uri: str) -> ContextContainer:
for file in context.file_paths:
self.contexts[file] = context
self.contexts[path] = context
self.context_id_counter += 1
self.context_ids[context] = str(self.context_id_counter)

return context

def retrieve_context(self, context_id: str) -> Optional[ContextContainer]:
for ctx, cid in self.context_ids.items():
if cid == context_id:
return ctx
return None

def eagerly_process_all_files(self, context: ContextContainer):
cfg = context.config
if cfg:
for file in cfg.files:
self.get_file_info("file://" + file, do_update=False)

def get_file_info(
self, uri: str, do_update: bool = False
self,
uri: str,
do_update: bool = False,
context_id: Optional[str] = None,
) -> Tuple[FileInfo, List[Any]]:
"""
The language server maintains a FileInfo object per file. The FileInfo
Expand All @@ -1128,19 +1164,34 @@ def get_file_info(
creating one if it doesn't exist. If do_update is set to True,
then the FileInfo's index is rebuilt even if it has already been
computed. This is useful if the underlying file has changed.

Most of the tiem, we will create a new context for a given URI. When
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, typo

requested, however, context_id will be used to create a FileInfo
for a specific context. This is useful if e.g., file A wants to display
an instantiation in file B.
"""

errors = []

if uri in self.file_infos:
file_info = self.file_infos[uri]
fi_key = (uri, context_id)
if fi_key in self.file_infos:
file_info = self.file_infos[fi_key]
if do_update:
errors = file_info.context.advance()
else:
file_info, errors = self.get_context(uri).new_file_info(
uri, self.use_resolver
)
self.file_infos[uri] = file_info
if context_id:
context = self.retrieve_context(context_id)
assert context
else:
context = self.get_context(uri)

file_info, errors = context.new_file_info(uri, self.use_resolver)
self.file_infos[fi_key] = file_info

# Also make this the "default" context for this file in case we
# open it.
if (uri, None) not in self.file_infos:
self.file_infos[(uri, None)] = file_info

# filter out errors that are not related to the file
cur_path = uri[len("file://") :]
Expand Down Expand Up @@ -1396,7 +1447,8 @@ def sym_to_call_hierarchy_item(
"""
loc = location_to_location(sym.location())

inst_idx = -1
inst_id = None
context_id = None

return CallHierarchyItem(
name=sym.name(),
Expand All @@ -1405,11 +1457,11 @@ def sym_to_call_hierarchy_item(
uri=loc.uri,
range=loc.range,
selection_range=location_to_range(sym.name_location()),
data=[sym.unique_id(), inst_idx],
data=[sym.unique_id(), inst_id, context_id],
)

def fn_to_call_hierarchy_item(
self, sig: chapel.TypedSignature
self, sig: chapel.TypedSignature, caller_context: ContextContainer
) -> CallHierarchyItem:
"""
Like sym_to_call_hierarchy_item, but for function instantiations.
Expand All @@ -1419,8 +1471,8 @@ def fn_to_call_hierarchy_item(
"""
fn: chapel.Function = sig.ast()
item = self.sym_to_call_hierarchy_item(fn)
fi, _ = self.get_file_info(item.uri)
item.data[1] = fi.index_of_instantiation(fn, sig)
item.data[1] = caller_context.register_signature(sig)
item.data[2] = self.context_ids[caller_context]

return item

Expand All @@ -1433,16 +1485,17 @@ def unpack_call_hierarchy_item(
item.data is None
or not isinstance(item.data, list)
or not isinstance(item.data[0], str)
or not isinstance(item.data[1], int)
or not isinstance(item.data[1], Optional[str])
or not isinstance(item.data[2], Optional[str])
):
self.show_message(
"Call hierarchy item contains missing or invalid additional data",
MessageType.Error,
)
return None
uid, idx = item.data
uid, idx, ctx = item.data
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, rename idx to inst_id


fi, _ = self.get_file_info(item.uri)
fi, _ = self.get_file_info(item.uri, context_id=ctx)

# TODO: Performance:
# Once the Python bindings supports it, we can use the
Expand All @@ -1456,11 +1509,7 @@ def unpack_call_hierarchy_item(
# We don't handle that here.
return None

instantiation = None
if idx != -1:
instantiation = fi.instantiation_at_index(fn, idx)
else:
instantiation = fi.concrete_instantiation_for(fn)
instantiation = fi.context.retrieve_signature(idx)

return (fi, fn, instantiation)

Expand Down Expand Up @@ -2000,7 +2049,10 @@ async def prepare_call_hierarchy(

# Oddly, returning multiple here makes for no child nodes in the VSCode
# UI. Just take one signature for now.
return next(([ls.fn_to_call_hierarchy_item(sig)] for sig in sigs), [])
return next(
([ls.fn_to_call_hierarchy_item(sig, fi.context)] for sig in sigs),
[],
)

@server.feature(CALL_HIERARCHY_INCOMING_CALLS)
async def call_hierarchy_incoming(
Expand Down Expand Up @@ -2046,7 +2098,7 @@ async def call_hierarchy_incoming(
if isinstance(called_fn, str):
item = ls.sym_to_call_hierarchy_item(hack_id_to_node[called_fn])
else:
item = ls.fn_to_call_hierarchy_item(called_fn)
item = ls.fn_to_call_hierarchy_item(called_fn, fi.context)

to_return.append(
CallHierarchyIncomingCall(
Expand All @@ -2070,7 +2122,7 @@ async def call_hierarchy_outgoing(
if unpacked is None:
return None

_, fn, instantiation = unpacked
fi, fn, instantiation = unpacked

outgoing_calls: Dict[chapel.TypedSignature, List[chapel.FnCall]] = (
defaultdict(list)
Expand All @@ -2093,7 +2145,7 @@ async def call_hierarchy_outgoing(

to_return = []
for called_fn, calls in outgoing_calls.items():
item = ls.fn_to_call_hierarchy_item(called_fn)
item = ls.fn_to_call_hierarchy_item(called_fn, fi.context)
to_return.append(
CallHierarchyOutgoingCall(
item,
Expand Down
34 changes: 34 additions & 0 deletions tools/chpl-language-server/test/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,40 @@ async def test_go_to_definition_use_standard(client: LanguageClient):
await check_goto_decl_def_module(client, doc, pos((2, 8)), mod_Time)


@pytest.mark.asyncio
async def test_go_to_definition_use_across_modules(client: LanguageClient):
"""
Ensure that go-to-definition works on symbols that reference other modules
"""

fileA = """
module A {
var x = 42;
}
"""
fileB = """
module B {
use A;
var y = x;
}
"""

async def check(docs):
docA = docs("A")
docB = docs("B")

await check_goto_decl_def_module(client, docB, pos((1, 6)), docA)
await check_goto_decl_def(
client, docB, pos((2, 10)), (docA, pos((1, 6)))
)

async with source_files(client, A=fileA, B=fileB) as docs:
await check(docs)

async with unrelated_source_files(client, A=fileA, B=fileB) as docs:
await check(docs)


@pytest.mark.asyncio
async def test_go_to_definition_standard_rename(client: LanguageClient):
"""
Expand Down
Loading