Skip to content

Commit

Permalink
refactored SHGetFolderPathA hook for condemned
Browse files Browse the repository at this point in the history
  • Loading branch information
ThirteenAG committed Oct 24, 2023
1 parent cfc4e1e commit 142ebad
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 137 deletions.
13 changes: 10 additions & 3 deletions includes/stdafx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,20 @@ void CreateThreadAutoClose(LPSECURITY_ATTRIBUTES lpThreadAttributes, SIZE_T dwSt
CloseHandle(CreateThread(lpThreadAttributes, dwStackSize, lpStartAddress, lpParameter, dwCreationFlags, lpThreadId));
}

bool IsModuleUAL(HMODULE mod)
{
if (GetProcAddress(mod, "DirectInput8Create") != NULL && GetProcAddress(mod, "DirectSoundCreate8") != NULL && GetProcAddress(mod, "InternetOpenA") != NULL)
return true;
return false;
}

bool IsUALPresent()
{
ModuleList dlls;
dlls.Enumerate(ModuleList::SearchLocation::LocalOnly);
for (auto& e : dlls.m_moduleList)
{
if (GetProcAddress(std::get<HMODULE>(e), "DirectInput8Create") != NULL && GetProcAddress(std::get<HMODULE>(e), "DirectSoundCreate8") != NULL && GetProcAddress(std::get<HMODULE>(e), "InternetOpenA") != NULL)
if (IsModuleUAL(std::get<HMODULE>(e)))
return true;
}
return false;
Expand Down Expand Up @@ -223,5 +230,5 @@ std::string RegistryWrapper::section;
CIniReader RegistryWrapper::RegistryReader;
std::map<std::string, std::string> RegistryWrapper::DefaultStrings;
std::set<std::string, std::less<>> RegistryWrapper::PathStrings;
std::map<std::wstring, std::function<void()>, CallbackHandler::Comparator> CallbackHandler::functions;
std::map<std::wstring, std::function<void()>, CallbackHandler::Comparator> CallbackHandler::functions_unload;
std::map<std::wstring, std::function<void()>, CallbackHandler::Comparator> CallbackHandler::onModuleLoad;
std::map<std::wstring, std::function<void()>, CallbackHandler::Comparator> CallbackHandler::onModuleUnload;
85 changes: 63 additions & 22 deletions includes/stdafx.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ float GetFOV(float f, float ar);
float GetFOV2(float f, float ar);
float AdjustFOV(float f, float ar);

bool IsModuleUAL(HMODULE mod);
bool IsUALPresent();
void CreateThreadAutoClose(LPSECURITY_ATTRIBUTES lpThreadAttributes, SIZE_T dwStackSize, LPTHREAD_START_ROUTINE lpStartAddress, LPVOID lpParameter, DWORD dwCreationFlags, LPDWORD lpThreadId);
std::tuple<int32_t, int32_t> GetDesktopRes();
Expand Down Expand Up @@ -318,12 +319,18 @@ class CallbackHandler
{
RegisterDllNotification();
if (!bOnUnload)
GetCallbackList().emplace(module_name, std::forward<std::function<void()>>(fn));
GetOnModuleLoadCallbackList().emplace(module_name, std::forward<std::function<void()>>(fn));
else
GetUnloadCallbackList().emplace(module_name, std::forward<std::function<void()>>(fn));
GetOnModuleUnloadCallbackList().emplace(module_name, std::forward<std::function<void()>>(fn));
}
}

static inline void RegisterCallback(std::function<void(HMODULE)>&& fn)
{
RegisterDllNotification();
GetOnAnyModuleLoadCallbackList().emplace_back(std::forward<std::function<void(HMODULE)>>(fn));
}

static inline void RegisterCallback(std::function<void()>&& fn, bool bPatternNotFound, ptrdiff_t offset = 0x1100, uint32_t* ptr = nullptr)
{
if (!bPatternNotFound)
Expand Down Expand Up @@ -360,29 +367,47 @@ class CallbackHandler
}

private:
static inline void call(std::wstring_view module_name)
static inline void invokeOnModuleLoad(std::wstring_view module_name)
{
if (GetCallbackList().count(module_name.data()))
if (GetOnModuleLoadCallbackList().count(module_name.data()))
{
GetCallbackList().at(module_name.data())();
//GetCallbackList().erase(module_name.data()); //shouldn't do that in case dll with callback gets unloaded and loaded again
GetOnModuleLoadCallbackList().at(module_name.data())();
}
}

