Fix and refactor SVC SignalToAddress/WaitForAddress

SVC `SignalToAddress` had a bug with the behavior of `SignalAndModifyBasedOnWaitingThreadCountIfEqual` which was entirely incorrect and led to deadlocks in titles such as ARMS that were dependent on it. This commit corrects the behavior and refactors both SVCs and moves their arbitration/waiting to inside the corresponding `KProcess` function rather than the SVC to avoid redundancies and improve code readability.
This commit is contained in:
PixelyIon 2022-05-05 17:58:50 +05:30
parent 396979e897
commit 37327f1955
3 changed files with 75 additions and 44 deletions

View File

@ -1205,11 +1205,8 @@ namespace skyline::kernel::svc {
return; return;
} }
enum class ArbitrationType : u32 { using ArbitrationType = type::KProcess::ArbitrationType;
WaitIfLessThan = 0, auto arbitrationType{static_cast<ArbitrationType>(static_cast<u32>(state.ctx->gpr.w1))};
DecrementAndWaitIfLessThan = 1,
WaitIfEqual = 2,
} arbitrationType{static_cast<ArbitrationType>(static_cast<u32>(state.ctx->gpr.w1))};
u32 value{state.ctx->gpr.w2}; u32 value{state.ctx->gpr.w2};
i64 timeout{static_cast<i64>(state.ctx->gpr.x3)}; i64 timeout{static_cast<i64>(state.ctx->gpr.x3)};
@ -1217,28 +1214,17 @@ namespace skyline::kernel::svc {
switch (arbitrationType) { switch (arbitrationType) {
case ArbitrationType::WaitIfLessThan: case ArbitrationType::WaitIfLessThan:
Logger::Debug("Waiting on 0x{:X} if less than {} for {}ns", address, value, timeout); Logger::Debug("Waiting on 0x{:X} if less than {} for {}ns", address, value, timeout);
result = state.process->WaitForAddress(address, value, timeout, [](u32 *address, u32 value) { result = state.process->WaitForAddress(address, value, timeout, ArbitrationType::WaitIfLessThan);
return *address < value;
});
break; break;
case ArbitrationType::DecrementAndWaitIfLessThan: case ArbitrationType::DecrementAndWaitIfLessThan:
Logger::Debug("Waiting on and decrementing 0x{:X} if less than {} for {}ns", address, value, timeout); Logger::Debug("Waiting on and decrementing 0x{:X} if less than {} for {}ns", address, value, timeout);
result = state.process->WaitForAddress(address, value, timeout, [](u32 *address, u32 value) { result = state.process->WaitForAddress(address, value, timeout, ArbitrationType::DecrementAndWaitIfLessThan);
u32 userValue{__atomic_load_n(address, __ATOMIC_SEQ_CST)};
do {
if (value <= userValue) [[unlikely]] // We want to explicitly decrement **after** the check
return false;
} while (!__atomic_compare_exchange_n(address, &userValue, userValue - 1, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST));
return true;
});
break; break;
case ArbitrationType::WaitIfEqual: case ArbitrationType::WaitIfEqual:
Logger::Debug("Waiting on 0x{:X} if equal to {} for {}ns", address, value, timeout); Logger::Debug("Waiting on 0x{:X} if equal to {} for {}ns", address, value, timeout);
result = state.process->WaitForAddress(address, value, timeout, [](u32 *address, u32 value) { result = state.process->WaitForAddress(address, value, timeout, ArbitrationType::WaitIfEqual);
return *address == value;
});
break; break;
default: default:
@ -1267,11 +1253,8 @@ namespace skyline::kernel::svc {
return; return;
} }
enum class SignalType : u32 { using SignalType = type::KProcess::SignalType;
Signal = 0, auto signalType{static_cast<SignalType>(static_cast<u32>(state.ctx->gpr.w1))};
SignalAndIncrementIfEqual = 1,
SignalAndModifyBasedOnWaitingThreadCountIfEqual = 2,
} signalType{static_cast<SignalType>(static_cast<u32>(state.ctx->gpr.w1))};
u32 value{state.ctx->gpr.w2}; u32 value{state.ctx->gpr.w2};
i32 count{static_cast<i32>(state.ctx->gpr.w3)}; i32 count{static_cast<i32>(state.ctx->gpr.w3)};
@ -1279,21 +1262,17 @@ namespace skyline::kernel::svc {
switch (signalType) { switch (signalType) {
case SignalType::Signal: case SignalType::Signal:
Logger::Debug("Signalling 0x{:X} for {} waiters", address, count); Logger::Debug("Signalling 0x{:X} for {} waiters", address, count);
result = state.process->SignalToAddress(address, value, count); result = state.process->SignalToAddress(address, value, count, SignalType::Signal);
break; break;
case SignalType::SignalAndIncrementIfEqual: case SignalType::SignalAndIncrementIfEqual:
Logger::Debug("Signalling 0x{:X} and incrementing if equal to {} for {} waiters", address, value, count); Logger::Debug("Signalling 0x{:X} and incrementing if equal to {} for {} waiters", address, value, count);
result = state.process->SignalToAddress(address, value, count, [](u32 *address, u32 value, u32) { result = state.process->SignalToAddress(address, value, count, SignalType::SignalAndIncrementIfEqual);
return __atomic_compare_exchange_n(address, &value, value + 1, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST);
});
break; break;
case SignalType::SignalAndModifyBasedOnWaitingThreadCountIfEqual: case SignalType::SignalAndModifyBasedOnWaitingThreadCountIfEqual:
Logger::Debug("Signalling 0x{:X} and setting to waiting thread count if equal to {} for {} waiters", address, value, count); Logger::Debug("Signalling 0x{:X} and setting to waiting thread count if equal to {} for {} waiters", address, value, count);
result = state.process->SignalToAddress(address, value, count, [](u32 *address, u32 value, u32 waiterCount) { result = state.process->SignalToAddress(address, value, count, SignalType::SignalAndModifyBasedOnWaitingThreadCountIfEqual);
return __atomic_compare_exchange_n(address, &value, waiterCount, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST);
});
break; break;
default: default:

View File

@ -54,7 +54,7 @@ namespace skyline::kernel::type {
u8 *KProcess::AllocateTlsSlot() { u8 *KProcess::AllocateTlsSlot() {
std::scoped_lock lock{tlsMutex}; std::scoped_lock lock{tlsMutex};
u8 *slot; u8 *slot;
for (auto &tlsPage: tlsPages) for (auto &tlsPage : tlsPages)
if ((slot = tlsPage->ReserveSlot())) if ((slot = tlsPage->ReserveSlot()))
return slot; return slot;
@ -268,13 +268,32 @@ namespace skyline::kernel::type {
__atomic_store_n(key, false, __ATOMIC_SEQ_CST); // We need to update the boolean flag denoting that there are no more threads waiting on this conditional variable __atomic_store_n(key, false, __ATOMIC_SEQ_CST); // We need to update the boolean flag denoting that there are no more threads waiting on this conditional variable
} }
Result KProcess::WaitForAddress(u32 *address, u32 value, i64 timeout, bool (*arbitrationFunction)(u32 *, u32)) { Result KProcess::WaitForAddress(u32 *address, u32 value, i64 timeout, ArbitrationType type) {
TRACE_EVENT_FMT("kernel", "WaitForAddress 0x{:X}", address); TRACE_EVENT_FMT("kernel", "WaitForAddress 0x{:X}", address);
{ {
std::scoped_lock lock{syncWaiterMutex}; std::scoped_lock lock{syncWaiterMutex};
if (!arbitrationFunction(address, value)) [[unlikely]]
switch (type) {
case ArbitrationType::WaitIfLessThan:
if (*address >= value) [[unlikely]]
return result::InvalidState; return result::InvalidState;
break;
case ArbitrationType::DecrementAndWaitIfLessThan: {
u32 userValue{__atomic_load_n(address, __ATOMIC_SEQ_CST)};
do {
if (value <= userValue) [[unlikely]] // We want to explicitly decrement **after** the check
return result::InvalidState;
} while (!__atomic_compare_exchange_n(address, &userValue, userValue - 1, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST));
break;
}
case ArbitrationType::WaitIfEqual:
if (*address != value) [[unlikely]]
return result::InvalidState;
break;
}
auto queue{syncWaiters.equal_range(address)}; auto queue{syncWaiters.equal_range(address)};
syncWaiters.insert(std::upper_bound(queue.first, queue.second, state.thread->priority.load(), [](const i8 priority, const SyncWaiters::value_type &it) { return it.second->priority > priority; }), {address, state.thread}); syncWaiters.insert(std::upper_bound(queue.first, queue.second, state.thread->priority.load(), [](const i8 priority, const SyncWaiters::value_type &it) { return it.second->priority > priority; }), {address, state.thread});
@ -303,15 +322,36 @@ namespace skyline::kernel::type {
return {}; return {};
} }
Result KProcess::SignalToAddress(u32 *address, u32 value, i32 amount, bool(*mutateFunction)(u32 *address, u32 value, u32 waiterCount)) { Result KProcess::SignalToAddress(u32 *address, u32 value, i32 amount, SignalType type) {
TRACE_EVENT_FMT("kernel", "SignalToAddress 0x{:X}", address); TRACE_EVENT_FMT("kernel", "SignalToAddress 0x{:X}", address);
std::scoped_lock lock{syncWaiterMutex}; std::scoped_lock lock{syncWaiterMutex};
auto queue{syncWaiters.equal_range(address)}; auto queue{syncWaiters.equal_range(address)};
if (mutateFunction) if (type != SignalType::Signal) {
if (!mutateFunction(address, value, (amount <= 0) ? 0 : std::min(static_cast<u32>(std::distance(queue.first, queue.second) - amount), 0U))) [[unlikely]] u32 newValue{value};
if (type == SignalType::SignalAndIncrementIfEqual) {
newValue++;
} else if (type == SignalType::SignalAndModifyBasedOnWaitingThreadCountIfEqual) {
if (amount <= 0) {
if (queue.first != queue.second)
newValue -= 2;
else
newValue++;
} else {
if (queue.first != queue.second) {
i32 waiterCount{static_cast<i32>(std::distance(queue.first, queue.second))};
if (waiterCount < amount)
newValue--;
} else {
newValue++;
}
}
}
if (!__atomic_compare_exchange_n(address, &value, newValue, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST)) [[unlikely]]
return result::InvalidState; return result::InvalidState;
}
i32 waiterCount{amount}; i32 waiterCount{amount};
for (auto it{queue.first}; it != queue.second && (amount <= 0 || waiterCount); it = syncWaiters.erase(it), waiterCount--) for (auto it{queue.first}; it != queue.second && (amount <= 0 || waiterCount); it = syncWaiters.erase(it), waiterCount--)

View File

@ -230,15 +230,27 @@ namespace skyline {
*/ */
void ConditionalVariableSignal(u32 *key, i32 amount); void ConditionalVariableSignal(u32 *key, i32 amount);
/** enum class ArbitrationType : u32 {
* @brief Waits on the supplied address with the specified arbitration function WaitIfLessThan = 0,
*/ DecrementAndWaitIfLessThan = 1,
Result WaitForAddress(u32 *address, u32 value, i64 timeout, bool(*arbitrationFunction)(u32 *address, u32 value)); WaitIfEqual = 2,
};
/** /**
* @brief Signals a variable amount of waiters at the supplied address * @brief Waits on the supplied address with the specified arbitration type
*/ */
Result SignalToAddress(u32 *address, u32 value, i32 amount, bool(*mutateFunction)(u32 *address, u32 value, u32 waiterCount) = nullptr); Result WaitForAddress(u32 *address, u32 value, i64 timeout, ArbitrationType type);
enum class SignalType : u32 {
Signal = 0,
SignalAndIncrementIfEqual = 1,
SignalAndModifyBasedOnWaitingThreadCountIfEqual = 2,
};
/**
* @brief Signals a variable for amount of waiters at the supplied address with the specified signal type
*/
Result SignalToAddress(u32 *address, u32 value, i32 amount, SignalType type);
}; };
} }
} }