diff --git a/crates/aiken-lang/src/ast.rs b/crates/aiken-lang/src/ast.rs index b44ee487c..61f93e6aa 100644 --- a/crates/aiken-lang/src/ast.rs +++ b/crates/aiken-lang/src/ast.rs @@ -404,8 +404,8 @@ impl TypedDefinition { }) | Definition::Test(Function { body, .. }) = self { - if let Some(expression) = body.find_node(byte_index) { - return Some(Located::Expression(expression)); + if let Some(located) = body.find_node(byte_index) { + return Some(located); } } @@ -420,6 +420,7 @@ impl TypedDefinition { #[derive(Debug, Clone, PartialEq)] pub enum Located<'a> { Expression(&'a TypedExpr), + Pattern(&'a TypedPattern, &'a TypedExpr), Definition(&'a TypedDefinition), } @@ -427,6 +428,10 @@ impl<'a> Located<'a> { pub fn definition_location(&self) -> Option> { match self { Self::Expression(expression) => expression.definition_location(), + // TODO: Revise definition location semantic for 'Pattern' + // e.g. for constructors, we might want to show the type definition + // for that constructor. + Self::Pattern(_, _) => None, Self::Definition(definition) => Some(DefinitionLocation { module: None, span: definition.location(), @@ -499,7 +504,7 @@ impl CallArg { } impl TypedCallArg { - pub fn find_node(&self, byte_index: usize) -> Option<&TypedExpr> { + pub fn find_node(&self, byte_index: usize) -> Option> { self.value.find_node(byte_index) } } @@ -918,6 +923,45 @@ impl Pattern { } } +impl TypedPattern { + pub fn find_node<'a>(&'a self, byte_index: usize, value: &'a TypedExpr) -> Option> { + if !self.location().contains(byte_index) { + return None; + } + + match self { + Pattern::Int { .. } + | Pattern::Var { .. } + | Pattern::Assign { .. } + | Pattern::Discard { .. } => Some(Located::Pattern(self, value)), + + Pattern::List { elements, .. } + | Pattern::Tuple { + elems: elements, .. + } => elements + .iter() + .find_map(|e| e.find_node(byte_index, value)) + .or(Some(Located::Pattern(self, value))), + + Pattern::Constructor { arguments, .. } => arguments + .iter() + .find_map(|e| e.value.find_node(byte_index, value)) + .or(Some(Located::Pattern(self, value))), + } + } + + pub fn tipo(&self, value: &TypedExpr) -> Option> { + match self { + Pattern::Int { .. } => Some(builtins::int()), + Pattern::Constructor { tipo, .. } => Some(tipo.clone()), + Pattern::Var { .. } | Pattern::Assign { .. } | Pattern::Discard { .. } => { + Some(value.tipo()) + } + Pattern::List { .. } | Pattern::Tuple { .. } => None, + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, Copy)] pub enum ByteArrayFormatPreference { HexadecimalString, @@ -977,7 +1021,7 @@ impl TypedClause { } } - pub fn find_node(&self, byte_index: usize) -> Option<&TypedExpr> { + pub fn find_node(&self, byte_index: usize) -> Option> { self.then.find_node(byte_index) } } @@ -1119,7 +1163,7 @@ pub struct TypedRecordUpdateArg { } impl TypedRecordUpdateArg { - pub fn find_node(&self, byte_index: usize) -> Option<&TypedExpr> { + pub fn find_node(&self, byte_index: usize) -> Option> { self.value.find_node(byte_index) } } diff --git a/crates/aiken-lang/src/expr.rs b/crates/aiken-lang/src/expr.rs index dee664f43..4570909ef 100644 --- a/crates/aiken-lang/src/expr.rs +++ b/crates/aiken-lang/src/expr.rs @@ -5,7 +5,7 @@ use vec1::Vec1; use crate::{ ast::{ self, Annotation, Arg, AssignmentKind, BinOp, ByteArrayFormatPreference, CallArg, - DefinitionLocation, IfBranch, LogicalOpChainKind, ParsedCallArg, Pattern, + DefinitionLocation, IfBranch, Located, LogicalOpChainKind, ParsedCallArg, Pattern, RecordUpdateSpread, Span, TraceKind, TypedClause, TypedRecordUpdateArg, UnOp, UntypedClause, UntypedRecordUpdateArg, }, @@ -312,7 +312,7 @@ impl TypedExpr { // This could be optimised in places to exit early if the first of a series // of expressions is after the byte index. - pub fn find_node(&self, byte_index: usize) -> Option<&Self> { + pub fn find_node(&self, byte_index: usize) -> Option> { if !self.location().contains(byte_index) { return None; } @@ -323,18 +323,20 @@ impl TypedExpr { | TypedExpr::UInt { .. } | TypedExpr::String { .. } | TypedExpr::ByteArray { .. } - | TypedExpr::ModuleSelect { .. } => Some(self), + | TypedExpr::ModuleSelect { .. } => Some(Located::Expression(self)), TypedExpr::Trace { text, then, .. } => text .find_node(byte_index) .or_else(|| then.find_node(byte_index)) - .or(Some(self)), + .or(Some(Located::Expression(self))), TypedExpr::Pipeline { expressions, .. } | TypedExpr::Sequence { expressions, .. } => { expressions.iter().find_map(|e| e.find_node(byte_index)) } - TypedExpr::Fn { body, .. } => body.find_node(byte_index).or(Some(self)), + TypedExpr::Fn { body, .. } => body + .find_node(byte_index) + .or(Some(Located::Expression(self))), TypedExpr::Tuple { elems: elements, .. @@ -342,19 +344,21 @@ impl TypedExpr { | TypedExpr::List { elements, .. } => elements .iter() .find_map(|e| e.find_node(byte_index)) - .or(Some(self)), + .or(Some(Located::Expression(self))), TypedExpr::Call { fun, args, .. } => args .iter() .find_map(|arg| arg.find_node(byte_index)) .or_else(|| fun.find_node(byte_index)) - .or(Some(self)), + .or(Some(Located::Expression(self))), TypedExpr::BinOp { left, right, .. } => left .find_node(byte_index) .or_else(|| right.find_node(byte_index)), - TypedExpr::Assignment { value, .. } => value.find_node(byte_index), + TypedExpr::Assignment { value, pattern, .. } => pattern + .find_node(byte_index, value) + .or_else(|| value.find_node(byte_index)), TypedExpr::When { subject, clauses, .. @@ -365,20 +369,22 @@ impl TypedExpr { .iter() .find_map(|clause| clause.find_node(byte_index)) }) - .or(Some(self)), + .or(Some(Located::Expression(self))), TypedExpr::RecordAccess { record: expression, .. } | TypedExpr::TupleIndex { tuple: expression, .. - } => expression.find_node(byte_index).or(Some(self)), + } => expression + .find_node(byte_index) + .or(Some(Located::Expression(self))), TypedExpr::RecordUpdate { spread, args, .. } => args .iter() .find_map(|arg| arg.find_node(byte_index)) .or_else(|| spread.find_node(byte_index)) - .or(Some(self)), + .or(Some(Located::Expression(self))), TypedExpr::If { branches, @@ -393,9 +399,11 @@ impl TypedExpr { .or_else(|| branch.body.find_node(byte_index)) }) .or_else(|| final_else.find_node(byte_index)) - .or(Some(self)), + .or(Some(Located::Expression(self))), - TypedExpr::UnOp { value, .. } => value.find_node(byte_index).or(Some(self)), + TypedExpr::UnOp { value, .. } => value + .find_node(byte_index) + .or(Some(Located::Expression(self))), } } } diff --git a/crates/aiken-lang/src/tipo.rs b/crates/aiken-lang/src/tipo.rs index 79f039023..521a5d7d5 100644 --- a/crates/aiken-lang/src/tipo.rs +++ b/crates/aiken-lang/src/tipo.rs @@ -582,6 +582,9 @@ impl ValueConstructor { ValueConstructorVariant::Record { module, location, .. } + | ValueConstructorVariant::ModuleFn { + module, location, .. + } | ValueConstructorVariant::ModuleConstant { location, module, .. } => DefinitionLocation { @@ -589,8 +592,7 @@ impl ValueConstructor { span: *location, }, - ValueConstructorVariant::ModuleFn { location, .. } - | ValueConstructorVariant::LocalVariable { location } => DefinitionLocation { + ValueConstructorVariant::LocalVariable { location } => DefinitionLocation { module: None, span: *location, }, diff --git a/crates/aiken-lsp/src/server.rs b/crates/aiken-lsp/src/server.rs index 8a0b5f2ce..ff42e881f 100644 --- a/crates/aiken-lsp/src/server.rs +++ b/crates/aiken-lsp/src/server.rs @@ -344,6 +344,9 @@ impl Server { self.completion_for_import() } + // TODO: autocompletion for patterns + Some(Located::Pattern(_pattern, _value)) => None, + // TODO: autocompletion for other definitions Some(Located::Definition(_expression)) => None, @@ -458,13 +461,17 @@ impl Server { None => return Ok(None), }; - let expression = match found { - Located::Expression(expression) => expression, + let (location, definition_location, tipo) = match found { + Located::Expression(expression) => ( + expression.location(), + expression.definition_location(), + Some(expression.tipo()), + ), + Located::Pattern(pattern, value) => (pattern.location(), None, pattern.tipo(value)), Located::Definition(_) => return Ok(None), }; - let doc = expression - .definition_location() + let doc = definition_location .and_then(|loc| loc.module.map(|m| (m, loc.span))) .and_then(|(m, span)| { self.compiler @@ -475,12 +482,16 @@ impl Server { .and_then(|(checked_module, span)| checked_module.ast.find_node(span.start)) .and_then(|node| match node { Located::Expression(_) => None, + Located::Pattern(_, _) => None, Located::Definition(def) => def.doc(), }) .unwrap_or_default(); // Show the type of the hovered node to the user - let type_ = Printer::new().pretty_print(expression.tipo().as_ref(), 0); + let type_ = match tipo { + Some(t) => Printer::new().pretty_print(t.as_ref(), 0), + None => "?".to_string(), + }; let contents = formatdoc! {r#" ```aiken @@ -491,7 +502,7 @@ impl Server { Ok(Some(lsp_types::Hover { contents: lsp_types::HoverContents::Scalar(lsp_types::MarkedString::String(contents)), - range: Some(span_to_lsp_range(expression.location(), &line_numbers)), + range: Some(span_to_lsp_range(location, &line_numbers)), })) }