Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding two abstract visitor classes for XXE vulnerabilities #98

Merged
merged 3 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,17 @@
*/
package org.openrewrite.java.security.xml;

import lombok.AllArgsConstructor;
import org.openrewrite.Cursor;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Preconditions;
import org.openrewrite.TreeVisitor;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.search.UsesType;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.TypeUtils;

import javax.xml.XMLConstants;

@AllArgsConstructor
public class TransformerFactoryFixVisitor<P> extends JavaIsoVisitor<P> {
public class TransformerFactoryFixVisitor<P> extends XmlFactoryVisitor<P> {
static final MethodMatcher TRANSFORMER_FACTORY_INSTANCE = new MethodMatcher("javax.xml.transform.TransformerFactory new*()");
static final MethodMatcher TRANSFORMER_FACTORY_SET_ATTRIBUTE = new MethodMatcher("javax.xml.transform.TransformerFactory setAttribute(java.lang.String, ..)");
static final MethodMatcher TRANSFORMER_FACTORY_SET_FEATURE = new MethodMatcher("javax.xml.transform.TransformerFactory setFeature(java.lang.String, ..)");
Expand All @@ -44,6 +40,16 @@ public class TransformerFactoryFixVisitor<P> extends JavaIsoVisitor<P> {

private static final String DISALLOW_MODIFY_FLAG = "DISALLOW_MODIFY_FLAG";

public TransformerFactoryFixVisitor(ExternalDTDAccumulator acc) {
super(
TRANSFORMER_FACTORY_INSTANCE,
TRANSFORMER_FACTORY_FQN,
TRANSFORMER_FACTORY_INITIALIZATION_METHOD,
TRANSFORMER_FACTORY_VARIABLE_NAME,
acc
);
}

@Override
public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, P ctx) {
J.ClassDeclaration cd = super.visitClassDeclaration(classDecl, ctx);
Expand Down Expand Up @@ -78,60 +84,49 @@ public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, P
return cd;
}

@Override
public J.VariableDeclarations.NamedVariable visitVariable(J.VariableDeclarations.NamedVariable variable, P ctx) {
J.VariableDeclarations.NamedVariable v = super.visitVariable(variable, ctx);
if (TypeUtils.isOfClassType(v.getType(), TRANSFORMER_FACTORY_FQN)) {
getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, TRANSFORMER_FACTORY_VARIABLE_NAME, v.getSimpleName());
}
return v;
}

@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, P ctx) {
J.MethodInvocation m = super.visitMethodInvocation(method, ctx);
if (TRANSFORMER_FACTORY_INSTANCE.matches(m)) {
getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, TRANSFORMER_FACTORY_INITIALIZATION_METHOD, getCursor().dropParentUntil(J.Block.class::isInstance));
} else if (TRANSFORMER_FACTORY_SET_ATTRIBUTE.matches(m) && m.getArguments().get(0) instanceof J.FieldAccess) {
if (TRANSFORMER_FACTORY_SET_ATTRIBUTE.matches(m) && m.getArguments().get(0) instanceof J.FieldAccess) {
// If either attribute value is not equal to the empty string, do not make any changes
if (m.getArguments().get(1) instanceof J.Literal) {
J.Literal string = (J.Literal) m.getArguments().get(1);
assert string.getValue() != null;
if (!(((String) string.getValue()).isEmpty())) {
getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, DISALLOW_MODIFY_FLAG, getCursor().dropParentUntil(J.Block.class::isInstance));
addMessage(DISALLOW_MODIFY_FLAG);
}
}
J.FieldAccess fa = (J.FieldAccess) m.getArguments().get(0);
if (ACCESS_EXTERNAL_DTD_NAME.equals(fa.getSimpleName())) {
getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, ACCESS_EXTERNAL_DTD_NAME, getCursor().dropParentUntil(J.Block.class::isInstance));
addMessage(ACCESS_EXTERNAL_DTD_NAME);
} else if (ACCESS_EXTERNAL_STYLESHEET_NAME.equals(fa.getSimpleName())) {
getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, ACCESS_EXTERNAL_STYLESHEET_NAME, getCursor().dropParentUntil(J.Block.class::isInstance));
addMessage(ACCESS_EXTERNAL_STYLESHEET_NAME);
}
} else if (TRANSFORMER_FACTORY_SET_FEATURE.matches(m)) {
// If FEATURE_SECURE_PROCESSING is set to false, do not make any changes
if (m.getArguments().get(1) instanceof J.Literal) {
J.Literal bool = (J.Literal) m.getArguments().get(1);
assert bool.getValue() != null;
if (!((Boolean) bool.getValue())) {
getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, DISALLOW_MODIFY_FLAG, getCursor().dropParentUntil(J.Block.class::isInstance));
if (Boolean.FALSE.equals(bool.getValue())) {
addMessage(DISALLOW_MODIFY_FLAG);
}
}
if (m.getArguments().get(0) instanceof J.FieldAccess) {
J.FieldAccess fa = (J.FieldAccess) m.getArguments().get(0);
if (FEATURE_SECURE_PROCESSING_NAME.equals(fa.getSimpleName())) {
getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, FEATURE_SECURE_PROCESSING_NAME, getCursor().dropParentUntil(J.Block.class::isInstance));
addMessage(FEATURE_SECURE_PROCESSING_NAME);
}
} else if (m.getArguments().get(0) instanceof J.Literal) {
J.Literal literal = (J.Literal) m.getArguments().get(0);
if (XMLConstants.FEATURE_SECURE_PROCESSING.equals(literal.getValue())) {
getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, FEATURE_SECURE_PROCESSING_NAME, getCursor().dropParentUntil(J.Block.class::isInstance));
addMessage(FEATURE_SECURE_PROCESSING_NAME);
}
}
}
return m;
}

