From 091a9eec260b60ae244df1080e2c2c6af4e40e05 Mon Sep 17 00:00:00 2001 From: DH Date: Sun, 24 Nov 2024 20:49:13 +0300 Subject: [PATCH] kernel: umtx: implement umtx_wake2_umutex fixed shared unlock --- orbis-kernel/include/orbis/umtx.hpp | 6 +- orbis-kernel/src/sys/sys_umtx.cpp | 8 +- orbis-kernel/src/umtx.cpp | 121 ++++++++++++++++++++++------ 3 files changed, 102 insertions(+), 33 deletions(-) diff --git a/orbis-kernel/include/orbis/umtx.hpp b/orbis-kernel/include/orbis/umtx.hpp index dd42eb3cf..79234d6e8 100644 --- a/orbis-kernel/include/orbis/umtx.hpp +++ b/orbis-kernel/include/orbis/umtx.hpp @@ -52,6 +52,8 @@ inline constexpr auto kUmtxOpSemWait = 19; inline constexpr auto kUmtxOpSemWake = 20; inline constexpr auto kUmtxOpNwakePrivate = 21; inline constexpr auto kUmtxOpMutexWake2 = 22; +inline constexpr auto kUmtxOpMutexWake3 = 23; + inline constexpr auto kSemNamed = 2; @@ -114,6 +116,6 @@ ErrorCode umtx_sem_wait(Thread *thread, ptr sem, std::uint64_t ut); ErrorCode umtx_sem_wake(Thread *thread, ptr sem); ErrorCode umtx_nwake_private(Thread *thread, ptr uaddrs, std::int64_t count); -ErrorCode umtx_wake2_umutex(Thread *thread, ptr obj, std::int64_t val, - ptr uaddr1, ptr uaddr2); +ErrorCode umtx_wake2_umutex(Thread *thread, ptr m, sint wakeFlags); +ErrorCode umtx_wake3_umutex(Thread *thread, ptr m, sint wakeFlags); } // namespace orbis \ No newline at end of file diff --git a/orbis-kernel/src/sys/sys_umtx.cpp b/orbis-kernel/src/sys/sys_umtx.cpp index b6da684f5..dd07f6bed 100644 --- a/orbis-kernel/src/sys/sys_umtx.cpp +++ b/orbis-kernel/src/sys/sys_umtx.cpp @@ -168,11 +168,9 @@ orbis::SysResult orbis::sys__umtx_op(Thread *thread, ptr obj, sint op, case kUmtxOpNwakePrivate: return umtx_nwake_private(thread, (ptr)obj, val); case kUmtxOpMutexWake2: - return umtx_wake2_umutex(thread, obj, val, uaddr1, uaddr2); - case 23: - ORBIS_LOG_ERROR("sys__umtx_op: unknown wake operation", op, val, uaddr1, uaddr2); - // thread->where(); - return umtx_wake_umutex(thread, (orbis::ptr)obj, val); + return umtx_wake2_umutex(thread, (orbis::ptr)obj, val); + case kUmtxOpMutexWake3: + return umtx_wake3_umutex(thread, (orbis::ptr)obj, val); } return ErrorCode::INVAL; diff --git a/orbis-kernel/src/umtx.cpp b/orbis-kernel/src/umtx.cpp index ab8552be6..01939f784 100644 --- a/orbis-kernel/src/umtx.cpp +++ b/orbis-kernel/src/umtx.cpp @@ -41,15 +41,6 @@ uint UmtxChain::notify_n(const UmtxKey &key, sint count) { uint n = 0; while (count > 0) { - while (true) { - auto flags = it->second.thr->suspendFlags.load(); - if (~flags & kThreadSuspendFlag) { - break; - } - - orbis::scoped_unblock unblock; - it->second.thr->suspendFlags.wait(flags); - } it->second.thr = nullptr; it->second.cv.notify_all(mtx); it = erase(it); @@ -57,7 +48,7 @@ uint UmtxChain::notify_n(const UmtxKey &key, sint count) { n++; count--; - if (it->first != key || it == sleep_queue.end()) { + if (it == sleep_queue.end() || it->first != key) { break; } } @@ -211,7 +202,8 @@ static ErrorCode do_lock_normal(Thread *thread, ptr m, uint flags, orbis::scoped_unblock unblock; error = orbis::toErrorCode(node->second.cv.wait(chain.mtx, ut)); } - if (error == ErrorCode{} && !isSpuriousWakeup(error) && node->second.thr == thread && m->owner.load() != 0) { + if (error == ErrorCode{} && !isSpuriousWakeup(error) && + node->second.thr == thread && m->owner.load() != 0) { error = ErrorCode::TIMEDOUT; } } @@ -244,17 +236,20 @@ static ErrorCode do_unlock_normal(Thread *thread, ptr m, uint flags) { } std::size_t count = chain.sleep_queue.count(key); - bool ok = m->owner.compare_exchange_strong( - owner, count <= 1 ? kUmutexUnowned : kUmutexContested); + bool ok; if (key.pid == 0) { + ok = m->owner.compare_exchange_strong(owner, kUmutexUnowned); // IPC workaround (TODO) chain.notify_all(key); if (!ok) return ErrorCode::INVAL; return {}; } - if (count) - chain.notify_all(key); + + ok = m->owner.compare_exchange_strong(owner, count <= 1 ? kUmutexUnowned + : kUmutexContested); + chain.notify_one(key); + if (!ok) return ErrorCode::INVAL; return {}; @@ -718,23 +713,26 @@ orbis::ErrorCode orbis::umtx_wait_umutex(Thread *thread, ptr m, orbis::ErrorCode orbis::umtx_wake_umutex(Thread *thread, ptr m, sint wakeFlags) { - ORBIS_LOG_TRACE(__FUNCTION__, m); - int owner = m->owner.load(std::memory_order::acquire); - if ((owner & ~kUmutexContested) != 0) - return {}; + ORBIS_LOG_TRACE(__FUNCTION__, thread->tid, m); uint flags; if (ErrorCode err = uread(flags, &m->flags); err != ErrorCode{}) return err; auto [chain, key, lock] = g_context.getUmtxChain1(thread, flags, m); + + int owner = m->owner.load(std::memory_order::acquire); + if ((owner & ~kUmutexContested) != 0) + return {}; + std::size_t count = chain.sleep_queue.count(key); if (count <= 1) { owner = kUmutexContested; m->owner.compare_exchange_strong(owner, kUmutexUnowned); } + if (count != 0 && (owner & ~kUmutexContested) == 0) { - if ((wakeFlags & 0x400) || (flags & 1)) { + if (flags & 1) { chain.notify_all(key); } else { chain.notify_one(key); @@ -822,10 +820,81 @@ orbis::ErrorCode orbis::umtx_nwake_private(Thread *thread, ptr uaddrs, return {}; } -orbis::ErrorCode orbis::umtx_wake2_umutex(Thread *thread, ptr obj, - std::int64_t val, ptr uaddr1, - ptr uaddr2) { - ORBIS_LOG_TODO(__FUNCTION__, obj, val, uaddr1, uaddr2); - std::abort(); - return ErrorCode::NOSYS; +orbis::ErrorCode orbis::umtx_wake2_umutex(Thread *thread, ptr m, + sint wakeFlags) { + ORBIS_LOG_NOTICE(__FUNCTION__, thread->tid, m); + + uint flags; + if (ErrorCode err = uread(flags, &m->flags); err != ErrorCode{}) + return err; + + auto [chain, key, lock] = g_context.getUmtxChain1(thread, wakeFlags & 1, m); + + int owner = 0; + + std::size_t count = chain.sleep_queue.count(key); + + if (count > 1) { + owner = m->owner.load(std::memory_order::acquire); + + while ((owner & kUmutexContested) == 0) { + if (m->owner.compare_exchange_weak(owner, owner | kUmutexContested)) { + break; + } + } + } else if (count == 1) { + owner = m->owner.load(std::memory_order::acquire); + + while ((owner & ~kUmutexContested) != 0 && + (owner & kUmutexContested) == 0) { + if (m->owner.compare_exchange_weak(owner, owner | kUmutexContested)) { + break; + } + } + } + + if (count != 0 && (owner & ~kUmutexContested) == 0) { + chain.notify_one(key); + return {}; + } + + return {}; +} + +orbis::ErrorCode orbis::umtx_wake3_umutex(Thread *thread, ptr m, + sint wakeFlags) { + ORBIS_LOG_NOTICE(__FUNCTION__, thread->tid, m); + + uint flags; + if (ErrorCode err = uread(flags, &m->flags); err != ErrorCode{}) + return err; + + auto [chain, key, lock] = g_context.getUmtxChain1(thread, wakeFlags & 1, m); + + int owner = 0; + std::size_t count = chain.sleep_queue.count(key); + + if (count > 1) { + owner = m->owner.load(std::memory_order::acquire); + + while ((owner & kUmutexContested) == 0) { + if (m->owner.compare_exchange_weak(owner, owner | kUmutexContested)) { + break; + } + } + } else if (count == 1) { + owner = m->owner.load(std::memory_order::acquire); + + while ((owner & ~kUmutexContested) != 0 && + (owner & kUmutexContested) == 0) { + if (m->owner.compare_exchange_weak(owner, owner | kUmutexContested)) { + break; + } + } + } + + if (count != 0 && (owner & ~kUmutexContested) == 0) { + chain.notify_one(key); + } + return {}; }