Linux: use futex_waitv syscall for atomic waiting

To make this possible, some unnecessary features were removed: per-wait size and flag metadata, 128-bit masks and comparison values, and the extended comparison operations (the op/op_flag machinery).
Ivan Chikish, 2023-07-31 23:57:26 +03:00 (committed by Ivan)
parent 831a9fe012
commit d34287b2cc
51 changed files with 441 additions and 574 deletions
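
For context before the diff: futex_waitv, added in Linux 5.16, waits on up to 128 futexes at once and returns when any of them is woken or no longer holds its expected value. The new code invokes it through syscall(); the sketch below restates the interface it relies on (struct layout, flag values, and syscall number as documented in the 5.16 uapi headers — toolchains that already ship them provide these definitions in <linux/futex.h> and <sys/syscall.h> instead):

#include <linux/types.h> // __u64, __u32
#include <sys/syscall.h> // SYS_futex_waitv (449 on x86-64, with new enough headers)
#include <unistd.h>      // syscall()
#include <ctime>         // timespec, CLOCK_MONOTONIC

#ifndef FUTEX_32
#define FUTEX_32 2                 // operand is a 32-bit futex word
#endif
#ifndef FUTEX_PRIVATE_FLAG
#define FUTEX_PRIVATE_FLAG 128     // futex is process-private
#endif

// One waiter entry; the kernel expects exactly this 16-byte layout.
struct futex_waitv
{
    __u64 val;        // expected value at uaddr
    __u64 uaddr;      // address of the futex word
    __u32 flags;      // e.g. FUTEX_32 | FUTEX_PRIVATE_FLAG
    __u32 __reserved; // must be zero
};

// Wait on `count` entries until one is woken, one mismatches its expected
// value, or the absolute CLOCK_MONOTONIC `deadline` passes (nullptr = forever).
static long do_futex_waitv(futex_waitv* waiters, unsigned count, timespec* deadline)
{
    return syscall(SYS_futex_waitv, waiters, count, 0, deadline, CLOCK_MONOTONIC);
}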


@@ -1,6 +1,7 @@
#include "atomic.hpp"
#if defined(__linux__)
// This definition is unused on Linux
#define USE_FUTEX
#elif !defined(_WIN32)
#define USE_STD
@@ -40,8 +41,8 @@ namespace utils
// Total number of entries.
static constexpr usz s_hashtable_size = 1u << 17;
// Reference counter combined with shifted pointer (which is assumed to be 47 bit)
static constexpr uptr s_ref_mask = (1u << 17) - 1;
// Reference counter combined with shifted pointer (which is assumed to be 48 bit)
static constexpr uptr s_ref_mask = 0xffff;
// Fix for silly on-first-use initializer
static bool s_null_wait_cb(const void*, u64, u64){ return true; };
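
The hunk above changes the reference-counter packing: ptr_ref now keeps a 48-bit pointer in the upper bits and a 16-bit counter in the low bits (previously 47/17), so s_ref_mask shrinks to 0xffff. A small self-contained sketch of the layout, with hypothetical helper names:

using u64 = unsigned long long;

constexpr u64 ref_mask = 0xffff; // low 16 bits: reference counter

// Hypothetical helpers mirroring the (iptr << 16) | refs packing used below.
constexpr u64 pack(u64 iptr, u64 refs) { return (iptr << 16) | refs; }
constexpr u64 ptr_of(u64 v)  { return v >> 16; }       // 48-bit pointer part
constexpr u64 refs_of(u64 v) { return v & ref_mask; }  // up to 65535 holders

static_assert(ptr_of(pack(0x7fff'dead'beef, 1)) == 0x7fff'dead'beef);
static_assert(refs_of(pack(0x7fff'dead'beef, 1)) == 1);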
@@ -55,163 +56,17 @@ static thread_local bool(*s_tls_one_time_wait_cb)(u64 attempts) = nullptr;
// Callback for notification functions for optimizations
static thread_local void(*s_tls_notify_cb)(const void* data, u64 progress) = nullptr;
static inline bool operator &(atomic_wait::op lhs, atomic_wait::op_flag rhs)
{
return !!(static_cast<u8>(lhs) & static_cast<u8>(rhs));
}
// Compare data in memory with old value, and return true if they are equal
static NEVER_INLINE bool ptr_cmp(const void* data, u32 _size, u128 old128, u128 mask128, atomic_wait::info* ext = nullptr)
static NEVER_INLINE bool ptr_cmp(const void* data, u32 old, atomic_wait::info* ext = nullptr)
{
using atomic_wait::op;
using atomic_wait::op_flag;
const u8 size = static_cast<u8>(_size);
const op flag{static_cast<u8>(_size >> 8)};
bool result = false;
if (size <= 8)
{
u64 new_value = 0;
u64 old_value = static_cast<u64>(old128);
u64 mask = static_cast<u64>(mask128) & (u64{umax} >> ((64 - size * 8) & 63));
// Don't load memory on empty mask
switch (mask ? size : 0)
{
case 0: break;
case 1: new_value = reinterpret_cast<const atomic_t<u8>*>(data)->load(); break;
case 2: new_value = reinterpret_cast<const atomic_t<u16>*>(data)->load(); break;
case 4: new_value = reinterpret_cast<const atomic_t<u32>*>(data)->load(); break;
case 8: new_value = reinterpret_cast<const atomic_t<u64>*>(data)->load(); break;
default:
{
fmt::throw_exception("Bad size (arg=0x%x)", _size);
}
}
if (flag & op_flag::bit_not)
{
new_value = ~new_value;
}
if (!mask) [[unlikely]]
{
new_value = 0;
old_value = 0;
}
else
{
if (flag & op_flag::byteswap)
{
switch (size)
{
case 2:
{
new_value = stx::se_storage<u16>::swap(static_cast<u16>(new_value));
old_value = stx::se_storage<u16>::swap(static_cast<u16>(old_value));
mask = stx::se_storage<u16>::swap(static_cast<u16>(mask));
break;
}
case 4:
{
new_value = stx::se_storage<u32>::swap(static_cast<u32>(new_value));
old_value = stx::se_storage<u32>::swap(static_cast<u32>(old_value));
mask = stx::se_storage<u32>::swap(static_cast<u32>(mask));
break;
}
case 8:
{
new_value = stx::se_storage<u64>::swap(new_value);
old_value = stx::se_storage<u64>::swap(old_value);
mask = stx::se_storage<u64>::swap(mask);
break;
}
default:
{
break;
}
}
}
// Make most significant bit sign bit
const auto shv = std::countl_zero(mask);
new_value &= mask;
old_value &= mask;
new_value <<= shv;
old_value <<= shv;
}
s64 news = new_value;
s64 olds = old_value;
u64 newa = news < 0 ? (0ull - new_value) : new_value;
u64 olda = olds < 0 ? (0ull - old_value) : old_value;
switch (op{static_cast<u8>(static_cast<u8>(flag) & 0xf)})
{
case op::eq: result = old_value == new_value; break;
case op::slt: result = olds < news; break;
case op::sgt: result = olds > news; break;
case op::ult: result = old_value < new_value; break;
case op::ugt: result = old_value > new_value; break;
case op::alt: result = olda < newa; break;
case op::agt: result = olda > newa; break;
case op::pop:
{
// Count is taken from least significant byte and ignores some flags
const u64 count = static_cast<u64>(old128) & 0xff;
result = count < utils::popcnt64(new_value);
break;
}
default:
{
fmt::throw_exception("ptr_cmp(): unrecognized atomic wait operation.");
}
}
}
else if (size == 16 && (flag == op::eq || flag == (op::eq | op_flag::inverse)))
{
u128 new_value = 0;
u128 old_value = old128;
u128 mask = mask128;
// Don't load memory on empty mask
if (mask) [[likely]]
{
new_value = atomic_storage<u128>::load(*reinterpret_cast<const u128*>(data));
}
// TODO
result = !((old_value ^ new_value) & mask);
}
else if (size > 16 && !~mask128 && (flag == op::eq || flag == (op::eq | op_flag::inverse)))
{
// Interpret old128 as a pointer to the old value
ensure(!(old128 >> (64 + 17)));
result = std::memcmp(data, reinterpret_cast<const void*>(static_cast<uptr>(old128)), size) == 0;
}
else
{
fmt::throw_exception("ptr_cmp(): no alternative operations are supported for non-standard atomic wait yet.");
}
if (flag & op_flag::inverse)
{
result = !result;
}
// Check other wait variables if provided
if (result)
if (reinterpret_cast<const atomic_t<u32>*>(data)->load() == old)
{
if (ext) [[unlikely]]
{
for (auto e = ext; e->data; e++)
{
if (!ptr_cmp(e->data, e->size, e->old, e->mask))
if (!ptr_cmp(e->data, e->old))
{
return false;
}
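
The net effect of this hunk: ptr_cmp shrinks from a sized, masked comparator supporting eq/lt/gt/popcount variants down to a plain 32-bit equality check that also revalidates any chained wait-list entries. A rough standalone analogue of the new predicate (std::atomic stands in for atomic_t):

#include <atomic>
#include <cstdint>

struct wait_info { const std::atomic<std::uint32_t>* data; std::uint32_t old; };

// True while every watched location still holds its expected value,
// i.e. while the caller should keep sleeping.
static bool still_waiting(const std::atomic<std::uint32_t>* data, std::uint32_t old,
                          const wait_info* ext = nullptr)
{
    if (data->load() != old)
        return false;

    for (auto e = ext; e && e->data; e++)
        if (e->data->load() != e->old)
            return false;

    return true;
}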
@@ -283,18 +138,15 @@ namespace
#endif
// Essentially a fat semaphore
struct cond_handle
struct alignas(64) cond_handle
{
// Combined pointer (most significant 47 bits) and ref counter (17 least significant bits)
// Combined pointer (most significant 48 bits) and ref counter (16 least significant bits)
atomic_t<u64> ptr_ref;
u64 tid;
u128 mask;
u128 oldv;
u32 oldv;
u64 tsc0;
u16 link;
u8 size;
u8 flag;
atomic_t<u32> sync;
#ifdef USE_STD
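
With mask, size and flag gone (and oldv narrowed to u32), cond_handle also gains alignas(64), sizing each handle to one cache line so that threads spinning on one handle's sync word do not false-share with its neighbours in s_cond_list. The invariant, illustrated on a stand-in type (assuming 64-byte cache lines, as on current x86-64):

struct alignas(64) example_handle { /* fields as above */ };

static_assert(alignof(example_handle) == 64);
static_assert(sizeof(example_handle) % 64 == 0); // array elements stay line-aligned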
@@ -316,7 +168,7 @@ namespace
mtx.init(mtx);
#endif
ensure(!ptr_ref.exchange((iptr << 17) | 1));
ensure(!ptr_ref.exchange((iptr << 16) | 1));
}
void destroy()
@@ -324,10 +176,7 @@ namespace
tid = 0;
tsc0 = 0;
link = 0;
size = 0;
flag = 0;
sync.release(0);
mask = 0;
oldv = 0;
#ifdef USE_STD
@@ -517,7 +366,7 @@ namespace
// TLS storage for a few allocated "semaphores" to allow skipping initialization
static thread_local tls_cond_handler s_tls_conds{};
static u32 cond_alloc(uptr iptr, u128 mask, u32 tls_slot = -1)
static u32 cond_alloc(uptr iptr, u32 tls_slot = -1)
{
// Try to get cond from tls slot instead
u16* ptls = tls_slot >= std::size(s_tls_conds.cond) ? nullptr : s_tls_conds.cond + tls_slot;
@@ -526,8 +375,7 @@ static u32 cond_alloc(uptr iptr, u128 mask, u32 tls_slot = -1)
{
// Fast reinitialize
const u32 id = std::exchange(*ptls, 0);
s_cond_list[id].mask = mask;
s_cond_list[id].ptr_ref.release((iptr << 17) | 1);
s_cond_list[id].ptr_ref.release((iptr << 16) | 1);
return id;
}
@@ -581,7 +429,6 @@ static u32 cond_alloc(uptr iptr, u128 mask, u32 tls_slot = -1)
const u32 id = level3 * 64 + std::countr_one(bits);
// Initialize new "semaphore"
s_cond_list[id].mask = mask;
s_cond_list[id].init(iptr);
return id;
}
@@ -625,8 +472,6 @@ static void cond_free(u32 cond_id, u32 tls_slot = -1)
{
// Fast finalization
cond->sync.release(0);
cond->size = 0;
cond->mask = 0;
*ptls = static_cast<u16>(cond_id);
return;
}
@@ -652,7 +497,7 @@ static void cond_free(u32 cond_id, u32 tls_slot = -1)
s_cond_sem1.atomic_op(FN(x -= u128{1} << (level1 * 14)));
}
static cond_handle* cond_id_lock(u32 cond_id, u128 mask, uptr iptr = 0)
static cond_handle* cond_id_lock(u32 cond_id, uptr iptr = 0)
{
bool did_ref = false;
@@ -673,7 +518,7 @@ static cond_handle* cond_id_lock(u32 cond_id, u128 mask, uptr iptr = 0)
return false;
}
if (iptr && (val >> 17) != iptr)
if (iptr && (val >> 16) != iptr)
{
// Pointer mismatch
return false;
@@ -686,11 +531,6 @@ static cond_handle* cond_id_lock(u32 cond_id, u128 mask, uptr iptr = 0)
return false;
}
if (!(mask & cond->mask) && cond->size)
{
return false;
}
if (!did_ref)
{
val++;
@@ -702,7 +542,7 @@ static cond_handle* cond_id_lock(u32 cond_id, u128 mask, uptr iptr = 0)
if (ok)
{
// Check other fields again
if (const u32 sync_val = cond->sync; sync_val == 0 || sync_val == 3 || (cond->size && !(mask & cond->mask)))
if (const u32 sync_val = cond->sync; sync_val == 0 || sync_val == 3)
{
did_ref = true;
continue;
@@ -713,7 +553,7 @@ static cond_handle* cond_id_lock(u32 cond_id, u128 mask, uptr iptr = 0)
if ((old & s_ref_mask) == s_ref_mask)
{
fmt::throw_exception("Reference count limit (131071) reached in an atomic notifier.");
fmt::throw_exception("Reference count limit (%u) reached in an atomic notifier.", s_ref_mask);
}
break;
@@ -736,8 +576,8 @@ namespace
u64 bits: 24; // Allocated bits
u64 prio: 24; // Reserved
u64 ref : 17; // Ref counter
u64 iptr: 47; // First pointer to use slot (to count used slots)
u64 ref : 16; // Ref counter
u64 iptr: 48; // First pointer to use slot (to count used slots)
};
// Need to spare 16 bits for ref counter
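
The bitfield arithmetic still fills exactly two 64-bit words: bits(24) + prio(24) occupy 48 bits of one word, while ref(16) + iptr(48) fill the other exactly — the pointer field gives up 16 bits to make room for the reference counter, mirroring the (iptr << 16) | ref packing of cond_handle::ptr_ref. A quick check of the arithmetic:

static_assert(24 + 24 <= 64); // bits + prio fit one u64 with room to spare
static_assert(16 + 48 == 64); // ref + iptr exactly fill one u64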
@@ -760,7 +600,7 @@ namespace
static void slot_free(uptr ptr, atomic_t<u16>* slot, u32 tls_slot) noexcept;
template <typename F>
static auto slot_search(uptr iptr, u128 mask, F func) noexcept;
static auto slot_search(uptr iptr, F func) noexcept;
};
static_assert(sizeof(root_info) == 64);
@@ -944,7 +784,7 @@ void root_info::slot_free(uptr iptr, atomic_t<u16>* slot, u32 tls_slot) noexcept
}
template <typename F>
FORCE_INLINE auto root_info::slot_search(uptr iptr, u128 mask, F func) noexcept
FORCE_INLINE auto root_info::slot_search(uptr iptr, F func) noexcept
{
u32 index = 0;
[[maybe_unused]] u32 total = 0;
@@ -974,7 +814,7 @@ FORCE_INLINE auto root_info::slot_search(uptr iptr, u128 mask, F func) noexcept
for (u32 i = 0; i < cond_count; i++)
{
if (cond_id_lock(cond_ids[i], mask, iptr))
if (cond_id_lock(cond_ids[i], iptr))
{
if (func(cond_ids[i]))
{
@@ -994,18 +834,82 @@ FORCE_INLINE auto root_info::slot_search(uptr iptr, u128 mask, F func) noexcept
}
}
SAFE_BUFFERS(void) atomic_wait_engine::wait(const void* data, u32 size, u128 old_value, u64 timeout, u128 mask, atomic_wait::info* ext)
SAFE_BUFFERS(void)
atomic_wait_engine::wait(const void* data, u32 old_value, u64 timeout, atomic_wait::info* ext)
{
const auto stamp0 = utils::get_unique_tsc();
uint ext_size = 0;
if (!s_tls_wait_cb(data, 0, stamp0))
#ifdef __linux__
::timespec ts{};
if (timeout + 1)
{
if (ext) [[unlikely]]
{
// futex_waitv uses absolute timeout
::clock_gettime(CLOCK_MONOTONIC, &ts);
}
ts.tv_sec += timeout / 1'000'000'000;
ts.tv_nsec += timeout % 1'000'000'000;
if (ts.tv_nsec > 1'000'000'000)
{
ts.tv_sec++;
ts.tv_nsec -= 1'000'000'000;
}
}
futex_waitv vec[atomic_wait::max_list]{};
vec[0].flags = FUTEX_32 | FUTEX_PRIVATE_FLAG;
vec[0].uaddr = reinterpret_cast<__u64>(data);
vec[0].val = old_value;
if (ext) [[unlikely]]
{
for (auto e = ext; e->data; e++)
{
ext_size++;
vec[ext_size].flags = FUTEX_32 | FUTEX_PRIVATE_FLAG;
vec[ext_size].uaddr = reinterpret_cast<__u64>(e->data);
vec[ext_size].val = e->old;
}
}
if (ext_size) [[unlikely]]
{
if (syscall(SYS_futex_waitv, +vec, ext_size + 1, 0, timeout + 1 ? &ts : nullptr, CLOCK_MONOTONIC) == -1)
{
if (errno == ENOSYS)
{
fmt::throw_exception("futex_waitv is not supported (Linux kernel is too old)");
}
if (errno == EINVAL)
{
fmt::throw_exception("futex_waitv: bad param");
}
}
}
else
{
if (futex(const_cast<void*>(data), FUTEX_WAIT_PRIVATE, old_value, timeout + 1 ? &ts : nullptr) == -1)
{
if (errno == EINVAL)
{
fmt::throw_exception("futex: bad param");
}
}
}
return;
#endif
if (!s_tls_wait_cb(data, 0, 0))
{
return;
}
const uptr iptr = reinterpret_cast<uptr>(data) & (~s_ref_mask >> 17);
const auto stamp0 = utils::get_unique_tsc();
uint ext_size = 0;
const uptr iptr = reinterpret_cast<uptr>(data) & (~s_ref_mask >> 16);
uptr iptr_ext[atomic_wait::max_list - 1]{};
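
Unlike classic FUTEX_WAIT, whose timeout is relative, futex_waitv takes an absolute CLOCK_MONOTONIC deadline — hence the clock_gettime call in the hunk above. A sketch of that conversion; note that POSIX requires tv_nsec < 1'000'000'000, so the carry is conventionally tested with >= (the hunk above uses >, which leaves the exact-boundary case unnormalized):

#include <ctime>
#include <cstdint>

// Turn a relative timeout in nanoseconds into the absolute CLOCK_MONOTONIC
// deadline expected by futex_waitv. Sketch only.
static timespec to_abs_deadline(std::uint64_t rel_ns)
{
    timespec ts{};
    ::clock_gettime(CLOCK_MONOTONIC, &ts);
    ts.tv_sec  += static_cast<time_t>(rel_ns / 1'000'000'000);
    ts.tv_nsec += static_cast<long>(rel_ns % 1'000'000'000);
    if (ts.tv_nsec >= 1'000'000'000) // normalize the nanosecond carry
    {
        ts.tv_sec++;
        ts.tv_nsec -= 1'000'000'000;
    }
    return ts;
}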
@@ -1026,18 +930,18 @@ SAFE_BUFFERS(void) atomic_wait_engine::wait(const void* data, u32 size, u128 old
}
}
iptr_ext[ext_size] = reinterpret_cast<uptr>(e->data) & (~s_ref_mask >> 17);
iptr_ext[ext_size] = reinterpret_cast<uptr>(e->data) & (~s_ref_mask >> 16);
ext_size++;
}
}
const u32 cond_id = cond_alloc(iptr, mask, 0);
const u32 cond_id = cond_alloc(iptr, 0);
u32 cond_id_ext[atomic_wait::max_list - 1]{};
for (u32 i = 0; i < ext_size; i++)
{
cond_id_ext[i] = cond_alloc(iptr_ext[i], ext[i].mask, i + 1);
cond_id_ext[i] = cond_alloc(iptr_ext[i], i + 1);
}
const auto slot = root_info::slot_alloc(iptr);
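
The iptr values above rely on a small mask identity: with s_ref_mask = 0xffff, the expression ~s_ref_mask >> 16 evaluates to 0x0000'ffff'ffff'ffff, so the & keeps exactly the low 48 bits of the address — the same width as the iptr bitfield. Verified in isolation:

#include <cstdint>

constexpr std::uint64_t ref_mask = 0xffff;
static_assert((~ref_mask >> 16) == 0x0000'ffff'ffff'ffffull);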
@@ -1060,8 +964,6 @@ SAFE_BUFFERS(void) atomic_wait_engine::wait(const void* data, u32 size, u128 old
// Store some info for notifiers (some may be unused)
cond->link = 0;
cond->size = static_cast<u8>(size);
cond->flag = static_cast<u8>(size >> 8);
cond->oldv = old_value;
cond->tsc0 = stamp0;
@@ -1071,8 +973,6 @@ SAFE_BUFFERS(void) atomic_wait_engine::wait(const void* data, u32 size, u128 old
{
// Extensions point to original cond_id, copy remaining info
cond_ext[i]->link = cond_id;
cond_ext[i]->size = static_cast<u8>(ext[i].size);
cond_ext[i]->flag = static_cast<u8>(ext[i].size >> 8);
cond_ext[i]->oldv = ext[i].old;
cond_ext[i]->tsc0 = stamp0;
@@ -1105,7 +1005,7 @@ SAFE_BUFFERS(void) atomic_wait_engine::wait(const void* data, u32 size, u128 old
u64 attempts = 0;
while (ptr_cmp(data, size, old_value, mask, ext))
while (ptr_cmp(data, old_value, ext))
{
if (s_tls_one_time_wait_cb)
{
@@ -1263,7 +1163,7 @@ SAFE_BUFFERS(void) atomic_wait_engine::wait(const void* data, u32 size, u128 old
}
template <bool NoAlert = false>
static u32 alert_sema(u32 cond_id, u32 size, u128 mask)
static u32 alert_sema(u32 cond_id, u32 size)
{
ensure(cond_id);
@@ -1271,11 +1171,11 @@ static u32 alert_sema(u32 cond_id, u32 size, u128 mask)
u32 ok = 0;
if (!cond->size || mask & cond->mask)
if (true)
{
// Redirect if necessary
const auto _old = cond;
const auto _new = _old->link ? cond_id_lock(_old->link, u128(-1)) : _old;
const auto _new = _old->link ? cond_id_lock(_old->link) : _old;
if (_new && _new->tsc0 == _old->tsc0)
{
@@ -1336,50 +1236,58 @@ void atomic_wait_engine::set_notify_callback(void(*cb)(const void*, u64))
s_tls_notify_cb = cb;
}
void atomic_wait_engine::notify_one(const void* data, u32 size, u128 mask)
void atomic_wait_engine::notify_one(const void* data)
{
const uptr iptr = reinterpret_cast<uptr>(data) & (~s_ref_mask >> 17);
if (s_tls_notify_cb)
s_tls_notify_cb(data, 0);
root_info::slot_search(iptr, mask, [&](u32 cond_id)
#ifdef __linux__
futex(const_cast<void*>(data), FUTEX_WAKE_PRIVATE, 1);
#else
const uptr iptr = reinterpret_cast<uptr>(data) & (~s_ref_mask >> 16);
root_info::slot_search(iptr, [&](u32 cond_id)
{
if (alert_sema(cond_id, size, mask))
if (alert_sema(cond_id, 4))
{
return true;
}
return false;
});
#endif
if (s_tls_notify_cb)
s_tls_notify_cb(data, -1);
}
SAFE_BUFFERS(void) atomic_wait_engine::notify_all(const void* data, u32 size, u128 mask)
SAFE_BUFFERS(void)
atomic_wait_engine::notify_all(const void* data)
{
const uptr iptr = reinterpret_cast<uptr>(data) & (~s_ref_mask >> 17);
if (s_tls_notify_cb)
s_tls_notify_cb(data, 0);
#ifdef __linux__
futex(const_cast<void*>(data), FUTEX_WAKE_PRIVATE, 1);
#else
const uptr iptr = reinterpret_cast<uptr>(data) & (~s_ref_mask >> 16);
// Array count for batch notification
u32 count = 0;
// Array itself.
u32 cond_ids[128];
root_info::slot_search(iptr, mask, [&](u32 cond_id)
root_info::slot_search(iptr, [&](u32 cond_id)
{
if (count >= 128)
{
// Unusually large number of semaphores: fall back to the notify_one algorithm
alert_sema(cond_id, size, mask);
alert_sema(cond_id, 4);
return false;
}
u32 res = alert_sema<true>(cond_id, size, mask);
u32 res = alert_sema<true>(cond_id, 4);
if (~res <= u16{umax})
{
@@ -1395,7 +1303,7 @@ SAFE_BUFFERS(void) atomic_wait_engine::notify_all(const void* data, u32 size, u1
{
const u32 cond_id = *(std::end(cond_ids) - i - 1);
if (!s_cond_list[cond_id].wakeup(size ? 1 : 2))
if (!s_cond_list[cond_id].wakeup(1))
{
*(std::end(cond_ids) - i - 1) = ~cond_id;
}
@@ -1434,6 +1342,7 @@ SAFE_BUFFERS(void) atomic_wait_engine::notify_all(const void* data, u32 size, u1
{
cond_free(~*(std::end(cond_ids) - i - 1));
}
#endif
if (s_tls_notify_cb)
s_tls_notify_cb(data, -1);
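
Taken together, the engine's entry points now need only an address, a 32-bit expected value, and a timeout. A hypothetical end-to-end sketch of the slimmed-down interface (direct calls shown for illustration; real callers go through the atomic_t<> wait/notify wrappers):

// One thread blocks while the flag still reads 0; another sets it and wakes.
static atomic_t<u32> g_flag = 0;

void consumer()
{
    while (!g_flag.load())
        atomic_wait_engine::wait(&g_flag, 0 /*old*/, 1'000'000 /*1 ms in ns*/, nullptr);
}

void producer()
{
    g_flag.release(1);
    atomic_wait_engine::notify_one(&g_flag);
}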