public static TreeVisitor<?, ExecutionContext> create(ExternalDTDAccumulator acc) {
return Preconditions.check(new UsesType<>(TRANSFORMER_FACTORY_FQN, true), new TransformerFactoryFixVisitor<>());
return Preconditions.check(new UsesType<>(TRANSFORMER_FACTORY_FQN, true), new TransformerFactoryFixVisitor<>(acc));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,76 +15,43 @@
*/
package org.openrewrite.java.security.xml;

import org.openrewrite.Cursor;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaCoordinates;
import org.openrewrite.java.tree.Statement;

public class TransformerFactoryInsertAttributeStatementVisitor<P> extends JavaIsoVisitor<P> {
private final J.Block scope;
private final StringBuilder attributeTemplate = new StringBuilder();
private final String transformerFactoryVariableName;
import java.util.Collections;

public class TransformerFactoryInsertAttributeStatementVisitor<P> extends XmlFactoryInsertVisitor<P> {
public TransformerFactoryInsertAttributeStatementVisitor(
J.Block scope,
String factoryVariableName,
boolean needsExternalEntitiesDisabled,
boolean needsStylesheetsDisabled,
boolean needsFeatureSecureProcessing
) {
this.scope = scope;
this.transformerFactoryVariableName = factoryVariableName;
super(
scope,
factoryVariableName,
TransformerFactoryFixVisitor.TRANSFORMER_FACTORY_INSTANCE,
TransformerFactoryFixVisitor.TRANSFORMER_FACTORY_SET_ATTRIBUTE
);

if (needsExternalEntitiesDisabled) {
attributeTemplate.append(transformerFactoryVariableName).append(".setAttribute(XMLConstants.ACCESS_EXTERNAL_DTD, \"\");");
getTemplate().append(getFactoryVariableName()).append(".setAttribute(XMLConstants.ACCESS_EXTERNAL_DTD, \"\");");
}
if (needsStylesheetsDisabled) {
attributeTemplate.append(transformerFactoryVariableName).append(".setAttribute(XMLConstants.ACCESS_EXTERNAL_STYLESHEET, \"\");");
getTemplate().append(getFactoryVariableName()).append(".setAttribute(XMLConstants.ACCESS_EXTERNAL_STYLESHEET, \"\");");
}
if (needsFeatureSecureProcessing) {
attributeTemplate.append(transformerFactoryVariableName).append(".setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true);");
getTemplate().append(getFactoryVariableName()).append(".setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true);");
}
}

@Override
public J.Block visitBlock(J.Block block, P ctx) {
J.Block b = super.visitBlock(block, ctx);
Statement beforeStatement = null;
if (b.isScope(scope)) {
for (int i = b.getStatements().size() - 2; i > -1; i--) {
Statement st = b.getStatements().get(i);
Statement stBefore = b.getStatements().get(i + 1);
if (st instanceof J.MethodInvocation) {
J.MethodInvocation m = (J.MethodInvocation) st;
if (TransformerFactoryFixVisitor.TRANSFORMER_FACTORY_INSTANCE.matches(m) || TransformerFactoryFixVisitor.TRANSFORMER_FACTORY_SET_ATTRIBUTE.matches(m)) {
beforeStatement = stBefore;
}
} else if (st instanceof J.VariableDeclarations) {
J.VariableDeclarations vd = (J.VariableDeclarations) st;
if (vd.getVariables().get(0).getInitializer() instanceof J.MethodInvocation) {
J.MethodInvocation m = (J.MethodInvocation) vd.getVariables().get(0).getInitializer();
if (m != null && TransformerFactoryFixVisitor.TRANSFORMER_FACTORY_INSTANCE.matches(m)) {
beforeStatement = stBefore;
}
}
}
}

if (getCursor().getParent() != null && getCursor().getParent().getValue() instanceof J.ClassDeclaration) {
attributeTemplate.insert(0, "{").append("}");
}
JavaCoordinates insertCoordinates = beforeStatement != null ?
beforeStatement.getCoordinates().before() :
b.getCoordinates().lastStatement();
b = JavaTemplate
.builder(attributeTemplate.toString())
.imports("javax.xml.XMLConstants")
.contextSensitive()
.build()
.apply(new Cursor(getCursor().getParent(), b), insertCoordinates);
maybeAddImport("javax.xml.XMLConstants");
Statement beforeStatement = getInsertStatement(b);
if (b.isScope(getScope())) {
b = updateBlock(b, block, beforeStatement, Collections.singleton("javax.xml.XMLConstants"));
}
return b;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2022 the original author or authors.
* Copyright 2023 the original author or authors.
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,26 +15,19 @@
*/
package org.openrewrite.java.security.xml;

import org.openrewrite.Cursor;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.VariableNameUtils;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaCoordinates;
import org.openrewrite.java.tree.Statement;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;

class XmlFactoryInsertPropertyStatementVisitor<P> extends JavaIsoVisitor<P> {
private final J.Block scope;
private final StringBuilder propertyTemplate = new StringBuilder();
class XmlFactoryInsertPropertyStatementVisitor<P> extends XmlFactoryInsertVisitor<P> {
private final ExternalDTDAccumulator acc;

private final boolean generateAllowList;
private final String xmlFactoryVariableName;

public XmlFactoryInsertPropertyStatementVisitor(
J.Block scope,
Expand All @@ -46,21 +39,25 @@ public XmlFactoryInsertPropertyStatementVisitor(
boolean needsResolverMethod,
ExternalDTDAccumulator acc
) {
this.scope = scope;
super(
scope,
factoryVariableName,
XmlInputFactoryFixVisitor.XML_PARSER_FACTORY_INSTANCE,
XmlInputFactoryFixVisitor.XML_PARSER_FACTORY_SET_PROPERTY
);
this.acc = acc;
this.xmlFactoryVariableName = factoryVariableName;

if (needsExternalEntitiesDisabled) {
propertyTemplate.append(xmlFactoryVariableName).append(".setProperty(XMLInputFactory.IS_SUPPORTING_EXTERNAL_ENTITIES, false);");
getTemplate().append(getFactoryVariableName()).append(".setProperty(XMLInputFactory.IS_SUPPORTING_EXTERNAL_ENTITIES, false);");
}
if (needsSupportDTDFalse && accIsEmpty) {
if (needsSupportDTDTrue) {
propertyTemplate.append(xmlFactoryVariableName).append(".setProperty(XMLInputFactory.SUPPORT_DTD, false);");
getTemplate().append(getFactoryVariableName()).append(".setProperty(XMLInputFactory.SUPPORT_DTD, false);");
}
}
if (needsSupportDTDFalse && !accIsEmpty) {
if (needsResolverMethod && needsSupportDTDTrue) {
propertyTemplate.append(xmlFactoryVariableName).append(".setProperty(XMLInputFactory.SUPPORT_DTD, true);");
getTemplate().append(getFactoryVariableName()).append(".setProperty(XMLInputFactory.SUPPORT_DTD, true);");
}
this.generateAllowList = needsResolverMethod;
} else if (!needsSupportDTDTrue && !accIsEmpty) {
Expand All @@ -86,12 +83,12 @@ private Set<String> addAllowList(boolean generateAllowList) {

if (acc.getExternalDTDs().size() > 1) {
imports.add("java.util.Arrays");
propertyTemplate.append(
getTemplate().append(
"Collection<String>" + newAllowListVariableName + " = Arrays.asList(\n"
);
} else {
imports.add("java.util.Collections");
propertyTemplate.append(
getTemplate().append(
"Collection<String>" + newAllowListVariableName + " = Collections.singleton(\n"
);
}
Expand All @@ -101,8 +98,8 @@ private Set<String> addAllowList(boolean generateAllowList) {
"\t",
""
));
propertyTemplate.append(allowListContent).append("\n);\n");
propertyTemplate.append(xmlFactoryVariableName).append(
getTemplate().append(allowListContent).append("\n);\n");
getTemplate().append(getFactoryVariableName()).append(
".setXMLResolver((publicID, systemID, baseURI, namespace) -> {\n" +
" if (" + newAllowListVariableName + ".contains(systemID)){\n" +
" // returning null will cause the parser to resolve the entity\n" +
Expand All @@ -117,44 +114,10 @@ private Set<String> addAllowList(boolean generateAllowList) {
@Override
public J.Block visitBlock(J.Block block, P ctx) {
J.Block b = super.visitBlock(block, ctx);
Statement beforeStatement = null;
if (b.isScope(scope)) {
for (int i = b.getStatements().size() - 2; i > -1; i--) {
Statement st = b.getStatements().get(i);
Statement stBefore = b.getStatements().get(i + 1);
if (st instanceof J.MethodInvocation) {
J.MethodInvocation m = (J.MethodInvocation) st;
if (XmlInputFactoryFixVisitor.XML_PARSER_FACTORY_INSTANCE.matches(m) || XmlInputFactoryFixVisitor.XML_PARSER_FACTORY_SET_PROPERTY.matches(m)) {
beforeStatement = stBefore;
}
} else if (st instanceof J.VariableDeclarations) {
J.VariableDeclarations vd = (J.VariableDeclarations) st;
if (vd.getVariables().get(0).getInitializer() instanceof J.MethodInvocation) {
J.MethodInvocation m = (J.MethodInvocation) vd.getVariables().get(0).getInitializer();
if (m != null && XmlInputFactoryFixVisitor.XML_PARSER_FACTORY_INSTANCE.matches(m)) {
beforeStatement = stBefore;
}
}
}
}

Statement beforeStatement = getInsertStatement(b);
if (b.isScope(getScope())) {
Set<String> imports = addAllowList(generateAllowList);

if (getCursor().getParent() != null && getCursor().getParent().getValue() instanceof J.ClassDeclaration) {
propertyTemplate.insert(0, "{\n").append("}");
}
JavaCoordinates insertCoordinates = beforeStatement != null ?
beforeStatement.getCoordinates().before() :
b.getCoordinates().lastStatement();
b = JavaTemplate
.builder(propertyTemplate.toString())
.imports(imports.toArray(new String[0]))
.contextSensitive()
.build()
.apply(new Cursor(getCursor().getParent(), b), insertCoordinates);
if (b != block) {
imports.forEach(this::maybeAddImport);
}
b = updateBlock(b, block, beforeStatement, imports);
}
return b;
}
Expand Down
Loading