Fix NCE Trap API Lock Callback

The lock callback would `continue` which would end up skipping over the current item as it applied to the inner loop rather than the outer loop as intended. This has now been fixed by using `break` and a check instead.
This commit is contained in:
PixelyIon 2022-07-13 22:23:33 +05:30
parent 745d809e07
commit 45cb8388cc
No known key found for this signature in database
GPG Key ID: 11BC6C3201BC2C05

View File

@ -469,17 +469,18 @@ namespace skyline::nce {
bool NCE::TrapHandler(u8 *address, bool write) { bool NCE::TrapHandler(u8 *address, bool write) {
LockCallback lockCallback{}; LockCallback lockCallback{};
while (true) { while (true) {
if (lockCallback) if (lockCallback) {
// We want to avoid a deadlock of holding trapMutex while locking the resource inside a callback while another thread holding the resource's mutex waits on trapMutex, we solve this by quitting the loop if a callback would be blocking and attempt to lock the resource externally // We want to avoid a deadlock of holding trapMutex while locking the resource inside a callback while another thread holding the resource's mutex waits on trapMutex, we solve this by quitting the loop if a callback would be blocking and attempt to lock the resource externally
lockCallback(); lockCallback();
lockCallback = {};
}
std::scoped_lock lock(trapMutex); std::scoped_lock lock(trapMutex);
// Check if we have a callback for this address // Retrieve any callbacks for the page that was faulted
auto[entries, intervals]{trapMap.GetAlignedRecursiveRange<PAGE_SIZE>(address)}; auto[entries, intervals]{trapMap.GetAlignedRecursiveRange<PAGE_SIZE>(address)};
if (entries.empty()) if (entries.empty())
return false; return false; // There's no callbacks associated with this page
// Do callbacks for every entry in the intervals // Do callbacks for every entry in the intervals
if (write) { if (write) {
@ -491,10 +492,12 @@ namespace skyline::nce {
if (!entry.writeCallback()) { if (!entry.writeCallback()) {
lockCallback = entry.lockCallback; lockCallback = entry.lockCallback;
continue; break;
} }
entry.protection = TrapProtection::None; // We don't need to protect this entry anymore entry.protection = TrapProtection::None; // We don't need to protect this entry anymore
} }
if (lockCallback)
continue; // We need to retry the loop because a callback was blocking
} else { } else {
bool allNone{true}; // If all entries require no protection, we can protect to allow all accesses bool allNone{true}; // If all entries require no protection, we can protect to allow all accesses
for (auto entryRef : entries) { for (auto entryRef : entries) {
@ -507,10 +510,12 @@ namespace skyline::nce {
if (!entry.readCallback()) { if (!entry.readCallback()) {
lockCallback = entry.lockCallback; lockCallback = entry.lockCallback;
continue; break;
} }
entry.protection = TrapProtection::WriteOnly; // We only need to trap writes to this entry entry.protection = TrapProtection::WriteOnly; // We only need to trap writes to this entry
} }
if (lockCallback)
continue; // We need to retry the loop because a callback was blocking
write = allNone; write = allNone;
} }
@ -525,7 +530,7 @@ namespace skyline::nce {
constexpr NCE::TrapHandle::TrapHandle(const TrapMap::GroupHandle &handle) : TrapMap::GroupHandle(handle) {} constexpr NCE::TrapHandle::TrapHandle(const TrapMap::GroupHandle &handle) : TrapMap::GroupHandle(handle) {}
NCE::TrapHandle NCE::TrapRegions(span<span<u8>> regions, bool writeOnly, const LockCallback& lockCallback, const TrapCallback &readCallback, const TrapCallback &writeCallback) { NCE::TrapHandle NCE::TrapRegions(span<span<u8>> regions, bool writeOnly, const LockCallback &lockCallback, const TrapCallback &readCallback, const TrapCallback &writeCallback) {
std::scoped_lock lock(trapMutex); std::scoped_lock lock(trapMutex);
auto protection{writeOnly ? TrapProtection::WriteOnly : TrapProtection::ReadWrite}; auto protection{writeOnly ? TrapProtection::WriteOnly : TrapProtection::ReadWrite};
TrapHandle handle{trapMap.Insert(regions, CallbackEntry{protection, lockCallback, readCallback, writeCallback})}; TrapHandle handle{trapMap.Insert(regions, CallbackEntry{protection, lockCallback, readCallback, writeCallback})};