diff --git a/src/main/java/org/openrewrite/python/PythonParser.java b/src/main/java/org/openrewrite/python/PythonParser.java index b04356d8..f8447647 100644 --- a/src/main/java/org/openrewrite/python/PythonParser.java +++ b/src/main/java/org/openrewrite/python/PythonParser.java @@ -181,7 +181,8 @@ private static com.jetbrains.python.psi.LanguageLevel mapLanguageLevel(LanguageL @Override public boolean accept(Path path) { - return path.toString().endsWith(".py"); + String pathString = path.toString(); + return pathString.endsWith(".py") || pathString.endsWith(".pyi"); } @Override diff --git a/src/main/java/org/openrewrite/python/internal/PsiPythonMapper.java b/src/main/java/org/openrewrite/python/internal/PsiPythonMapper.java index f9138417..639146f5 100644 --- a/src/main/java/org/openrewrite/python/internal/PsiPythonMapper.java +++ b/src/main/java/org/openrewrite/python/internal/PsiPythonMapper.java @@ -2320,7 +2320,7 @@ public J.Literal mapNoneLiteral(PyNoneLiteralExpression element) { spaceBefore(element), EMPTY, null, - "null", + element.isEllipsis() ? "..." : "null", null, JavaType.Primitive.Null ); diff --git a/src/main/java/org/openrewrite/python/internal/PythonPrinter.java b/src/main/java/org/openrewrite/python/internal/PythonPrinter.java index 639925b5..bee7608b 100755 --- a/src/main/java/org/openrewrite/python/internal/PythonPrinter.java +++ b/src/main/java/org/openrewrite/python/internal/PythonPrinter.java @@ -610,7 +610,9 @@ public J visitLiteral(J.Literal literal, PrintOutputCapture
p) { if (literal.getMarkers().findFirst(ImplicitNone.class).isPresent()) { literal = literal.withValueSource(""); } else { - literal = literal.withValueSource("None"); + if ("null".equals(literal.getValueSource())) { + literal = literal.withValueSource("None"); + } } } diff --git a/src/test/java/org/openrewrite/python/tree/ImportTest.java b/src/test/java/org/openrewrite/python/tree/ImportTest.java index a77dfe67..7881c39e 100644 --- a/src/test/java/org/openrewrite/python/tree/ImportTest.java +++ b/src/test/java/org/openrewrite/python/tree/ImportTest.java @@ -16,138 +16,163 @@ package org.openrewrite.python.tree; import org.intellij.lang.annotations.Language; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.openrewrite.Issue; +import org.openrewrite.internal.lang.Nullable; import org.openrewrite.test.RewriteTest; - -import static org.openrewrite.python.tree.ParserAssertions.python; +import org.openrewrite.test.SourceSpecs; class ImportTest implements RewriteTest { - @ParameterizedTest - //language=py - @ValueSource(strings = { - "import math", - "import math", - }) - void simpleImport(@Language("py") String arg) { - rewriteRun( - python(arg) - ); - } - @ParameterizedTest - //language=py - @ValueSource(strings = { - "import math as math2", - "import math as math2", - "import math as math2", - "import math as math2", - }) - void simpleImportAlias(@Language("py") String arg) { - rewriteRun( - python(arg) - ); - } + abstract class Base { + abstract SourceSpecs python(@Language("py") @Nullable String before); - @ParameterizedTest - //language=py - @ValueSource(strings = { - "from . import foo", - "from . import foo", - "from . import foo", - "from . import foo", - "from .mod import foo", - "from .mod import foo", - "from .mod import foo", - "from .mod import foo", - "from ...mod import foo", - "from ....mod import foo", - }) - void localImport(@Language("py") String arg) { - rewriteRun( - python(arg) - ); - } + @ParameterizedTest + //language=py + @ValueSource(strings = { + "import math", + "import math", + }) + void simpleImport(@Language("py") String arg) { + rewriteRun( + python(arg) + ); + } - @ParameterizedTest - //language=py - @ValueSource(strings = { - "from math import ceil", - "from math import ceil", - "from math import ceil", - "from math import ceil", - }) - void qualifiedImport(@Language("py") String arg) { - rewriteRun( - python(arg) - ); - } + @ParameterizedTest + //language=py + @ValueSource(strings = { + "import math as math2", + "import math as math2", + "import math as math2", + "import math as math2", + }) + void simpleImportAlias(@Language("py") String arg) { + rewriteRun( + python(arg) + ); + } - @ParameterizedTest - //language=py - @ValueSource(strings = { - "from . import foo as foo2", - "from . import foo as foo2", - "from . import foo as foo2", - "from . import foo as foo2", - "from . import foo as foo2", - "from . import foo as foo2", - }) - void localImportAlias(@Language("py") String arg) { - rewriteRun( - python(arg) - ); - } - @Issue("https://github.com/openrewrite/rewrite-python/issues/35") - @Test - void multipleImports() { - rewriteRun( - python( - """ - import sys - - import math - """) - ); + @ParameterizedTest + //language=py + @ValueSource(strings = { + "from . import foo", + "from . import foo", + "from . import foo", + "from . import foo", + "from .mod import foo", + "from .mod import foo", + "from .mod import foo", + "from .mod import foo", + "from ...mod import foo", + "from ....mod import foo", + }) + void localImport(@Language("py") String arg) { + rewriteRun( + python(arg) + ); + } + + @ParameterizedTest + //language=py + @ValueSource(strings = { + "from math import ceil", + "from math import ceil", + "from math import ceil", + "from math import ceil", + }) + void qualifiedImport(@Language("py") String arg) { + rewriteRun( + python(arg) + ); + } + + @ParameterizedTest + //language=py + @ValueSource(strings = { + "from . import foo as foo2", + "from . import foo as foo2", + "from . import foo as foo2", + "from . import foo as foo2", + "from . import foo as foo2", + "from . import foo as foo2", + }) + void localImportAlias(@Language("py") String arg) { + rewriteRun( + python(arg) + ); + } + + @Issue("https://github.com/openrewrite/rewrite-python/issues/35") + @Test + void multipleImports() { + rewriteRun( + python( + """ + import sys + + import math + """) + ); + } + + @Issue("https://github.com/openrewrite/rewrite-python/issues/35") + @Test + void enclosedInParens() { + rewriteRun( + python( + """ + from math import ( + sin, + cos + ) + """) + ); + } + + @SuppressWarnings("TrailingWhitespacesInTextBlock") + @ParameterizedTest + //language=py + @ValueSource(strings = { + "from math import sin, cos # stuff\n\n", + "from math import sin as sin2, cos\n", + "from math import sin as sin2, cos as cos2\n", + """ + from math import ( + sin, + cos + ) + """, + }) + void multipleImport(@Language("py") String arg) { + rewriteRun( + python(arg) + ); + } } - @Issue("https://github.com/openrewrite/rewrite-python/issues/35") - @Test - void enclosedInParens() { - rewriteRun( - python( - """ - from math import ( - sin, - cos - ) - """) - ); + @Nested + class PythonFile extends Base { + + @Override + SourceSpecs python(@Nullable String before) { + return org.openrewrite.python.tree.ParserAssertions.python(before); + } } - @SuppressWarnings("TrailingWhitespacesInTextBlock") - @ParameterizedTest - //language=py - @ValueSource(strings = { - "from math import sin, cos # stuff\n\n", - "from math import sin as sin2, cos\n", - "from math import sin as sin2, cos as cos2\n", - """ - from math import ( - sin, - cos - ) - """, - }) - void multipleImport(@Language("py") String arg) { - rewriteRun( - python(arg) - ); + @Nested + class PythonStub extends Base { + + @Override + SourceSpecs python(@Nullable String before) { + return org.openrewrite.python.tree.ParserAssertions.python(before, spec -> spec.path("file.pyi")); + } } } diff --git a/src/test/java/org/openrewrite/python/tree/StubTest.java b/src/test/java/org/openrewrite/python/tree/StubTest.java new file mode 100644 index 00000000..8bf4deed --- /dev/null +++ b/src/test/java/org/openrewrite/python/tree/StubTest.java @@ -0,0 +1,44 @@ +/* + * Copyright 2021 the original author or authors. + *
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *
+ * https://www.apache.org/licenses/LICENSE-2.0 + *
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.python.tree; + +import org.junit.jupiter.api.Test; +import org.openrewrite.test.RewriteTest; + +import static org.openrewrite.python.Assertions.python; + +public class StubTest implements RewriteTest { + + @Test + void simpleStub() { + rewriteRun( + python(""" + # Variables with annotations do not need to be assigned a value. + # So by convention, we omit them in the stub file. + x: int + + # Function bodies cannot be completely removed. By convention, + # we replace them with `...` instead of the `pass` statement. + def func_1(code: str) -> int: ... + + # We can do the same with default arguments. + def func_2(a: int, b: int = ...) -> int: ... + """, + spec -> spec.path("file.pyi") + ) + ); + } +}