Skip to content

Commit

Permalink
Failing test for InMemoryEngine Where with no constant
Browse files Browse the repository at this point in the history
  • Loading branch information
Pete Forrest authored and markjerz committed Nov 16, 2022
1 parent 573a313 commit a0b7c08
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 19 deletions.
25 changes: 25 additions & 0 deletions Dashing.Tests/Engine/InMemory/InMemoryEngineTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ public void WhereNotFetchedWorks() {
Assert.True(comments.First().Post.Author.Username == null);
}

[Fact]
public void WhereComparisonWithNoConstantWorks() {
var session = this.GetSession();
var commentsByAuthor = session.Query<Comment>()
.Where(c => c.User == c.Post.Author)
.ToArray();
Assert.Equal(2, commentsByAuthor.Length);
}

[Fact]
public void TestConfigWorks() {
var config = new TestConfiguration();
Expand Down Expand Up @@ -185,6 +194,22 @@ public void PagedWorks() {
Assert.Equal(4, thirdFourthComments.Items.ElementAt(1).CommentId);
}

[Fact]
public void DeleteSameWorks() {
var session = this.GetSession();
var blog1 = new Blog();
var blog2 = new Blog();
session.Insert(blog1);
session.Insert(blog2);
var pair = new Pair { Left = blog1, Right = blog2 };
session.Insert(pair);
var samePair = new Pair { Left = blog1, Right = blog1 };
session.Insert(samePair);
Assert.Equal(2, session.Query<Pair>().Count());
session.Delete<Pair>(p => p.Left == p.Right);
Assert.Single(session.Query<Pair>());
}

private ISession GetSession() {
var sessionCreator = new InMemoryDatabase(new TestConfiguration());
var session = sessionCreator.BeginSession();
Expand Down
9 changes: 9 additions & 0 deletions Dashing.Tests/Engine/InMemory/TestDomain/Pair.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace Dashing.Tests.Engine.InMemory.TestDomain {
public class Pair {
public int PairId { get; set; }

public Blog Left { get; set; }

public Blog Right { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,15 @@ public void NullableEqualsNullOnParentWithOrWorks() {
Assert.Equal(expectedResult.ToDebugString(), rewrittenClause.ToDebugString());
}

[Fact]
public void BinaryComparisonParametersBothSides() {
Expression<Func<Pair, bool>> exp = c => c.Left == c.Right;
Expression<Func<Pair, bool>> expectedResult = c => c.Left != null && c.Right != null && c.Left == c.Right;
var rewriter = new WhereClauseNullCheckRewriter();
var rewrittenClause = rewriter.Rewrite(exp);
Assert.Equal(expectedResult.ToDebugString(), rewrittenClause.ToDebugString());
}

public class CourseType {
public virtual int CourseTypeId { get; set; }
}
Expand Down
54 changes: 35 additions & 19 deletions Dashing/Engine/InMemory/WhereClauseNullCheckRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,39 @@ protected override Expression VisitBinary(BinaryExpression node) {
this.Visit(node.Left);
}

// visit right hand side
var leftHandSideIsNull = this.isNullCheck;
var leftHandSideContainsNullable = this.containsNullable;
this.treeHasParameter = false;
this.isNullCheck = false;
this.containsNullable = false;
var leftHandSideNullCheckExpressions = this.nullCheckExpressions;
var leftHandSideHasParameter = this.treeHasParameter;
this.ResetVariables();

// visit right hand side
if (node.Right.NodeType == ExpressionType.Convert) {
this.Visit(((UnaryExpression)node.Right).Operand);
}
else {
this.Visit(node.Right);
}

if ((leftHandSideIsNull || this.isNullCheck) && !leftHandSideContainsNullable && !this.containsNullable) {
// we're checking null somewhere (e.g. e.Post == null) so we should remove that last null check (unless it's got a nullable inside
if (this.nullCheckExpressions.Count > 0) {
this.nullCheckExpressions.RemoveAt(this.nullCheckExpressions.Count - 1);
}
var rightHandSideIsNull = this.isNullCheck;
var rightHandSideContainsNullable = this.containsNullable;
var rightHandSideNullCheckExpressions = this.nullCheckExpressions;
var rightHandSideHasParameter = this.treeHasParameter;

if (leftHandSideIsNull && !rightHandSideContainsNullable && rightHandSideNullCheckExpressions.Count > 0) {
rightHandSideNullCheckExpressions.RemoveAt(
rightHandSideNullCheckExpressions.Count - 1);
}

if (rightHandSideIsNull && !leftHandSideContainsNullable && leftHandSideNullCheckExpressions.Count > 0) {
leftHandSideNullCheckExpressions.RemoveAt(
leftHandSideNullCheckExpressions.Count - 1);
}


return this.CombineExpressions(node);
var combined = CombineExpressions(node, leftHandSideNullCheckExpressions.Union(rightHandSideNullCheckExpressions));
this.nullCheckExpressions.Clear();
return combined;
}

if (isInAndOrOrExpression) {
Expand Down Expand Up @@ -208,26 +220,30 @@ private Expression ModifyExpression(Expression leftExpr, Expression rightExpr, E
}
}

private Expression CombineExpressions(Expression exp) {
if (!Enumerable.Any<Expression>(this.nullCheckExpressions)) {
private static Expression CombineExpressions(Expression exp, IEnumerable<Expression> expressions) {
if (!expressions.Any()) {
return exp;
}

if (this.nullCheckExpressions.Count == 1) {
var expr = Expression.AndAlso(Enumerable.First<Expression>(this.nullCheckExpressions), exp);
this.nullCheckExpressions.Clear();
if (expressions.Count() == 1) {
var expr = Expression.AndAlso(Enumerable.First<Expression>(expressions), exp);
return expr;
}

var combinedExpr = Expression.AndAlso(Enumerable.First<Expression>(this.nullCheckExpressions), Enumerable.ElementAt<Expression>(this.nullCheckExpressions, 1));
for (var i = 2; i < this.nullCheckExpressions.Count; i++) {
combinedExpr = Expression.AndAlso(combinedExpr, Enumerable.ElementAt<Expression>(this.nullCheckExpressions, i));
var combinedExpr = Expression.AndAlso(Enumerable.First<Expression>(expressions), Enumerable.ElementAt<Expression>(expressions, 1));
for (var i = 2; i < expressions.Count(); i++) {
combinedExpr = Expression.AndAlso(combinedExpr, Enumerable.ElementAt<Expression>(expressions, i));
}

this.nullCheckExpressions.Clear();
return Expression.AndAlso(combinedExpr, exp);
}

private Expression CombineExpressions(Expression exp) {
var combined = CombineExpressions(exp, this.nullCheckExpressions);
this.nullCheckExpressions.Clear();
return combined;
}

private bool IsInBinaryComparisonExpression(ExpressionType nodeType) {
switch (nodeType) {
case ExpressionType.Equal:
Expand Down

0 comments on commit a0b7c08

Please sign in to comment.