Fix shared_cv deadlock

Was incorrect order of args for futex
This commit is contained in:
Ivan Chikish 2023-07-12 14:05:23 +03:00
parent 123321e2bc
commit de973e369f
5 changed files with 22 additions and 18 deletions

View file

@ -34,7 +34,7 @@ protected:
}
// Internal waiting function
void impl_wait(shared_mutex &mutex, unsigned _old,
void impl_wait(shared_mutex &mutex, unsigned _val,
std::uint64_t usec_timeout) noexcept;
// Try to notify up to _count threads
@ -44,13 +44,13 @@ public:
constexpr shared_cv() = default;
void wait(shared_mutex &mutex, std::uint64_t usec_timeout = -1) noexcept {
const unsigned _old = add_waiter();
if (!_old) {
const unsigned _val = add_waiter();
if (!_val) {
return;
}
mutex.unlock();
impl_wait(mutex, _old, usec_timeout);
impl_wait(mutex, _val, usec_timeout);
}
// Wake one thread

View file

@ -60,7 +60,7 @@ orbis::ErrorCode orbis::umtx_unlock_umtx(Thread *thread, ptr<umtx> umtx,
orbis::ErrorCode orbis::umtx_wait(Thread *thread, ptr<void> addr, ulong id,
std::uint64_t ut, bool is32) {
ORBIS_LOG_NOTICE(__FUNCTION__, addr, id, ut);
ORBIS_LOG_NOTICE(__FUNCTION__, thread, addr, id, ut, is32);
auto [chain, key, lock] = g_context.getUmtxChain0(thread->tproc->pid, addr);
auto node = chain.enqueue(key, thread);
ErrorCode result = {};
@ -93,7 +93,7 @@ orbis::ErrorCode orbis::umtx_wait(Thread *thread, ptr<void> addr, ulong id,
}
orbis::ErrorCode orbis::umtx_wake(Thread *thread, ptr<void> addr, sint n_wake) {
ORBIS_LOG_NOTICE(__FUNCTION__, addr, n_wake);
ORBIS_LOG_NOTICE(__FUNCTION__, thread, addr, n_wake);
auto [chain, key, lock] = g_context.getUmtxChain0(thread->tproc->pid, addr);
std::size_t count = chain.sleep_queue.count(key);
// TODO: check this
@ -128,7 +128,7 @@ void log_class_string<umutex_lock_mode>::format(std::string &out,
}
static ErrorCode do_lock_normal(Thread *thread, ptr<umutex> m, uint flags,
std::uint64_t ut, umutex_lock_mode mode) {
ORBIS_LOG_NOTICE(__FUNCTION__, m, flags, ut, mode);
ORBIS_LOG_NOTICE(__FUNCTION__, thread, m, flags, ut, mode);
ErrorCode error = {};
while (true) {
@ -182,7 +182,7 @@ static ErrorCode do_lock_pp(Thread *thread, ptr<umutex> m, uint flags,
return ErrorCode::NOSYS;
}
static ErrorCode do_unlock_normal(Thread *thread, ptr<umutex> m, uint flags) {
ORBIS_LOG_NOTICE(__FUNCTION__, m, flags);
ORBIS_LOG_NOTICE(__FUNCTION__, thread, m, flags);
int owner = m->owner.load(std::memory_order_acquire);
if ((owner & ~kUmutexContested) != thread->tid)
@ -267,7 +267,7 @@ orbis::ErrorCode orbis::umtx_set_ceiling(Thread *thread, ptr<umutex> m,
orbis::ErrorCode orbis::umtx_cv_wait(Thread *thread, ptr<ucond> cv,
ptr<umutex> m, std::uint64_t ut,
ulong wflags) {
ORBIS_LOG_NOTICE(__FUNCTION__, cv, m, ut, wflags);
ORBIS_LOG_NOTICE(__FUNCTION__, thread, cv, m, ut, wflags);
const uint flags = uread(&cv->flags);
if ((wflags & kCvWaitClockId) != 0) {
ORBIS_LOG_FATAL("umtx_cv_wait: CLOCK_ID unimplemented", wflags);
@ -315,7 +315,7 @@ orbis::ErrorCode orbis::umtx_cv_wait(Thread *thread, ptr<ucond> cv,
}
orbis::ErrorCode orbis::umtx_cv_signal(Thread *thread, ptr<ucond> cv) {
ORBIS_LOG_NOTICE(__FUNCTION__, cv);
ORBIS_LOG_NOTICE(__FUNCTION__, thread, cv);
auto [chain, key, lock] = g_context.getUmtxChain0(thread->tproc->pid, cv);
std::size_t count = chain.sleep_queue.count(key);
if (chain.notify_one(key) >= count)
@ -324,7 +324,7 @@ orbis::ErrorCode orbis::umtx_cv_signal(Thread *thread, ptr<ucond> cv) {
}
orbis::ErrorCode orbis::umtx_cv_broadcast(Thread *thread, ptr<ucond> cv) {
ORBIS_LOG_NOTICE(__FUNCTION__, cv);
ORBIS_LOG_NOTICE(__FUNCTION__, thread, cv);
auto [chain, key, lock] = g_context.getUmtxChain0(thread->tproc->pid, cv);
chain.notify_all(key);
cv->has_waiters.store(0, std::memory_order::relaxed);
@ -407,10 +407,10 @@ orbis::ErrorCode orbis::umtx_sem_wake(Thread *thread, ptr<void> obj,
orbis::ErrorCode orbis::umtx_nwake_private(Thread *thread, ptr<void *> uaddrs,
std::int64_t count) {
ORBIS_LOG_NOTICE(__FUNCTION__, uaddrs, count);
ORBIS_LOG_NOTICE(__FUNCTION__, thread, uaddrs, count);
while (count-- > 0) {
void *uaddr;
auto error = uread(uaddr, uaddrs);
auto error = uread(uaddr, uaddrs++);
if (error != ErrorCode{})
return error;
umtx_wake_private(thread, uaddr, 1);

View file

@ -1,20 +1,21 @@
#include "orbis/utils/SharedCV.hpp"
#include "orbis/utils/Logs.hpp"
#include <linux/futex.h>
#include <syscall.h>
#include <unistd.h>
namespace orbis::utils {
void shared_cv::impl_wait(shared_mutex &mutex, unsigned _old,
void shared_cv::impl_wait(shared_mutex &mutex, unsigned _val,
std::uint64_t usec_timeout) noexcept {
// Not supposed to fail
if (!_old)
if (!_val)
std::abort();
// Wait with timeout
struct timespec timeout {};
timeout.tv_nsec = (usec_timeout % 1000'000) * 1000;
timeout.tv_sec = (usec_timeout / 1000'000);
syscall(SYS_futex, &m_value, FUTEX_WAIT, _old,
syscall(SYS_futex, &m_value, FUTEX_WAIT, _val,
usec_timeout + 1 ? &timeout : nullptr, 0, 0);
// Cleanup
@ -64,6 +65,8 @@ void shared_cv::impl_wake(shared_mutex &mutex, int _count) noexcept {
// Add lock signal (mutex was immediately locked)
if (locked && max_sig)
value |= c_locked_mask;
else if (locked)
std::abort();
// Add normal signals
value += c_signal_one * max_sig;
@ -84,7 +87,7 @@ void shared_cv::impl_wake(shared_mutex &mutex, int _count) noexcept {
// Wake up one thread + requeue remaining waiters
unsigned awake_count = locked ? 1 : 0;
if (auto r = syscall(SYS_futex, &m_value, FUTEX_REQUEUE, awake_count,
&mutex, _count - awake_count, 0);
_count - awake_count, &mutex, 0);
r < _count) {
// Keep awaking waiters
return impl_wake(mutex, is_one ? 1 : INT_MAX);

View file

@ -1,4 +1,5 @@
#include "utils/SharedMutex.hpp"
#include "utils/Logs.hpp"
#include <linux/futex.h>
#include <syscall.h>
#include <unistd.h>

View file

@ -135,7 +135,7 @@ orbis::SysResult close(orbis::Thread *thread, orbis::sint fd) {
#define IOC_DIRMASK (IOC_VOID | IOC_OUT | IOC_IN)
#define _IOC(inout, group, num, len) \
((unsigned long)((inout) | (((len) & IOCPARM_MASK) << 16) | ((group) << 8) | \
((unsigned long)((inout) | (((len)&IOCPARM_MASK) << 16) | ((group) << 8) | \
(num)))
#define _IO(g, n) _IOC(IOC_VOID, (g), (n), 0)
#define _IOWINT(g, n) _IOC(IOC_VOID, (g), (n), sizeof(int))