//if (GetCallbackList().empty()) //win7 crash in splinter cell
// UnRegisterDllNotification();
static inline void invokeOnUnload(std::wstring_view module_name)
{
if (GetOnModuleUnloadCallbackList().count(module_name.data()))
{
GetOnModuleUnloadCallbackList().at(module_name.data())();
}
}

static inline void call_onunload(std::wstring_view module_name)
static inline void invokeOnAnyModuleLoad(HMODULE mod)
{
if (GetUnloadCallbackList().count(module_name.data()))
if (!GetOnAnyModuleLoadCallbackList().empty())
{
GetUnloadCallbackList().at(module_name.data())();
for (auto& f : GetOnAnyModuleLoadCallbackList())
{
f(mod);
}
}
}

static inline void invoke_all()
static inline void invokeOnAnyModuleUnload(HMODULE mod)
{
for (auto&& fn : GetCallbackList())
if (!GetOnAnyModuleUnloadCallbackList().empty())
{
for (auto& f : GetOnAnyModuleUnloadCallbackList())
{
f(mod);
}
}
}

static inline void InvokeAll()
{
for (auto&& fn : GetOnModuleLoadCallbackList())
fn.second();
}

Expand All @@ -399,14 +424,24 @@ class CallbackHandler
}
};

static std::map<std::wstring, std::function<void()>, Comparator>& GetCallbackList()
static std::map<std::wstring, std::function<void()>, Comparator>& GetOnModuleLoadCallbackList()
{
return functions;
return onModuleLoad;
}

static std::map<std::wstring, std::function<void()>, Comparator>& GetUnloadCallbackList()
static std::map<std::wstring, std::function<void()>, Comparator>& GetOnModuleUnloadCallbackList()
{
return functions_unload;
return onModuleUnload;
}

static inline std::vector<std::function<void(HMODULE)>>& GetOnAnyModuleLoadCallbackList()
{
return onAnyModuleLoad;
}

static inline std::vector<std::function<void(HMODULE)>>& GetOnAnyModuleUnloadCallbackList()
{
return onAnyModuleUnload;
}

struct ThreadParams
Expand Down Expand Up @@ -466,11 +501,13 @@ class CallbackHandler
static constexpr auto LDR_DLL_NOTIFICATION_REASON_UNLOADED = 2;
if (NotificationReason == LDR_DLL_NOTIFICATION_REASON_LOADED)
{
call(NotificationData->Loaded.BaseDllName->Buffer);
invokeOnModuleLoad(NotificationData->Loaded.BaseDllName->Buffer);
invokeOnAnyModuleLoad((HMODULE)NotificationData->Loaded.DllBase);
}
else if (NotificationReason == LDR_DLL_NOTIFICATION_REASON_UNLOADED)
{
call_onunload(NotificationData->Loaded.BaseDllName->Buffer);
invokeOnUnload(NotificationData->Loaded.BaseDllName->Buffer);
invokeOnAnyModuleUnload((HMODULE)NotificationData->Loaded.DllBase);
}
}

Expand All @@ -479,7 +516,8 @@ class CallbackHandler
//wprintf(L"ProbeCallback: Base %p, path '%ls', context %p\r\n", DllBase, FullDllPath, *ActivationContext);

std::wstring str(FullDllPath);
call(str.substr(str.find_last_of(L"/\\") + 1));
invokeOnModuleLoad(str.substr(str.find_last_of(L"/\\") + 1));
invokeOnAnyModuleLoad(DllBase);

//if (!*ActivationContext)
// return STATUS_INVALID_PARAMETER; // breaks on xp
Expand Down Expand Up @@ -532,6 +570,7 @@ class CallbackHandler
LdrUnregisterDllNotification(cookie);
}

private:
static inline DWORD WINAPI ThreadProc(LPVOID ptr)
{
auto paramsPtr = static_cast<CallbackHandler::ThreadParams*>(ptr);
Expand Down Expand Up @@ -566,8 +605,10 @@ class CallbackHandler
static inline fnLdrSetDllManifestProber LdrSetDllManifestProber;
public:
static inline std::once_flag flag;
static std::map<std::wstring, std::function<void()>, Comparator> functions;
static std::map<std::wstring, std::function<void()>, Comparator> functions_unload;
static std::map<std::wstring, std::function<void()>, Comparator> onModuleLoad;
static std::map<std::wstring, std::function<void()>, Comparator> onModuleUnload;
static inline std::vector<std::function<void(HMODULE)>> onAnyModuleLoad;
static inline std::vector<std::function<void(HMODULE)>> onAnyModuleUnload;
};

class RegistryWrapper
Expand Down
157 changes: 45 additions & 112 deletions source/Condemned.WidescreenFix/dllmain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,77 +35,6 @@ HRESULT SHGetFolderPathAHook(HWND hwnd, int csidl, HANDLE hToken, DWORD dwFlags,
return r;
}

void InitSavePath(HMODULE module)
{
if (IniFile.FixSavePath)
{
auto hInst = (size_t)module;
IMAGE_NT_HEADERS* ntHeader = (IMAGE_NT_HEADERS*)(hInst + ((IMAGE_DOS_HEADER*)hInst)->e_lfanew);
IMAGE_IMPORT_DESCRIPTOR* pImports = (IMAGE_IMPORT_DESCRIPTOR*)(hInst + ntHeader->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress);
size_t nNumImports = ntHeader->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].Size / sizeof(IMAGE_IMPORT_DESCRIPTOR) - 1;

auto PatchIAT = [&](size_t start, size_t end, size_t exe_end)
{
for (size_t i = 0; i < nNumImports; i++)
{
if (hInst + (pImports + i)->FirstThunk > start && !(end && hInst + (pImports + i)->FirstThunk > end))
end = hInst + (pImports + i)->FirstThunk;
}

if (!end) { end = start + 0x100; }
if (end > exe_end) //for very broken exes
{
start = hInst;
end = exe_end;
}

for (auto i = start; i < end; i += sizeof(size_t))
{
DWORD dwProtect[2];
VirtualProtect((size_t*)i, sizeof(size_t), PAGE_EXECUTE_READWRITE, &dwProtect[0]);

auto ptr = *(size_t*)i;
if (!ptr)
continue;

if (ptr == (size_t)::SHGetFolderPathA)
{
*(size_t*)i = (size_t)SHGetFolderPathAHook;
}

VirtualProtect((size_t*)i, sizeof(size_t), dwProtect[0], &dwProtect[1]);
}
};

static auto getSection = [](const PIMAGE_NT_HEADERS nt_headers, unsigned section) -> PIMAGE_SECTION_HEADER
{
return reinterpret_cast<PIMAGE_SECTION_HEADER>(
(UCHAR*)nt_headers->OptionalHeader.DataDirectory +
nt_headers->OptionalHeader.NumberOfRvaAndSizes * sizeof(IMAGE_DATA_DIRECTORY) +
section * sizeof(IMAGE_SECTION_HEADER));
};

static auto getSectionEnd = [](IMAGE_NT_HEADERS* ntHeader, size_t inst) -> auto
{
auto sec = getSection(ntHeader, ntHeader->FileHeader.NumberOfSections - 1);
auto secSize = max(sec->SizeOfRawData, sec->Misc.VirtualSize);
auto end = inst + max(sec->PointerToRawData, sec->VirtualAddress) + secSize;
return end;
};

auto hInst_end = getSectionEnd(ntHeader, hInst);

for (size_t i = 0; i < nNumImports; i++)
{
if ((size_t)(hInst + (pImports + i)->Name) < hInst_end)
{
if (!_stricmp((const char*)(hInst + (pImports + i)->Name), "SHELL32.dll"))
PatchIAT(hInst + (pImports + i)->FirstThunk, 0, hInst_end);
}
}
}
}

void __fastcall sub_4059F0(float* _this, uint32_t edx, float* a2)
{
_this[59] = 0.0f;
Expand Down Expand Up @@ -202,35 +131,8 @@ void Init()
}
}

void InitSavePathExe()
{
InitSavePath(GetModuleHandle(NULL));
}

void InitSavePathEngineServer()
{
InitSavePath(GetModuleHandle(L"EngineServer"));
}

void InitSavePathGameDatabase()
{
InitSavePath(GetModuleHandle(L"GameDatabase"));
}

void InitSavePathGameServer()
{
InitSavePath(GetModuleHandle(L"GameServer"));
}

void InitSavePathGameClient()
{
InitSavePath(GetModuleHandle(L"GameClient"));
}

