Skip to content

Commit

Permalink
Update AssertJ recipes to current recipe code style
Browse files Browse the repository at this point in the history
  • Loading branch information
jevanlingen committed Nov 5, 2024
1 parent 9118a3c commit 152617c
Show file tree
Hide file tree
Showing 15 changed files with 738 additions and 1,005 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.search.UsesType;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
Expand All @@ -32,7 +32,10 @@
import java.util.List;

public class JUnitAssertArrayEqualsToAssertThat extends Recipe {
private static final String JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME = "org.junit.jupiter.api.Assertions";

private static final String JUNIT = "org.junit.jupiter.api.Assertions";
private static final String ASSERTJ = "org.assertj.core.api.Assertions";
private static final MethodMatcher ASSERT_ARRAY_EQUALS_MATCHER = new MethodMatcher(JUNIT + " assertArrayEquals(..)", true);

@Override
public String getDisplayName() {
Expand All @@ -46,93 +49,77 @@ public String getDescription() {

@Override
public TreeVisitor<?, ExecutionContext> getVisitor() {
return Preconditions.check(new UsesType<>(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME, false), new AssertArrayEqualsToAssertThatVisitor());
}

public static class AssertArrayEqualsToAssertThatVisitor extends JavaIsoVisitor<ExecutionContext> {
private static final MethodMatcher JUNIT_ASSERT_EQUALS = new MethodMatcher(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME + " assertArrayEquals(..)");

private JavaParser.Builder<?, ?> assertionsParser;

private JavaParser.Builder<?, ?> assertionsParser(ExecutionContext ctx) {
if (assertionsParser == null) {
assertionsParser = JavaParser.fromJavaVersion()
.classpathFromResources(ctx, "assertj-core-3.24");
}
return assertionsParser;
}


@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
if (!JUNIT_ASSERT_EQUALS.matches(method)) {
return method;
}

List<Expression> args = method.getArguments();
Expression expected = args.get(0);
Expression actual = args.get(1);

// Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced)
maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false);
maybeRemoveImport(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME);

if (args.size() == 2) {
return JavaTemplate.builder("assertThat(#{anyArray()}).containsExactly(#{anyArray()});")
.staticImports("org.assertj.core.api.Assertions.assertThat")
.javaParser(assertionsParser(ctx))
.build()
.apply(getCursor(), method.getCoordinates().replace(), actual, expected);
} else if (args.size() == 3 && !isFloatingPointType(args.get(2))) {
Expression message = args.get(2);
return Preconditions.check(new UsesMethod<>(ASSERT_ARRAY_EQUALS_MATCHER), new JavaIsoVisitor<ExecutionContext>() {
@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
J.MethodInvocation md = super.visitMethodInvocation(method, ctx);
if (!ASSERT_ARRAY_EQUALS_MATCHER.matches(md)) {
return md;
}

maybeAddImport(ASSERTJ, "assertThat", false);
maybeRemoveImport(JUNIT);

List<Expression> args = md.getArguments();
Expression expected = args.get(0);
Expression actual = args.get(1);
if (args.size() == 2) {
return JavaTemplate.builder("assertThat(#{anyArray()}).containsExactly(#{anyArray()});")
.staticImports(ASSERTJ +".assertThat")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24"))
.build()
.apply(getCursor(), md.getCoordinates().replace(), actual, expected);
} else if (args.size() == 3 && !isFloatingPointType(args.get(2))) {
Expression message = args.get(2);
JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ?
JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(String)}).containsExactly(#{anyArray()});") :
JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(java.util.function.Supplier)}).containsExactly(#{anyArray()});");
return template
.staticImports(ASSERTJ +".assertThat")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24"))
.build()
.apply(getCursor(), md.getCoordinates().replace(), actual, message, expected);
} else if (args.size() == 3) {
maybeAddImport(ASSERTJ, "within", false);
// assert is using floating points with a delta and no message.
return JavaTemplate.builder("assertThat(#{anyArray()}).containsExactly(#{anyArray()}, within(#{any()}));")
.staticImports(ASSERTJ +".assertThat", ASSERTJ +".within")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24"))
.build()
.apply(getCursor(), md.getCoordinates().replace(), actual, expected, args.get(2));
}


maybeAddImport(ASSERTJ, "within", false);

// The assertEquals is using a floating point with a delta argument and a message.
Expression message = args.get(3);
JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ?
JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(String)}).containsExactly(#{anyArray()});") :
JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(java.util.function.Supplier)}).containsExactly(#{anyArray()});");
JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(String)}).containsExactly(#{anyArray()}, within(#{any()}));") :
JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(java.util.function.Supplier)}).containsExactly(#{anyArray()}, within(#{}));");
return template
.staticImports("org.assertj.core.api.Assertions.assertThat")
.javaParser(assertionsParser(ctx))
.build()
.apply(getCursor(), method.getCoordinates().replace(), actual, message, expected);
} else if (args.size() == 3) {
maybeAddImport("org.assertj.core.api.Assertions", "within", false);
// assert is using floating points with a delta and no message.
return JavaTemplate.builder("assertThat(#{anyArray()}).containsExactly(#{anyArray()}, within(#{any()}));")
.staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within")
.javaParser(assertionsParser(ctx))
.staticImports(ASSERTJ +".assertThat", ASSERTJ +".within")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24"))
.build()
.apply(getCursor(), method.getCoordinates().replace(), actual, expected, args.get(2));
.apply(getCursor(), md.getCoordinates().replace(), actual, message, expected, args.get(2));
}

// The assertEquals is using a floating point with a delta argument and a message.
Expression message = args.get(3);
maybeAddImport("org.assertj.core.api.Assertions", "within", false);

JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ?
JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(String)}).containsExactly(#{anyArray()}, within(#{any()}));") :
JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(java.util.function.Supplier)}).containsExactly(#{anyArray()}, within(#{}));");
return template
.staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within")
.javaParser(assertionsParser(ctx))
.build()
.apply(getCursor(), method.getCoordinates().replace(), actual, message, expected, args.get(2));
}

/**
* Returns true if the expression's type is either a primitive float/double or their object forms Float/Double
*
* @param expression The expression parsed from the original AST.
* @return true if the type is a floating point number.
*/
private static boolean isFloatingPointType(Expression expression) {

JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(expression.getType());
if (fullyQualified != null) {
String typeName = fullyQualified.getFullyQualifiedName();
return "java.lang.Double".equals(typeName) || "java.lang.Float".equals(typeName);
/**
* Returns true if the expression's type is either a primitive float/double or their object forms Float/Double
*
* @param expression The expression parsed from the original AST.
* @return true if the type is a floating point number.
*/
private boolean isFloatingPointType(Expression expression) {
JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(expression.getType());
if (fullyQualified != null) {
String typeName = fullyQualified.getFullyQualifiedName();
return "java.lang.Double".equals(typeName) || "java.lang.Float".equals(typeName);
}

JavaType.Primitive parameterType = TypeUtils.asPrimitive(expression.getType());
return parameterType == JavaType.Primitive.Double || parameterType == JavaType.Primitive.Float;
}

JavaType.Primitive parameterType = TypeUtils.asPrimitive(expression.getType());
return parameterType == JavaType.Primitive.Double || parameterType == JavaType.Primitive.Float;
}
});
}
}
Loading

0 comments on commit 152617c

Please sign in to comment.