From a0b7c08dbd4fc2df1fea8f84fd7d34f673a1433d Mon Sep 17 00:00:00 2001 From: Pete Forrest Date: Fri, 7 Oct 2022 20:03:39 +0100 Subject: [PATCH] Failing test for InMemoryEngine Where with no constant --- .../Engine/InMemory/InMemoryEngineTests.cs | 25 +++++++++ .../Engine/InMemory/TestDomain/Pair.cs | 9 ++++ .../WhereClauseNullCheckRewriterTests.cs | 9 ++++ .../InMemory/WhereClauseNullCheckRewriter.cs | 54 ++++++++++++------- 4 files changed, 78 insertions(+), 19 deletions(-) create mode 100644 Dashing.Tests/Engine/InMemory/TestDomain/Pair.cs diff --git a/Dashing.Tests/Engine/InMemory/InMemoryEngineTests.cs b/Dashing.Tests/Engine/InMemory/InMemoryEngineTests.cs index bd1d4056..f63d8620 100644 --- a/Dashing.Tests/Engine/InMemory/InMemoryEngineTests.cs +++ b/Dashing.Tests/Engine/InMemory/InMemoryEngineTests.cs @@ -51,6 +51,15 @@ public void WhereNotFetchedWorks() { Assert.True(comments.First().Post.Author.Username == null); } + [Fact] + public void WhereComparisonWithNoConstantWorks() { + var session = this.GetSession(); + var commentsByAuthor = session.Query() + .Where(c => c.User == c.Post.Author) + .ToArray(); + Assert.Equal(2, commentsByAuthor.Length); + } + [Fact] public void TestConfigWorks() { var config = new TestConfiguration(); @@ -185,6 +194,22 @@ public void PagedWorks() { Assert.Equal(4, thirdFourthComments.Items.ElementAt(1).CommentId); } + [Fact] + public void DeleteSameWorks() { + var session = this.GetSession(); + var blog1 = new Blog(); + var blog2 = new Blog(); + session.Insert(blog1); + session.Insert(blog2); + var pair = new Pair { Left = blog1, Right = blog2 }; + session.Insert(pair); + var samePair = new Pair { Left = blog1, Right = blog1 }; + session.Insert(samePair); + Assert.Equal(2, session.Query().Count()); + session.Delete(p => p.Left == p.Right); + Assert.Single(session.Query()); + } + private ISession GetSession() { var sessionCreator = new InMemoryDatabase(new TestConfiguration()); var session = sessionCreator.BeginSession(); diff --git a/Dashing.Tests/Engine/InMemory/TestDomain/Pair.cs b/Dashing.Tests/Engine/InMemory/TestDomain/Pair.cs new file mode 100644 index 00000000..01824054 --- /dev/null +++ b/Dashing.Tests/Engine/InMemory/TestDomain/Pair.cs @@ -0,0 +1,9 @@ +namespace Dashing.Tests.Engine.InMemory.TestDomain { + public class Pair { + public int PairId { get; set; } + + public Blog Left { get; set; } + + public Blog Right { get; set; } + } +} diff --git a/Dashing.Tests/Engine/InMemory/WhereClauseNullCheckRewriterTests.cs b/Dashing.Tests/Engine/InMemory/WhereClauseNullCheckRewriterTests.cs index a376b14f..d5d48a5c 100644 --- a/Dashing.Tests/Engine/InMemory/WhereClauseNullCheckRewriterTests.cs +++ b/Dashing.Tests/Engine/InMemory/WhereClauseNullCheckRewriterTests.cs @@ -192,6 +192,15 @@ public void NullableEqualsNullOnParentWithOrWorks() { Assert.Equal(expectedResult.ToDebugString(), rewrittenClause.ToDebugString()); } + [Fact] + public void BinaryComparisonParametersBothSides() { + Expression> exp = c => c.Left == c.Right; + Expression> expectedResult = c => c.Left != null && c.Right != null && c.Left == c.Right; + var rewriter = new WhereClauseNullCheckRewriter(); + var rewrittenClause = rewriter.Rewrite(exp); + Assert.Equal(expectedResult.ToDebugString(), rewrittenClause.ToDebugString()); + } + public class CourseType { public virtual int CourseTypeId { get; set; } } diff --git a/Dashing/Engine/InMemory/WhereClauseNullCheckRewriter.cs b/Dashing/Engine/InMemory/WhereClauseNullCheckRewriter.cs index ce5b5c4f..2fe63787 100644 --- a/Dashing/Engine/InMemory/WhereClauseNullCheckRewriter.cs +++ b/Dashing/Engine/InMemory/WhereClauseNullCheckRewriter.cs @@ -42,12 +42,13 @@ protected override Expression VisitBinary(BinaryExpression node) { this.Visit(node.Left); } - // visit right hand side var leftHandSideIsNull = this.isNullCheck; var leftHandSideContainsNullable = this.containsNullable; - this.treeHasParameter = false; - this.isNullCheck = false; - this.containsNullable = false; + var leftHandSideNullCheckExpressions = this.nullCheckExpressions; + var leftHandSideHasParameter = this.treeHasParameter; + this.ResetVariables(); + + // visit right hand side if (node.Right.NodeType == ExpressionType.Convert) { this.Visit(((UnaryExpression)node.Right).Operand); } @@ -55,14 +56,25 @@ protected override Expression VisitBinary(BinaryExpression node) { this.Visit(node.Right); } - if ((leftHandSideIsNull || this.isNullCheck) && !leftHandSideContainsNullable && !this.containsNullable) { - // we're checking null somewhere (e.g. e.Post == null) so we should remove that last null check (unless it's got a nullable inside - if (this.nullCheckExpressions.Count > 0) { - this.nullCheckExpressions.RemoveAt(this.nullCheckExpressions.Count - 1); - } + var rightHandSideIsNull = this.isNullCheck; + var rightHandSideContainsNullable = this.containsNullable; + var rightHandSideNullCheckExpressions = this.nullCheckExpressions; + var rightHandSideHasParameter = this.treeHasParameter; + + if (leftHandSideIsNull && !rightHandSideContainsNullable && rightHandSideNullCheckExpressions.Count > 0) { + rightHandSideNullCheckExpressions.RemoveAt( + rightHandSideNullCheckExpressions.Count - 1); + } + + if (rightHandSideIsNull && !leftHandSideContainsNullable && leftHandSideNullCheckExpressions.Count > 0) { + leftHandSideNullCheckExpressions.RemoveAt( + leftHandSideNullCheckExpressions.Count - 1); } + - return this.CombineExpressions(node); + var combined = CombineExpressions(node, leftHandSideNullCheckExpressions.Union(rightHandSideNullCheckExpressions)); + this.nullCheckExpressions.Clear(); + return combined; } if (isInAndOrOrExpression) { @@ -208,26 +220,30 @@ private Expression ModifyExpression(Expression leftExpr, Expression rightExpr, E } } - private Expression CombineExpressions(Expression exp) { - if (!Enumerable.Any(this.nullCheckExpressions)) { + private static Expression CombineExpressions(Expression exp, IEnumerable expressions) { + if (!expressions.Any()) { return exp; } - if (this.nullCheckExpressions.Count == 1) { - var expr = Expression.AndAlso(Enumerable.First(this.nullCheckExpressions), exp); - this.nullCheckExpressions.Clear(); + if (expressions.Count() == 1) { + var expr = Expression.AndAlso(Enumerable.First(expressions), exp); return expr; } - var combinedExpr = Expression.AndAlso(Enumerable.First(this.nullCheckExpressions), Enumerable.ElementAt(this.nullCheckExpressions, 1)); - for (var i = 2; i < this.nullCheckExpressions.Count; i++) { - combinedExpr = Expression.AndAlso(combinedExpr, Enumerable.ElementAt(this.nullCheckExpressions, i)); + var combinedExpr = Expression.AndAlso(Enumerable.First(expressions), Enumerable.ElementAt(expressions, 1)); + for (var i = 2; i < expressions.Count(); i++) { + combinedExpr = Expression.AndAlso(combinedExpr, Enumerable.ElementAt(expressions, i)); } - this.nullCheckExpressions.Clear(); return Expression.AndAlso(combinedExpr, exp); } + private Expression CombineExpressions(Expression exp) { + var combined = CombineExpressions(exp, this.nullCheckExpressions); + this.nullCheckExpressions.Clear(); + return combined; + } + private bool IsInBinaryComparisonExpression(ExpressionType nodeType) { switch (nodeType) { case ExpressionType.Equal: