diff --git a/src/D3D12Hook.cpp b/src/D3D12Hook.cpp index 4f479774..a35c7e0b 100644 --- a/src/D3D12Hook.cpp +++ b/src/D3D12Hook.cpp @@ -2,12 +2,14 @@ #include #include #include +#include #include #include #include #include #include +#include #include "REFramework.hpp" @@ -21,6 +23,74 @@ D3D12Hook::~D3D12Hook() { unhook(); } +void* D3D12Hook::Streamline::link_swapchain_to_cmd_queue(void* rcx, void* rdx, void* r8, void* r9) { + std::scoped_lock _{g_framework->get_hook_monitor_mutex()}; + + spdlog::info("[Streamline] linkSwapchainToCmdQueue: {:x}", (uintptr_t)_ReturnAddress()); + + g_framework->on_reset(); // Needed to prevent a crash due to resources hanging around + + if (g_d3d12_hook != nullptr) { + g_d3d12_hook->unhook(); // Removes all vtable hooks + } + + auto& hook = D3D12Hook::s_streamline.link_swapchain_to_cmd_queue_hook; + const auto result = hook->get_original()(rcx, rdx, r8, r9); + + // Re-hooks present after the above function creates the swapchain + // This allows the hook to immediately still function + // rather than waiting on the hook monitor to notice the hook isn't working + g_framework->hook_d3d12(); + + return result; +} + +void D3D12Hook::hook_streamline() { + if (D3D12Hook::s_streamline.setup) { + return; + } + + spdlog::info("[Streamline] Hooking Streamline"); + + const auto dlssg_module = GetModuleHandleW(L"sl.dlss_g.dll"); + + if (dlssg_module == nullptr) { + spdlog::error("[Streamline] Failed to get sl.dlss_g.dll module handle"); + return; + } + + const auto str = utility::scan_string(dlssg_module, "linkSwapchainToCmdQueue"); + + if (!str) { + spdlog::error("[Streamline] Failed to find linkSwapchainToCmdQueue"); + return; + } + + const auto str_ref = utility::scan_displacement_reference(dlssg_module, *str); + + if (!str_ref) { + spdlog::error("[Streamline] Failed to find linkSwapchainToCmdQueue reference"); + return; + } + + const auto fn = utility::find_function_start_with_call(*str_ref); + + if (!fn) { + spdlog::error("[Streamline] Failed to find linkSwapchainToCmdQueue function"); + return; + } + + D3D12Hook::s_streamline.link_swapchain_to_cmd_queue_hook = std::make_unique(*fn, (uintptr_t)&Streamline::link_swapchain_to_cmd_queue); + + if (D3D12Hook::s_streamline.link_swapchain_to_cmd_queue_hook->create()) { + spdlog::info("[Streamline] Hooked linkSwapchainToCmdQueue"); + } else { + spdlog::error("[Streamline] Failed to hook linkSwapchainToCmdQueue"); + } + + D3D12Hook::s_streamline.setup = true; +} + bool D3D12Hook::hook() { spdlog::info("Hooking D3D12"); @@ -330,7 +400,9 @@ bool D3D12Hook::hook() { return false; } - utility::ThreadSuspender suspender{}; + hook_streamline(); + + //utility::ThreadSuspender suspender{}; try { spdlog::info("Initializing hooks"); @@ -341,20 +413,21 @@ bool D3D12Hook::hook() { m_is_phase_1 = true; auto& present_fn = (*(void***)target_swapchain)[8]; // Present - m_present_hook = std::make_unique(&present_fn, (void*)&D3D12Hook::present); + m_present_hook = std::make_unique((uintptr_t)present_fn, (uintptr_t)&D3D12Hook::present); + m_present_hook->create(); m_hooked = true; } catch (const std::exception& e) { spdlog::error("Failed to initialize hooks: {}", e.what()); m_hooked = false; } - suspender.resume(); + //suspender.resume(); - device->Release(); command_queue->Release(); - factory->Release(); swap_chain1->Release(); swap_chain->Release(); + device->Release(); + factory->Release(); if (hwnd) { ::DestroyWindow(hwnd); @@ -368,6 +441,8 @@ bool D3D12Hook::hook() { } bool D3D12Hook::unhook() { + std::scoped_lock _{g_framework->get_hook_monitor_mutex()}; + if (!m_hooked) { return true; } @@ -385,57 +460,68 @@ bool D3D12Hook::unhook() { thread_local int32_t g_present_depth = 0; -HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, UINT sync_interval, UINT flags) { +HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, uint64_t sync_interval, uint64_t flags, void* r9) { std::scoped_lock _{g_framework->get_hook_monitor_mutex()}; auto d3d12 = g_d3d12_hook; - HWND swapchain_wnd{nullptr}; - swap_chain->GetHwnd(&swapchain_wnd); - decltype(D3D12Hook::present)* present_fn{nullptr}; - //if (d3d12->m_is_phase_1) { - present_fn = d3d12->m_present_hook->get_original(); - /*} else { + if (d3d12->m_is_phase_1) { + //present_fn = d3d12->m_present_hook->get_original(); + present_fn = d3d12->m_present_hook->get_original(); + } else { present_fn = d3d12->m_swapchain_hook->get_method(8); - }*/ + } + + HWND swapchain_wnd{nullptr}; + swap_chain->GetHwnd(&swapchain_wnd); if (d3d12->m_is_phase_1 && WindowFilter::get().is_filtered(swapchain_wnd)) { //present_fn = d3d12->m_present_hook->get_original(); - return present_fn(swap_chain, sync_interval, flags); + return present_fn(swap_chain, sync_interval, flags, r9); } if (!d3d12->m_is_phase_1 && swap_chain != d3d12->m_swapchain_hook->get_instance()) { - return present_fn(swap_chain, sync_interval, flags); + return present_fn(swap_chain, sync_interval, flags, r9); } if (d3d12->m_is_phase_1) { - //d3d12->m_present_hook.reset(); + // Remove the present hook, we will just rely on the vtable hook below + // because we don't want to cause any conflicts with other hooks + // vtable hooks are the least intrusive + // And doing a global pointer replacement seems to have + // conflicts with Streamline's hooks, causing unexplainable crashes + d3d12->m_present_hook.reset(); // vtable hook the swapchain instead of global hooking // this seems safer for whatever reason // if we globally hook the vtable pointers, it causes all sorts of weird conflicts with other hooks // dont hook present though via this hook so other hooks dont get confused d3d12->m_swapchain_hook = std::make_unique(swap_chain); - //d3d12->m_swapchain_hook->hook_method(8, (uintptr_t)&D3D12Hook::present); + //d3d12->m_swapchain_hook->hook_method(2, (uintptr_t)&D3D12Hook::release); + d3d12->m_swapchain_hook->hook_method(8, (uintptr_t)&D3D12Hook::present); d3d12->m_swapchain_hook->hook_method(13, (uintptr_t)&D3D12Hook::resize_buffers); d3d12->m_swapchain_hook->hook_method(14, (uintptr_t)&D3D12Hook::resize_target); d3d12->m_is_phase_1 = false; + + present_fn = d3d12->m_swapchain_hook->get_method(8); } d3d12->m_inside_present = true; d3d12->m_swap_chain = swap_chain; - swap_chain->GetDevice(IID_PPV_ARGS(&d3d12->m_device)); + { + Microsoft::WRL::ComPtr temp_device{}; + swap_chain->GetDevice(IID_PPV_ARGS(&temp_device)); + d3d12->m_device = temp_device.Get(); + } - if (d3d12->m_device != nullptr) { - if (d3d12->m_using_proton_swapchain) { - const auto real_swapchain = *(uintptr_t*)((uintptr_t)swap_chain + d3d12->m_proton_swapchain_offset); - d3d12->m_command_queue = *(ID3D12CommandQueue**)(real_swapchain + d3d12->m_command_queue_offset); - } else { - d3d12->m_command_queue = *(ID3D12CommandQueue**)((uintptr_t)swap_chain + d3d12->m_command_queue_offset); - } + if (d3d12->m_using_proton_swapchain) { + const auto real_swapchain = *(uintptr_t*)((uintptr_t)swap_chain + d3d12->m_proton_swapchain_offset); + d3d12->m_command_queue = *(ID3D12CommandQueue**)(real_swapchain + d3d12->m_command_queue_offset); + } else { + d3d12->m_command_queue = *(ID3D12CommandQueue**)((uintptr_t)swap_chain + d3d12->m_command_queue_offset); } if (d3d12->m_swapchain_0 == nullptr) { @@ -462,7 +548,7 @@ HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, UINT sync_interva spdlog::info("Attempting to call real present function"); ++g_present_depth; - const auto result = present_fn(swap_chain, sync_interval, flags); + const auto result = present_fn(swap_chain, sync_interval, flags, r9); --g_present_depth; if (result != S_OK) { @@ -485,7 +571,7 @@ HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, UINT sync_interva auto result = S_OK; if (!d3d12->m_ignore_next_present) { - result = present_fn(swap_chain, sync_interval, flags); + result = present_fn(swap_chain, sync_interval, flags, r9); if (result != S_OK) { spdlog::error("Present failed: {:x}", (uint64_t)result); diff --git a/src/D3D12Hook.hpp b/src/D3D12Hook.hpp index 1a3315ac..d65a8555 100644 --- a/src/D3D12Hook.hpp +++ b/src/D3D12Hook.hpp @@ -10,6 +10,7 @@ #include #include "utility/PointerHook.hpp" +#include "utility/FunctionHook.hpp" #include "utility/VtableHook.hpp" class D3D12Hook @@ -88,7 +89,7 @@ class D3D12Hook bool is_proton_swapchain() const { return m_using_proton_swapchain; } - + bool is_framegen_swapchain() const { return m_using_frame_generation_swapchain; } @@ -97,6 +98,8 @@ class D3D12Hook m_ignore_next_present = true; } + void hook_streamline(); + protected: ID3D12Device4* m_device{ nullptr }; IDXGISwapChain3* m_swap_chain{ nullptr }; @@ -118,17 +121,27 @@ class D3D12Hook bool m_inside_present{false}; bool m_ignore_next_present{false}; - std::unique_ptr m_present_hook{}; + std::unique_ptr m_present_hook{}; + //std::unique_ptr m_release_hook{}; std::unique_ptr m_swapchain_hook{}; //std::unique_ptr m_create_swap_chain_hook{}; + struct Streamline { + static void* link_swapchain_to_cmd_queue(void* rcx, void* rdx, void* r8, void* r9); + + std::unique_ptr link_swapchain_to_cmd_queue_hook{}; + bool setup{ false }; + }; + + static inline Streamline s_streamline{}; + OnPresentFn m_on_present{ nullptr }; OnPresentFn m_on_post_present{ nullptr }; OnResizeBuffersFn m_on_resize_buffers{ nullptr }; OnResizeTargetFn m_on_resize_target{ nullptr }; //OnCreateSwapChainFn m_on_create_swap_chain{ nullptr }; - static HRESULT WINAPI present(IDXGISwapChain3* swap_chain, UINT sync_interval, UINT flags); + static HRESULT WINAPI present(IDXGISwapChain3* swap_chain, uint64_t sync_interval, uint64_t flags, void* r9); static HRESULT WINAPI resize_buffers(IDXGISwapChain3* swap_chain, UINT buffer_count, UINT width, UINT height, DXGI_FORMAT new_format, UINT swap_chain_flags); static HRESULT WINAPI resize_target(IDXGISwapChain3* swap_chain, const DXGI_MODE_DESC* new_target_parameters); //static HRESULT WINAPI create_swap_chain(IDXGIFactory4* factory, IUnknown* device, HWND hwnd, const DXGI_SWAP_CHAIN_DESC* desc, const DXGI_SWAP_CHAIN_FULLSCREEN_DESC* p_fullscreen_desc, IDXGIOutput* p_restrict_to_output, IDXGISwapChain** swap_chain); diff --git a/src/REFramework.hpp b/src/REFramework.hpp index 60b4deb8..cd42303b 100644 --- a/src/REFramework.hpp +++ b/src/REFramework.hpp @@ -145,9 +145,11 @@ class REFramework { void draw_ui(); void draw_about(); +public: bool hook_d3d11(); bool hook_d3d12(); +private: bool initialize(); bool initialize_game_data(); bool initialize_windows_message_hook();