diff --git a/core/src/main/java/io/dogboy/serializationisbad/core/ClassFilteringObjectInputStream.java b/core/src/main/java/io/dogboy/serializationisbad/core/ClassFilteringObjectInputStream.java index 6432d82..f8a7248 100644 --- a/core/src/main/java/io/dogboy/serializationisbad/core/ClassFilteringObjectInputStream.java +++ b/core/src/main/java/io/dogboy/serializationisbad/core/ClassFilteringObjectInputStream.java @@ -24,7 +24,7 @@ public ClassFilteringObjectInputStream(InputStream in, PatchModule patchModule) this(in, patchModule, null); } - private boolean isClassAllowed(String className) { + private static boolean isClassAllowed(String className, PatchModule patchModule) { // strip all array dimensions, just get the base type while (className.startsWith("[")) { className = className.substring(1); @@ -35,12 +35,12 @@ private boolean isClassAllowed(String className) { } if (SerializationIsBad.getInstance().getConfig().getClassAllowlist().contains(className) - || this.patchModule.getClassAllowlist().contains(className)) { + || patchModule.getClassAllowlist().contains(className)) { return true; } Set allowedPackages = new HashSet<>(SerializationIsBad.getInstance().getConfig().getPackageAllowlist()); - allowedPackages.addAll(this.patchModule.getPackageAllowlist()); + allowedPackages.addAll(patchModule.getPackageAllowlist()); for (String allowedPackage : allowedPackages) { if (className.startsWith(allowedPackage + ".")) { @@ -53,13 +53,7 @@ private boolean isClassAllowed(String className) { @Override protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { - SerializationIsBad.logger.debug("Resolving class " + desc.getName()); - - if (!this.isClassAllowed(desc.getName())) { - SerializationIsBad.logger.warn("Tried to resolve class " + desc.getName() + ", which is not allowed to be deserialized"); - if (SerializationIsBad.getInstance().getConfig().isExecuteBlocking()) - throw new ClassNotFoundException("Class " + desc.getName() + " is not allowed to be deserialized"); - } + ClassFilteringObjectInputStream.resolveClassPrecheck(desc, this.patchModule); if (this.parentClassLoader == null) { return super.resolveClass(desc); @@ -78,6 +72,16 @@ protected Class resolveClass(ObjectStreamClass desc) throws IOException, Clas } } + public static void resolveClassPrecheck(ObjectStreamClass desc, PatchModule patchModule) throws ClassNotFoundException { + SerializationIsBad.logger.debug("Resolving class " + desc.getName()); + + if (!ClassFilteringObjectInputStream.isClassAllowed(desc.getName(), patchModule)) { + SerializationIsBad.logger.warn("Tried to resolve class " + desc.getName() + ", which is not allowed to be deserialized"); + if (SerializationIsBad.getInstance().getConfig().isExecuteBlocking()) + throw new ClassNotFoundException("Class " + desc.getName() + " is not allowed to be deserialized"); + } + } + private static final HashMap> primClasses = new HashMap<>(8, 1.0F); static { ClassFilteringObjectInputStream.primClasses.put("boolean", boolean.class); diff --git a/core/src/main/java/io/dogboy/serializationisbad/core/Patches.java b/core/src/main/java/io/dogboy/serializationisbad/core/Patches.java index dab88f9..4d694e3 100644 --- a/core/src/main/java/io/dogboy/serializationisbad/core/Patches.java +++ b/core/src/main/java/io/dogboy/serializationisbad/core/Patches.java @@ -44,6 +44,28 @@ private static byte[] writeClassNode(ClassNode classNode) { private static void applyPatches(String className, ClassNode classNode, boolean passClassLoader) { SerializationIsBad.logger.info("Applying patches to " + className); + if ("java/io/ObjectInputStream".equals(classNode.superName)) { + for (MethodNode methodNode : classNode.methods) { + if (!"resolveClass".equals(methodNode.name)) continue; + + InsnList additionalInstructions = new InsnList(); + additionalInstructions.add(new VarInsnNode(Opcodes.ALOAD, 1)); // Class Descriptor + additionalInstructions.add(new LdcInsnNode(className)); + additionalInstructions.add(new MethodInsnNode(Opcodes.INVOKESTATIC, "io/dogboy/serializationisbad/core/Patches", + "getPatchModuleForClass", "(Ljava/lang/String;)Lio/dogboy/serializationisbad/core/config/PatchModule;", false)); + additionalInstructions.add(new MethodInsnNode(Opcodes.INVOKESTATIC, "io/dogboy/serializationisbad/core/ClassFilteringObjectInputStream", + "resolveClassPrecheck", "(Ljava/io/ObjectStreamClass;Lio/dogboy/serializationisbad/core/config/PatchModule;)V", false)); + + methodNode.instructions.insertBefore(methodNode.instructions.getFirst(), additionalInstructions); + + SerializationIsBad.logger.info(" Injecting resolveClass precheck in method " + methodNode.name); + + break; + } + + return; + } + for (MethodNode methodNode : classNode.methods) { InsnList instructions = methodNode.instructions; for (int i = 0; i < instructions.size(); i++) {