#pragma once // No BOM and only basic ASCII in this header, or a neko will die #include #include #include "atomic.hpp" namespace stx { template constexpr bool is_same_ptr() noexcept { // I would like to make it a trait if there is some trick. // And believe it shall possible with constexpr bit_cast. // Otherwise I hope it will compile in null code anyway. const auto u = reinterpret_cast(0x11223344556); const volatile void* x = u; return static_cast(u) == x; } // TODO template constexpr bool is_same_ptr_v = true; template constexpr bool is_same_ptr_cast_v = std::is_same_v || std::is_convertible_v && is_same_ptr_v; template class single_ptr; template class shared_ptr; template class atomic_ptr; // Basic assumption of userspace pointer size constexpr uint c_ptr_size = 47; // Use lower 17 bits as atomic_ptr internal refcounter (pointer is shifted) constexpr uint c_ref_mask = 0x1ffff, c_ref_size = 17; struct shared_counter { // Stored destructor void (*destroy)(shared_counter* _this); // Reference counter atomic_t refs{1}; }; template struct align_filler { }; template struct align_filler Size)>> { char dummy[Align - Size]; }; // Control block with data and reference counter template class alignas(T) shared_data final : align_filler { public: shared_counter m_ctr; T m_data; template explicit constexpr shared_data(Args&&... args) noexcept : m_ctr{} , m_data(std::forward(args)...) { } }; template class alignas(T) shared_data final : align_filler { public: std::size_t m_count; shared_counter m_ctr; constexpr shared_data() noexcept = default; }; // Simplified unique pointer. Wwell, not simplified, std::unique_ptr is preferred. // This one is shared_ptr counterpart, it has a control block with refs = 1. template class single_ptr { std::remove_extent_t* m_ptr{}; shared_counter* d() const noexcept { // Shared counter, deleter, should be at negative offset return std::launder(reinterpret_cast(reinterpret_cast(m_ptr) - sizeof(shared_counter))); } template friend class shared_ptr; template friend class atomic_ptr; public: using pointer = T*; using element_type = std::remove_extent_t; constexpr single_ptr() noexcept = default; constexpr single_ptr(std::nullptr_t) noexcept {} single_ptr(const single_ptr&) = delete; single_ptr(single_ptr&& r) noexcept : m_ptr(r.m_ptr) { r.m_ptr = nullptr; } template >> single_ptr(single_ptr&& r) noexcept : m_ptr(r.m_ptr) { verify(HERE), is_same_ptr(); r.m_ptr = nullptr; } ~single_ptr() { reset(); } single_ptr& operator=(const single_ptr&) = delete; single_ptr& operator=(std::nullptr_t) noexcept { reset(); } single_ptr& operator=(single_ptr&& r) noexcept { m_ptr = r.m_ptr; r.m_ptr = nullptr; return *this; } template >> single_ptr& operator=(single_ptr&& r) noexcept { verify(HERE), is_same_ptr(); m_ptr = r.m_ptr; r.m_ptr = nullptr; return *this; } void reset() noexcept { if (m_ptr) [[likely]] { const auto o = d(); o->destroy(o); m_ptr = nullptr; } } void swap(single_ptr& r) noexcept { std::swap(m_ptr, r.m_ptr); } element_type* get() const noexcept { return m_ptr; } decltype(auto) operator*() const noexcept { if constexpr (std::is_void_v) { return; } else { return *m_ptr; } } element_type* operator->() const noexcept { return m_ptr; } decltype(auto) operator[](std::ptrdiff_t idx) const noexcept { if constexpr (std::is_void_v) { return; } else if constexpr (std::is_array_v) { return m_ptr[idx]; } else { return *m_ptr; } } explicit constexpr operator bool() const noexcept { return m_ptr != nullptr; } // "Moving" "static cast" template (std::declval())), typename = std::enable_if_t>> explicit operator single_ptr() && noexcept { verify(HERE), is_same_ptr(); single_ptr r; r.m_ptr = static_cast(std::exchange(m_ptr, nullptr)); return r; } }; template static std::enable_if_t) && (Init || !sizeof...(Args)), single_ptr> make_single(Args&&... args) noexcept { static_assert(offsetof(shared_data, m_data) - offsetof(shared_data, m_ctr) == sizeof(shared_counter)); using etype = std::remove_extent_t; shared_data* ptr = nullptr; if constexpr (Init && !std::is_array_v) { ptr = new shared_data(std::forward(args)...); } else { ptr = new shared_data; if constexpr (Init && std::is_array_v) { // Weird case, destroy and reinitialize every fixed array arg (fill) for (auto& e : ptr->m_data) { e.~etype(); new (&e) etype(std::forward(args)...); } } } ptr->m_ctr.destroy = [](shared_counter* _this) { delete reinterpret_cast*>(reinterpret_cast(_this) - offsetof(shared_data, m_ctr)); }; single_ptr r; if constexpr (std::is_array_v) { reinterpret_cast(r) = +ptr->m_data; } else { reinterpret_cast(r) = &ptr->m_data; } return r; } template static std::enable_if_t, single_ptr> make_single(std::size_t count) noexcept { static_assert(sizeof(shared_data) - offsetof(shared_data, m_ctr) == sizeof(shared_counter)); using etype = std::remove_extent_t; const std::size_t size = sizeof(shared_data) + count * sizeof(etype); std::byte* bytes = nullptr; if constexpr (alignof(etype) > (__STDCPP_DEFAULT_NEW_ALIGNMENT__)) { bytes = new (std::align_val_t{alignof(etype)}) std::byte[size]; } else { bytes = new std::byte[size]; } // Initialize control block shared_data* ptr = new (reinterpret_cast*>(bytes)) shared_data(); // Initialize array next to the control block etype* arr = reinterpret_cast(bytes + sizeof(shared_data)); if constexpr (Init) { std::uninitialized_value_construct_n(arr, count); } else { std::uninitialized_default_construct_n(arr, count); } ptr->m_count = count; ptr->m_ctr.destroy = [](shared_counter* _this) { shared_data* ptr = reinterpret_cast*>(reinterpret_cast(_this) - offsetof(shared_data, m_ctr)); std::byte* bytes = reinterpret_cast(ptr); std::destroy_n(std::launder(reinterpret_cast(bytes + sizeof(shared_data))), ptr->m_count); ptr->~shared_data(); if constexpr (alignof(etype) > (__STDCPP_DEFAULT_NEW_ALIGNMENT__)) { ::operator delete[](bytes, std::align_val_t{alignof(etype)}); } else { delete[] bytes; } }; single_ptr r; reinterpret_cast*&>(r) = std::launder(arr); return r; } // Simplified shared pointer template class shared_ptr { std::remove_extent_t* m_ptr{}; shared_counter* d() const noexcept { // Shared counter, deleter, should be at negative offset return std::launder(reinterpret_cast(reinterpret_cast(m_ptr) - sizeof(shared_counter))); } template friend class atomic_ptr; public: using pointer = T*; using element_type = std::remove_extent_t; constexpr shared_ptr() noexcept = default; constexpr shared_ptr(std::nullptr_t) noexcept {} shared_ptr(const shared_ptr& r) noexcept : m_ptr(r.m_ptr) { if (m_ptr) d()->refs++; } template >> shared_ptr(const shared_ptr& r) noexcept : m_ptr(r.m_ptr) { verify(HERE), is_same_ptr(); if (m_ptr) d()->refs++; } shared_ptr(shared_ptr&& r) noexcept : m_ptr(r.m_ptr) { r.m_ptr = nullptr; } template >> shared_ptr(shared_ptr&& r) noexcept : m_ptr(r.m_ptr) { verify(HERE), is_same_ptr(); r.m_ptr = nullptr; } template >> shared_ptr(single_ptr&& r) noexcept : m_ptr(r.m_ptr) { verify(HERE), is_same_ptr(); r.m_ptr = nullptr; } ~shared_ptr() { reset(); } shared_ptr& operator=(const shared_ptr& r) noexcept { shared_ptr(r).swap(*this); return *this; } template >> shared_ptr& operator=(const shared_ptr& r) noexcept { verify(HERE), is_same_ptr(); shared_ptr(r).swap(*this); return *this; } shared_ptr& operator=(shared_ptr&& r) noexcept { shared_ptr(std::move(r)).swap(*this); return *this; } template >> shared_ptr& operator=(shared_ptr&& r) noexcept { verify(HERE), is_same_ptr(); shared_ptr(std::move(r)).swap(*this); return *this; } template >> shared_ptr& operator=(single_ptr&& r) noexcept { verify(HERE), is_same_ptr(); shared_ptr(std::move(r)).swap(*this); return *this; } // Set to null void reset() noexcept { const auto o = d(); if (m_ptr && !--o->refs) [[unlikely]] { o->destroy(o); m_ptr = nullptr; } } // Converts to unique (single) ptr if reference is 1, otherwise returns null. Nullifies self. template (std::declval())), typename = std::enable_if_t>> explicit operator single_ptr() && noexcept { verify(HERE), is_same_ptr(); const auto o = d(); if (m_ptr && !--o->refs) { // Convert last reference to single_ptr instance. o->refs.release(1); single_ptr r; r.m_ptr = static_cast(std::exchange(m_ptr, nullptr)); return r; } // Otherwise, both pointers are gone. Didn't seem right to do it in the constructor. m_ptr = nullptr; return {}; } void swap(shared_ptr& r) noexcept { std::swap(this->m_ptr, r.m_ptr); } element_type* get() const noexcept { return m_ptr; } decltype(auto) operator*() const noexcept { if constexpr (std::is_void_v) { return; } else { return *m_ptr; } } element_type* operator->() const noexcept { return m_ptr; } decltype(auto) operator[](std::ptrdiff_t idx) const noexcept { if constexpr (std::is_void_v) { return; } else if constexpr (std::is_array_v) { return m_ptr[idx]; } else { return *m_ptr; } } std::size_t use_count() const noexcept { if (m_ptr) { return d()->refs; } else { return 0; } } explicit constexpr operator bool() const noexcept { return m_ptr != nullptr; } // Basic "static cast" support template (std::declval())), typename = std::enable_if_t>> explicit operator shared_ptr() const& noexcept { verify(HERE), is_same_ptr(); if (m_ptr) { d()->refs++; } shared_ptr r; r.m_ptr = static_cast(m_ptr); return r; } // "Moving" "static cast" template (std::declval())), typename = std::enable_if_t>> explicit operator shared_ptr() && noexcept { verify(HERE), is_same_ptr(); shared_ptr r; r.m_ptr = static_cast(std::exchange(m_ptr, nullptr)); return r; } }; template static std::enable_if_t && (!Init || !sizeof...(Args)), shared_ptr> make_shared(Args&&... args) noexcept { return make_single(std::forward(args)...); } template static std::enable_if_t, shared_ptr> make_shared(std::size_t count) noexcept { return make_single(count); } // Atomic simplified shared pointer template class atomic_ptr { mutable atomic_t m_val{0}; static shared_counter* d(uptr val) { return std::launder(reinterpret_cast((val >> c_ref_size) - sizeof(shared_counter))); } shared_counter* d() const noexcept { return d(m_val); } public: using pointer = T*; using element_type = std::remove_extent_t; using shared_type = shared_ptr; constexpr atomic_ptr() noexcept = default; // Optimized value construct template >> explicit atomic_ptr(Args&&... args) noexcept { shared_type r = make_single(std::forward(args)...); m_val = reinterpret_cast(std::exchange(r.m_ptr, nullptr)) << c_ref_size; d()->refs.raw() += c_ref_mask; } template >> atomic_ptr(const shared_ptr& r) noexcept : m_val(reinterpret_cast(r.m_ptr) << c_ref_size) { verify(HERE), is_same_ptr(); // Obtain a ref + as many refs as an atomic_ptr can additionally reference if (m_val) d()->refs += c_ref_mask + 1; } template >> atomic_ptr(shared_ptr&& r) noexcept : m_val(reinterpret_cast(r.m_ptr) << c_ref_size) { verify(HERE), is_same_ptr(); r.m_ptr = nullptr; if (m_val) d()->refs += c_ref_mask; } template >> atomic_ptr(single_ptr&& r) noexcept : m_val(reinterpret_cast(r.m_ptr) << c_ref_size) { verify(HERE), is_same_ptr(); r.m_ptr = nullptr; if (m_val) d()->refs += c_ref_mask; } ~atomic_ptr() { const uptr v = m_val.raw(); const auto o = d(v); if (v >> c_ref_size && !o->refs.sub_fetch(c_ref_mask + 1 - (v & c_ref_mask))) { o->destroy(o); } } // Optimized value assignment atomic_ptr& operator=(std::remove_cv_t value) noexcept { shared_type r = make_single(std::move(value)); r.d()->refs += c_ref_mask; atomic_ptr old; old.m_val.raw() = m_val.exchange(reinterpret_cast(std::exchange(r.m_ptr, nullptr)) << c_ref_size); return *this; } template >> atomic_ptr& operator=(const shared_ptr& r) noexcept { verify(HERE), is_same_ptr(); store(r); return *this; } template >> atomic_ptr& operator=(shared_ptr&& r) noexcept { verify(HERE), is_same_ptr(); store(std::move(r)); return *this; } template >> atomic_ptr& operator=(single_ptr&& r) noexcept { verify(HERE), is_same_ptr(); store(std::move(r)); return *this; } shared_type load() const noexcept { shared_type r; // Add reference const auto [prev, did_ref] = m_val.fetch_op([](uptr& val) { if (val >> c_ref_size) { val++; return true; } return false; }); if (!did_ref) { // Null pointer return r; } // Set referenced pointer r.m_ptr = std::launder(reinterpret_cast(prev >> c_ref_size)); r.d()->refs++; // Dereference if same pointer m_val.fetch_op([prev = prev](uptr& val) { if (val >> c_ref_size == prev >> c_ref_size) { val--; return true; } return false; }); return r; } operator shared_type() const noexcept { return load(); } template >> void store(Args&&... args) noexcept { shared_type r = make_single(std::forward(args)...); r.d()->refs.raw() += c_ref_mask; atomic_ptr old; old.m_val.raw() = m_val.exchange(reinterpret_cast(std::exchange(r.m_ptr, nullptr)) << c_ref_size); } void store(shared_type value) noexcept { if (value.m_ptr) { // Consume value and add refs value.d()->refs += c_ref_mask; } atomic_ptr old; old.m_val.raw() = m_val.exchange(reinterpret_cast(std::exchange(value.m_ptr, nullptr)) << c_ref_size); } template >> [[nodiscard]] shared_type exchange(Args&&... args) noexcept { shared_type r = make_single(std::forward(args)...); r.d()->refs.raw() += c_ref_mask; atomic_ptr old; old.m_val.raw() += m_val.exchange(reinterpret_cast(r.m_ptr) << c_ref_size); old.m_val.raw() += 1; r.m_ptr = std::launder(reinterpret_cast(old.m_val >> c_ref_size)); return r; } [[nodiscard]] shared_type exchange(shared_type value) noexcept { if (value.m_ptr) { // Consume value and add refs value.d()->refs += c_ref_mask; } atomic_ptr old; old.m_val.raw() += m_val.exchange(reinterpret_cast(value.m_ptr) << c_ref_size); old.m_val.raw() += 1; value.m_ptr = std::launder(reinterpret_cast(old.m_val >> c_ref_size)); return value; } // bool compare_exchange(shared_type& cmp_and_old, shared_type exch) // { // } // template >> // shared_type compare_and_swap(const shared_ptr& cmp, shared_type exch) // { // } // template >> // bool compare_and_swap_test(const shared_ptr& cmp, shared_type exch) // { // } // template >> // shared_type compare_and_swap(const single_ptr& cmp, shared_type exch) // { // } // template >> // bool compare_and_swap_test(const single_ptr& cmp, shared_type exch) // { // } // Simple atomic load is much more effective than load(), but it's a non-owning reference const volatile void* observe() const noexcept { return reinterpret_cast(m_val >> c_ref_size); } explicit constexpr operator bool() const noexcept { return m_val != 0; } template >> bool is_equal(const shared_ptr& r) const noexcept { return observe() == r.get(); } template >> bool is_equal(const single_ptr& r) const noexcept { return observe() == r.get(); } }; } namespace std { template void swap(stx::single_ptr& lhs, stx::single_ptr& rhs) noexcept { lhs.swap(rhs); } template void swap(stx::shared_ptr& lhs, stx::shared_ptr& rhs) noexcept { lhs.swap(rhs); } } using stx::single_ptr; using stx::shared_ptr; using stx::atomic_ptr; using stx::make_single;