Skip to content

Commit

Permalink
Merge fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
cursey committed Mar 15, 2024
1 parent f78fb9d commit aba9ff4
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 279 deletions.
12 changes: 1 addition & 11 deletions include/safetyhook/os.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,9 @@ struct SystemInfo {

SystemInfo system_info();

using ThreadId = uint32_t;
using ThreadHandle = void*;
using ThreadContext = void*;

/// @brief Executes a function while all other threads are frozen. Also allows for visiting each frozen thread and
/// modifying it's context.
/// @param run_fn The function to run while all other threads are frozen.
/// @param visit_fn The function that will be called for each frozen thread.
/// @note The visit function will be called in the order that the threads were frozen.
/// @note The visit function will be called before the run function.
/// @note Keep the logic inside run_fn and visit_fn as simple as possible to avoid deadlocks.
void execute_while_frozen(const std::function<void()>& run_fn,
const std::function<void(ThreadId, ThreadHandle, ThreadContext)>& visit_fn = {});
void trap_threads(uint8_t* from, uint8_t* to, size_t len, const std::function<void()>& run_fn);

/// @brief Will modify the context of a thread's IP to point to a new address if its IP is at the old address.
/// @param ctx The thread context to modify.
Expand Down
21 changes: 0 additions & 21 deletions include/safetyhook/thread_freezer.hpp

This file was deleted.

4 changes: 2 additions & 2 deletions src/os.linux.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ SystemInfo system_info() {
};
}

