[orbis-kernel] umtx: implement rwlock ops

This commit is contained in:
DH 2023-10-30 21:54:43 +03:00
parent 60e11486f4
commit 06a0910c80
3 changed files with 263 additions and 27 deletions

View file

@ -79,12 +79,10 @@ ErrorCode umtx_cv_wait(Thread *thread, ptr<ucond> cv, ptr<umutex> m,
std::uint64_t ut, ulong wflags);
ErrorCode umtx_cv_signal(Thread *thread, ptr<ucond> cv);
ErrorCode umtx_cv_broadcast(Thread *thread, ptr<ucond> cv);
ErrorCode umtx_rw_rdlock(Thread *thread, ptr<void> obj, std::int64_t val,
ptr<void> uaddr1, ptr<void> uaddr2);
ErrorCode umtx_rw_wrlock(Thread *thread, ptr<void> obj, std::int64_t val,
ptr<void> uaddr1, ptr<void> uaddr2);
ErrorCode umtx_rw_unlock(Thread *thread, ptr<void> obj, std::int64_t val,
ptr<void> uaddr1, ptr<void> uaddr2);
ErrorCode umtx_rw_rdlock(Thread *thread, ptr<urwlock> rwlock, slong fflag,
ulong ut);
ErrorCode umtx_rw_wrlock(Thread *thread, ptr<urwlock> rwlock, ulong ut);
ErrorCode umtx_rw_unlock(Thread *thread, ptr<urwlock> rwlock);
ErrorCode umtx_wake_private(Thread *thread, ptr<void> uaddr, sint n_wake);
ErrorCode umtx_wait_umutex(Thread *thread, ptr<umutex> m, std::uint64_t ut);
ErrorCode umtx_wake_umutex(Thread *thread, ptr<umutex> m);

View file

