diff --git a/src/main/java/org/openrewrite/java/security/xml/TransformerFactoryFixVisitor.java b/src/main/java/org/openrewrite/java/security/xml/TransformerFactoryFixVisitor.java index bf8ff3f..5616b15 100644 --- a/src/main/java/org/openrewrite/java/security/xml/TransformerFactoryFixVisitor.java +++ b/src/main/java/org/openrewrite/java/security/xml/TransformerFactoryFixVisitor.java @@ -15,21 +15,17 @@ */ package org.openrewrite.java.security.xml; -import lombok.AllArgsConstructor; import org.openrewrite.Cursor; import org.openrewrite.ExecutionContext; import org.openrewrite.Preconditions; import org.openrewrite.TreeVisitor; -import org.openrewrite.java.JavaIsoVisitor; import org.openrewrite.java.MethodMatcher; import org.openrewrite.java.search.UsesType; import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.TypeUtils; import javax.xml.XMLConstants; -@AllArgsConstructor -public class TransformerFactoryFixVisitor

extends JavaIsoVisitor

{ +public class TransformerFactoryFixVisitor

extends XmlFactoryVisitor

{ static final MethodMatcher TRANSFORMER_FACTORY_INSTANCE = new MethodMatcher("javax.xml.transform.TransformerFactory new*()"); static final MethodMatcher TRANSFORMER_FACTORY_SET_ATTRIBUTE = new MethodMatcher("javax.xml.transform.TransformerFactory setAttribute(java.lang.String, ..)"); static final MethodMatcher TRANSFORMER_FACTORY_SET_FEATURE = new MethodMatcher("javax.xml.transform.TransformerFactory setFeature(java.lang.String, ..)"); @@ -44,6 +40,16 @@ public class TransformerFactoryFixVisitor

extends JavaIsoVisitor

{ private static final String DISALLOW_MODIFY_FLAG = "DISALLOW_MODIFY_FLAG"; + public TransformerFactoryFixVisitor(ExternalDTDAccumulator acc) { + super( + TRANSFORMER_FACTORY_INSTANCE, + TRANSFORMER_FACTORY_FQN, + TRANSFORMER_FACTORY_INITIALIZATION_METHOD, + TRANSFORMER_FACTORY_VARIABLE_NAME, + acc + ); + } + @Override public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, P ctx) { J.ClassDeclaration cd = super.visitClassDeclaration(classDecl, ctx); @@ -78,53 +84,42 @@ public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, P return cd; } - @Override - public J.VariableDeclarations.NamedVariable visitVariable(J.VariableDeclarations.NamedVariable variable, P ctx) { - J.VariableDeclarations.NamedVariable v = super.visitVariable(variable, ctx); - if (TypeUtils.isOfClassType(v.getType(), TRANSFORMER_FACTORY_FQN)) { - getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, TRANSFORMER_FACTORY_VARIABLE_NAME, v.getSimpleName()); - } - return v; - } - @Override public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, P ctx) { J.MethodInvocation m = super.visitMethodInvocation(method, ctx); - if (TRANSFORMER_FACTORY_INSTANCE.matches(m)) { - getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, TRANSFORMER_FACTORY_INITIALIZATION_METHOD, getCursor().dropParentUntil(J.Block.class::isInstance)); - } else if (TRANSFORMER_FACTORY_SET_ATTRIBUTE.matches(m) && m.getArguments().get(0) instanceof J.FieldAccess) { + if (TRANSFORMER_FACTORY_SET_ATTRIBUTE.matches(m) && m.getArguments().get(0) instanceof J.FieldAccess) { // If either attribute value is not equal to the empty string, do not make any changes if (m.getArguments().get(1) instanceof J.Literal) { J.Literal string = (J.Literal) m.getArguments().get(1); assert string.getValue() != null; if (!(((String) string.getValue()).isEmpty())) { - getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, DISALLOW_MODIFY_FLAG, getCursor().dropParentUntil(J.Block.class::isInstance)); + addMessage(DISALLOW_MODIFY_FLAG); } } J.FieldAccess fa = (J.FieldAccess) m.getArguments().get(0); if (ACCESS_EXTERNAL_DTD_NAME.equals(fa.getSimpleName())) { - getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, ACCESS_EXTERNAL_DTD_NAME, getCursor().dropParentUntil(J.Block.class::isInstance)); + addMessage(ACCESS_EXTERNAL_DTD_NAME); } else if (ACCESS_EXTERNAL_STYLESHEET_NAME.equals(fa.getSimpleName())) { - getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, ACCESS_EXTERNAL_STYLESHEET_NAME, getCursor().dropParentUntil(J.Block.class::isInstance)); + addMessage(ACCESS_EXTERNAL_STYLESHEET_NAME); } } else if (TRANSFORMER_FACTORY_SET_FEATURE.matches(m)) { // If FEATURE_SECURE_PROCESSING is set to false, do not make any changes if (m.getArguments().get(1) instanceof J.Literal) { J.Literal bool = (J.Literal) m.getArguments().get(1); assert bool.getValue() != null; - if (!((Boolean) bool.getValue())) { - getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, DISALLOW_MODIFY_FLAG, getCursor().dropParentUntil(J.Block.class::isInstance)); + if (Boolean.FALSE.equals(bool.getValue())) { + addMessage(DISALLOW_MODIFY_FLAG); } } if (m.getArguments().get(0) instanceof J.FieldAccess) { J.FieldAccess fa = (J.FieldAccess) m.getArguments().get(0); if (FEATURE_SECURE_PROCESSING_NAME.equals(fa.getSimpleName())) { - getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, FEATURE_SECURE_PROCESSING_NAME, getCursor().dropParentUntil(J.Block.class::isInstance)); + addMessage(FEATURE_SECURE_PROCESSING_NAME); } } else if (m.getArguments().get(0) instanceof J.Literal) { J.Literal literal = (J.Literal) m.getArguments().get(0); if (XMLConstants.FEATURE_SECURE_PROCESSING.equals(literal.getValue())) { - getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, FEATURE_SECURE_PROCESSING_NAME, getCursor().dropParentUntil(J.Block.class::isInstance)); + addMessage(FEATURE_SECURE_PROCESSING_NAME); } } } @@ -132,6 +127,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, P ctx } public static TreeVisitor create(ExternalDTDAccumulator acc) { - return Preconditions.check(new UsesType<>(TRANSFORMER_FACTORY_FQN, true), new TransformerFactoryFixVisitor<>()); + return Preconditions.check(new UsesType<>(TRANSFORMER_FACTORY_FQN, true), new TransformerFactoryFixVisitor<>(acc)); } } diff --git a/src/main/java/org/openrewrite/java/security/xml/TransformerFactoryInsertAttributeStatementVisitor.java b/src/main/java/org/openrewrite/java/security/xml/TransformerFactoryInsertAttributeStatementVisitor.java index 62097e3..ea40cf3 100644 --- a/src/main/java/org/openrewrite/java/security/xml/TransformerFactoryInsertAttributeStatementVisitor.java +++ b/src/main/java/org/openrewrite/java/security/xml/TransformerFactoryInsertAttributeStatementVisitor.java @@ -15,18 +15,12 @@ */ package org.openrewrite.java.security.xml; -import org.openrewrite.Cursor; -import org.openrewrite.java.JavaIsoVisitor; -import org.openrewrite.java.JavaTemplate; import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.JavaCoordinates; import org.openrewrite.java.tree.Statement; -public class TransformerFactoryInsertAttributeStatementVisitor

extends JavaIsoVisitor

{ - private final J.Block scope; - private final StringBuilder attributeTemplate = new StringBuilder(); - private final String transformerFactoryVariableName; +import java.util.Collections; +public class TransformerFactoryInsertAttributeStatementVisitor

extends XmlFactoryInsertVisitor

{ public TransformerFactoryInsertAttributeStatementVisitor( J.Block scope, String factoryVariableName, @@ -34,57 +28,30 @@ public TransformerFactoryInsertAttributeStatementVisitor( boolean needsStylesheetsDisabled, boolean needsFeatureSecureProcessing ) { - this.scope = scope; - this.transformerFactoryVariableName = factoryVariableName; + super( + scope, + factoryVariableName, + TransformerFactoryFixVisitor.TRANSFORMER_FACTORY_INSTANCE, + TransformerFactoryFixVisitor.TRANSFORMER_FACTORY_SET_ATTRIBUTE + ); if (needsExternalEntitiesDisabled) { - attributeTemplate.append(transformerFactoryVariableName).append(".setAttribute(XMLConstants.ACCESS_EXTERNAL_DTD, \"\");"); + getTemplate().append(getFactoryVariableName()).append(".setAttribute(XMLConstants.ACCESS_EXTERNAL_DTD, \"\");"); } if (needsStylesheetsDisabled) { - attributeTemplate.append(transformerFactoryVariableName).append(".setAttribute(XMLConstants.ACCESS_EXTERNAL_STYLESHEET, \"\");"); + getTemplate().append(getFactoryVariableName()).append(".setAttribute(XMLConstants.ACCESS_EXTERNAL_STYLESHEET, \"\");"); } if (needsFeatureSecureProcessing) { - attributeTemplate.append(transformerFactoryVariableName).append(".setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true);"); + getTemplate().append(getFactoryVariableName()).append(".setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true);"); } } @Override public J.Block visitBlock(J.Block block, P ctx) { J.Block b = super.visitBlock(block, ctx); - Statement beforeStatement = null; - if (b.isScope(scope)) { - for (int i = b.getStatements().size() - 2; i > -1; i--) { - Statement st = b.getStatements().get(i); - Statement stBefore = b.getStatements().get(i + 1); - if (st instanceof J.MethodInvocation) { - J.MethodInvocation m = (J.MethodInvocation) st; - if (TransformerFactoryFixVisitor.TRANSFORMER_FACTORY_INSTANCE.matches(m) || TransformerFactoryFixVisitor.TRANSFORMER_FACTORY_SET_ATTRIBUTE.matches(m)) { - beforeStatement = stBefore; - } - } else if (st instanceof J.VariableDeclarations) { - J.VariableDeclarations vd = (J.VariableDeclarations) st; - if (vd.getVariables().get(0).getInitializer() instanceof J.MethodInvocation) { - J.MethodInvocation m = (J.MethodInvocation) vd.getVariables().get(0).getInitializer(); - if (m != null && TransformerFactoryFixVisitor.TRANSFORMER_FACTORY_INSTANCE.matches(m)) { - beforeStatement = stBefore; - } - } - } - } - - if (getCursor().getParent() != null && getCursor().getParent().getValue() instanceof J.ClassDeclaration) { - attributeTemplate.insert(0, "{").append("}"); - } - JavaCoordinates insertCoordinates = beforeStatement != null ? - beforeStatement.getCoordinates().before() : - b.getCoordinates().lastStatement(); - b = JavaTemplate - .builder(attributeTemplate.toString()) - .imports("javax.xml.XMLConstants") - .contextSensitive() - .build() - .apply(new Cursor(getCursor().getParent(), b), insertCoordinates); - maybeAddImport("javax.xml.XMLConstants"); + Statement beforeStatement = getInsertStatement(b); + if (b.isScope(getScope())) { + b = updateBlock(b, block, beforeStatement, Collections.singleton("javax.xml.XMLConstants")); } return b; } diff --git a/src/main/java/org/openrewrite/java/security/xml/XmlFactoryInsertPropertyStatementVisitor.java b/src/main/java/org/openrewrite/java/security/xml/XmlFactoryInsertPropertyStatementVisitor.java index 5be1b61..be33a57 100644 --- a/src/main/java/org/openrewrite/java/security/xml/XmlFactoryInsertPropertyStatementVisitor.java +++ b/src/main/java/org/openrewrite/java/security/xml/XmlFactoryInsertPropertyStatementVisitor.java @@ -1,5 +1,5 @@ /* - * Copyright 2022 the original author or authors. + * Copyright 2023 the original author or authors. *

* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,12 +15,8 @@ */ package org.openrewrite.java.security.xml; -import org.openrewrite.Cursor; -import org.openrewrite.java.JavaIsoVisitor; -import org.openrewrite.java.JavaTemplate; import org.openrewrite.java.VariableNameUtils; import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.JavaCoordinates; import org.openrewrite.java.tree.Statement; import java.util.Collections; @@ -28,13 +24,10 @@ import java.util.Set; import java.util.stream.Collectors; -class XmlFactoryInsertPropertyStatementVisitor

extends JavaIsoVisitor

{ - private final J.Block scope; - private final StringBuilder propertyTemplate = new StringBuilder(); +class XmlFactoryInsertPropertyStatementVisitor

extends XmlFactoryInsertVisitor

{ private final ExternalDTDAccumulator acc; private final boolean generateAllowList; - private final String xmlFactoryVariableName; public XmlFactoryInsertPropertyStatementVisitor( J.Block scope, @@ -46,21 +39,25 @@ public XmlFactoryInsertPropertyStatementVisitor( boolean needsResolverMethod, ExternalDTDAccumulator acc ) { - this.scope = scope; + super( + scope, + factoryVariableName, + XmlInputFactoryFixVisitor.XML_PARSER_FACTORY_INSTANCE, + XmlInputFactoryFixVisitor.XML_PARSER_FACTORY_SET_PROPERTY + ); this.acc = acc; - this.xmlFactoryVariableName = factoryVariableName; if (needsExternalEntitiesDisabled) { - propertyTemplate.append(xmlFactoryVariableName).append(".setProperty(XMLInputFactory.IS_SUPPORTING_EXTERNAL_ENTITIES, false);"); + getTemplate().append(getFactoryVariableName()).append(".setProperty(XMLInputFactory.IS_SUPPORTING_EXTERNAL_ENTITIES, false);"); } if (needsSupportDTDFalse && accIsEmpty) { if (needsSupportDTDTrue) { - propertyTemplate.append(xmlFactoryVariableName).append(".setProperty(XMLInputFactory.SUPPORT_DTD, false);"); + getTemplate().append(getFactoryVariableName()).append(".setProperty(XMLInputFactory.SUPPORT_DTD, false);"); } } if (needsSupportDTDFalse && !accIsEmpty) { if (needsResolverMethod && needsSupportDTDTrue) { - propertyTemplate.append(xmlFactoryVariableName).append(".setProperty(XMLInputFactory.SUPPORT_DTD, true);"); + getTemplate().append(getFactoryVariableName()).append(".setProperty(XMLInputFactory.SUPPORT_DTD, true);"); } this.generateAllowList = needsResolverMethod; } else if (!needsSupportDTDTrue && !accIsEmpty) { @@ -86,12 +83,12 @@ private Set addAllowList(boolean generateAllowList) { if (acc.getExternalDTDs().size() > 1) { imports.add("java.util.Arrays"); - propertyTemplate.append( + getTemplate().append( "Collection" + newAllowListVariableName + " = Arrays.asList(\n" ); } else { imports.add("java.util.Collections"); - propertyTemplate.append( + getTemplate().append( "Collection" + newAllowListVariableName + " = Collections.singleton(\n" ); } @@ -101,8 +98,8 @@ private Set addAllowList(boolean generateAllowList) { "\t", "" )); - propertyTemplate.append(allowListContent).append("\n);\n"); - propertyTemplate.append(xmlFactoryVariableName).append( + getTemplate().append(allowListContent).append("\n);\n"); + getTemplate().append(getFactoryVariableName()).append( ".setXMLResolver((publicID, systemID, baseURI, namespace) -> {\n" + " if (" + newAllowListVariableName + ".contains(systemID)){\n" + " // returning null will cause the parser to resolve the entity\n" + @@ -117,44 +114,10 @@ private Set addAllowList(boolean generateAllowList) { @Override public J.Block visitBlock(J.Block block, P ctx) { J.Block b = super.visitBlock(block, ctx); - Statement beforeStatement = null; - if (b.isScope(scope)) { - for (int i = b.getStatements().size() - 2; i > -1; i--) { - Statement st = b.getStatements().get(i); - Statement stBefore = b.getStatements().get(i + 1); - if (st instanceof J.MethodInvocation) { - J.MethodInvocation m = (J.MethodInvocation) st; - if (XmlInputFactoryFixVisitor.XML_PARSER_FACTORY_INSTANCE.matches(m) || XmlInputFactoryFixVisitor.XML_PARSER_FACTORY_SET_PROPERTY.matches(m)) { - beforeStatement = stBefore; - } - } else if (st instanceof J.VariableDeclarations) { - J.VariableDeclarations vd = (J.VariableDeclarations) st; - if (vd.getVariables().get(0).getInitializer() instanceof J.MethodInvocation) { - J.MethodInvocation m = (J.MethodInvocation) vd.getVariables().get(0).getInitializer(); - if (m != null && XmlInputFactoryFixVisitor.XML_PARSER_FACTORY_INSTANCE.matches(m)) { - beforeStatement = stBefore; - } - } - } - } - + Statement beforeStatement = getInsertStatement(b); + if (b.isScope(getScope())) { Set imports = addAllowList(generateAllowList); - - if (getCursor().getParent() != null && getCursor().getParent().getValue() instanceof J.ClassDeclaration) { - propertyTemplate.insert(0, "{\n").append("}"); - } - JavaCoordinates insertCoordinates = beforeStatement != null ? - beforeStatement.getCoordinates().before() : - b.getCoordinates().lastStatement(); - b = JavaTemplate - .builder(propertyTemplate.toString()) - .imports(imports.toArray(new String[0])) - .contextSensitive() - .build() - .apply(new Cursor(getCursor().getParent(), b), insertCoordinates); - if (b != block) { - imports.forEach(this::maybeAddImport); - } + b = updateBlock(b, block, beforeStatement, imports); } return b; } diff --git a/src/main/java/org/openrewrite/java/security/xml/XmlFactoryInsertVisitor.java b/src/main/java/org/openrewrite/java/security/xml/XmlFactoryInsertVisitor.java new file mode 100644 index 0000000..3d84f34 --- /dev/null +++ b/src/main/java/org/openrewrite/java/security/xml/XmlFactoryInsertVisitor.java @@ -0,0 +1,83 @@ +/* + * Copyright 2023 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.security.xml; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.openrewrite.Cursor; +import org.openrewrite.java.JavaIsoVisitor; +import org.openrewrite.java.JavaTemplate; +import org.openrewrite.java.MethodMatcher; +import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.JavaCoordinates; +import org.openrewrite.java.tree.Statement; + +import java.util.Set; + +@AllArgsConstructor +@Getter +public abstract class XmlFactoryInsertVisitor

extends JavaIsoVisitor

{ + private final StringBuilder template = new StringBuilder(); + private final J.Block scope; + private final String factoryVariableName; + private final MethodMatcher factoryInstanceMatcher; + private final MethodMatcher factoryMethodCallMatcher; + + public Statement getInsertStatement(J.Block b) { + Statement beforeStatement = null; + if (b.isScope(scope)) { + for (int i = b.getStatements().size() - 2; i > -1; i--) { + Statement st = b.getStatements().get(i); + Statement stBefore = b.getStatements().get(i + 1); + if (st instanceof J.MethodInvocation) { + J.MethodInvocation m = (J.MethodInvocation) st; + if (factoryInstanceMatcher.matches(m) || factoryMethodCallMatcher.matches(m)) { + beforeStatement = stBefore; + } + } else if (st instanceof J.VariableDeclarations) { + J.VariableDeclarations vd = (J.VariableDeclarations) st; + if (vd.getVariables().get(0).getInitializer() instanceof J.MethodInvocation) { + J.MethodInvocation m = (J.MethodInvocation) vd.getVariables().get(0).getInitializer(); + if (m != null && factoryInstanceMatcher.matches(m)) { + beforeStatement = stBefore; + } + } + } + } + } + return beforeStatement; + } + + private JavaCoordinates getInsertCoordinates(J.Block b, Statement s) { + return s != null ? s.getCoordinates().before() : b.getCoordinates().lastStatement(); + } + + public J.Block updateBlock(J.Block b, J.Block block, Statement beforeStatement, Set imports) { + if (getCursor().getParent() != null && getCursor().getParent().getValue() instanceof J.ClassDeclaration) { + template.insert(0, "{\n").append("}"); + } + b = JavaTemplate + .builder(template.toString()) + .imports(imports.toArray(new String[0])) + .contextSensitive() + .build() + .apply(new Cursor(getCursor().getParent(), b), getInsertCoordinates(b, beforeStatement)); + if (b != block) { + imports.forEach(this::maybeAddImport); + } + return b; + } +} diff --git a/src/main/java/org/openrewrite/java/security/xml/XmlFactoryVisitor.java b/src/main/java/org/openrewrite/java/security/xml/XmlFactoryVisitor.java new file mode 100644 index 0000000..f3c5386 --- /dev/null +++ b/src/main/java/org/openrewrite/java/security/xml/XmlFactoryVisitor.java @@ -0,0 +1,63 @@ +/* + * Copyright 2023 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.security.xml; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.openrewrite.java.JavaIsoVisitor; +import org.openrewrite.java.MethodMatcher; +import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.TypeUtils; + +@AllArgsConstructor +@Getter +public abstract class XmlFactoryVisitor

extends JavaIsoVisitor

{ + private final MethodMatcher FACTORY_INSTANCE; + + private final String FACTORY_FQN; + + private final String FACTORY_INITIALIZATION_METHOD; + private final String FACTORY_VARIABLE_NAME; + + private final ExternalDTDAccumulator acc; + + @Override + public J.VariableDeclarations.NamedVariable visitVariable(J.VariableDeclarations.NamedVariable variable, P ctx) { + J.VariableDeclarations.NamedVariable v = super.visitVariable(variable, ctx); + if (TypeUtils.isOfClassType(v.getType(), FACTORY_FQN)) { + getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, FACTORY_VARIABLE_NAME, v.getSimpleName()); + } + return v; + } + + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, P ctx) { + J.MethodInvocation m = super.visitMethodInvocation(method, ctx); + if (FACTORY_INSTANCE.matches(m)) { + getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, FACTORY_INITIALIZATION_METHOD, getCursor().dropParentUntil(J.Block.class::isInstance)); + } + return m; + } + + /** + * Adds a message/flag on the first enclosing class instance. + * + * @param message The message to be added. + */ + public void addMessage(String message) { + getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, message, getCursor().dropParentUntil(J.Block.class::isInstance)); + } +} diff --git a/src/main/java/org/openrewrite/java/security/xml/XmlInputFactoryFixVisitor.java b/src/main/java/org/openrewrite/java/security/xml/XmlInputFactoryFixVisitor.java index 30beeba..f56af5a 100644 --- a/src/main/java/org/openrewrite/java/security/xml/XmlInputFactoryFixVisitor.java +++ b/src/main/java/org/openrewrite/java/security/xml/XmlInputFactoryFixVisitor.java @@ -1,5 +1,5 @@ /* - * Copyright 2022 the original author or authors. + * Copyright 2023 the original author or authors. *

* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,12 +15,10 @@ */ package org.openrewrite.java.security.xml; -import lombok.AllArgsConstructor; import org.openrewrite.Cursor; import org.openrewrite.ExecutionContext; import org.openrewrite.Preconditions; import org.openrewrite.TreeVisitor; -import org.openrewrite.java.JavaIsoVisitor; import org.openrewrite.java.MethodMatcher; import org.openrewrite.java.search.UsesType; import org.openrewrite.java.tree.J; @@ -28,8 +26,7 @@ import javax.xml.stream.XMLInputFactory; -@AllArgsConstructor -public class XmlInputFactoryFixVisitor

extends JavaIsoVisitor

{ +public class XmlInputFactoryFixVisitor

extends XmlFactoryVisitor

{ static final MethodMatcher XML_PARSER_FACTORY_INSTANCE = new MethodMatcher("javax.xml.stream.XMLInputFactory new*()"); static final MethodMatcher XML_PARSER_FACTORY_SET_PROPERTY = new MethodMatcher("javax.xml.stream.XMLInputFactory setProperty(java.lang.String, ..)"); @@ -45,10 +42,18 @@ public class XmlInputFactoryFixVisitor

extends JavaIsoVisitor

{ private static final String XML_RESOLVER_METHOD = "xml-resolver-initialization-method"; - private final ExternalDTDAccumulator acc; + public XmlInputFactoryFixVisitor(ExternalDTDAccumulator acc) { + super( + XML_PARSER_FACTORY_INSTANCE, + XML_FACTORY_FQN, + XML_PARSER_INITIALIZATION_METHOD, + XML_FACTORY_VARIABLE_NAME, + acc + ); + } + @Override public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, P ctx) { - J.ClassDeclaration cd = super.visitClassDeclaration(classDecl, ctx); Cursor supportsExternalCursor = getCursor().getMessage(SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_NAME); Cursor supportsFalseDTDCursor = getCursor().getMessage(SUPPORT_DTD_FALSE_PROPERTY_NAME); @@ -60,7 +65,6 @@ public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, P Cursor setPropertyBlockCursor = null; if (supportsExternalCursor == null && supportsFalseDTDCursor == null) { setPropertyBlockCursor = initializationCursor; - } else if (supportsExternalCursor == null ^ supportsFalseDTDCursor == null) { setPropertyBlockCursor = supportsExternalCursor == null ? supportsFalseDTDCursor : supportsExternalCursor; } @@ -70,33 +74,22 @@ public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, P xmlFactoryVariableName, supportsExternalCursor == null, supportsFalseDTDCursor == null, - acc.getExternalDTDs().isEmpty(), + getAcc().getExternalDTDs().isEmpty(), supportsDTDTrueCursor == null, xmlResolverMethod == null, - acc + getAcc() )); } return cd; } - @Override - public J.VariableDeclarations.NamedVariable visitVariable(J.VariableDeclarations.NamedVariable variable, P ctx) { - J.VariableDeclarations.NamedVariable v = super.visitVariable(variable, ctx); - if (TypeUtils.isOfClassType(v.getType(), XML_FACTORY_FQN)) { - getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, XML_FACTORY_VARIABLE_NAME, v.getSimpleName()); - } - return v; - } - @Override public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, P ctx) { J.MethodInvocation m = super.visitMethodInvocation(method, ctx); - if (XML_PARSER_FACTORY_INSTANCE.matches(m)) { - getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, XML_PARSER_INITIALIZATION_METHOD, getCursor().dropParentUntil(J.Block.class::isInstance)); - } else if (XML_PARSER_FACTORY_SET_PROPERTY.matches(m) && m.getArguments().get(0) instanceof J.FieldAccess) { + if (XML_PARSER_FACTORY_SET_PROPERTY.matches(m) && m.getArguments().get(0) instanceof J.FieldAccess) { J.FieldAccess fa = (J.FieldAccess) m.getArguments().get(0); if (SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_NAME.equals(fa.getSimpleName())) { - getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_NAME, getCursor().dropParentUntil(J.Block.class::isInstance)); + addMessage(SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_NAME); } else if (SUPPORT_DTD_FALSE_PROPERTY_NAME.equals(fa.getSimpleName())) { checkDTDSupport(m); } @@ -104,26 +97,24 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, P ctx J.Literal literal = (J.Literal) m.getArguments().get(0); if (TypeUtils.isString(literal.getType())) { if (XMLInputFactory.IS_SUPPORTING_EXTERNAL_ENTITIES.equals(literal.getValue())) { - getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_NAME, getCursor().dropParentUntil(J.Block.class::isInstance)); + addMessage(SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_NAME); } else if (XMLInputFactory.SUPPORT_DTD.equals(literal.getValue())) { checkDTDSupport(m); } } } else if (XML_PARSER_FACTORY_SET_RESOLVER.matches(m)) { - getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, XML_RESOLVER_METHOD, getCursor().dropParentUntil((J.Block.class::isInstance))); + addMessage(XML_RESOLVER_METHOD); } return m; - - } private void checkDTDSupport(J.MethodInvocation m) { if (m.getArguments().get(1) instanceof J.Literal) { J.Literal literal = (J.Literal) m.getArguments().get(1); if (Boolean.TRUE.equals(literal.getValue())) { - getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, SUPPORT_DTD_TRUE_PROPERTY_NAME, getCursor().dropParentUntil(J.Block.class::isInstance)); + addMessage(SUPPORT_DTD_TRUE_PROPERTY_NAME); } else { - getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, SUPPORT_DTD_FALSE_PROPERTY_NAME, getCursor().dropParentUntil(J.Block.class::isInstance)); + addMessage(SUPPORT_DTD_FALSE_PROPERTY_NAME); } } } diff --git a/src/test/java/org/openrewrite/java/security/xml/XmlInputFactoryXXEVulnerabilityTest.java b/src/test/java/org/openrewrite/java/security/xml/XmlInputFactoryXXEVulnerabilityTest.java index a31ed29..f8b84d5 100644 --- a/src/test/java/org/openrewrite/java/security/xml/XmlInputFactoryXXEVulnerabilityTest.java +++ b/src/test/java/org/openrewrite/java/security/xml/XmlInputFactoryXXEVulnerabilityTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2021 the original author or authors. + * Copyright 2023 the original author or authors. *

* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.