void execute_while_frozen(const std::function<void()>& run_fn,
[[maybe_unused]] const std::function<void(ThreadId, ThreadHandle, ThreadContext)>& visit_fn) {
void trap_threads([[maybe_unused]] uint8_t* from, [[maybe_unused]] uint8_t* to, [[maybe_unused]] size_t len,
const std::function<void()>& run_fn) {
run_fn();
}

Expand Down
179 changes: 98 additions & 81 deletions src/os.windows.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
#include <map>
#include <memory>
#include <mutex>

#include "safetyhook/common.hpp"
#include "safetyhook/utility.hpp"

#if SAFETYHOOK_OS_WINDOWS

Expand All @@ -11,19 +16,8 @@
#error "Windows.h not found"
#endif

#include <winternl.h>

#include "safetyhook/os.hpp"

#pragma comment(lib, "ntdll")

extern "C" {
NTSTATUS
NTAPI
NtGetNextThread(HANDLE ProcessHandle, HANDLE ThreadHandle, ACCESS_MASK DesiredAccess, ULONG HandleAttributes,
ULONG Flags, PHANDLE NewThreadHandle);
}

namespace safetyhook {
std::expected<uint8_t*, OsError> vm_allocate(uint8_t* address, size_t size, VmAccess access) {
DWORD protect = 0;
Expand Down Expand Up @@ -158,104 +152,127 @@ SystemInfo system_info() {
return info;
}

void execute_while_frozen(
const std::function<void()>& run_fn, const std::function<void(ThreadId, ThreadHandle, ThreadContext)>& visit_fn) {
// Freeze all threads.
int num_threads_frozen;
auto first_run = true;

do {
num_threads_frozen = 0;
HANDLE thread{};
struct TrapInfo {
uint8_t* page_start;
uint8_t* page_end;
uint8_t* from;
uint8_t* to;
size_t len;
};

class TrapManager final {
public:
static std::mutex mutex;
static TrapManager instance;

~TrapManager() {
if (m_trap_veh != nullptr) {
RemoveVectoredExceptionHandler(m_trap_veh);
}
}

while (true) {
HANDLE next_thread{};
const auto status = NtGetNextThread(GetCurrentProcess(), thread,
THREAD_QUERY_LIMITED_INFORMATION | THREAD_SUSPEND_RESUME | THREAD_GET_CONTEXT | THREAD_SET_CONTEXT, 0,
0, &next_thread);
TrapInfo* find_trap(uint8_t* address) {
auto search = std::find_if(m_traps.begin(), m_traps.end(), [address](auto& trap) {
return address >= trap.second.from && address < trap.second.from + trap.second.len;
});

if (thread != nullptr) {
CloseHandle(thread);
}
if (search == m_traps.end()) {
return nullptr;
}

if (!NT_SUCCESS(status)) {
break;
}
return &search->second;
}

thread = next_thread;
TrapInfo* find_trap_page(uint8_t* address) {
auto search = std::find_if(m_traps.begin(), m_traps.end(),
[address](auto& trap) { return address >= trap.second.page_start && address < trap.second.page_end; });

const auto thread_id = GetThreadId(thread);
if (search == m_traps.end()) {
return nullptr;
}

if (thread_id == 0 || thread_id == GetCurrentThreadId()) {
continue;
}
return &search->second;
}

const auto suspend_count = SuspendThread(thread);
void add_trap(uint8_t* from, uint8_t* to, size_t len) {
m_traps.insert_or_assign(from, TrapInfo{.page_start = align_down(from, 0x1000),
.page_end = align_up(from + len, 0x1000),
.from = from,
.to = to,
.len = len});
}

if (suspend_count == static_cast<DWORD>(-1)) {
continue;
}
private:
std::map<uint8_t*, TrapInfo> m_traps;
PVOID m_trap_veh{};

// Check if the thread was already frozen. Only resume if the thread was already frozen, and it wasn't the
// first run of this freeze loop to account for threads that may have already been frozen for other reasons.
if (suspend_count != 0 && !first_run) {
ResumeThread(thread);
continue;
}
TrapManager() { m_trap_veh = AddVectoredExceptionHandler(1, trap_handler); }

CONTEXT thread_ctx{};
static LONG CALLBACK trap_handler(PEXCEPTION_POINTERS exp) {
auto exception_code = exp->ExceptionRecord->ExceptionCode;

thread_ctx.ContextFlags = CONTEXT_FULL;
if (exception_code != EXCEPTION_ACCESS_VIOLATION) {
return EXCEPTION_CONTINUE_SEARCH;
}

if (GetThreadContext(thread, &thread_ctx) == FALSE) {
continue;
}
std::scoped_lock lock{mutex};
auto* faulting_address = reinterpret_cast<uint8_t*>(exp->ExceptionRecord->ExceptionInformation[1]);
auto* trap = instance.find_trap(faulting_address);

if (visit_fn) {
visit_fn(static_cast<ThreadId>(thread_id), static_cast<ThreadHandle>(thread),
static_cast<ThreadContext>(&thread_ctx));
if (trap == nullptr) {
if (instance.find_trap_page(faulting_address) != nullptr) {
return EXCEPTION_CONTINUE_EXECUTION;
} else {
return EXCEPTION_CONTINUE_SEARCH;
}
}

SetThreadContext(thread, &thread_ctx);
auto* ctx = exp->ContextRecord;

++num_threads_frozen;
for (size_t i = 0; i < trap->len; i++) {
fix_ip(ctx, trap->from + i, trap->to + i);
}

first_run = false;
} while (num_threads_frozen != 0);

// Run the function.
if (run_fn) {
run_fn();
return EXCEPTION_CONTINUE_EXECUTION;
}
};

// Resume all threads.
HANDLE thread{};
std::mutex TrapManager::mutex{};
TrapManager TrapManager::instance{};

while (true) {
HANDLE next_thread{};
const auto status = NtGetNextThread(GetCurrentProcess(), thread,
THREAD_QUERY_LIMITED_INFORMATION | THREAD_SUSPEND_RESUME | THREAD_GET_CONTEXT | THREAD_SET_CONTEXT, 0, 0,
&next_thread);
void find_me() {
}

if (thread != nullptr) {
CloseHandle(thread);
}
void trap_threads(uint8_t* from, uint8_t* to, size_t len, const std::function<void()>& run_fn) {
MEMORY_BASIC_INFORMATION find_me_mbi{};
MEMORY_BASIC_INFORMATION from_mbi{};
MEMORY_BASIC_INFORMATION to_mbi{};

if (!NT_SUCCESS(status)) {
break;
}
VirtualQuery(reinterpret_cast<void*>(find_me), &find_me_mbi, sizeof(find_me_mbi));
VirtualQuery(from, &from_mbi, sizeof(from_mbi));
VirtualQuery(to, &to_mbi, sizeof(to_mbi));

auto new_protect = PAGE_READWRITE;

thread = next_thread;
if (from_mbi.AllocationBase == find_me_mbi.AllocationBase || to_mbi.AllocationBase == find_me_mbi.AllocationBase) {
new_protect = PAGE_EXECUTE_READWRITE;
}

const auto thread_id = GetThreadId(thread);
std::scoped_lock lock{TrapManager::mutex};
TrapManager::instance.add_trap(from, to, len);

if (thread_id == 0 || thread_id == GetCurrentThreadId()) {
continue;
}
DWORD from_protect;
DWORD to_protect;

ResumeThread(thread);
VirtualProtect(from, len, new_protect, &from_protect);
VirtualProtect(to, len, new_protect, &to_protect);

if (run_fn) {
run_fn();
}

VirtualProtect(to, len, to_protect, &to_protect);
VirtualProtect(from, len, from_protect, &from_protect);
}

void fix_ip(ThreadContext thread_ctx, uint8_t* old_ip, uint8_t* new_ip) {
Expand Down
Loading

0 comments on commit aba9ff4

Please sign in to comment.