diff --git a/app/src/main/cpp/skyline/nce.cpp b/app/src/main/cpp/skyline/nce.cpp index c06948c3..b7437788 100644 --- a/app/src/main/cpp/skyline/nce.cpp +++ b/app/src/main/cpp/skyline/nce.cpp @@ -75,6 +75,12 @@ namespace skyline::nce { if (*tls) { // If TLS was restored then this occurred in guest code auto &mctx{ctx->uc_mcontext}; const auto &state{*reinterpret_cast(*tls)->state}; + + if (signal == SIGSEGV && info->si_code == SEGV_ACCERR) + // If we get a guest access violation then we want to handle any accesses that may be from a trapped region + if (state.nce->TrapHandler(reinterpret_cast(info->si_addr), true)) + return; + if (signal != SIGINT) { signal::StackFrame topFrame{.lr = reinterpret_cast(ctx->uc_mcontext.pc), .next = reinterpret_cast(ctx->uc_mcontext.regs[29])}; std::string trace{state.loader->GetStackTrace(&topFrame)}; @@ -84,7 +90,7 @@ namespace skyline::nce { cpuContext += fmt::format("\n Fault Address: 0x{:X}", mctx.fault_address); if (mctx.sp) cpuContext += fmt::format("\n Stack Pointer: 0x{:X}", mctx.sp); - for (u8 index{}; index < (sizeof(mcontext_t::regs) / sizeof(u64)); index += 2) + for (size_t index{}; index < (sizeof(mcontext_t::regs) / sizeof(u64)); index += 2) cpuContext += fmt::format("\n X{:<2}: 0x{:<16X} X{:<2}: 0x{:X}", index, mctx.regs[index], index + 1, mctx.regs[index + 1]); Logger::Error("Thread #{} has crashed due to signal: {}\nStack Trace:{}\nCPU Context:{}", state.thread->id, strsignal(signal), trace, cpuContext); @@ -106,7 +112,7 @@ namespace skyline::nce { static std::ifstream status("/proc/self/status"); status.seekg(0); - constexpr std::string_view TracerPidTag = "TracerPid:"; + constexpr std::string_view TracerPidTag{"TracerPid:"}; for (std::string line; std::getline(status, line);) { if (line.starts_with(TracerPidTag)) { line = line.substr(TracerPidTag.size()); @@ -355,4 +361,128 @@ namespace skyline::nce { } } } + + NCE::CallbackEntry::CallbackEntry(TrapProtection protection, NCE::TrapCallback readCallback, NCE::TrapCallback writeCallback) : protection(protection), readCallback(std::move(readCallback)), writeCallback(std::move(writeCallback)) {} + + void NCE::ReprotectIntervals(const std::vector &intervals, TrapProtection protection) { + auto reprotectIntervalsWithFunction = [&intervals](auto getProtection) { + for (auto region : intervals) { + region = region.Align(PAGE_SIZE); + mprotect(region.start, region.Size(), getProtection(region)); + } + }; + + // We need to determine the lowest protection possible for the given interval + if (protection == TrapProtection::None) { + reprotectIntervalsWithFunction([&](auto region) { + auto entries{trapMap.GetRange(region)}; + + TrapProtection lowestProtection{TrapProtection::None}; + for (const auto &entry : entries) { + auto entryProtection{entry.get().protection}; + if (entryProtection > lowestProtection) { + lowestProtection = entryProtection; + if (entryProtection == TrapProtection::ReadWrite) + return PROT_EXEC; + } + } + + switch (lowestProtection) { + case TrapProtection::None: + return PROT_READ | PROT_WRITE | PROT_EXEC; + case TrapProtection::WriteOnly: + return PROT_READ | PROT_EXEC; + case TrapProtection::ReadWrite: + return PROT_EXEC; + } + }); + } else if (protection == TrapProtection::WriteOnly) { + reprotectIntervalsWithFunction([&](auto region) { + auto entries{trapMap.GetRange(region)}; + for (const auto &entry : entries) + if (entry.get().protection == TrapProtection::ReadWrite) + return PROT_EXEC; + + return PROT_READ | PROT_EXEC; + }); + } else { + reprotectIntervalsWithFunction([&](auto region) { + return PROT_EXEC; // No checks are needed as this is already the highest level of protection + }); + } + } + + bool NCE::TrapHandler(u8 *address, bool write) { + std::scoped_lock lock(trapMutex); + + // Check if we have a callback for this address + auto[entries, intervals]{trapMap.GetAlignedRecursiveRange(address)}; + + if (entries.empty()) + return false; + + // Do callbacks for every entry in the intervals + if (write) { + for (auto entryRef : entries) { + auto &entry{entryRef.get()}; + if (entry.protection == TrapProtection::None) + // We don't need to do the callback if the entry doesn't require any protection already + continue; + + entry.writeCallback(); + entry.protection = TrapProtection::None; // We don't need to protect this entry anymore + } + } else { + bool allNone{true}; // If all entries require no protection, we can protect to allow all accesses + for (auto entryRef : entries) { + auto &entry{entryRef.get()}; + if (entry.protection < TrapProtection::ReadWrite) { + // We don't need to do the callback if the entry can already handle read accesses + allNone = allNone && entry.protection == TrapProtection::None; + continue; + } + + entry.readCallback(); + entry.protection = TrapProtection::WriteOnly; // We only need to trap writes to this entry + } + write = allNone; + } + + int permission{PROT_READ | (write ? PROT_WRITE : 0) | PROT_EXEC}; + for (const auto &interval : intervals) + // Reprotect the interval to the lowest protection level that the callbacks performed allow + mprotect(interval.start, interval.Size(), permission); + + return true; + } + + constexpr NCE::TrapHandle::TrapHandle(const TrapMap::GroupHandle &handle) : TrapMap::GroupHandle(handle) {} + + NCE::TrapHandle NCE::TrapRegions(span> regions, bool writeOnly, const TrapCallback &readCallback, const TrapCallback &writeCallback) { + std::scoped_lock lock(trapMutex); + auto protection{writeOnly ? TrapProtection::WriteOnly : TrapProtection::ReadWrite}; + TrapHandle handle{trapMap.Insert(regions, CallbackEntry{protection, readCallback, writeCallback})}; + ReprotectIntervals(handle->intervals, protection); + return handle; + } + + void NCE::RetrapRegions(TrapHandle handle, bool writeOnly) { + std::scoped_lock lock(trapMutex); + auto protection{writeOnly ? TrapProtection::WriteOnly : TrapProtection::ReadWrite}; + handle->value.protection = protection; + ReprotectIntervals(handle->intervals, protection); + } + + void NCE::RemoveTrap(TrapHandle handle) { + std::scoped_lock lock(trapMutex); + handle->value.protection = TrapProtection::None; + ReprotectIntervals(handle->intervals, TrapProtection::None); + } + + void NCE::DeleteTrap(TrapHandle handle) { + std::scoped_lock lock(trapMutex); + handle->value.protection = TrapProtection::None; + ReprotectIntervals(handle->intervals, TrapProtection::None); + trapMap.Remove(handle); + } } diff --git a/app/src/main/cpp/skyline/nce.h b/app/src/main/cpp/skyline/nce.h index de8848c0..184d26bf 100644 --- a/app/src/main/cpp/skyline/nce.h +++ b/app/src/main/cpp/skyline/nce.h @@ -3,8 +3,9 @@ #pragma once -#include "common.h" #include +#include "common.h" +#include "common/interval_map.h" namespace skyline::nce { /** @@ -14,6 +15,35 @@ namespace skyline::nce { private: const DeviceState &state; + /** + * @brief The level of protection that is required for a callback entry + */ + enum class TrapProtection { + None = 0, //!< No protection is required + WriteOnly = 1, //!< Only write protection is required + ReadWrite = 2, //!< Both read and write protection are required + }; + + using TrapCallback = std::function; + + struct CallbackEntry { + TrapProtection protection; //!< The least restrictive protection that this callback needs to have + TrapCallback readCallback, writeCallback; + + CallbackEntry(TrapProtection protection, NCE::TrapCallback readCallback, NCE::TrapCallback writeCallback); + }; + + std::mutex trapMutex; //!< Synchronizes the accesses to the trap map + using TrapMap = IntervalMap; + TrapMap trapMap; //!< A map of all intervals and corresponding callbacks that have been registered + + /** + * @brief Reprotects the intervals to the least restrictive protection given the supplied protection + */ + void ReprotectIntervals(const std::vector& intervals, TrapProtection protection); + + bool TrapHandler(u8* address, bool write); + static void SvcHandler(u16 svcId, ThreadContext *ctx); public: @@ -26,7 +56,7 @@ namespace skyline::nce { ExitException(bool killAllThreads = true); - virtual const char* what() const noexcept; + virtual const char *what() const noexcept; }; /** @@ -48,5 +78,38 @@ namespace skyline::nce { * @param patch A pointer to the .patch section which should be exactly patchSize in size and located before the .text section */ static void PatchCode(std::vector &text, u32 *patch, size_t patchSize, const std::vector &offsets); + + /** + * @brief An opaque handle to a group of trapped region + */ + class TrapHandle : private TrapMap::GroupHandle { + constexpr TrapHandle(const TrapMap::GroupHandle &handle); + + friend NCE; + }; + + /** + * @brief Traps a region of guest memory with a callback for when an access to it has been made + * @param writeOnly If the trap is optimally for write-only accesses initially, this is not guarenteed + * @note The handle **must** be deleted using DeleteTrap before the NCE instance is destroyed + * @note It is UB to supply a region of host memory rather than guest memory + */ + TrapHandle TrapRegions(span> regions, bool writeOnly, const TrapCallback& readCallback, const TrapCallback& writeCallback); + + /** + * @brief Re-traps a region of memory after protections were removed + * @param writeOnly If the trap is optimally for write-only accesses, this is not guarenteed + */ + void RetrapRegions(TrapHandle handle, bool writeOnly); + + /** + * @brief Removes protections from a region of memory + */ + void RemoveTrap(TrapHandle handle); + + /** + * @brief Deletes a trap handle and removes the protection from the region + */ + void DeleteTrap(TrapHandle handle); }; }