Skip to content

Commit

Permalink
Version 1.5 (#76)
Browse files Browse the repository at this point in the history
* feat: cache remote config locally and allow changing the config directory (#73)

* fix(fabric-agent): allow SiB classes to be loaded from the parent classloader (#74)

* Fix other ois implementations (#75)

* feat: add support for patching custom OIS implementations

* fix: move custom OIS implementations to patch to a separate config value
Otherwise the class patching in older SiB
versions would break with a newer remote config

* fix(modlauncher): add custom ois classes to transformer targets

* chore: bump version to 1.5
  • Loading branch information
dogboy21 authored Aug 29, 2023
1 parent 2d0e3ee commit 5b1b03a
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ public byte[] transform(ClassLoader loader, String className, Class<?> classBein
try {
if (className == null) return classfileBuffer;
if ("net/minecraft/launchwrapper/ITweaker".equals(className)) SerializationIsBadAgent.insertLaunchWrapperExclusion();
if ("net/fabricmc/loader/ModContainer".equals(className)) SerializationIsBadAgent.insertFabricValidParentUrl(loader);

String classNameDots = className.replace('/', '.');

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import java.lang.instrument.Instrumentation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.nio.file.Path;
import java.util.HashSet;
import java.util.Set;

public class SerializationIsBadAgent {

Expand Down Expand Up @@ -35,4 +38,39 @@ static void insertLaunchWrapperExclusion() {
}
}

/**
* Another hacky workaround for newer Fabric versions that enforce
* classpath isolation. This adds the path to the SiB jar to the
* list of jar paths that are allowed to be loaded by the parent
* classloader
*
* @param fabricClassLoader The classloader that was used to load the Fabric classes
*/
static void insertFabricValidParentUrl(ClassLoader fabricClassLoader) {
try {
Path sibPath = new File(SerializationIsBadAgent.class.getProtectionDomain().getCodeSource().getLocation().toURI()).toPath();

// basically accessing the following:
// ((KnotClassDelegate) ((Knot) FabricLauncherBase.getLauncher()).classLoader).validParentCodeSources

Class<?> fabricLauncherBaseClass = Class.forName("net.fabricmc.loader.impl.launch.FabricLauncherBase", true, fabricClassLoader);
Method getLauncherMethod = fabricLauncherBaseClass.getDeclaredMethod("getLauncher");
Object fabricLauncher = getLauncherMethod.invoke(null);
Field classLoaderField = fabricLauncher.getClass().getDeclaredField("classLoader");
classLoaderField.setAccessible(true);
Object classLoader = classLoaderField.get(fabricLauncher);
Field validParentCodeSourcesField = classLoader.getClass().getDeclaredField("validParentCodeSources");
validParentCodeSourcesField.setAccessible(true);
@SuppressWarnings("unchecked")
Set<Path> validParentCodeSources = (Set<Path>) validParentCodeSourcesField.get(classLoader);

Set<Path> newValidParentCodeSources = new HashSet<>(validParentCodeSources);
newValidParentCodeSources.add(sibPath);

validParentCodeSourcesField.set(classLoader, newValidParentCodeSources);
} catch (Throwable e) {
SerializationIsBad.logger.error("Failed to insert Fabric valid parent URL", e);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public ClassFilteringObjectInputStream(InputStream in, PatchModule patchModule)
this(in, patchModule, null);
}

private boolean isClassAllowed(String className) {
private static boolean isClassAllowed(String className, PatchModule patchModule) {
// strip all array dimensions, just get the base type
while (className.startsWith("[")) {
className = className.substring(1);
Expand All @@ -35,12 +35,12 @@ private boolean isClassAllowed(String className) {
}

if (SerializationIsBad.getInstance().getConfig().getClassAllowlist().contains(className)
|| this.patchModule.getClassAllowlist().contains(className)) {
|| patchModule.getClassAllowlist().contains(className)) {
return true;
}

Set<String> allowedPackages = new HashSet<>(SerializationIsBad.getInstance().getConfig().getPackageAllowlist());
allowedPackages.addAll(this.patchModule.getPackageAllowlist());
allowedPackages.addAll(patchModule.getPackageAllowlist());

for (String allowedPackage : allowedPackages) {
if (className.startsWith(allowedPackage + ".")) {
Expand All @@ -53,13 +53,7 @@ private boolean isClassAllowed(String className) {

@Override
protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
SerializationIsBad.logger.debug("Resolving class " + desc.getName());

if (!this.isClassAllowed(desc.getName())) {
SerializationIsBad.logger.warn("Tried to resolve class " + desc.getName() + ", which is not allowed to be deserialized");
if (SerializationIsBad.getInstance().getConfig().isExecuteBlocking())
throw new ClassNotFoundException("Class " + desc.getName() + " is not allowed to be deserialized");
}
ClassFilteringObjectInputStream.resolveClassPrecheck(desc, this.patchModule);

if (this.parentClassLoader == null) {
return super.resolveClass(desc);
Expand All @@ -78,6 +72,16 @@ protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, Clas
}
}

public static void resolveClassPrecheck(ObjectStreamClass desc, PatchModule patchModule) throws ClassNotFoundException {
SerializationIsBad.logger.debug("Resolving class " + desc.getName());

if (!ClassFilteringObjectInputStream.isClassAllowed(desc.getName(), patchModule)) {
SerializationIsBad.logger.warn("Tried to resolve class " + desc.getName() + ", which is not allowed to be deserialized");
if (SerializationIsBad.getInstance().getConfig().isExecuteBlocking())
throw new ClassNotFoundException("Class " + desc.getName() + " is not allowed to be deserialized");
}
}

private static final HashMap<String, Class<?>> primClasses = new HashMap<>(8, 1.0F);
static {
ClassFilteringObjectInputStream.primClasses.put("boolean", boolean.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ public class Patches {

public static PatchModule getPatchModuleForClass(String className) {
for (PatchModule patchModule : SerializationIsBad.getInstance().getConfig().getPatchModules()) {
if (patchModule.getClassesToPatch().contains(className)) {
if (patchModule.getClassesToPatch().contains(className)
|| patchModule.getCustomOISClasses().contains(className)) {
return patchModule;
}
}
Expand All @@ -43,6 +44,33 @@ private static byte[] writeClassNode(ClassNode classNode) {

private static void applyPatches(String className, ClassNode classNode, boolean passClassLoader) {
SerializationIsBad.logger.info("Applying patches to " + className);
PatchModule patchModule = Patches.getPatchModuleForClass(className);
if (patchModule == null) {
SerializationIsBad.logger.info(" No patches to apply");
return;
}

if (patchModule.getCustomOISClasses().contains(className) && "java/io/ObjectInputStream".equals(classNode.superName)) {
for (MethodNode methodNode : classNode.methods) {
if (!"resolveClass".equals(methodNode.name)) continue;

InsnList additionalInstructions = new InsnList();
additionalInstructions.add(new VarInsnNode(Opcodes.ALOAD, 1)); // Class Descriptor
additionalInstructions.add(new LdcInsnNode(className));
additionalInstructions.add(new MethodInsnNode(Opcodes.INVOKESTATIC, "io/dogboy/serializationisbad/core/Patches",
"getPatchModuleForClass", "(Ljava/lang/String;)Lio/dogboy/serializationisbad/core/config/PatchModule;", false));
additionalInstructions.add(new MethodInsnNode(Opcodes.INVOKESTATIC, "io/dogboy/serializationisbad/core/ClassFilteringObjectInputStream",
"resolveClassPrecheck", "(Ljava/io/ObjectStreamClass;Lio/dogboy/serializationisbad/core/config/PatchModule;)V", false));

methodNode.instructions.insertBefore(methodNode.instructions.getFirst(), additionalInstructions);

SerializationIsBad.logger.info(" Injecting resolveClass precheck in method " + methodNode.name);

break;
}

return;
}

for (MethodNode methodNode : classNode.methods) {
InsnList instructions = methodNode.instructions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@

import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -72,8 +74,17 @@ public SIBConfig getConfig() {
return this.config;
}

private static File getConfigDir(File minecraftDir) {
String configDirOverride = System.getProperty("serializationisbad.configdir");
if (configDirOverride != null) {
return new File(configDirOverride);
}

return new File(minecraftDir, "config");
}

private static SIBConfig readConfig(File minecraftDir) {
File configFile = new File(new File(minecraftDir, "config"), "serializationisbad.json");
File configFile = new File(SerializationIsBad.getConfigDir(minecraftDir), "serializationisbad.json");
Gson gson = new GsonBuilder().setPrettyPrinting().create();

SIBConfig localConfig = new SIBConfig();
Expand All @@ -98,7 +109,7 @@ private static SIBConfig readConfig(File minecraftDir) {
return localConfig;
}

SIBConfig remoteConfig = SerializationIsBad.readRemoteConfig(localConfig.getRemoteConfigUrl());
SIBConfig remoteConfig = SerializationIsBad.readRemoteConfig(minecraftDir, localConfig.getRemoteConfigUrl());
if (remoteConfig != null) {
SerializationIsBad.logger.info("Using remote config file");
return remoteConfig;
Expand All @@ -108,8 +119,10 @@ private static SIBConfig readConfig(File minecraftDir) {
return localConfig;
}

private static SIBConfig readRemoteConfig(String url) {
private static SIBConfig readRemoteConfig(File minecraftDir, String url) {
Gson gson = new Gson();
File cacheFile = new File(SerializationIsBad.getConfigDir(minecraftDir), "serializationisbad-remotecache.json");

try {
HttpsURLConnection connection = (HttpsURLConnection) new URL(url).openConnection();
SSLContext sslContext = SSLContext.getInstance("TLSv1.2");
Expand All @@ -120,16 +133,41 @@ private static SIBConfig readRemoteConfig(String url) {

if (connection.getResponseCode() != 200) throw new IOException("Invalid response code: " + connection.getResponseCode());

try (InputStreamReader inputStreamReader = new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8)) {
return gson.fromJson(inputStreamReader, SIBConfig.class);
byte[] configBytes = SerializationIsBad.readInputStream(connection.getInputStream());
SIBConfig remoteConfig = gson.fromJson(new String(configBytes, StandardCharsets.UTF_8), SIBConfig.class);

try (FileOutputStream fileOutputStream = new FileOutputStream(cacheFile)) {
fileOutputStream.write(configBytes);
}

return remoteConfig;
} catch (Exception e) {
SerializationIsBad.logger.error("Failed to load remote config file", e);
}

if (cacheFile.isFile()) {
try (FileInputStream fileInputStream = new FileInputStream(cacheFile)) {
return gson.fromJson(new InputStreamReader(fileInputStream, StandardCharsets.UTF_8), SIBConfig.class);
} catch (Exception e) {
SerializationIsBad.logger.error("Failed to load cached remote config file", e);
}
}

return null;
}

private static byte[] readInputStream(InputStream inputStream) throws IOException {
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
byte[] buffer = new byte[4096];
int read;

while ((read = inputStream.read(buffer)) != -1) {
byteArrayOutputStream.write(buffer, 0, read);
}

return byteArrayOutputStream.toByteArray();
}

private static String getImplementationType() {
for (StackTraceElement stackTraceElement : Thread.currentThread().getStackTrace()) {
if (stackTraceElement.getClassName().startsWith("io.dogboy.serializationisbad.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

public class PatchModule {
private Set<String> classesToPatch;
private Set<String> customOISClasses;
private Set<String> classAllowlist;
private Set<String> packageAllowlist;

public PatchModule() {
this.classesToPatch = new HashSet<>();
this.customOISClasses = new HashSet<>();
this.classAllowlist = new HashSet<>();
this.packageAllowlist = new HashSet<>();
}
Expand All @@ -22,6 +24,15 @@ public void setClassesToPatch(Set<String> classesToPatch) {
this.classesToPatch = classesToPatch;
}


public Set<String> getCustomOISClasses() {
return this.customOISClasses;
}

public void setCustomOISClasses(Set<String> customOISClasses) {
this.customOISClasses = customOISClasses;
}

public Set<String> getClassAllowlist() {
return this.classAllowlist;
}
Expand All @@ -37,4 +48,5 @@ public Set<String> getPackageAllowlist() {
public void setPackageAllowlist(Set<String> packageAllowlist) {
this.packageAllowlist = packageAllowlist;
}

}
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
group=io.dogboy.serializationisbad
name=serializationisbad
version=1.4
version=1.5
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class SIBTransformer implements ITransformer<ClassNode> {
private final PatchModule patchModule;
Expand All @@ -36,7 +37,8 @@ public TransformerVoteResult castVote(ITransformerVotingContext context) {

@Override
public Set<Target> targets() {
return this.patchModule.getClassesToPatch().stream()
return Stream.concat(this.patchModule.getClassesToPatch().stream(),
this.patchModule.getCustomOISClasses().stream())
.map(Target::targetClass)
.collect(Collectors.toSet());
}
Expand Down

0 comments on commit 5b1b03a

Please sign in to comment.