Skip to content

Commit

Permalink
Fixing pattern match consistencies
Browse files Browse the repository at this point in the history
let ["a", a] = ["a", 1] now works as a result of removing the top-level
special casing, and these two match statements now produce the same
result (although the code generated is different!).

```muse
match (1, 2) {
  (1, 2) => true,
};

let a = (1, 2);
match a {
  (1, 2) => true,
}
```

Both of these examples produced pattern mismatches before, although in
one case Nil was returned due to a bug in the compiler that has also
been fixed.
  • Loading branch information
ecton committed Mar 3, 2024
1 parent f9b66e9 commit bea9518
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 98 deletions.
211 changes: 137 additions & 74 deletions src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -917,45 +917,13 @@ impl<'a> Scope<'a> {
refutable
}
PatternKind::DestructureTuple(patterns) => {
let expected_count = u64::try_from(patterns.len()).unwrap_or_else(|_| {
self.compiler
.errors
.push(Ranged::new(pattern.range(), Error::UsizeTooLarge));
u64::MAX
});

self.compiler
.code
.invoke(source.clone(), Symbol::len_symbol(), 0);
self.compiler.code.compare(
CompareKind::Equal,
Register(0),
expected_count,
Register(0),
self.compile_tuple_destructure(
patterns,
&source,
doesnt_match,
pattern.range(),
bindings,
);
self.compiler
.code
.jump_if_not(doesnt_match, Register(0), ());

let element = self.new_temporary();
for (i, index) in (0..expected_count).zip(0_usize..) {
if i == expected_count - 1 && matches!(&pattern.0, PatternKind::AnyRemaining) {
break;
}
self.compiler.code.set_current_source_range(pattern.range());
self.compiler.code.copy(i, Register(0));
self.compiler
.code
.invoke(source.clone(), Symbol::nth_symbol(), 1);
self.compiler.code.copy(Register(0), element);
self.compile_pattern_binding(
&patterns[index],
ValueOrSource::Stack(element),
doesnt_match,
bindings,
);
}

Refutability::Refutable
}
PatternKind::DestructureMap(entries) => {
Expand Down Expand Up @@ -1009,29 +977,78 @@ impl<'a> Scope<'a> {
}
}

fn compile_tuple_destructure(
&mut self,
patterns: &[Ranged<PatternKind>],
source: &ValueOrSource,
doesnt_match: Label,
pattern_range: SourceRange,
bindings: &mut PatternBindings,
) {
let expected_count = u64::try_from(patterns.len()).unwrap_or_else(|_| {
self.compiler
.errors
.push(Ranged::new(pattern_range, Error::UsizeTooLarge));
u64::MAX
});

self.compiler
.code
.invoke(source.clone(), Symbol::len_symbol(), 0);
self.compiler
.code
.compare(CompareKind::Equal, Register(0), expected_count, Register(0));
self.compiler
.code
.jump_if_not(doesnt_match, Register(0), ());

let element = self.new_temporary();
for (i, index) in (0..expected_count).zip(0_usize..) {
if i == expected_count - 1 && matches!(&patterns[index].0, PatternKind::AnyRemaining) {
break;
}
self.compiler
.code
.set_current_source_range(patterns[index].range());
self.compiler.code.copy(i, Register(0));
self.compiler
.code
.invoke(source.clone(), Symbol::nth_symbol(), 1);
self.compiler.code.copy(Register(0), element);
self.compile_pattern_binding(
&patterns[index],
ValueOrSource::Stack(element),
doesnt_match,
bindings,
);
}
}

fn compile_match_expression(&mut self, match_expr: &MatchExpression, dest: OpDestination) {
let condition = if let Expression::Tuple(tuple) = &match_expr.condition.0 {
&**tuple
let conditions = if let Expression::Tuple(tuple) | Expression::List(tuple) =
&match_expr.condition.0
{
MatchExpressions::Tuple(tuple.iter().map(|expr| self.compile_source(expr)).collect())
} else {
std::slice::from_ref(&match_expr.condition)
MatchExpressions::Single(self.compile_source(&match_expr.condition))
};
let conditions = condition
.iter()
.map(|condition| self.compile_source(condition))
.collect::<Vec<_>>();
let after_expression = self.compiler.code.new_label();
self.compile_match(&conditions, &match_expr.matches, dest);
self.compiler.code.label(after_expression);
}

#[allow(clippy::too_many_lines)]
fn compile_match(
&mut self,
conditions: &[ValueOrSource],
conditions: &MatchExpressions,
matches: &Ranged<Matches>,
dest: OpDestination,
) {
self.compiler.code.set_current_source_range(matches.range());
let mut refutable = Refutability::Irrefutable;
let previous_handler = self.new_temporary();
let mut stored_previous_handler = false;
let after_expression = self.compiler.code.new_label();

for matches in &matches.patterns {
let mut pattern_block = self.enter_block(None);
Expand All @@ -1048,37 +1065,73 @@ impl<'a> Scope<'a> {
.code
.set_exception_handler(next_pattern, previous_handler);

let parameters = match &matches.pattern.kind.0 {
PatternKind::Any(None) | PatternKind::AnyRemaining => None,
PatternKind::Any(_)
| PatternKind::Literal(_)
| PatternKind::Or(_, _)
| PatternKind::DestructureMap(_) => {
Some(std::slice::from_ref(&matches.pattern.kind))
}
PatternKind::DestructureTuple(patterns) => Some(&**patterns),
};

if let Some(parameters) = parameters {
if parameters.len() == conditions.len()
|| parameters
.last()
.map_or(false, |param| matches!(&param.0, PatternKind::AnyRemaining))
{
for (parameter, condition) in parameters.iter().zip(conditions) {
refutable |= pattern_block.compile_pattern_binding(
parameter,
condition.clone(),
next_pattern,
&mut PatternBindings::default(),
);
match &conditions {
MatchExpressions::Single(condition) => {
match &matches.pattern.kind.0 {
PatternKind::Any(None) | PatternKind::AnyRemaining => {}
PatternKind::Any(_)
| PatternKind::Literal(_)
| PatternKind::Or(_, _)
| PatternKind::DestructureMap(_) => {
refutable |= pattern_block.compile_pattern_binding(
&matches.pattern.kind,
condition.clone(),
next_pattern,
&mut PatternBindings::default(),
);
}
PatternKind::DestructureTuple(patterns) => {
pattern_block.compile_tuple_destructure(
patterns,
condition,
next_pattern,
matches.range(),
&mut PatternBindings::default(),
);
refutable = Refutability::Refutable;
}
}

pattern_block
.compiler
.code
.set_current_source_range(matches.range());
} else {
pattern_block.compiler.code.jump(next_pattern, ());
}
MatchExpressions::Tuple(conditions) => {
let parameters = match &matches.pattern.kind.0 {
PatternKind::Any(None) | PatternKind::AnyRemaining => None,
PatternKind::Any(_)
| PatternKind::Literal(_)
| PatternKind::Or(_, _)
| PatternKind::DestructureMap(_) => {
Some(std::slice::from_ref(&matches.pattern.kind))
}
PatternKind::DestructureTuple(patterns) => Some(&**patterns),
};

if let Some(parameters) = parameters {
if parameters.len() == conditions.len()
|| parameters.last().map_or(false, |param| {
matches!(&param.0, PatternKind::AnyRemaining)
})
{
for (parameter, condition) in parameters.iter().zip(conditions) {
refutable |= pattern_block.compile_pattern_binding(
parameter,
condition.clone(),
next_pattern,
&mut PatternBindings::default(),
);
}
pattern_block
.compiler
.code
.set_current_source_range(matches.range());
} else {
refutable = Refutability::Refutable;
pattern_block.compiler.code.jump(next_pattern, ());
}
}
}
}

Expand All @@ -1105,7 +1158,7 @@ impl<'a> Scope<'a> {
.compiler
.code
.set_current_source_range(matches.range());
pattern_block.compiler.code.return_early();
pattern_block.compiler.code.jump(after_expression, ());
drop(pattern_block);

self.compiler.code.label(next_pattern);
Expand All @@ -1120,6 +1173,7 @@ impl<'a> Scope<'a> {
if refutable == Refutability::Refutable {
self.compiler.code.throw(FaultKind::PatternMismatch);
}
self.compiler.code.label(after_expression);
}

fn compile_throw(&mut self, value: &Ranged<Expression>, range: SourceRange) {
Expand All @@ -1136,7 +1190,11 @@ impl<'a> Scope<'a> {
|this| this.compile_expression(&try_expr.body, dest),
|this| {
if let Some(matches) = &try_expr.catch {
this.compile_match(&[ValueOrSource::Register(Register(0))], matches, dest);
this.compile_match(
&MatchExpressions::Single(ValueOrSource::Register(Register(0))),
matches,
dest,
);
} else {
// A catch-less try converts to nil. It's just a different
// form of the ? operator.
Expand Down Expand Up @@ -2089,3 +2147,8 @@ struct InstructionRange {
range: SourceRange,
instructions: Range<usize>,
}

enum MatchExpressions {
Single(ValueOrSource),
Tuple(Vec<ValueOrSource>),
}
21 changes: 5 additions & 16 deletions src/syntax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2368,7 +2368,7 @@ fn parse_pattern(
tokens: &mut TokenReader<'_>,
config: &ParserConfig<'_>,
) -> Result<Option<Ranged<Pattern>>, Ranged<Error>> {
let Some(kind) = parse_pattern_kind(tokens, true)? else {
let Some(kind) = parse_pattern_kind(tokens)? else {
return Ok(None);
};

Expand All @@ -2390,7 +2390,6 @@ fn parse_pattern(
#[allow(clippy::too_many_lines)]
fn parse_pattern_kind(
tokens: &mut TokenReader<'_>,
top_level: bool,
) -> Result<Option<Ranged<PatternKind>>, Ranged<Error>> {
let Some(indicator) = tokens.peek() else {
return Ok(None);
Expand Down Expand Up @@ -2471,17 +2470,7 @@ fn parse_pattern_kind(
}
Token::Open(kind @ (Paired::Paren | Paired::Bracket)) => {
tokens.next()?;
let mut pattern =
parse_tuple_destructure_pattern(indicator.range().start, *kind, tokens)?;
// TODO move this to pattern compilation, checking to see if the
// argument is a tuple or not
if top_level && *kind == Paired::Bracket {
pattern = Ranged::new(
pattern.range(),
PatternKind::DestructureTuple(vec![pattern]),
);
}
pattern
parse_tuple_destructure_pattern(indicator.range().start, *kind, tokens)?
}
Token::Open(Paired::Brace) => {
tokens.next()?;
Expand All @@ -2493,7 +2482,7 @@ fn parse_pattern_kind(

while tokens.peek_token() == Some(Token::Char('|')) {
tokens.next()?;
let Some(rhs) = parse_pattern_kind(tokens, top_level)? else {
let Some(rhs) = parse_pattern_kind(tokens)? else {
return Err(tokens.ranged(tokens.last_index.., Error::ExpectedPattern));
};
pattern = tokens.ranged(
Expand All @@ -2511,7 +2500,7 @@ fn parse_tuple_destructure_pattern(
tokens: &mut TokenReader<'_>,
) -> Result<Ranged<PatternKind>, Ranged<Error>> {
let mut patterns = Vec::new();
while let Some(pattern) = parse_pattern_kind(tokens, false)? {
while let Some(pattern) = parse_pattern_kind(tokens)? {
patterns.push(pattern);

if tokens.peek_token() == Some(Token::Char(',')) {
Expand Down Expand Up @@ -2566,7 +2555,7 @@ fn parse_map_destructure_pattern(
return Err(colon.map(|_| Error::ExpectedColon));
}

let Some(value) = parse_pattern_kind(tokens, false)? else {
let Some(value) = parse_pattern_kind(tokens)? else {
return Err(tokens.ranged(tokens.last_index.., Error::ExpectedPattern));
};
entries.push(tokens.ranged(key.range().start.., EntryPattern { key, value }));
Expand Down
21 changes: 20 additions & 1 deletion tests/cases/match.rsn
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ let_else_doesnt_diverge: {
}

map_tuple_destructure: {
src: r#"let ["a", 1] = {"a": 1}"#,
src: r#"let (["a", 1]) = {"a": 1}"#,
output: Bool(true),
}

Expand All @@ -100,6 +100,11 @@ tuple_mismatch: {
),
}

list_match: {
src: r#"let ["a", a] = ["a", 1]; a"#,
output: Int(1),
}

map_match: {
src: r#"let {"a": a} = {"a": 1}; a"#,
output: Int(1),
Expand All @@ -116,4 +121,18 @@ tuple_remaining: {
(1, ...) => true,
}"#,
output: Bool(true),
}

match_problem: {
src: r#"
match (1, 2) {
(1, 2) => true,
};

let a = (1, 2);
match a {
(1, 2) => true,
}
"#,
output: Bool(true),
}
Loading

0 comments on commit bea9518

Please sign in to comment.