diff --git a/src/main/java/org/openrewrite/java/testing/jmockit/JMockitExpectationsToMockitoWhen.java b/src/main/java/org/openrewrite/java/testing/jmockit/JMockitExpectationsToMockitoWhen.java index 3ef35a195..f830ecc6d 100644 --- a/src/main/java/org/openrewrite/java/testing/jmockit/JMockitExpectationsToMockitoWhen.java +++ b/src/main/java/org/openrewrite/java/testing/jmockit/JMockitExpectationsToMockitoWhen.java @@ -15,79 +15,180 @@ */ package org.openrewrite.java.testing.jmockit; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.regex.Pattern; + import lombok.EqualsAndHashCode; import lombok.Value; +import org.openrewrite.Cursor; import org.openrewrite.ExecutionContext; import org.openrewrite.Preconditions; import org.openrewrite.Recipe; import org.openrewrite.TreeVisitor; +import org.openrewrite.java.JavaIsoVisitor; import org.openrewrite.java.JavaParser; import org.openrewrite.java.JavaTemplate; -import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.search.UsesType; import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.JavaCoordinates; +import org.openrewrite.java.tree.JavaType; import org.openrewrite.java.tree.Statement; @Value @EqualsAndHashCode(callSuper = false) public class JMockitExpectationsToMockitoWhen extends Recipe { - @Override - public String getDisplayName() { - return "Rewrite JMockit Expectations"; - } - - @Override - public String getDescription() { - return "Rewrites JMockit `Expectations` to `Mockito.when`."; - } - - @Override - public TreeVisitor getVisitor() { - return Preconditions.check(new UsesType<>("mockit.*", false), - new RewriteExpectationsVisitor()); - } - - private static class RewriteExpectationsVisitor extends JavaVisitor { @Override - public J visitNewClass(J.NewClass newClass, ExecutionContext executionContext) { - J.NewClass nc = (J.NewClass) super.visitNewClass(newClass, executionContext); - if (!(nc.getClazz() instanceof J.Identifier)) { - return nc; - } - J.Identifier clazz = (J.Identifier) nc.getClazz(); - if (!clazz.getSimpleName().equals("Expectations")) { - return nc; - } - - // empty Expectations block is considered invalid - assert nc.getBody() != null : "Expectations block is empty"; - - // prepare the statements for moving - J.Block innerBlock = (J.Block) nc.getBody().getStatements().get(0); - - // TODO: handle multiple mock statements - Statement mockInvocation = innerBlock.getStatements().get(0); - Expression result = ((J.Assignment) innerBlock.getStatements().get(1)).getAssignment(); - - // apply the template and replace the `new Expectations()` statement coordinates - // TODO: handle exception results with another template - J.MethodInvocation newMethod = JavaTemplate.builder("when(#{any()}).thenReturn(#{});") - .javaParser(JavaParser.fromJavaVersion().classpathFromResources(executionContext, "mockito-core-3.12")) - .staticImports("org.mockito.Mockito.when") - .build() - .apply( - getCursor(), - nc.getCoordinates().replace(), - mockInvocation, - result - ); - - // handle import changes - maybeAddImport("org.mockito.Mockito", "when"); - maybeRemoveImport("mockit.Expectations"); - - return newMethod.withPrefix(nc.getPrefix()); + public String getDisplayName() { + return "Rewrite JMockit Expectations"; + } + + @Override + public String getDescription() { + return "Rewrites JMockit `Expectations` to `Mockito.when`."; + } + + @Override + public TreeVisitor getVisitor() { + return Preconditions.check(new UsesType<>("mockit.*", false), + new RewriteExpectationsVisitor()); + } + + private static class RewriteExpectationsVisitor extends JavaIsoVisitor { + + private static final String PRIMITIVE_RESULT_TEMPLATE = "when(#{any()}).thenReturn(#{});"; + private static final String OBJECT_RESULT_TEMPLATE = "when(#{any()}).thenReturn(#{any(java.lang.String)});"; + private static final String EXCEPTION_RESULT_TEMPLATE = "when(#{any()}).thenThrow(#{any()});"; + private static final Pattern EXPECTATIONS_PATTERN = Pattern.compile("mockit.Expectations"); + + // the LST element that is being updated when applying one of the java templates + private Object cursorLocation; + + // the coordinates where the next statement should be inserted + private JavaCoordinates coordinates; + + @Override + public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration methodDeclaration, ExecutionContext ctx) { + J.MethodDeclaration md = super.visitMethodDeclaration(methodDeclaration, ctx); + if (md.getBody() == null) { + return md; + } + cursorLocation = md.getBody(); + J.Block newBody = md.getBody(); + List statements = md.getBody().getStatements(); + + // iterate over each statement in the method body, find Expectations blocks and rewrite them + for (int i = 0; i < statements.size(); i++) { + Statement s = statements.get(i); + if (!(s instanceof J.NewClass)) { + continue; + } + J.NewClass nc = (J.NewClass) s; + if (!(nc.getClazz() instanceof J.Identifier)) { + continue; + } + J.Identifier clazz = (J.Identifier) nc.getClazz(); + if (clazz.getType() == null || !clazz.getType().isAssignableFrom(EXPECTATIONS_PATTERN)) { + continue; + } + // empty Expectations block is considered invalid + assert nc.getBody() != null && !nc.getBody().getStatements().isEmpty() : "Expectations block is empty"; + // Expectations block should be composed of a block within another block + assert nc.getBody().getStatements().size() == 1 : "Expectations block is malformed"; + + // we have a valid Expectations block, update imports and rewrite with Mockito statements + maybeAddImport("org.mockito.Mockito", "when"); + maybeRemoveImport("mockit.Expectations"); + + // the first coordinates are the coordinates the Expectations block, replacing it + coordinates = nc.getCoordinates().replace(); + J.Block expectationsBlock = (J.Block) nc.getBody().getStatements().get(0); + List expectationStatements = expectationsBlock.getStatements(); + List templateParams = new ArrayList<>(); + + // iterate over the expectations statements and rebuild the method body + for (Statement expectationStatement : expectationStatements) { + // TODO: handle void methods (including final statement) + + // TODO: handle additional jmockit expectations features + + if (expectationStatement instanceof J.MethodInvocation) { + if (!templateParams.isEmpty()) { + // apply template to build new method body + newBody = buildNewBody(ctx, templateParams, i); + + // reset template params for next expectation + templateParams = new ArrayList<>(); + } + templateParams.add(expectationStatement); + } else { + // assignment + templateParams.add(((J.Assignment) expectationStatement).getAssignment()); + } + } + + // handle the last statement + if (!templateParams.isEmpty()) { + newBody = buildNewBody(ctx, templateParams, i); + } + } + + return md.withBody(newBody); + } + + private J.Block buildNewBody(ExecutionContext ctx, List templateParams, int newStatementIndex) { + Expression result = (Expression) templateParams.get(1); + String template = getTemplate(result); + + J.Block newBody = JavaTemplate.builder(template) + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "mockito-core-3.12")) + .staticImports("org.mockito.Mockito.when") + .build() + .apply( + new Cursor(getCursor(), cursorLocation), + coordinates, + templateParams.toArray() + ); + + List newStatements = new ArrayList<>(newBody.getStatements().size()); + for (int i = 0; i < newBody.getStatements().size(); i++) { + Statement s = newBody.getStatements().get(i); + if (i == newStatementIndex) { + // next statement coordinates are immediately after the statement just added + coordinates = s.getCoordinates().after(); + } + newStatements.add(s); + } + newBody = newBody.withStatements(newStatements); + + // cursor location is now the new body + cursorLocation = newBody; + + return newBody; + } + + /* + * Based on the result type, we need to use a different template. + */ + private static String getTemplate(Expression result) { + String template; + JavaType resultType = Objects.requireNonNull(result.getType()); + if (resultType instanceof JavaType.Primitive) { + template = PRIMITIVE_RESULT_TEMPLATE; + } else if (resultType instanceof JavaType.Class) { + Class resultClass; + try { + resultClass = Class.forName(((JavaType.Class) resultType).getFullyQualifiedName()); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + template = Throwable.class.isAssignableFrom(resultClass) ? EXCEPTION_RESULT_TEMPLATE : OBJECT_RESULT_TEMPLATE; + } else { + throw new IllegalStateException("Unexpected value: " + result.getType()); + } + return template; + } } - } } diff --git a/src/test/java/org/openrewrite/java/testing/jmockit/JMockitToMockitoTest.java b/src/test/java/org/openrewrite/java/testing/jmockit/JMockitToMockitoTest.java index dd7ae8ea9..b8658ec10 100644 --- a/src/test/java/org/openrewrite/java/testing/jmockit/JMockitToMockitoTest.java +++ b/src/test/java/org/openrewrite/java/testing/jmockit/JMockitToMockitoTest.java @@ -43,7 +43,7 @@ public void defaults(RecipeSpec spec) { } @Test - void jMockitExpectationsToMockitoWhen() { + void jMockitExpectationsToMockitoWhenNullResult() { //language=java rewriteRun( java( @@ -66,16 +66,16 @@ public String getSomeField() { @ExtendWith(JMockitExtension.class) class MyTest { - @Mocked - MyObject myObject; + @Mocked + MyObject myObject; - void test() { - new Expectations() {{ - myObject.getSomeField(); - result = null; - }}; - assertNull(myObject.getSomeField()); - } + void test() { + new Expectations() {{ + myObject.getSomeField(); + result = null; + }}; + assertNull(myObject.getSomeField()); + } } """, """ @@ -88,13 +88,325 @@ void test() { @ExtendWith(MockitoExtension.class) class MyTest { - @Mock - MyObject myObject; + @Mock + MyObject myObject; + + void test() { + when(myObject.getSomeField()).thenReturn(null); + assertNull(myObject.getSomeField()); + } + } + """ + ) + ); + } + + @Test + void jMockitExpectationsToMockitoWhenIntResult() { + //language=java + rewriteRun( + java( + """ + class MyObject { + public int getSomeField() { + return 0; + } + } + """ + ), + java( + """ + import mockit.Expectations; + import mockit.Mocked; + import mockit.integration.junit5.JMockitExtension; + import org.junit.jupiter.api.extension.ExtendWith; + + import static org.junit.jupiter.api.Assertions.assertEquals; + + @ExtendWith(JMockitExtension.class) + class MyTest { + @Mocked + MyObject myObject; + + void test() { + new Expectations() {{ + myObject.getSomeField(); + result = 10; + }}; + assertEquals(10, myObject.getSomeField()); + } + } + """, + """ + import org.junit.jupiter.api.extension.ExtendWith; + import org.mockito.Mock; + import org.mockito.junit.jupiter.MockitoExtension; + + import static org.junit.jupiter.api.Assertions.assertEquals; + import static org.mockito.Mockito.when; + + @ExtendWith(MockitoExtension.class) + class MyTest { + @Mock + MyObject myObject; + + void test() { + when(myObject.getSomeField()).thenReturn(10); + assertEquals(10, myObject.getSomeField()); + } + } + """ + ) + ); + } + + @Test + void jMockitExpectationsToMockitoWhenVariableResult() { + //language=java + rewriteRun( + java( + """ + class MyObject { + public String getSomeField() { + return "X"; + } + } + """ + ), + java( + """ + import mockit.Expectations; + import mockit.Mocked; + import mockit.integration.junit5.JMockitExtension; + import org.junit.jupiter.api.extension.ExtendWith; + + import static org.junit.jupiter.api.Assertions.assertEquals; + + @ExtendWith(JMockitExtension.class) + class MyTest { + @Mocked + MyObject myObject; + + String expected = "expected"; + + void test() { + new Expectations() {{ + myObject.getSomeField(); + result = expected; + }}; + assertEquals(expected, myObject.getSomeField()); + } + } + """, + """ + import org.junit.jupiter.api.extension.ExtendWith; + import org.mockito.Mock; + import org.mockito.junit.jupiter.MockitoExtension; + + import static org.junit.jupiter.api.Assertions.assertEquals; + import static org.mockito.Mockito.when; + + @ExtendWith(MockitoExtension.class) + class MyTest { + @Mock + MyObject myObject; + + String expected = "expected"; + + void test() { + when(myObject.getSomeField()).thenReturn(expected); + assertEquals(expected, myObject.getSomeField()); + } + } + """ + ) + ); + } - void test() { - when(myObject.getSomeField()).thenReturn(null); - assertNull(myObject.getSomeField()); - } + @Test + void jMockitExpectationsToMockitoWhenNewClassResult() { + //language=java + rewriteRun( + java( + """ + class MyObject { + public String getSomeField() { + return "X"; + } + } + """ + ), + java( + """ + import mockit.Expectations; + import mockit.Mocked; + import mockit.integration.junit5.JMockitExtension; + import org.junit.jupiter.api.extension.ExtendWith; + + import static org.junit.jupiter.api.Assertions.assertNotNull; + + @ExtendWith(JMockitExtension.class) + class MyTest { + @Mocked + MyObject myObject; + + void test() { + new Expectations() {{ + myObject.getSomeField(); + result = new Object(); + }}; + assertNotNull(myObject.getSomeField()); + } + } + """, + """ + import org.junit.jupiter.api.extension.ExtendWith; + import org.mockito.Mock; + import org.mockito.junit.jupiter.MockitoExtension; + + import static org.junit.jupiter.api.Assertions.assertNotNull; + import static org.mockito.Mockito.when; + + @ExtendWith(MockitoExtension.class) + class MyTest { + @Mock + MyObject myObject; + + void test() { + when(myObject.getSomeField()).thenReturn(new Object()); + assertNotNull(myObject.getSomeField()); + } + } + """ + ) + ); + } + + @Test + void jMockitExpectationsToMockitoWhenExceptionResult() { + //language=java + rewriteRun( + java( + """ + class MyObject { + public String getSomeField() { + return "X"; + } + } + """ + ), + java( + """ + import mockit.Expectations; + import mockit.Mocked; + import mockit.integration.junit5.JMockitExtension; + import org.junit.jupiter.api.extension.ExtendWith; + + @ExtendWith(JMockitExtension.class) + class MyTest { + @Mocked + MyObject myObject; + + void test() throws RuntimeException { + new Expectations() {{ + myObject.getSomeField(); + result = new RuntimeException(); + }}; + myObject.getSomeField(); + } + } + """, + """ + import org.junit.jupiter.api.extension.ExtendWith; + import org.mockito.Mock; + import org.mockito.junit.jupiter.MockitoExtension; + + import static org.mockito.Mockito.when; + + @ExtendWith(MockitoExtension.class) + class MyTest { + @Mock + MyObject myObject; + + void test() throws RuntimeException { + when(myObject.getSomeField()).thenThrow(new RuntimeException()); + myObject.getSomeField(); + } + } + """ + ) + ); + } + + @Test + void jMockitExpectationsToMockitoWhenMultipleStatements() { + //language=java + rewriteRun( + java( + """ + class MyObject { + public int getSomeField() { + return 0; + } + public Object getSomeObjectField() { + return new Object(); + } + } + """ + ), + java( + """ + import mockit.Expectations; + import mockit.Mocked; + import mockit.integration.junit5.JMockitExtension; + import org.junit.jupiter.api.extension.ExtendWith; + + import static org.junit.jupiter.api.Assertions.assertEquals; + import static org.junit.jupiter.api.Assertions.assertNull; + + @ExtendWith(JMockitExtension.class) + class MyTest { + @Mocked + MyObject myObject; + + @Mocked + MyObject myOtherObject; + + void test() { + new Expectations() {{ + myObject.getSomeField(); + result = 10; + myOtherObject.getSomeObjectField(); + result = null; + }}; + assertEquals(10, myObject.getSomeField()); + assertNull(myOtherObject.getSomeObjectField()); + } + } + """, + """ + import org.junit.jupiter.api.extension.ExtendWith; + import org.mockito.Mock; + import org.mockito.junit.jupiter.MockitoExtension; + + import static org.junit.jupiter.api.Assertions.assertEquals; + import static org.junit.jupiter.api.Assertions.assertNull; + import static org.mockito.Mockito.when; + + @ExtendWith(MockitoExtension.class) + class MyTest { + @Mock + MyObject myObject; + + @Mock + MyObject myOtherObject; + + void test() { + when(myObject.getSomeField()).thenReturn(10); + when(myOtherObject.getSomeObjectField()).thenReturn(null); + assertEquals(10, myObject.getSomeField()); + assertNull(myOtherObject.getSomeObjectField()); + } } """ )