diff --git a/rx/include/rx/Serializer.hpp b/rx/include/rx/Serializer.hpp new file mode 100644 index 000000000..a1c9c03aa --- /dev/null +++ b/rx/include/rx/Serializer.hpp @@ -0,0 +1,494 @@ +#pragma once + +#include "rx/align.hpp" +#include "rx/refl.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace rx { +struct Serializer; +struct Deserializer; + +template struct TypeSerializer; + +namespace detail { +// TODO: replace with std::is_trivially_relocatable once available +template +concept TriviallyRelocatable = + std::is_trivially_copyable_v> && + std::is_trivially_default_constructible_v> && + !std::is_pointer_v && !std::is_reference_v; + +template +concept TypeSerializable = requires(Serializer &s, const T &value) { + TypeSerializer>::serialize(s, value); +} && (requires(Deserializer &d) { + { + TypeSerializer>::deserialize(d) + } -> std::same_as>; +} || requires(Deserializer &d, T &value) { + TypeSerializer>::deserialize(d, value); +}); + +template +concept IsRange = requires(T &object) { + object.size(); + *object.begin(); + object.begin() != object.end(); +}; + +struct StructSerializerField { + std::size_t offset; + std::size_t alignment; + std::size_t size; + void (*serialize)(rx::Serializer &s, const void *object); + void (*deserialize)(rx::Deserializer &s, void *object); +}; + +// try to call free function +template +void callSerializeFunction(Serializer &s, const T &value) + requires requires { serialize(s, value); } +{ + serialize(s, value); +} + +// try to call free function +template +void callDeserializeFunction(Deserializer &s, T &value) + requires requires { deserialize(s, value); } +{ + deserialize(s, value); +} + +template +T callDeserializeFunction(Deserializer &s) + requires requires(T &value) { deserialize(s); } +{ + return deserialize(s); +} + +template +concept SerializableImpl = requires(Serializer &s, const T &value) { + value.serialize(s); +} || requires(Serializer &s, const T &value) { + callSerializeFunction(s, value); +} || requires(Serializer &s, const T &value) { + TypeSerializer>::serialize(s, value); +}; + +template +concept DeserializableImpl = requires(Deserializer &d, T &value) { + value.deserialize(d); +} || requires(Deserializer &d, T &value) { + value = std::remove_cvref_t::deserialize(d); +} || requires(Deserializer &d, T &value) { + callDeserializeFunction(d, value); +} || requires(Deserializer &d, T &value) { + value = callDeserializeFunction>(d); +} || requires(Deserializer &d) { + { + TypeSerializer>::deserialize(d) + } -> std::same_as>; +} || requires(Deserializer &d, T &value) { + TypeSerializer>::deserialize(d, value); +}; +} // namespace detail + +template +concept Serializable = + detail::SerializableImpl && detail::DeserializableImpl; + +namespace detail { +struct SerializableFieldTest { + template + requires(std::is_default_constructible_v) + constexpr operator FieldT(); +}; + +struct SerializableAnyFieldTest { + template constexpr operator FieldT(); +}; + +template constexpr bool isSerializableField() { + auto impl = []( + std::index_sequence, + std::index_sequence) { + return requires { + T{(Before, SerializableAnyFieldTest{})..., SerializableFieldTest{}, + (After, SerializableAnyFieldTest{})...}; + }; + }; + + return impl(std::make_index_sequence{}, + std::make_index_sequence - I - 1>{}); +} + +template > +constexpr bool isSerializableFields() { + auto impl = [](std::index_sequence) { + return requires { T{(I, SerializableFieldTest{})...}; }; + }; + + return impl(std::make_index_sequence{}); +} + +template +concept SerializableClass = !detail::IsRange && std::is_class_v && + rx::fieldCount > 0 && isSerializableFields(); +} // namespace detail + +struct Serializer { + virtual ~Serializer() = default; + virtual void write(std::span data) = 0; + + template void serialize(const T &value) { + if constexpr (requires { value.serialize(*this); }) { + value.serialize(*this); + } else if constexpr (requires { callSerializeFunction(*this, value); }) { + callSerializeFunction(*this, value); + } else { + TypeSerializer>::serialize(*this, value); + } + } +}; + +struct Deserializer { + virtual ~Deserializer() = default; + virtual void read(std::span data) = 0; + + template [[nodiscard]] std::remove_cvref_t deserialize() { + using type = std::remove_cvref_t; + if constexpr (requires { + { type::deserialize(*this) } -> std::convertible_to; + }) { + return T::deserialize(*this); + } else if constexpr (requires(type &result) { + type::deserialize(*this, result); + }) { + type result; + T::deserialize(*this, result); + return result; + } else if constexpr (requires(type &result) { + { + result.deserialize(*this) + } -> std::convertible_to; + }) { + T result; + result.deserialize(*this); + return result; + } else if constexpr (requires(type &result) { + type::deserialize(*this, result); + }) { + type result; + T::deserialize(*this, result); + return result; + } else if constexpr (requires { + detail::callDeserializeFunction(*this); + }) { + return detail::callDeserializeFunction(*this); + } else if constexpr (requires(type &value) { + detail::callDeserializeFunction(*this, value); + }) { + + type result; + detail::callDeserializeFunction(*this, result); + return result; + } else if constexpr (requires { + TypeSerializer::deserialize(*this); + }) { + return TypeSerializer::deserialize(*this); + } else { + type result; + TypeSerializer::deserialize(*this, result); + return result; + } + } + + template void deserialize(T &result) { + if constexpr (requires { T::deserialize(*this, result); }) { + T::deserialize(*this, result); + } else if constexpr (requires { result.deserialize(*this); }) { + result.deserialize(*this); + } else if constexpr (requires { + { T::deserialize(*this) } -> std::convertible_to; + }) { + result = T::deserialize(*this); + } else if constexpr (requires { + detail::callDeserializeFunction(*this, result); + }) { + detail::callDeserializeFunction(*this, result); + } else if constexpr (requires { + detail::callDeserializeFunction(*this); + }) { + result = detail::callDeserializeFunction(*this); + } else if constexpr (requires { + TypeSerializer::deserialize(*this, result); + }) { + TypeSerializer::deserialize(*this, result); + } else { + result = TypeSerializer::deserialize(*this); + } + } + + void setFailure() { mFailure = true; } + [[nodiscard]] bool failure() const { return mFailure; } + +private: + bool mFailure = false; +}; + +template struct TypeSerializer { + static void serialize(Serializer &s, const T &t) { + std::byte rawBytes[sizeof(T)]; + std::memcpy(rawBytes, &t, sizeof(T)); + s.write(rawBytes); + } + + static T deserialize(Deserializer &s) { + alignas(T) std::byte rawBytes[sizeof(T)]; + s.read(rawBytes); + + return std::move(std::bit_cast(rawBytes)); + } +}; + +template +struct TypeSerializer> { + static void serialize(Serializer &s, const std::pair &t) { + s.serialize(t.first); + s.serialize(t.second); + } + static std::pair deserialize(Deserializer &s) { + auto a = s.deserialize(); + auto b = s.deserialize(); + + return { + std::move(a), + std::move(b), + }; + } +}; + +template struct TypeSerializer> { + static void serialize(Serializer &s, const std::tuple &t) { + std::apply([&s](auto &value) { s.serialize(value); }, t); + } + + static std::tuple deserialize(Deserializer &s) { + return std::tuple{s.deserialize()...}; + } +}; + +template + requires std::is_default_constructible_v && + requires(Serializer &s, T &object) { + s.serialize(object.size()); + s.serialize(*object.begin()); + object.resize(1); + object.begin() != object.end(); + } +struct TypeSerializer { + using item_type = std::remove_cvref_t().begin())>; + + static void serialize(Serializer &s, const T &t) { + s.serialize(static_cast(t.size())); + + if constexpr (detail::TriviallyRelocatable && + requires { reinterpret_cast(t.data()); }) { + s.write({reinterpret_cast(t.data()), + t.size() * sizeof(item_type)}); + } else { + for (auto &item : t) { + s.serialize(item); + } + } + } + + static T deserialize(Deserializer &s) { + auto size = s.deserialize(); + + T t; + t.resize(size); + + if constexpr (detail::TriviallyRelocatable && + requires { reinterpret_cast(t.data()); }) { + s.read({reinterpret_cast(t.data()), + t.size() * sizeof(item_type)}); + } else { + for (auto &item : t) { + s.deserialize(item); + } + } + + return t; + } +}; + +template + requires( + std::is_default_constructible_v && + requires(Serializer &s, T &object) { + s.serialize(object.size()); + s.serialize(*object.begin()); + object.insert(std::move(*object.begin())); + object.begin() != object.end(); + } && !requires(Serializer &s, T &object) { object.resize(1); }) +struct TypeSerializer { + using item_type = std::remove_cvref_t().begin())>; + + static void serialize(Serializer &s, const T &t) { + s.serialize(static_cast(t.size())); + + for (auto &item : t) { + s.serialize(item); + } + } + + static T deserialize(Deserializer &s) { + auto size = s.deserialize(); + + if (s.failure()) { + return {}; + } + + T result; + + for (std::uint32_t i = 0; i < size; ++i) { + result.insert(s.deserialize()); + + if (s.failure()) { + return {}; + } + } + + return result; + } +}; + +namespace detail { +template struct StructSerializerBuilder { + static constexpr std::array> + build() { + StructSerializerBuilder result; + + auto impl = [&](std::index_sequence) { + static_cast(T{FieldVisitor{&result, I}...}); + }; + + impl(std::make_index_sequence>{}); + + std::size_t nextOffset = 0; + for (auto &field : result.fields) { + auto fieldOffset = alignUp(nextOffset, field.alignment); + nextOffset = fieldOffset + field.size; + field.offset = fieldOffset; + } + + return result.fields; + } + +private: + struct FieldVisitor { + StructSerializerBuilder *builder; + std::size_t fieldIndex; + + template constexpr operator FieldT() { + builder->addField(fieldIndex); + return {}; + } + }; + + template constexpr void addField(std::size_t index) { + fields[index] = StructSerializerField{ + .offset = 0, + .alignment = alignof(FieldT), + .size = sizeof(FieldT), + .serialize = + +[](rx::Serializer &s, const void *object) { + s.serialize(*static_cast(object)); + }, + .deserialize = + +[](rx::Deserializer &s, void *object) { + s.deserialize(*static_cast(object)); + }, + }; + } + + std::array> fields; +}; +} // namespace detail + +template + requires(!requires { + std::index_sequence::build().size()>{}; + }) +struct TypeSerializer { + static const auto &getFields() { + static const auto fields = detail::StructSerializerBuilder::build(); + return fields; + } + + static void serialize(Serializer &s, const T &object) { + s.serialize(sizeof(object)); + auto bytes = std::bit_cast(&object); + for (auto field : getFields()) { + field.serialize(s, bytes + field.offset); + } + } + + static void deserialize(Deserializer &s, T &object) { + if (s.deserialize() != sizeof(object)) { + s.setFailure(); + return; + } + + auto bytes = std::bit_cast(&object); + for (auto field : getFields()) { + field.deserialize(s, bytes + field.offset); + if (s.failure()) { + return; + } + } + } +}; + +// all fields are constructable at compile time overload +template + requires requires { + std::index_sequence::build().size()>{}; + } +struct TypeSerializer { + static constexpr auto fields = detail::StructSerializerBuilder::build(); + + static void serialize(Serializer &s, const T &object) { + s.serialize(sizeof(object)); + auto bytes = std::bit_cast(&object); + for (auto field : fields) { + field.serialize(s, bytes + field.offset); + } + } + + static void deserialize(Deserializer &s, T &object) { + if (s.deserialize() != sizeof(object)) { + s.setFailure(); + return; + } + + auto bytes = std::bit_cast(&object); + for (auto field : fields) { + field.deserialize(s, bytes + field.offset); + + if (s.failure()) { + return; + } + } + } +}; +} // namespace rx