Skip to content

Commit

Permalink
Fix/get method from stackframe (#601)
Browse files Browse the repository at this point in the history
* .gitignore update
* GetMethodFromStackframe fix
* Revert ".gitignore update"
This reverts commit 679d2d7.
* cleans up, refactors getting original methods [DebugFat breaking]
---------
Co-authored-by: kohanis <[email protected]>
  • Loading branch information
pardeike authored Mar 25, 2024
1 parent 669729f commit acbd6ee
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 45 deletions.
2 changes: 1 addition & 1 deletion Harmony/Internal/CodeTranspiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ internal static IEnumerable ConvertToGeneralInstructions(MethodInfo transpiler,
{
var type = transpiler.GetParameters()
.Select(p => p.ParameterType)
.FirstOrDefault(t => IsCodeInstructionsParameter(t));
.FirstOrDefault(IsCodeInstructionsParameter);
if (type == typeof(IEnumerable<CodeInstruction>))
{
unassignedValues = null;
Expand Down
69 changes: 61 additions & 8 deletions Harmony/Internal/HarmonySharedState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,24 @@ internal static class HarmonySharedState
const string name = "HarmonySharedState";
internal const int internalVersion = 102; // bump this if the layout of the HarmonySharedState type changes

// state/originals/methodStarts are set to instances stored in the global dynamic types static fields with the same name
// state/originals/originalsMono are set to instances stored in the global dynamic types static fields with the same name
static readonly Dictionary<MethodBase, byte[]> state;
static readonly Dictionary<MethodInfo, MethodBase> originals;

static readonly Dictionary<long, MethodBase[]> originalsMono;

static readonly AccessTools.FieldRef<StackFrame, long> methodAddressRef;

internal static readonly int actualVersion;

static HarmonySharedState()
{
// create singleton type
var type = GetOrCreateSharedStateType();

// this field is useed to find methods from stackframes in Mono
if (AccessTools.IsMonoRuntime && AccessTools.Field(typeof(StackFrame), "methodAddress") is FieldInfo field)
methodAddressRef = AccessTools.FieldRefAccess<StackFrame, long>(field);

// copy 'actualVersion' over to our fields
var versionField = type.GetField("version");
if ((int)versionField.GetValue(null) == 0)
Expand All @@ -62,13 +69,23 @@ static HarmonySharedState()
if (originalsField != null && originalsField.GetValue(null) is null)
originalsField.SetValue(null, new Dictionary<MethodInfo, MethodBase>());

// get or initialize global 'originalsMono' field
var originalsMonoField = type.GetField("originalsMono");
if (originalsMonoField != null && originalsMonoField.GetValue(null) is null)
originalsMonoField.SetValue(null, new Dictionary<long, MethodBase[]>());

// copy 'state' over to our fields
state = (Dictionary<MethodBase, byte[]>)stateField.GetValue(null);

// copy 'originals' over to our fields
originals = [];
if (originalsField != null) // may not exist in older versions
originals = (Dictionary<MethodInfo, MethodBase>)originalsField.GetValue(null);

// copy 'originalsMono' over to our fields
originalsMono = [];
if (originalsMonoField != null) // may not exist in older versions
originalsMono = (Dictionary<long, MethodBase[]>)originalsMonoField.GetValue(null);
}

// creates a dynamic 'global' type if it does not exist
Expand All @@ -94,6 +111,12 @@ static Type GetOrCreateSharedStateType()
module.ImportReference(typeof(Dictionary<MethodInfo, MethodBase>))
));

typedef.Fields.Add(new FieldDefinition(
"originalsMono",
Mono.Cecil.FieldAttributes.Public | Mono.Cecil.FieldAttributes.Static,
module.ImportReference(typeof(Dictionary<long, MethodBase[]>))
));

typedef.Fields.Add(new FieldDefinition(
"version",
Mono.Cecil.FieldAttributes.Public | Mono.Cecil.FieldAttributes.Static,
Expand Down Expand Up @@ -122,19 +145,49 @@ internal static void UpdatePatchInfo(MethodBase original, MethodInfo replacement
{
var bytes = patchInfo.Serialize();
lock (state) state[original] = bytes;
lock (originals) originals[replacement] = original;
lock (originals) originals[replacement.Identifiable()] = original;
if (AccessTools.IsMonoRuntime)
{
var methodAddress = (long)replacement.MethodHandle.GetFunctionPointer();
lock (originalsMono) originalsMono[methodAddress] = [original, replacement];
}
}

internal static MethodBase GetOriginal(MethodInfo replacement)
// With mono, useReplacement is used to either return the original or the replacement
// On .NET, useReplacement is ignored and the original is always returned
internal static MethodBase GetRealMethod(MethodInfo method, bool useReplacement)
{
lock (originals) return originals.GetValueSafe(replacement);
var identifiableMethod = method.Identifiable();
lock (originals)
if (originals.TryGetValue(identifiableMethod, out var original))
return original;

if (AccessTools.IsMonoRuntime)
{
var methodAddress = (long)method.MethodHandle.GetFunctionPointer();
lock (originalsMono)
if (originalsMono.TryGetValue(methodAddress, out var info))
return useReplacement ? info[1] : info[0];
}

return method;
}

internal static MethodBase FindReplacement(StackFrame frame)
internal static MethodBase GetStackFrameMethod(StackFrame frame, bool useReplacement)
{
var method = frame.GetMethod() as MethodInfo;
if (method == null) return null;
return GetOriginal(method);
if (method != null)
return GetRealMethod(method, useReplacement);

if (methodAddressRef != null)
{
var methodAddress = methodAddressRef(frame);
lock (originalsMono)
if (originalsMono.TryGetValue(methodAddress, out var info))
return useReplacement ? info[1] : info[0];
}

return null;
}
}
}
2 changes: 1 addition & 1 deletion Harmony/Internal/MethodCopier.cs
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ internal List<CodeInstruction> FinalizeILCodes(Emitter emitter, List<MethodInfo>
// pass2 - filter through all processors
//
var codeTranspiler = new CodeTranspiler(ilInstructions);
transpilers.Do(transpiler => codeTranspiler.Add(transpiler));
transpilers.Do(codeTranspiler.Add);
var codeInstructions = codeTranspiler.GetResult(generator, method);

if (emitter is null)
Expand Down
2 changes: 1 addition & 1 deletion Harmony/Internal/MethodPatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ internal MethodInfo CreateReplacement(out Dictionary<int, CodeInstruction> final

Label? skipOriginalLabel = null;
LocalBuilder runOriginalVariable = null;
var prefixAffectsOriginal = prefixes.Any(fix => PrefixAffectsOriginal(fix));
var prefixAffectsOriginal = prefixes.Any(PrefixAffectsOriginal);
var anyFixHasRunOriginalVar = fixes.Any(fix => fix.GetParameters().Any(p => p.Name == RUN_ORIGINAL_VAR));
if (prefixAffectsOriginal || anyFixHasRunOriginalVar)
{
Expand Down
2 changes: 1 addition & 1 deletion Harmony/Internal/PatchModels.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ internal static AttributePatch Create(MethodInfo patch)
var f_info = AccessTools.Field(attr.GetType(), nameof(HarmonyAttribute.info));
return f_info.GetValue(attr);
})
.Select(harmonyInfo => AccessTools.MakeDeepCopy<HarmonyMethod>(harmonyInfo))
.Select(AccessTools.MakeDeepCopy<HarmonyMethod>)
.ToList();
var info = HarmonyMethod.Merge(list);
info.method = patch;
Expand Down
4 changes: 2 additions & 2 deletions Harmony/Internal/PatchTools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ internal static void DetourMethod(MethodBase method, MethodBase replacement)
static Assembly GetExecutingAssemblyReplacement()
{
var frames = new StackTrace().GetFrames();
if (frames?.Skip(1).FirstOrDefault() is { } frame && Harmony.GetOriginalMethodFromStackframe(frame) is { } original)
if (frames?.Skip(1).FirstOrDefault() is { } frame && Harmony.GetMethodFromStackframe(frame) is { } original)
return original.Module.Assembly;
return Assembly.GetExecutingAssembly();
}
Expand Down Expand Up @@ -78,7 +78,7 @@ internal static AssemblyBuilder DefineDynamicAssembly(string name)
internal static List<AttributePatch> GetPatchMethods(Type type)
{
return AccessTools.GetDeclaredMethods(type)
.Select(method => AttributePatch.Create(method))
.Select(AttributePatch.Create)
.Where(attributePatch => attributePatch is not null)
.ToList();
}
Expand Down
20 changes: 8 additions & 12 deletions Harmony/Public/Harmony.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using MonoMod.Core.Platforms;
using System;
using System.Collections.Generic;
using System.Diagnostics;
Expand Down Expand Up @@ -227,7 +226,7 @@ public void Unpatch(MethodBase original, MethodInfo patch)
public static bool HasAnyPatches(string harmonyID)
{
return GetAllPatchedMethods()
.Select(original => GetPatchInfo(original))
.Select(GetPatchInfo)
.Any(info => info.Owners.Contains(harmonyID));
}

Expand All @@ -252,15 +251,13 @@ public IEnumerable<MethodBase> GetPatchedMethods()
public static IEnumerable<MethodBase> GetAllPatchedMethods() => PatchProcessor.GetAllPatchedMethods();

/// <summary>Gets the original method from a given replacement method</summary>
/// <param name="replacement">A replacement method, for example from a stacktrace</param>
/// <param name="replacement">A replacement method (patched original method)</param>
/// <returns>The original method/constructor or <c>null</c> if not found</returns>
///
public static MethodBase GetOriginalMethod(MethodInfo replacement)
{
if (replacement == null) throw new ArgumentNullException(nameof(replacement));
// The runtime can return several different MethodInfo's that point to the same method. Use the correct one
var identifiableReplacement = PlatformTriple.Current.GetIdentifiable(replacement) as MethodInfo;
return HarmonySharedState.GetOriginal(identifiableReplacement);
return HarmonySharedState.GetRealMethod(replacement, useReplacement: false);
}

/// <summary>Tries to get the method from a stackframe including dynamic replacement methods</summary>
Expand All @@ -270,24 +267,23 @@ public static MethodBase GetOriginalMethod(MethodInfo replacement)
public static MethodBase GetMethodFromStackframe(StackFrame frame)
{
if (frame == null) throw new ArgumentNullException(nameof(frame));
return HarmonySharedState.FindReplacement(frame) ?? frame.GetMethod();
return HarmonySharedState.GetStackFrameMethod(frame, useReplacement: true);
}

/// <summary>Gets the original method from the stackframe and uses original if method is a dynamic replacement</summary>
/// <param name="frame">The <see cref="StackFrame"/></param>
/// <returns>The original method from that stackframe</returns>
public static MethodBase GetOriginalMethodFromStackframe(StackFrame frame)
{
var member = GetMethodFromStackframe(frame);
if (member is MethodInfo methodInfo)
member = GetOriginalMethod(methodInfo) ?? member;
return member;
if (frame == null) throw new ArgumentNullException(nameof(frame));
return HarmonySharedState.GetStackFrameMethod(frame, useReplacement: false);
}

/// <summary>Gets Harmony version for all active Harmony instances</summary>
/// <param name="currentVersion">[out] The current Harmony version</param>
/// <returns>A dictionary containing assembly versions keyed by Harmony IDs</returns>
///
public static Dictionary<string, Version> VersionInfo(out Version currentVersion) => PatchProcessor.VersionInfo(out currentVersion);
public static Dictionary<string, Version> VersionInfo(out Version currentVersion)
=> PatchProcessor.VersionInfo(out currentVersion);
}
}
4 changes: 2 additions & 2 deletions Harmony/Public/HarmonyMethod.cs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ static HarmonyMethod GetHarmonyMethodInfo(object attribute)
public static List<HarmonyMethod> GetFromType(Type type)
{
return type.GetCustomAttributes(true)
.Select(attr => GetHarmonyMethodInfo(attr))
.Select(GetHarmonyMethodInfo)
.Where(info => info is not null)
.ToList();
}
Expand All @@ -310,7 +310,7 @@ public static List<HarmonyMethod> GetFromType(Type type)
public static List<HarmonyMethod> GetFromMethod(MethodBase method)
{
return method.GetCustomAttributes(true)
.Select(attr => GetHarmonyMethodInfo(attr))
.Select(GetHarmonyMethodInfo)
.Where(info => info is not null)
.ToList();
}
Expand Down
8 changes: 7 additions & 1 deletion Harmony/Tools/AccessTools.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using MonoMod.Core.Platforms;
using MonoMod.Utils;
using System;
using System.Collections;
Expand Down Expand Up @@ -84,7 +85,7 @@ public static Type[] GetTypesFromAssembly(Assembly assembly)
/// <summary>Enumerates all successfully loaded types in the current app domain, excluding visual studio assemblies</summary>
/// <returns>An enumeration of all <see cref="Type"/> in all assemblies, excluding visual studio assemblies</returns>
///
public static IEnumerable<Type> AllTypes() => AllAssemblies().SelectMany(a => GetTypesFromAssembly(a));
public static IEnumerable<Type> AllTypes() => AllAssemblies().SelectMany(GetTypesFromAssembly);

/// <summary>Enumerates all inner types (non-recursive) of a given type</summary>
/// <param name="type">The class/type to start with</param>
Expand Down Expand Up @@ -133,6 +134,11 @@ public static T FindIncludingInnerTypes<T>(Type type, Func<Type, T> func) where
return result;
}

/// <summary>Creates an identifiable version of a method</summary>
/// <param name="method">The method</param>
/// <returns></returns>
public static MethodInfo Identifiable(this MethodInfo method) => PlatformTriple.Current.GetIdentifiable(method) as MethodInfo ?? method;

/// <summary>Gets the reflection information for a directly declared field</summary>
/// <param name="type">The class/type where the field is defined</param>
/// <param name="name">The name of the field</param>
Expand Down
32 changes: 16 additions & 16 deletions HarmonyTests/Extras/RetrieveOriginalMethod.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Diagnostics;
using System.Reflection;
using System.Runtime.CompilerServices;

namespace HarmonyLibTests.Extras
{
Expand All @@ -14,20 +15,20 @@ private static void CheckStackTraceFor(MethodBase expectedMethod)
Assert.NotNull(expectedMethod);

var st = new StackTrace(1, false);
var method = Harmony.GetMethodFromStackframe(st.GetFrame(0));

Assert.NotNull(method);

if (method is MethodInfo replacement)
{
var original = Harmony.GetOriginalMethod(replacement);
Assert.NotNull(original);
Assert.AreEqual(original, expectedMethod);
}
var frame = st.GetFrame(0);
Assert.NotNull(frame);

var methodFromStackframe = Harmony.GetMethodFromStackframe(frame);
Assert.NotNull(methodFromStackframe);
Assert.AreEqual(expectedMethod, methodFromStackframe);

var replacement = frame.GetMethod() as MethodInfo;
Assert.NotNull(replacement);
var original = Harmony.GetOriginalMethod(replacement);
Assert.NotNull(original);
Assert.AreEqual(expectedMethod, original);
}

/* TODO
*
[Test]
public void TestRegularMethod()
{
Expand All @@ -37,7 +38,7 @@ public void TestRegularMethod()
_ = harmony.Patch(originalMethod, new HarmonyMethod(dummyPrefix));
PatchTarget();
}

[Test]
public void TestConstructor()
{
Expand All @@ -48,7 +49,6 @@ public void TestConstructor()
var inst = new NestedClass(5);
_ = inst.index;
}
*/

internal static void PatchTarget()
{
Expand All @@ -60,7 +60,7 @@ internal static void PatchTarget()
}
}

// [MethodImpl(MethodImplOptions.NoInlining)]
[MethodImpl(MethodImplOptions.NoInlining)]
internal static void DummyPrefix()
{
}
Expand All @@ -69,7 +69,7 @@ class NestedClass {
public NestedClass(int i)
{
try {
CheckStackTraceFor(AccessTools.Constructor(typeof(NestedClass), [typeof(int)]));
CheckStackTraceFor(AccessTools.Constructor(typeof(NestedClass), [typeof(int)]));
throw new Exception();
} catch (Exception e)
{
Expand Down

0 comments on commit acbd6ee

Please sign in to comment.