rpcsx/rpcs3/util/shared_ptr.hpp
Elad 575a245f8d
IDM: Implement lock-free smart pointers (#16403)
Replaces `std::shared_ptr` with `stx::atomic_ptr` and `stx::shared_ptr`.

Notes to programmers:

* This pr kills the use of `dynamic_cast`, `std::dynamic_pointer_cast` and `std::weak_ptr` on IDM objects, possible replacement is to save the object ID on the base object, then use idm::check/get_unlocked to the destination type via the saved ID which may be null. Null pointer check is how you can tell type mismatch (as dynamic cast) or object destruction (as weak_ptr locking).
* Double-inheritance on IDM objects should be used with care, `stx::shared_ptr` does not support constant-evaluated pointer offsetting to parent/child type.
* `idm::check/get_unlocked` can now be used anywhere.

Misc fixes:
* Fixes some segfaults with RPCN with interaction with IDM.
* Fix deadlocks in access violation handler due to locking recursion.
* Fixes race condition in process exit-spawn on memory containers read.
* Fix bug that theoretically can prevent RPCS3 from booting - fix `id_manager::typeinfo` comparison to compare members instead of `memcmp` which can fail spuriously on padding bytes.
* Ensure every IDM type inherited from a base type has either `id_base` or `id_type` defined locally; this makes getters such as `idm::get_unlocked<lv2_socket, lv2_socket_raw>()` work, which were broken before. (requires save-states invalidation)
* Removes broken operator[] overload of `stx::shared_ptr` and `stx::single_ptr` for non-array types.
2024-12-22 20:59:48 +02:00

1134 lines
26 KiB
C++

#pragma once // No BOM and only basic ASCII in this header, or a neko will die
#include <cstdint>
#include <memory>
#include <utility>
#include "atomic.hpp"
#include "bless.hpp"
namespace stx
{
template <typename To, typename From>
constexpr bool same_ptr_implicit_v = std::is_convertible_v<const volatile From*, const volatile To*> ? PtrSame<From, To> : false;
// Forward declarations of the pointer templates defined further below,
// so that they can name each other (friend declarations, conversions)
template <typename T>
class single_ptr;
template <typename T>
class shared_ptr;
template <typename T>
class atomic_ptr;
// Basic assumption of userspace pointer size (in bits)
constexpr uint c_ptr_size = 48;

// atomic_ptr keeps its internal counter of borrowed refs in the lower bits (the pointer itself is shifted up):
// mask selecting the counter bits
constexpr uint c_ref_mask = 0xffff;

// width of the counter in bits, i.e. the shift applied to the pointer
constexpr uint c_ref_size = 16;

// Remaining (pointer) bits
constexpr uptr c_ptr_mask = static_cast<uptr>(-1) << c_ref_size;
// Control block header used by single_ptr, shared_ptr and atomic_ptr.
// The pointer classes locate it at a fixed negative offset (sizeof(shared_counter)) from the object.
struct shared_counter
{
// Stored destructor: called with a pointer to this counter to destroy the object and free its storage
atomic_t<void (*)(shared_counter* _this) noexcept> destroy{};
// Reference counter (starts at 1: the reference owned by the creator)
atomic_t<usz> refs{1};
};
// Padding helper used as a base class of shared_data.
// Primary template: empty, no padding is contributed.
template <usz Size, usz Align>
struct align_filler
{
};
// Specialization selected when the requested alignment exceeds Size:
// contributes (Align - Size) bytes of padding
template <usz Size, usz Align> requires (Align > Size)
struct align_filler<Size, Align>
{
char dummy[Align - Size];
};
// Control block with data and reference counter.
// Member order matters: m_ctr is declared right before m_data because the pointer classes
// find the counter at a negative offset from the object (make_single static_asserts this layout).
// The align_filler base supplies leading padding when alignof(T) exceeds sizeof(shared_counter).
template <typename T>
class shared_data final : align_filler<sizeof(shared_counter), alignof(T)>
{
public:
// Reference counter and stored destructor
shared_counter m_ctr{};
// The managed object itself
T m_data;
// Constructs the managed object in place from the forwarded arguments
template <typename... Args>
explicit constexpr shared_data(Args&&... args) noexcept
: m_data(std::forward<Args>(args)...)
{
}
};
// Control block for unbounded arrays (T[]): holds only the element count and the counter.
// The elements are not members; make_single(count) places them right after this object,
// which is why m_ctr has to be the last member.
template <typename T>
class shared_data<T[]> final : align_filler<sizeof(shared_counter) + sizeof(usz), alignof(T)>
{
public:
// Number of array elements stored after the control block
usz m_count{};
// Reference counter and stored destructor
shared_counter m_ctr{};
constexpr shared_data() noexcept = default;
};
// Simplified unique pointer. In some cases, std::unique_ptr is preferred.
// This one is shared_ptr counterpart, it has a control block with refs and deleter.
// It's trivially convertible to shared_ptr, and back if refs == 1.
template <typename T>
class single_ptr
{
std::remove_extent_t<T>* m_ptr{};
shared_counter* d() const noexcept
{
// Shared counter, deleter, should be at negative offset
return std::launder(reinterpret_cast<shared_counter*>(reinterpret_cast<u64>(m_ptr) - sizeof(shared_counter)));
}
template <typename U>
friend class single_ptr;
template <typename U>
friend class shared_ptr;
template <typename U>
friend class atomic_ptr;
public:
using element_type = std::remove_extent_t<T>;
constexpr single_ptr() noexcept = default;
single_ptr(const single_ptr&) = delete;
// Default constructor or null_ptr should be used instead
[[deprecated("Use null_ptr")]] single_ptr(std::nullptr_t) = delete;
explicit single_ptr(shared_data<T>&, element_type* ptr) noexcept
: m_ptr(ptr)
{
}
single_ptr(single_ptr&& r) noexcept
: m_ptr(r.m_ptr)
{
r.m_ptr = nullptr;
}
template <typename U> requires same_ptr_implicit_v<T, U>
single_ptr(single_ptr<U>&& r) noexcept
{
m_ptr = r.m_ptr;
r.m_ptr = nullptr;
}
~single_ptr() noexcept
{
reset();
}
single_ptr& operator=(const single_ptr&) = delete;
[[deprecated("Use null_ptr")]] single_ptr& operator=(std::nullptr_t) = delete;
single_ptr& operator=(single_ptr&& r) noexcept
{
single_ptr(std::move(r)).swap(*this);
return *this;
}
template <typename U> requires same_ptr_implicit_v<T, U>
single_ptr& operator=(single_ptr<U>&& r) noexcept
{
single_ptr(std::move(r)).swap(*this);
return *this;
}
void reset() noexcept
{
if (m_ptr) [[likely]]
{
const auto o = d();
o->destroy.load()(o);
m_ptr = nullptr;
}
}
void swap(single_ptr& r) noexcept
{
std::swap(m_ptr, r.m_ptr);
}
element_type* get() const noexcept
{
return m_ptr;
}
element_type& operator*() const noexcept requires (!std::is_void_v<element_type>)
{
return *m_ptr;
}
element_type* operator->() const noexcept
{
return m_ptr;
}
element_type& operator[](std::ptrdiff_t idx) const noexcept requires (!std::is_void_v<element_type> && std::is_array_v<T>)
{
return m_ptr[idx];
}
template <typename... Args> requires (std::is_invocable_v<T, Args&&...>)
decltype(auto) operator()(Args&&... args) const noexcept
{
return std::invoke(*m_ptr, std::forward<Args>(args)...);
}
explicit constexpr operator bool() const noexcept
{
return m_ptr != nullptr;
}
// "Moving" "static cast"
template <typename U> requires PtrSame<T, U>
explicit operator single_ptr<U>() && noexcept
{
single_ptr<U> r;
r.m_ptr = static_cast<decltype(r.m_ptr)>(std::exchange(m_ptr, nullptr));
return r;
}
template <typename U> requires same_ptr_implicit_v<T, U>
bool operator==(const single_ptr<U>& r) const noexcept
{
return get() == r.get();
}
};
#ifndef _MSC_VER
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Winvalid-offsetof"
#endif
// Create an object (or a bounded array) of type T together with its control block in one allocation.
// - Non-array T: the object is constructed from `args`.
// - Bounded array T, Init = true: every element is destroyed and re-created from `args` (fill).
// - Bounded array T, Init = false: elements keep the state produced by the shared_data constructor; no args allowed.
// Returns the unique owner of the new object.
template <typename T, bool Init = true, typename... Args>
	requires(!std::is_unbounded_array_v<T> && (Init || std::is_array_v<T>) && (Init || !sizeof...(Args)))
static single_ptr<T> make_single(Args&&... args) noexcept
{
	// The pointer classes expect the counter to sit immediately before the data
	static_assert(offsetof(shared_data<T>, m_data) - offsetof(shared_data<T>, m_ctr) == sizeof(shared_counter));

	using etype = std::remove_extent_t<T>;

	shared_data<T>* ptr = nullptr;

	if constexpr (!std::is_array_v<T>)
	{
		ptr = new shared_data<T>(std::forward<Args>(args)...);
	}
	else
	{
		ptr = new shared_data<T>;

		if constexpr (Init)
		{
			// Weird case, destroy and reinitialize every fixed array arg (fill)
			for (auto& e : ptr->m_data)
			{
				e.~etype();

				// The arguments are passed as lvalues on purpose: forwarding them on every
				// iteration would construct all elements after the first from moved-from values
				new (&e) etype(args...);
			}
		}
	}

	// Stored destructor: recover the enclosing shared_data from the counter address and delete it
	ptr->m_ctr.destroy.raw() = [](shared_counter* _this) noexcept
	{
		delete reinterpret_cast<shared_data<T>*>(reinterpret_cast<u64>(_this) - offsetof(shared_data<T>, m_ctr));
	};

	if constexpr (std::is_array_v<T>)
	{
		// single_ptr stores element_type*, so decay the array to a pointer to its first element
		// (&ptr->m_data would be a pointer to the whole array, which does not convert)
		return single_ptr<T>(*ptr, +ptr->m_data);
	}
	else
	{
		return single_ptr<T>(*ptr, &ptr->m_data);
	}
}
// Create an unbounded array (T = E[]) of `count` elements placed right after its control block.
// - Init = true: elements are value-initialized; Init = false: elements are default-initialized.
// - Align: alignment requested for the allocation (defaults to the element alignment); values above
//   the default new alignment use the aligned allocation functions.
// Returns the unique owner; the pointer refers to the first element.
template <typename T, bool Init = true, usz Align = alignof(std::remove_extent_t<T>)>
	requires (std::is_unbounded_array_v<T> && std::is_default_constructible_v<std::remove_extent_t<T>>)
static single_ptr<T> make_single(usz count) noexcept
{
	// The counter must be the last member so that it sits immediately before the first element
	static_assert(sizeof(shared_data<T>) - offsetof(shared_data<T>, m_ctr) == sizeof(shared_counter));

	using etype = std::remove_extent_t<T>;

	const usz size = sizeof(shared_data<T>) + count * sizeof(etype);

	std::byte* bytes = nullptr;

	if constexpr (Align > (__STDCPP_DEFAULT_NEW_ALIGNMENT__))
	{
		bytes = static_cast<std::byte*>(::operator new(size, std::align_val_t{Align}));
	}
	else
	{
		bytes = new std::byte[size];
	}

	// Initialize control block
	shared_data<T>* ptr = new (reinterpret_cast<shared_data<T>*>(bytes)) shared_data<T>();

	// Initialize array next to the control block
	etype* arr = reinterpret_cast<etype*>(bytes + sizeof(shared_data<T>));

	if constexpr (Init)
	{
		std::uninitialized_value_construct_n(arr, count);
	}
	else
	{
		std::uninitialized_default_construct_n(arr, count);
	}

	ptr->m_count = count;

	// Stored destructor: destroys the elements, then the control block, then frees the raw storage
	ptr->m_ctr.destroy.raw() = [](shared_counter* _this) noexcept
	{
		shared_data<T>* ptr = reinterpret_cast<shared_data<T>*>(reinterpret_cast<u64>(_this) - offsetof(shared_data<T>, m_ctr));
		std::byte* bytes = reinterpret_cast<std::byte*>(ptr);

		std::destroy_n(std::launder(reinterpret_cast<etype*>(bytes + sizeof(shared_data<T>))), ptr->m_count);

		ptr->~shared_data<T>();

		if constexpr (Align > (__STDCPP_DEFAULT_NEW_ALIGNMENT__))
		{
			// Must pair with the non-array aligned ::operator new used for the allocation
			::operator delete(bytes, std::align_val_t{Align});
		}
		else
		{
			delete[] bytes;
		}
	};

	return single_ptr<T>(*ptr, std::launder(arr));
}
// Create a single_ptr holding a copy of (or an object moved from) `value`;
// the stored type is T with the reference removed
template <typename T>
static single_ptr<std::remove_reference_t<T>> make_single_value(T&& value)
{
	using stored_type = std::remove_reference_t<T>;
	return make_single<stored_type>(std::forward<T>(value));
}
#ifndef _MSC_VER
#pragma GCC diagnostic pop
#endif
// Simplified shared pointer.
// Only the object pointer is stored; the control block (shared_counter with the reference
// counter and the stored destructor) is expected immediately before the pointee.
template <typename T>
class shared_ptr
{
	std::remove_extent_t<T>* m_ptr{};

	// Locate the control block (shared counter, deleter): it should be at negative offset
	shared_counter* d() const noexcept
	{
		return std::launder(reinterpret_cast<shared_counter*>(reinterpret_cast<u64>(m_ptr) - sizeof(shared_counter)));
	}

	template <typename U>
	friend class shared_ptr;

	template <typename U>
	friend class atomic_ptr;

public:
	using element_type = std::remove_extent_t<T>;

	constexpr shared_ptr() noexcept = default;

	// Copy: shares ownership by adding one reference
	shared_ptr(const shared_ptr& r) noexcept
		: m_ptr(r.m_ptr)
	{
		if (m_ptr)
			d()->refs++;
	}

	// Default constructor or null_ptr constant should be used instead
	[[deprecated("Use null_ptr")]] shared_ptr(std::nullptr_t) = delete;

	// Not-so-aliasing constructor: emulates std::enable_shared_from_this without its overhead.
	// _this must point to an object that lives inside a control block created by make_single/make_shared.
	explicit shared_ptr(T* _this) noexcept
		: m_ptr(_this)
	{
		// Random checks which may fail on invalid pointer
		ensure((reinterpret_cast<u64>(d()->destroy.load()) - 0x10000) >> 47 == 0);
		ensure((d()->refs++ - 1) >> 58 == 0);
	}

	// Converting copy from a pointer-compatible type
	template <typename U> requires same_ptr_implicit_v<T, U>
	shared_ptr(const shared_ptr<U>& r) noexcept
	{
		m_ptr = r.m_ptr;
		if (m_ptr)
			d()->refs++;
	}

	// Move: takes over the reference, the source becomes empty
	shared_ptr(shared_ptr&& r) noexcept
		: m_ptr(r.m_ptr)
	{
		r.m_ptr = nullptr;
	}

	template <typename U> requires same_ptr_implicit_v<T, U>
	shared_ptr(shared_ptr<U>&& r) noexcept
	{
		m_ptr = r.m_ptr;
		r.m_ptr = nullptr;
	}

	// Adopt a unique pointer: its single reference becomes the first shared one
	template <typename U> requires same_ptr_implicit_v<T, U>
	shared_ptr(single_ptr<U>&& r) noexcept
	{
		m_ptr = r.m_ptr;
		r.m_ptr = nullptr;
	}

	~shared_ptr() noexcept
	{
		reset();
	}

	// All assignments use construct-and-swap: the temporary releases the previous reference
	shared_ptr& operator=(const shared_ptr& r) noexcept
	{
		shared_ptr(r).swap(*this);
		return *this;
	}

	[[deprecated("Use null_ptr")]] shared_ptr& operator=(std::nullptr_t) = delete;

	template <typename U> requires same_ptr_implicit_v<T, U>
	shared_ptr& operator=(const shared_ptr<U>& r) noexcept
	{
		shared_ptr(r).swap(*this);
		return *this;
	}

	shared_ptr& operator=(shared_ptr&& r) noexcept
	{
		shared_ptr(std::move(r)).swap(*this);
		return *this;
	}

	template <typename U> requires same_ptr_implicit_v<T, U>
	shared_ptr& operator=(shared_ptr<U>&& r) noexcept
	{
		shared_ptr(std::move(r)).swap(*this);
		return *this;
	}

	template <typename U> requires same_ptr_implicit_v<T, U>
	shared_ptr& operator=(single_ptr<U>&& r) noexcept
	{
		shared_ptr(std::move(r)).swap(*this);
		return *this;
	}

	// Set to null: releases this instance's reference and destroys the object if it was the last one
	void reset() noexcept
	{
		if (!m_ptr)
		{
			return;
		}

		const auto o = d();

		// Always forget the pointer, even when other owners remain: keeping it would let
		// a later reset()/destructor release a reference this instance no longer holds
		m_ptr = nullptr;

		if (!--o->refs) [[unlikely]]
		{
			o->destroy.load()(o);
		}
	}

	// Converts to unique (single) ptr if reference is 1, otherwise returns null. Nullifies self.
	template <typename U> requires PtrSame<T, U>
	explicit operator single_ptr<U>() && noexcept
	{
		single_ptr<U> r;

		if (m_ptr)
		{
			const auto o = d();

			if (!--o->refs)
			{
				// Convert last reference to single_ptr instance.
				o->refs.release(1);
				r.m_ptr = static_cast<decltype(r.m_ptr)>(m_ptr);
			}

			// Otherwise, both pointers are gone. Didn't seem right to do it in the constructor.
			m_ptr = nullptr;
		}

		return r;
	}

	void swap(shared_ptr& r) noexcept
	{
		std::swap(this->m_ptr, r.m_ptr);
	}

	element_type* get() const noexcept
	{
		return m_ptr;
	}

	element_type& operator*() const noexcept requires (!std::is_void_v<element_type>)
	{
		return *m_ptr;
	}

	element_type* operator->() const noexcept
	{
		return m_ptr;
	}

	// Element access, available for array types only
	element_type& operator[](std::ptrdiff_t idx) const noexcept requires (!std::is_void_v<element_type> && std::is_array_v<T>)
	{
		return m_ptr[idx];
	}

	// Call through to the pointee when it is invocable with the given arguments
	template <typename... Args> requires (std::is_invocable_v<T, Args&&...>)
	decltype(auto) operator()(Args&&... args) const noexcept
	{
		return std::invoke(*m_ptr, std::forward<Args>(args)...);
	}

	// Current value of the reference counter, 0 for a null pointer
	usz use_count() const noexcept
	{
		if (m_ptr)
		{
			return d()->refs;
		}
		else
		{
			return 0;
		}
	}

	explicit constexpr operator bool() const noexcept
	{
		return m_ptr != nullptr;
	}

	// Basic "static cast" support: the result shares ownership (adds a reference)
	template <typename U> requires PtrSame<T, U>
	explicit operator shared_ptr<U>() const& noexcept
	{
		if (m_ptr)
		{
			d()->refs++;
		}

		shared_ptr<U> r;
		r.m_ptr = static_cast<decltype(r.m_ptr)>(m_ptr);
		return r;
	}

	// "Moving" "static cast": the reference is transferred, this pointer becomes empty
	template <typename U> requires PtrSame<T, U>
	explicit operator shared_ptr<U>() && noexcept
	{
		shared_ptr<U> r;
		r.m_ptr = static_cast<decltype(r.m_ptr)>(std::exchange(m_ptr, nullptr));
		return r;
	}

	// Compares the stored addresses
	template <typename U> requires same_ptr_implicit_v<T, U>
	bool operator==(const shared_ptr<U>& r) const noexcept
	{
		return get() == r.get();
	}
};
// Create a shared object (or bounded array) of type T from `args`:
// built as a single_ptr first, then converted into the first shared reference
template <typename T, typename... Args>
	requires(!std::is_unbounded_array_v<T> && std::is_constructible_v<std::remove_extent_t<T>, Args&& ...>)
static shared_ptr<T> make_shared(Args&&... args) noexcept
{
	return shared_ptr<T>(make_single<T>(std::forward<Args>(args)...));
}
// Create a shared unbounded array (T = E[]) of `count` elements;
// Init is passed through to make_single and selects how the elements are initialized
template <typename T, bool Init = true>
	requires (std::is_unbounded_array_v<T> && std::is_default_constructible_v<std::remove_extent_t<T>>)
static shared_ptr<T> make_shared(usz count) noexcept
{
	return shared_ptr<T>(make_single<T, Init>(count));
}
// Create a shared_ptr holding a copy of (or an object moved from) `value`;
// the stored type is T with the reference removed
template <typename T>
	requires (std::is_constructible_v<std::remove_reference_t<T>, T&&>)
static shared_ptr<std::remove_reference_t<T>> make_shared_value(T&& value)
{
	using stored_type = std::remove_reference_t<T>;
	return shared_ptr<stored_type>(make_single_value(std::forward<T>(value)));
}
// Atomic simplified shared pointer.
// The whole state is a single uptr: the object pointer shifted left by c_ref_size, with the low
// c_ref_size bits counting references that readers have borrowed and not yet returned (see load()).
// While a pointer is stored, its control block carries c_ref_mask + 1 references on behalf of this
// object; the destructor releases whatever part of them has not been handed out.
template <typename T>
class atomic_ptr
{
	mutable atomic_t<uptr> m_val{0};

	// Control block of the pointer encoded in `val` (expected at negative offset from the object)
	static shared_counter* d(uptr val)
	{
		return std::launder(reinterpret_cast<shared_counter*>((val >> c_ref_size) - sizeof(shared_counter)));
	}

	shared_counter* d() const noexcept
	{
		return d(m_val);
	}

	template <typename U>
	friend class atomic_ptr;

public:
	using element_type = std::remove_extent_t<T>;

	using shared_type = shared_ptr<T>;

	constexpr atomic_ptr() noexcept = default;

	// Optimized value construct: creates the object in place from args
	template <typename... Args> requires (!(sizeof...(Args) == 1 && (std::is_same_v<std::remove_cvref_t<Args>, shared_type> || ...)) && std::is_constructible_v<T, Args...>)
	explicit atomic_ptr(Args&&... args) noexcept
	{
		shared_type r = make_single<T>(std::forward<Args>(args)...);
		m_val = reinterpret_cast<uptr>(std::exchange(r.m_ptr, nullptr)) << c_ref_size;
		d()->refs.raw() += c_ref_mask;
	}

	template <typename U> requires same_ptr_implicit_v<T, U>
	atomic_ptr(const shared_ptr<U>& r) noexcept
	{
		// Obtain a ref + as many refs as an atomic_ptr can additionally reference
		m_val = reinterpret_cast<uptr>(r.m_ptr) << c_ref_size;
		if (m_val)
			d()->refs += c_ref_mask + 1;
	}

	// Consumes r: its reference is reused, only the borrowable refs are added
	template <typename U> requires same_ptr_implicit_v<T, U>
	atomic_ptr(shared_ptr<U>&& r) noexcept
	{
		m_val = reinterpret_cast<uptr>(r.m_ptr) << c_ref_size;
		r.m_ptr = nullptr;
		if (m_val)
			d()->refs += c_ref_mask;
	}

	template <typename U> requires same_ptr_implicit_v<T, U>
	atomic_ptr(single_ptr<U>&& r) noexcept
	{
		m_val = reinterpret_cast<uptr>(r.m_ptr) << c_ref_size;
		r.m_ptr = nullptr;
		if (m_val)
			d()->refs += c_ref_mask;
	}

	// Releases every reference still held for the stored pointer (c_ref_mask + 1 minus the borrowed ones)
	// and destroys the object when the counter reaches zero
	~atomic_ptr()
	{
		const uptr v = m_val.raw();

		if (v >> c_ref_size)
		{
			const auto o = d(v);

			if (!o->refs.sub_fetch(c_ref_mask + 1 - (v & c_ref_mask)))
			{
				o->destroy.load()(o);
			}
		}
	}

	// Optimized value assignment
	atomic_ptr& operator=(std::remove_cv_t<T> value) noexcept
	{
		shared_type r = make_single<T>(std::move(value));
		r.d()->refs.raw() += c_ref_mask;

		// `old` adopts the previous state; its destructor releases the previous pointer
		atomic_ptr old;
		old.m_val.raw() = m_val.exchange(reinterpret_cast<uptr>(std::exchange(r.m_ptr, nullptr)) << c_ref_size);
		return *this;
	}

	template <typename U> requires same_ptr_implicit_v<T, U>
	atomic_ptr& operator=(const shared_ptr<U>& r) noexcept
	{
		store(r);
		return *this;
	}

	template <typename U> requires same_ptr_implicit_v<T, U>
	atomic_ptr& operator=(shared_ptr<U>&& r) noexcept
	{
		store(std::move(r));
		return *this;
	}

	template <typename U> requires same_ptr_implicit_v<T, U>
	atomic_ptr& operator=(single_ptr<U>&& r) noexcept
	{
		store(std::move(r));
		return *this;
	}

	// Store a null pointer
	void reset() noexcept
	{
		store(shared_type{});
	}

	// Obtain an owning shared_ptr to the currently stored object (null if none)
	shared_type load() const noexcept
	{
		shared_type r;

		// Add reference
		const auto [prev, did_ref] = m_val.fetch_op([](uptr& val)
		{
			if (val >> c_ref_size)
			{
				val++;
				return true;
			}

			return false;
		});

		if (!did_ref)
		{
			// Null pointer
			return r;
		}

		// Set referenced pointer
		r.m_ptr = std::launder(reinterpret_cast<element_type*>(prev >> c_ref_size));
		r.d()->refs++;

		// Dereference if still the same pointer
		const auto [_, did_deref] = m_val.fetch_op([prev = prev](uptr& val)
		{
			if (val >> c_ref_size == prev >> c_ref_size)
			{
				val--;
				return true;
			}

			return false;
		});

		if (!did_deref)
		{
			// Otherwise fix ref count (atomic_ptr has been overwritten)
			r.d()->refs--;
		}

		return r;
	}

	// Atomically inspect pointer with the possibility to reference it if necessary.
	// `op` receives a shared_ptr to the stored object and its result is returned.
	template <typename F, typename RT = std::invoke_result_t<F, const shared_type&>>
	RT peek_op(F op) const noexcept
	{
		shared_type r;

		// Add reference
		const auto [prev, did_ref] = m_val.fetch_op([](uptr& val)
		{
			if (val >> c_ref_size)
			{
				val++;
				return true;
			}

			return false;
		});

		// Set fake unreferenced pointer
		if (did_ref)
		{
			r.m_ptr = std::launder(reinterpret_cast<element_type*>(prev >> c_ref_size));
		}

		// Result temp storage
		[[maybe_unused]] std::conditional_t<std::is_void_v<RT>, int, RT> result;

		// Invoke
		if constexpr (std::is_void_v<RT>)
		{
			std::invoke(op, std::as_const(r));

			if (!did_ref)
			{
				return;
			}
		}
		else
		{
			result = std::invoke(op, std::as_const(r));

			if (!did_ref)
			{
				return result;
			}
		}

		// Dereference if still the same pointer
		const auto [_, did_deref] = m_val.fetch_op([prev = prev](uptr& val)
		{
			if (val >> c_ref_size == prev >> c_ref_size)
			{
				val--;
				return true;
			}

			return false;
		});

		if (did_deref)
		{
			// Deactivate fake pointer
			// (if the borrowed ref could not be returned, r keeps the pointer and its destructor releases that ref)
			r.m_ptr = nullptr;
		}

		if constexpr (std::is_void_v<RT>)
		{
			return;
		}
		else
		{
			return result;
		}
	}

	// Create an object from variadic args
	// If a type needs shared_type to be constructed, std::reference_wrapper can be used
	template <typename... Args> requires (!(sizeof...(Args) == 1 && (std::is_same_v<std::remove_cvref_t<Args>, shared_type> || ...)) && std::is_constructible_v<T, Args...>)
	void store(Args&&... args) noexcept
	{
		shared_type r = make_single<T>(std::forward<Args>(args)...);
		r.d()->refs.raw() += c_ref_mask;

		atomic_ptr old;
		old.m_val.raw() = m_val.exchange(reinterpret_cast<uptr>(std::exchange(r.m_ptr, nullptr)) << c_ref_size);
	}

	// Replace the stored pointer with `value`; the previous pointer is released
	void store(shared_type value) noexcept
	{
		if (value.m_ptr)
		{
			// Consume value and add refs
			value.d()->refs += c_ref_mask;
		}

		atomic_ptr old;
		old.m_val.raw() = m_val.exchange(reinterpret_cast<uptr>(std::exchange(value.m_ptr, nullptr)) << c_ref_size);
	}

	// Create an object from args, store it and return the previously stored pointer
	template <typename... Args> requires (!(sizeof...(Args) == 1 && (std::is_same_v<std::remove_cvref_t<Args>, shared_type> || ...)) && std::is_constructible_v<T, Args...>)
	[[nodiscard]] shared_type exchange(Args&&... args) noexcept
	{
		shared_type r = make_single<T>(std::forward<Args>(args)...);
		r.d()->refs.raw() += c_ref_mask;

		atomic_ptr old;
		old.m_val.raw() += m_val.exchange(reinterpret_cast<uptr>(r.m_ptr) << c_ref_size);

		// Count one more ref as handed out, so that `old` leaves exactly one reference for the result
		old.m_val.raw() += 1;

		r.m_ptr = std::launder(reinterpret_cast<element_type*>(old.m_val >> c_ref_size));
		return r;
	}

	// Store `value` and return the previously stored pointer
	[[nodiscard]] shared_type exchange(shared_type value) noexcept
	{
		if (value.m_ptr)
		{
			// Consume value and add refs
			value.d()->refs += c_ref_mask;
		}

		atomic_ptr old;
		old.m_val.raw() += m_val.exchange(reinterpret_cast<uptr>(value.m_ptr) << c_ref_size);

		// Count one more ref as handed out, so that `old` leaves exactly one reference for the result
		old.m_val.raw() += 1;

		value.m_ptr = std::launder(reinterpret_cast<element_type*>(old.m_val >> c_ref_size));
		return value;
	}

	// Ineffective
	// Stores `exch` if the current pointer equals cmp_and_old and returns true.
	// Otherwise returns false and cmp_and_old is replaced with the current pointer (possibly null).
	[[nodiscard]] bool compare_exchange(shared_type& cmp_and_old, shared_type exch)
	{
		const uptr _old = reinterpret_cast<uptr>(cmp_and_old.m_ptr);
		const uptr _new = reinterpret_cast<uptr>(exch.m_ptr);

		if (exch.m_ptr)
		{
			exch.d()->refs += c_ref_mask;
		}

		atomic_ptr old;

		const uptr _val = m_val.fetch_op([&](uptr& val)
		{
			if (val >> c_ref_size == _old)
			{
				// Set new value
				val = _new << c_ref_size;
			}
			else if (val)
			{
				// Reference previous value
				val++;
			}
		});

		if (_val >> c_ref_size == _old)
		{
			// Success (exch is consumed, cmp_and_old is unchanged)
			if (exch.m_ptr)
			{
				exch.m_ptr = nullptr;
			}

			// Cleanup
			old.m_val.raw() = _val;
			return true;
		}

		// Failure: return the references prepared for exch
		atomic_ptr old_exch;
		old_exch.m_val.raw() = reinterpret_cast<uptr>(std::exchange(exch.m_ptr, nullptr)) << c_ref_size;

		// Set to reset old cmp_and_old value
		old.m_val.raw() = (reinterpret_cast<uptr>(cmp_and_old.m_ptr) << c_ref_size) | c_ref_mask;

		if (!_val)
		{
			// The stored pointer was null. `old` releases the reference cmp_and_old used to hold,
			// so cmp_and_old must not keep the stale address: report the observed null value.
			cmp_and_old.m_ptr = nullptr;
			return false;
		}

		// Set referenced pointer
		cmp_and_old.m_ptr = std::launder(reinterpret_cast<element_type*>(_val >> c_ref_size));
		cmp_and_old.d()->refs++;

		// Dereference if still the same pointer
		const auto [_, did_deref] = m_val.fetch_op([_val](uptr& val)
		{
			if (val >> c_ref_size == _val >> c_ref_size)
			{
				val--;
				return true;
			}

			return false;
		});

		if (!did_deref)
		{
			// Otherwise fix ref count (atomic_ptr has been overwritten)
			cmp_and_old.d()->refs--;
		}

		return false;
	}

	// Unoptimized
	// Returns cmp on success, otherwise the pointer that was found instead
	template <typename U> requires same_ptr_implicit_v<T, U>
	shared_type compare_and_swap(const shared_ptr<U>& cmp, shared_type exch)
	{
		shared_type old = cmp;
		static_cast<void>(compare_exchange(old, std::move(exch)));
		return old;
	}

	// More lightweight than compare_exchange: only reports whether the swap happened
	template <typename U> requires same_ptr_implicit_v<T, U>
	bool compare_and_swap_test(const shared_ptr<U>& cmp, shared_type exch)
	{
		const uptr _old = reinterpret_cast<uptr>(cmp.m_ptr);
		const uptr _new = reinterpret_cast<uptr>(exch.m_ptr);

		if (exch.m_ptr)
		{
			exch.d()->refs += c_ref_mask;
		}

		atomic_ptr old;

		const auto [_val, ok] = m_val.fetch_op([&](uptr& val)
		{
			if (val >> c_ref_size == _old)
			{
				// Set new value
				val = _new << c_ref_size;
				return true;
			}

			return false;
		});

		if (ok)
		{
			// Success (exch is consumed, cmp_and_old is unchanged)
			exch.m_ptr = nullptr;
			old.m_val.raw() = _val;
			return true;
		}

		// Failure (return references)
		old.m_val.raw() = reinterpret_cast<uptr>(std::exchange(exch.m_ptr, nullptr)) << c_ref_size;
		return false;
	}

	// Unoptimized
	// NOTE(review): `shared_type old = cmp` copies from a const single_ptr, for which neither single_ptr
	// nor shared_ptr in this header provides a constructor - this overload looks uninstantiable; confirm intent.
	template <typename U> requires same_ptr_implicit_v<T, U>
	shared_type compare_and_swap(const single_ptr<U>& cmp, shared_type exch)
	{
		shared_type old = cmp;
		static_cast<void>(compare_exchange(old, std::move(exch)));
		return old;
	}

	// Supplementary: single_ptr is reinterpreted as shared_ptr, only its address is compared
	template <typename U> requires same_ptr_implicit_v<T, U>
	bool compare_and_swap_test(const single_ptr<U>& cmp, shared_type exch)
	{
		return compare_and_swap_test(reinterpret_cast<const shared_ptr<U>&>(cmp), std::move(exch));
	}

	// Helper utility: atomically replaces the stored pointer with `exch`
	// and hands the previously stored pointer over to `next`
	void push_head(shared_type& next, shared_type exch) noexcept
	{
		if (exch.m_ptr) [[likely]]
		{
			// Add missing references first
			exch.d()->refs += c_ref_mask;
		}

		if (next.m_ptr) [[unlikely]]
		{
			// Just in case
			next.reset();
		}

		atomic_ptr old;
		old.m_val.raw() = m_val.load();

		do
		{
			// Update old head with current value
			next.m_ptr = reinterpret_cast<T*>(old.m_val.raw() >> c_ref_size);

		} while (!m_val.compare_exchange(old.m_val.raw(), reinterpret_cast<uptr>(exch.m_ptr) << c_ref_size));

		// This argument is consumed (moved from)
		exch.m_ptr = nullptr;

		if (next.m_ptr)
		{
			// Compensation for `next` assignment
			old.m_val.raw() += 1;
		}
	}

	// Simple atomic load is much more effective than load(), but it's a non-owning reference
	T* observe() const noexcept
	{
		return reinterpret_cast<T*>(m_val >> c_ref_size);
	}

	explicit constexpr operator bool() const noexcept
	{
		return m_val != 0;
	}

	// Address comparison against the currently stored pointer (no reference is taken)
	template <typename U> requires same_ptr_implicit_v<T, U>
	bool is_equal(const shared_ptr<U>& r) const noexcept
	{
		return static_cast<volatile const void*>(observe()) == r.get();
	}

	template <typename U> requires same_ptr_implicit_v<T, U>
	bool is_equal(const single_ptr<U>& r) const noexcept
	{
		return static_cast<volatile const void*>(observe()) == r.get();
	}

	// Wait while the second 32-bit word of the stored value equals 0
	// NOTE(review): presumably that word holds the upper pointer bits, making this "wait while null" - confirm
	void wait(std::nullptr_t, atomic_wait_timeout timeout = atomic_wait_timeout::inf)
	{
		utils::bless<atomic_t<u32>>(&m_val)[1].wait(0, timeout);
	}

	// Notifications target the same 32-bit word that wait() uses
	void notify_one()
	{
		utils::bless<atomic_t<u32>>(&m_val)[1].notify_one();
	}

	void notify_all()
	{
		utils::bless<atomic_t<u32>>(&m_val)[1].notify_all();
	}
};
// Replacement for nullptr in the few cases where the pointer classes reject std::nullptr_t:
// converts implicitly to an empty single_ptr, shared_ptr or atomic_ptr of any type
struct null_ptr_t
{
	template <typename T>
	constexpr operator single_ptr<T>() const noexcept
	{
		return {};
	}

	template <typename T>
	constexpr operator shared_ptr<T>() const noexcept
	{
		return {};
	}

	template <typename T>
	constexpr operator atomic_ptr<T>() const noexcept
	{
		return {};
	}

	// Always tests as false, like a null pointer
	explicit constexpr operator bool() const noexcept
	{
		return false;
	}

	// Mirrors the get() accessor of the pointer classes
	constexpr std::nullptr_t get() const noexcept
	{
		return nullptr;
	}
};

constexpr null_ptr_t null_ptr{};
}
// std::hash support for stx::single_ptr: hashes the stored address
template <typename T>
struct std::hash<stx::single_ptr<T>>
{
	usz operator()(const stx::single_ptr<T>& x) const noexcept
	{
		// get() returns element_type*, which differs from T* when T is an array type
		using pointer = typename stx::single_ptr<T>::element_type*;
		return std::hash<pointer>()(x.get());
	}
};
// std::hash support for stx::shared_ptr: hashes the stored address
template <typename T>
struct std::hash<stx::shared_ptr<T>>
{
	usz operator()(const stx::shared_ptr<T>& x) const noexcept
	{
		// get() returns element_type*, which differs from T* when T is an array type
		using pointer = typename stx::shared_ptr<T>::element_type*;
		return std::hash<pointer>()(x.get());
	}
};
// Make the stx smart pointer types and factory functions available without the stx:: prefix
using stx::null_ptr;
using stx::single_ptr;
using stx::shared_ptr;
using stx::atomic_ptr;
using stx::make_single;
using stx::make_shared;
using stx::make_single_value;
using stx::make_shared_value;