@ -126,11 +126,15 @@ orbis::SysResult orbis::sys__umtx_op(Thread *thread, ptr<void> obj, sint op,
false);
}
case 12:
return umtx_rw_rdlock(thread, obj, val, uaddr1, uaddr2);
return with_timeout([&](std::uint64_t ut) {
return umtx_rw_rdlock(thread, (ptr<urwlock>)obj, val, ut);
});
case 13:
return umtx_rw_wrlock(thread, obj, val, uaddr1, uaddr2);
return with_timeout([&](std::uint64_t ut) {
return umtx_rw_wrlock(thread, (ptr<urwlock>)obj, ut);
});
case 14:
return umtx_rw_unlock(thread, obj, val, uaddr1, uaddr2);
return umtx_rw_unlock(thread, (ptr<urwlock>)obj);
case 15: {
return with_timeout(
[&](std::uint64_t ut) {
@ -159,6 +163,9 @@ orbis::SysResult orbis::sys__umtx_op(Thread *thread, ptr<void> obj, sint op,
return umtx_nwake_private(thread, (ptr<void *>)obj, val);
case 22:
return umtx_wake2_umutex(thread, obj, val, uaddr1, uaddr2);
case 23:
ORBIS_LOG_ERROR("sys__umtx_op: unknown wake operation", op);
return umtx_wake_umutex(thread, (orbis::ptr<orbis::umutex>)obj);
}
return ErrorCode::INVAL;

View file

@ -49,12 +49,14 @@ uint UmtxChain::notify_all(const UmtxKey &key) {
orbis::ErrorCode orbis::umtx_lock_umtx(Thread *thread, ptr<umtx> umtx, ulong id,
std::uint64_t ut) {
ORBIS_LOG_TODO(__FUNCTION__, thread->tid, umtx, id, ut);
std::abort();
return ErrorCode::NOSYS;
}
orbis::ErrorCode orbis::umtx_unlock_umtx(Thread *thread, ptr<umtx> umtx,
ulong id) {
ORBIS_LOG_TODO(__FUNCTION__, thread->tid, umtx, id);
std::abort();
return ErrorCode::NOSYS;
}
@ -191,7 +193,7 @@ static ErrorCode do_lock_pi(Thread *thread, ptr<umutex> m, uint flags,
static ErrorCode do_lock_pp(Thread *thread, ptr<umutex> m, uint flags,
std::uint64_t ut, umutex_lock_mode mode) {
ORBIS_LOG_TODO(__FUNCTION__, m, flags, ut, mode);
return ErrorCode::NOSYS;
return do_lock_normal(thread, m, flags, ut, mode);
}
static ErrorCode do_unlock_normal(Thread *thread, ptr<umutex> m, uint flags) {
ORBIS_LOG_TRACE(__FUNCTION__, thread->tid, m, flags);
@ -227,7 +229,7 @@ static ErrorCode do_unlock_pi(Thread *thread, ptr<umutex> m, uint flags) {
}
static ErrorCode do_unlock_pp(Thread *thread, ptr<umutex> m, uint flags) {
ORBIS_LOG_TODO(__FUNCTION__, m, flags);
return ErrorCode::NOSYS;
return do_unlock_normal(thread, m, flags);
}
} // namespace orbis
@ -284,7 +286,8 @@ orbis::ErrorCode orbis::umtx_unlock_umutex(Thread *thread, ptr<umutex> m) {
orbis::ErrorCode orbis::umtx_set_ceiling(Thread *thread, ptr<umutex> m,
std::uint32_t ceiling,
ptr<uint32_t> oldCeiling) {
ORBIS_LOG_TODO(__FUNCTION__, m, ceiling, oldCeiling);
ORBIS_LOG_TRACE(__FUNCTION__, m, ceiling, oldCeiling);
std::abort();
return ErrorCode::NOSYS;
}
@ -301,10 +304,12 @@ orbis::ErrorCode orbis::umtx_cv_wait(Thread *thread, ptr<ucond> cv,
}
if ((wflags & kCvWaitClockId) != 0 && ut + 1) {
ORBIS_LOG_FATAL("umtx_cv_wait: CLOCK_ID unimplemented", wflags);
std::abort();
return ErrorCode::NOSYS;
}
if ((wflags & kCvWaitAbsTime) != 0 && ut + 1) {
ORBIS_LOG_FATAL("umtx_cv_wait: ABSTIME unimplemented", wflags);
std::abort();
return ErrorCode::NOSYS;
}
@ -373,25 +378,250 @@ orbis::ErrorCode orbis::umtx_cv_broadcast(Thread *thread, ptr<ucond> cv) {
return {};
}
orbis::ErrorCode orbis::umtx_rw_rdlock(Thread *thread, ptr<void> obj,
std::int64_t val, ptr<void> uaddr1,
ptr<void> uaddr2) {
ORBIS_LOG_TODO(__FUNCTION__, obj, val, uaddr1, uaddr2);
return ErrorCode::NOSYS;
orbis::ErrorCode orbis::umtx_rw_rdlock(Thread *thread, ptr<urwlock> rwlock,
slong fflag, ulong ut) {
ORBIS_LOG_TRACE(__FUNCTION__, thread->tid, rwlock, fflag, ut);
auto flags = rwlock->flags;
auto [chain, key, lock] = g_context.getUmtxChain1(thread, flags & 1, rwlock);
auto wrflags = kUrwLockWriteOwner;
if (!(fflag & kUrwLockPreferReader) && !(flags & kUrwLockPreferReader)) {
wrflags |= kUrwLockWriteWaiters;
}
while (true) {
auto state = rwlock->state.load(std::memory_order::relaxed);
while ((state & wrflags) == 0) {
if ((state & kUrwLockMaxReaders) == kUrwLockMaxReaders) {
return ErrorCode::AGAIN;
}
if (rwlock->state.compare_exchange_strong(state, state + 1)) {
return {};
}
}
while ((state & wrflags) && !(state & kUrwLockReadWaiters)) {
if (rwlock->state.compare_exchange_weak(state,
state | kUrwLockReadWaiters)) {
break;
}
}
if (!(state & wrflags)) {
continue;
}
++rwlock->blocked_readers;
ErrorCode result{};
while (state & wrflags) {
auto node = chain.enqueue(key, thread);
if (ut + 1 == 0) {
while (true) {
node->second.cv.wait(chain.mtx, ut);
if (node->second.thr != thread) {
break;
}
}
} else {
auto start = std::chrono::steady_clock::now();
std::uint64_t udiff = 0;
while (true) {
node->second.cv.wait(chain.mtx, ut - udiff);
if (node->second.thr != thread)
break;
udiff = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - start)
.count();
if (udiff >= ut) {
result = ErrorCode::TIMEDOUT;
break;
}
}
}
if (node->second.thr != thread) {
result = {};
} else {
chain.erase(node);
}
if (result != ErrorCode{}) {
break;
}
state = rwlock->state.load(std::memory_order::relaxed);
}
if (--rwlock->blocked_readers == 0) {
while (true) {
if (!rwlock->state.compare_exchange_weak(
state, state & ~kUrwLockReadWaiters)) {
break;
}
}
}
}
return {};
}
orbis::ErrorCode orbis::umtx_rw_wrlock(Thread *thread, ptr<void> obj,
std::int64_t val, ptr<void> uaddr1,
ptr<void> uaddr2) {
ORBIS_LOG_TODO(__FUNCTION__, obj, val, uaddr1, uaddr2);
return ErrorCode::NOSYS;
orbis::ErrorCode orbis::umtx_rw_wrlock(Thread *thread, ptr<urwlock> rwlock,
ulong ut) {
ORBIS_LOG_TRACE(__FUNCTION__, thread->tid, rwlock, ut);
auto flags = rwlock->flags;
auto [chain, key, lock] = g_context.getUmtxChain1(thread, flags & 1, rwlock);
uint32_t blocked_readers = 0;
ErrorCode error = {};
while (true) {
auto state = rwlock->state.load(std::memory_order::relaxed);
while (!(state & kUrwLockWriteOwner) && (state & kUrwLockMaxReaders) == 0) {
if (!rwlock->state.compare_exchange_strong(state,
state | kUrwLockWriteOwner)) {
return {};
}
}
if (error != ErrorCode{}) {
if (!(state & (kUrwLockWriteOwner | kUrwLockWriteWaiters)) &&
blocked_readers != 0) {
chain.notify_one(key);
}
break;
}
state = rwlock->state.load(std::memory_order::relaxed);
while (
((state & kUrwLockWriteOwner) || (state & kUrwLockMaxReaders) != 0) &&
(state & kUrwLockWriteWaiters) == 0) {
if (!rwlock->state.compare_exchange_strong(
state, state | kUrwLockWriteWaiters)) {
break;
}
}
if (!(state & kUrwLockWriteOwner) && (state & kUrwLockMaxReaders) == 0) {
continue;
}
++rwlock->blocked_writers;
while ((state & kUrwLockWriteOwner) || (state & kUrwLockMaxReaders) != 0) {
auto node = chain.enqueue(key, thread);
if (ut + 1 == 0) {
while (true) {
node->second.cv.wait(chain.mtx, ut);
if (node->second.thr != thread) {
break;
}
}
} else {
auto start = std::chrono::steady_clock::now();
std::uint64_t udiff = 0;
while (true) {
node->second.cv.wait(chain.mtx, ut - udiff);
if (node->second.thr != thread)
break;
udiff = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - start)
.count();
if (udiff >= ut) {
error = ErrorCode::TIMEDOUT;
break;
}
}
}
if (node->second.thr != thread) {
error = {};
} else {
chain.erase(node);
}
if (error != ErrorCode{}) {
break;
}
state = rwlock->state.load(std::memory_order::relaxed);
}
if (--rwlock->blocked_writers == 0) {
state = rwlock->state.load(std::memory_order::relaxed);
while (true) {
if (rwlock->state.compare_exchange_weak(
state, state & ~kUrwLockWriteWaiters)) {
break;
}
}
blocked_readers = rwlock->blocked_readers;
} else {
blocked_readers = 0;
}
}
return error;
}
orbis::ErrorCode orbis::umtx_rw_unlock(Thread *thread, ptr<void> obj,
std::int64_t val, ptr<void> uaddr1,
ptr<void> uaddr2) {
ORBIS_LOG_TODO(__FUNCTION__, obj, val, uaddr1, uaddr2);
return ErrorCode::NOSYS;
orbis::ErrorCode orbis::umtx_rw_unlock(Thread *thread, ptr<urwlock> rwlock) {
auto flags = rwlock->flags;
auto [chain, key, lock] = g_context.getUmtxChain1(thread, flags & 1, rwlock);
auto state = rwlock->state.load(std::memory_order::relaxed);
if (state & kUrwLockWriteOwner) {
while (true) {
if (rwlock->state.compare_exchange_weak(state, state & ~kUrwLockWriteOwner)) {
break;
}
if (!(state & kUrwLockWriteOwner)) {
return ErrorCode::PERM;
}
}
} else if ((state & kUrwLockMaxReaders) != 0) {
while (true) {
if (rwlock->state.compare_exchange_weak(state, state - 1)) {
break;
}
if ((state & kUrwLockMaxReaders) == 0) {
return ErrorCode::PERM;
}
}
} else {
return ErrorCode::PERM;
}
unsigned count = 0;
if (!(flags & kUrwLockPreferReader)) {
if (state & kUrwLockWriteWaiters) {
count = 1;
} else if (state & kUrwLockReadWaiters) {
count = UINT_MAX;
}
} else {
if (state & kUrwLockReadWaiters) {
count = UINT_MAX;
} else if (state & kUrwLockWriteWaiters) {
count = 1;
}
}
if (count == 1) {
chain.notify_one(key);
} else if (count != 0) {
chain.notify_all(key);
}
return {};
}
orbis::ErrorCode orbis::umtx_wake_private(Thread *thread, ptr<void> addr,
@ -522,5 +752,6 @@ orbis::ErrorCode orbis::umtx_wake2_umutex(Thread *thread, ptr<void> obj,
std::int64_t val, ptr<void> uaddr1,
ptr<void> uaddr2) {
ORBIS_LOG_TODO(__FUNCTION__, obj, val, uaddr1, uaddr2);
std::abort();
return ErrorCode::NOSYS;
}