orbis-kernel: umtx: implement notify_n

This commit is contained in:
DH 2024-10-31 14:19:22 +03:00
parent cc0e81e88f
commit 2723eb0bfd
2 changed files with 40 additions and 34 deletions

View file

@ -36,13 +36,16 @@ struct UmtxCond {
struct UmtxChain {
utils::shared_mutex mtx;
utils::kmultimap<UmtxKey, UmtxCond> sleep_queue;
utils::kmultimap<UmtxKey, UmtxCond> spare_queue;
using queue_type = utils::kmultimap<UmtxKey, UmtxCond>;
queue_type sleep_queue;
queue_type spare_queue;
std::pair<const UmtxKey, UmtxCond> *enqueue(UmtxKey &key, Thread *thr);
void erase(std::pair<const UmtxKey, UmtxCond> *obj);
queue_type::iterator erase(queue_type::iterator it);
uint notify_one(const UmtxKey &key);
uint notify_all(const UmtxKey &key);
uint notify_n(const UmtxKey &key, sint count);
};
class alignas(__STDCPP_DEFAULT_NEW_ALIGNMENT__) KernelContext final {

View file

@ -4,6 +4,7 @@
#include "orbis/utils/AtomicOp.hpp"
#include "orbis/utils/Logs.hpp"
#include "time.hpp"
#include <limits>
namespace orbis {
std::pair<const UmtxKey, UmtxCond> *UmtxChain::enqueue(UmtxKey &key,
@ -20,29 +21,48 @@ std::pair<const UmtxKey, UmtxCond> *UmtxChain::enqueue(UmtxKey &key,
void UmtxChain::erase(std::pair<const UmtxKey, UmtxCond> *obj) {
for (auto [it, e] = sleep_queue.equal_range(obj->first); it != e; it++) {
if (&*it == obj) {
auto node = sleep_queue.extract(it);
node.key() = {};
spare_queue.insert(spare_queue.begin(), std::move(node));
erase(it);
return;
}
}
}
uint UmtxChain::notify_one(const UmtxKey &key) {
UmtxChain::queue_type::iterator UmtxChain::erase(queue_type::iterator it) {
auto next = std::next(it);
auto node = sleep_queue.extract(it);
node.key() = {};
spare_queue.insert(spare_queue.begin(), std::move(node));
return next;
}
uint UmtxChain::notify_n(const UmtxKey &key, sint count) {
auto it = sleep_queue.find(key);
if (it == sleep_queue.end())
return 0;
it->second.thr = nullptr;
it->second.cv.notify_all(mtx);
this->erase(&*it);
return 1;
uint n = 0;
while (count > 0) {
it->second.thr = nullptr;
it->second.cv.notify_all(mtx);
it = erase(it);
n++;
count--;
if (it == sleep_queue.end()) {
break;
}
}
return n;
}
uint UmtxChain::notify_one(const UmtxKey &key) {
return notify_n(key, 1);
}
uint UmtxChain::notify_all(const UmtxKey &key) {
uint n = 0;
while (notify_one(key))
n++;
return n;
return notify_n(key, std::numeric_limits<sint>::max());
}
} // namespace orbis
@ -108,18 +128,12 @@ orbis::ErrorCode orbis::umtx_wait(Thread *thread, ptr<void> addr, ulong id,
orbis::ErrorCode orbis::umtx_wake(Thread *thread, ptr<void> addr, sint n_wake) {
ORBIS_LOG_NOTICE(__FUNCTION__, thread->tid, addr, n_wake);
auto [chain, key, lock] = g_context.getUmtxChain0(thread, true, addr);
std::size_t count = chain.sleep_queue.count(key);
if (key.pid == 0) {
// IPC workaround (TODO)
chain.notify_all(key);
return {};
}
// TODO: check this
while (count--) {
chain.notify_one(key);
if (n_wake-- <= 1)
break;
}
chain.notify_n(key, n_wake);
return {};
}
@ -645,12 +659,7 @@ orbis::ErrorCode orbis::umtx_rw_unlock(Thread *thread, ptr<urwlock> rwlock) {
}
}
if (count == 1) {
chain.notify_one(key);
} else if (count != 0) {
chain.notify_all(key);
}
chain.notify_n(key, count);
return {};
}
@ -658,13 +667,7 @@ orbis::ErrorCode orbis::umtx_wake_private(Thread *thread, ptr<void> addr,
sint n_wake) {
ORBIS_LOG_TRACE(__FUNCTION__, thread->tid, addr, n_wake);
auto [chain, key, lock] = g_context.getUmtxChain0(thread, false, addr);
std::size_t count = chain.sleep_queue.count(key);
// TODO: check this
while (count--) {
chain.notify_one(key);
if (n_wake-- <= 1)
break;
}
chain.notify_n(key, n_wake);
return {};
}