void InitGameClient()
{
InitSavePathGameClient();

if (IniFile.FixMenu)
{
auto unk_10169F30 = *hook::module_pattern(GetModuleHandle(L"GameClient"), "C7 05 ? ? ? ? ? ? ? ? C7 05 ? ? ? ? ? ? ? ? E8 ? ? ? ? 6A 04").get_first<void**>(6);
Expand Down Expand Up @@ -289,25 +191,56 @@ void InitConfig()
}
}

void HookShell32IAT(HMODULE mod)
{
IATHook::Replace(mod, "SHELL32.DLL",
std::forward_as_tuple("SHGetFolderPathA", SHGetFolderPathAHook)
);
}

HMODULE hm = NULL;
void OverrideSHGetFolderPathAInDLLs(HMODULE mod)
{
ModuleList dlls;
dlls.Enumerate(ModuleList::SearchLocation::LocalOnly);
for (auto& e : dlls.m_moduleList)
{
auto m = std::get<HMODULE>(e);
if (m == mod && !IsModuleUAL(m) && m != hm && m != GetModuleHandle(NULL))
HookShell32IAT(mod);
}
}

CEXP void InitializeASI()
{
std::call_once(CallbackHandler::flag, []()
{
CIniReader iniReader("");
IniFile.FixAspectRatio = iniReader.ReadInteger("MAIN", "FixAspectRatio", 1) != 0;
IniFile.FixMenu = iniReader.ReadInteger("MAIN", "FixMenu", 1) != 0;
IniFile.FixLowFramerate = iniReader.ReadInteger("MAIN", "FixLowFramerate", 1) != 0;
IniFile.FixSavePath = iniReader.ReadInteger("MAIN", "FixSavePath", 1) != 0;
IniFile.BorderlessWindowed = iniReader.ReadInteger("MAIN", "BorderlessWindowed", 1) != 0;

CallbackHandler::RegisterCallback(Init, hook::pattern("E8 ? ? ? ? 8B C6 5E 83 C4 10 C3"));
CallbackHandler::RegisterCallback(InitConfig, hook::pattern("0F 85 ? ? ? ? 83 7C 24 2C 16 0F 85 ? ? ? ? 6A 7F"));
CallbackHandler::RegisterCallback(L"GameClient.dll", InitGameClient);

if (IniFile.FixSavePath)
{
CIniReader iniReader("");
IniFile.FixAspectRatio = iniReader.ReadInteger("MAIN", "FixAspectRatio", 1) != 0;
IniFile.FixMenu = iniReader.ReadInteger("MAIN", "FixMenu", 1) != 0;
IniFile.FixLowFramerate = iniReader.ReadInteger("MAIN", "FixLowFramerate", 1) != 0;
IniFile.FixSavePath = iniReader.ReadInteger("MAIN", "FixSavePath", 1) != 0;
IniFile.BorderlessWindowed = iniReader.ReadInteger("MAIN", "BorderlessWindowed", 1) != 0;
GetModuleHandleExA(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, (LPCSTR)&OverrideSHGetFolderPathAInDLLs, &hm);

CallbackHandler::RegisterCallback(Init, hook::pattern("E8 ? ? ? ? 8B C6 5E 83 C4 10 C3"));
CallbackHandler::RegisterCallback(InitConfig, hook::pattern("0F 85 ? ? ? ? 83 7C 24 2C 16 0F 85 ? ? ? ? 6A 7F"));
CallbackHandler::RegisterCallback(InitSavePathExe, hook::pattern("FF 15 ? ? ? ? 85 C0 7C 4F 56 57 BE ? ? ? ? 56"));
CallbackHandler::RegisterCallback(L"GameClient.dll", InitGameClient);
CallbackHandler::RegisterCallback(L"EngineServer.dll", InitSavePathEngineServer);
CallbackHandler::RegisterCallback(L"GameDatabase.dll", InitSavePathGameDatabase);
CallbackHandler::RegisterCallback(L"GameServer.dll", InitSavePathGameServer);
});
ModuleList dlls;
dlls.Enumerate(ModuleList::SearchLocation::LocalOnly);
for (auto& e : dlls.m_moduleList)
{
auto m = std::get<HMODULE>(e);
if (!IsModuleUAL(m) && m != hm)
HookShell32IAT(m);
}
CallbackHandler::RegisterCallback(OverrideSHGetFolderPathAInDLLs);
}
});
}

BOOL APIENTRY DllMain(HMODULE hModule, DWORD reason, LPVOID lpReserved)
Expand Down

0 comments on commit 142ebad

Please sign in to comment.