diff --git a/include/oaknut/impl/arm64_encode_helpers.inc.hpp b/include/oaknut/impl/arm64_encode_helpers.inc.hpp index 72a3bac..9d629d4 100644 --- a/include/oaknut/impl/arm64_encode_helpers.inc.hpp +++ b/include/oaknut/impl/arm64_encode_helpers.inc.hpp @@ -1,12 +1,25 @@ // SPDX-FileCopyrightText: 2022 merryhime // SPDX-License-Identifier: MIT -#define OAKNUT_STD_ENCODE(TYPE, ACCESS, SIZE) \ - template \ - std::uint32_t encode(TYPE v) \ - { \ - static_assert(std::popcount(splat) == SIZE); \ - return detail::pdep(static_cast(ACCESS), splat); \ +template +static constexpr std::uint32_t pdep(std::uint32_t val) +{ + std::uint32_t mask = mask_; + std::uint32_t res = 0; + for (std::uint32_t bb = 1; mask; bb += bb) { + if (val & bb) + res |= mask & -mask; + mask &= mask - 1; + } + return res; +} + +#define OAKNUT_STD_ENCODE(TYPE, ACCESS, SIZE) \ + template \ + std::uint32_t encode(TYPE v) \ + { \ + static_assert(std::popcount(splat) == SIZE); \ + return pdep(static_cast(ACCESS)); \ } OAKNUT_STD_ENCODE(RReg, v.index() & 31, 5) @@ -44,54 +57,41 @@ std::uint32_t encode(MovImm16 v) if ((v.m_encoded & mask) != v.m_encoded) throw "invalid MovImm16"; } - return detail::pdep(v.m_encoded, splat); + return pdep(v.m_encoded); } template std::uint32_t encode(Imm v) { static_assert(std::popcount(splat) >= imm_size); - return detail::pdep(v.value(), splat); + return pdep(v.value()); } template std::uint32_t encode(ImmChoice v) { static_assert(std::popcount(splat) == 1); - return detail::pdep(v.m_encoded, splat); -} - -template -std::uint32_t encode(AddrOffset v) -{ - static_assert(std::popcount(splat) == size - align); - return detail::pdep(v.m_encoded, splat); -} - -template -std::uint32_t encode(PageOffset v) -{ - throw "to be implemented"; + return pdep(v.m_encoded); } template std::uint32_t encode(SOffset v) { static_assert(std::popcount(splat) == size - align); - return detail::pdep(v.m_encoded, splat); + return pdep(v.m_encoded); } template std::uint32_t encode(POffset v) { static_assert(std::popcount(splat) == size - align); - return detail::pdep(v.m_encoded, splat); + return pdep(v.m_encoded); } template std::uint32_t encode(std::uint32_t v) { - return detail::pdep(v, splat); + return pdep(v); } #undef OAKNUT_STD_ENCODE diff --git a/include/oaknut/impl/imm.hpp b/include/oaknut/impl/imm.hpp index ceeef34..2fd35a6 100644 --- a/include/oaknut/impl/imm.hpp +++ b/include/oaknut/impl/imm.hpp @@ -36,7 +36,8 @@ public: } private: - friend class CodeGenerator; + template + friend class BasicCodeGenerator; std::uint32_t m_value; }; @@ -71,7 +72,8 @@ public: } private: - friend class CodeGenerator; + template + friend class BasicCodeGenerator; std::uint32_t m_encoded; }; @@ -110,7 +112,8 @@ public: } private: - friend class CodeGenerator; + template + friend class BasicCodeGenerator; std::uint32_t m_encoded; }; @@ -121,15 +124,6 @@ constexpr std::uint64_t mask_from_esize(std::size_t esize) return (~std::uint64_t{0}) >> (64 - esize); } -constexpr std::optional encode_bit_imm(std::uint32_t value) -{ - const std::uint64_t value_u64 = (static_cast(value) << 32) | static_cast(value); - const auto result = encode_bit_imm(value_u64); - if (result && (*result & 0x3FF) != *result) - return std::nullopt; - return result; -} - constexpr std::uint64_t inverse_mask_from_trailing_ones(std::uint64_t value) { return ~value | (value + 1); @@ -170,6 +164,15 @@ constexpr std::optional encode_bit_imm(std::uint64_t value) return static_cast(((((-esize) << 7) | (S << 6) | R) ^ 0x1000) & 0x1fff); } +constexpr std::optional encode_bit_imm(std::uint32_t value) +{ + const std::uint64_t value_u64 = (static_cast(value) << 32) | static_cast(value); + const auto result = encode_bit_imm(value_u64); + if (result && (*result & 0x3FF) != *result) + return std::nullopt; + return result; +} + } // namespace detail struct BitImm32 { @@ -187,7 +190,8 @@ public: } private: - friend class CodeGenerator; + template + friend class BasicCodeGenerator; std::uint32_t m_encoded; }; @@ -206,7 +210,8 @@ public: } private: - friend class CodeGenerator; + template + friend class BasicCodeGenerator; std::uint32_t m_encoded; }; @@ -224,7 +229,8 @@ struct ImmChoice { } private: - friend class CodeGenerator; + template + friend class BasicCodeGenerator; std::uint32_t m_encoded; }; @@ -238,7 +244,8 @@ struct LslShift { } private: - friend class CodeGenerator; + template + friend class BasicCodeGenerator; std::uint32_t m_encoded; }; diff --git a/include/oaknut/impl/offset.hpp b/include/oaknut/impl/offset.hpp index 198276a..6bb5c35 100644 --- a/include/oaknut/impl/offset.hpp +++ b/include/oaknut/impl/offset.hpp @@ -5,9 +5,12 @@ #include #include +#include namespace oaknut { +struct Label; + namespace detail { constexpr std::uint64_t inverse_mask_from_size(std::size_t size) @@ -33,6 +36,18 @@ constexpr std::uint64_t sign_extend(std::uint64_t value) template struct AddrOffset { AddrOffset(std::ptrdiff_t diff) + : m_payload(encode(diff)) + {} + + AddrOffset(Label& label) + : m_payload(&label) + {} + + AddrOffset(void* ptr) + : m_payload(ptr) + {} + + static std::uint32_t encode(std::ptrdiff_t diff) { const std::uint64_t diff_u64 = static_cast(diff); if (detail::sign_extend(diff_u64) != diff_u64) @@ -40,23 +55,37 @@ struct AddrOffset { if (diff_u64 != (diff_u64 & detail::inverse_mask_from_size(alignment))) throw "misalignment"; - m_encoded = static_cast((diff_u64 & detail::mask_from_size(bitsize)) >> alignment); + return static_cast((diff_u64 & detail::mask_from_size(bitsize)) >> alignment); } private: - friend class CodeGenerator; - std::uint32_t m_encoded; + template + friend class BasicCodeGenerator; + std::variant m_payload; }; template struct PageOffset { PageOffset(void* ptr) - : m_ptr(ptr) + : m_payload(ptr) {} + PageOffset(Label& label) + : m_payload(&label) + {} + + static std::uint32_t encode(std::uintptr_t current_addr, std::uintptr_t target) + { + const std::int64_t page_diff = (static_cast(target) >> 12) - (static_cast(current_addr) >> 12); + if (detail::sign_extend(page_diff) != page_diff) + throw "out of range"; + return static_cast(page_diff & detail::mask_from_size(bitsize)); + } + private: - friend class CodeGenerator; - void* m_ptr; + template + friend class BasicCodeGenerator; + std::variant m_payload; }; template @@ -73,7 +102,8 @@ struct SOffset { } private: - friend class CodeGenerator; + template + friend class BasicCodeGenerator; std::uint32_t m_encoded; }; @@ -91,7 +121,8 @@ struct POffset { } private: - friend class CodeGenerator; + template + friend class BasicCodeGenerator; std::uint32_t m_encoded; }; diff --git a/include/oaknut/impl/reg.hpp b/include/oaknut/impl/reg.hpp index d8a93b8..4d5239d 100644 --- a/include/oaknut/impl/reg.hpp +++ b/include/oaknut/impl/reg.hpp @@ -50,7 +50,8 @@ struct RReg : public Reg { XReg toX() const; WReg toW() const; - friend class CodeGenerator; + template + friend class BasicCodeGenerator; }; struct ZrReg : public RReg { @@ -70,7 +71,8 @@ struct XReg : public RReg { constexpr /* implicit */ XReg(ZrReg) : RReg(64, 31) {} - friend class CodeGenerator; + template + friend class BasicCodeGenerator; }; struct WReg : public RReg { @@ -80,7 +82,8 @@ struct WReg : public RReg { constexpr /* implicit */ WReg(WzrReg) : RReg(32, 31) {} - friend class CodeGenerator; + template + friend class BasicCodeGenerator; }; XReg RReg::toX() const @@ -118,7 +121,8 @@ struct XRegSp : public RReg { throw "unexpected ZR passed into an XRegSp"; } - friend class CodeGenerator; + template + friend class BasicCodeGenerator; }; struct WRegWsp : public RReg { @@ -132,7 +136,8 @@ struct WRegWsp : public RReg { throw "unexpected WZR passed into an WRegWsp"; } - friend class CodeGenerator; + template + friend class BasicCodeGenerator; }; } // namespace oaknut diff --git a/include/oaknut/oaknut.hpp b/include/oaknut/oaknut.hpp index 33ba367..4c87e26 100644 --- a/include/oaknut/oaknut.hpp +++ b/include/oaknut/oaknut.hpp @@ -4,7 +4,11 @@ #include #include #include +#include +#include #include +#include +#include #include "oaknut/impl/enum.hpp" #include "oaknut/impl/imm.hpp" @@ -31,25 +35,133 @@ constexpr std::uint32_t get_bits() return result; } -constexpr std::uint32_t pdep(std::uint32_t val, std::uint32_t mask) -{ - std::uint32_t res = 0; - for (std::uint32_t bb = 1; mask; bb += bb) { - if (val & bb) - res |= mask & -mask; - mask &= mask - 1; - } - return res; -} +template +struct overloaded : Ts... { + using Ts::operator()...; +}; } // namespace detail -class CodeGenerator { +struct Label { public: - explicit CodeGenerator(std::uint32_t* ptr) - : m_ptr(ptr) + Label() = default; + +private: + template + friend class BasicCodeGenerator; + + explicit Label(std::uintptr_t addr) + : m_addr(addr) {} + using EmitFunctionType = std::uint32_t (*)(std::uintptr_t wb_addr, std::uintptr_t resolved_addr); + + struct Writeback { + std::uintptr_t m_wb_addr; + std::uint32_t m_mask; + EmitFunctionType m_fn; + }; + + std::optional m_addr; + std::vector m_wbs; +}; + +template +class BasicCodeGenerator : public Policy { +public: + BasicCodeGenerator(typename Policy::constructor_argument_type arg) + : Policy(arg) + {} + + Label l() + { + return Label{Policy::current_address()}; + } + + void l(Label& label) + { + if (label.m_addr) + throw "label already resolved"; + + const auto target_addr = Policy::current_address(); + label.m_addr = target_addr; + for (auto& wb : label.m_wbs) { + const std::uint32_t value = wb.m_fn(wb.m_wb_addr, target_addr); + Policy::set_at_address(wb.m_wb_addr, value, wb.m_mask); + } + label.m_wbs.clear(); + } + +#include "oaknut/impl/arm64_mnemonics.inc.hpp" + +private: +#include "oaknut/impl/arm64_encode_helpers.inc.hpp" + + template + void emit(Ts... args) + { + std::uint32_t encoding = detail::get_bits(); + encoding |= (0 | ... | encode()>(std::forward(args))); + Policy::append(encoding); + } + + template + std::uint32_t encode(AddrOffset v) + { + static_assert(std::popcount(splat) == size - align); + + const auto encode_fn = [](std::uintptr_t current_addr, std::uintptr_t target) { + const std::ptrdiff_t diff = target - current_addr; + return pdep(AddrOffset::encode(diff)); + }; + + return std::visit(detail::overloaded{ + [&](std::uint32_t encoding) { + return pdep(encoding); + }, + [&](Label* label) { + if (label->m_addr) { + return encode_fn(Policy::current_address(), *label->m_addr); + } + + label->m_wbs.emplace_back({Policy::current_address(), ~splat, static_cast(encode_fn)}); + return 0; + }, + [&](void* p) { + return encode_fn(Policy::current_address(), reinterpret_cast(p)); + }, + }, + v.m_payload); + } + + template + std::uint32_t encode(PageOffset v) + { + static_assert(std::popcount(splat) == size); + + const auto encode_fn = [](std::uintptr_t current_addr, std::uintptr_t target) { + return pdep(PageOffset::encode(current_addr, target)); + }; + + return std::visit(detail::overloaded{ + [&](Label* label) { + if (label->m_addr) { + return encode_fn(Policy::current_address(), *label->m_addr); + } + + label->m_wbs.emplace_back({Policy::current_address(), ~splat, static_cast(encode_fn)}); + return 0; + }, + [&](void* p) { + return encode_fn(Policy::current_address(), reinterpret_cast(p)); + }, + }, + v.m_payload); + } +}; + +struct PointerCodeGeneratorPolicy { +public: template T ptr() { @@ -57,29 +169,40 @@ public: return reinterpret_cast(m_ptr); } - void set_ptr(std::uint32_t* ptr) + void set_ptr(std::uint32_t* ptr_) { - m_ptr = ptr; + m_ptr = ptr_; } -#include "oaknut/impl/arm64_mnemonics.inc.hpp" +protected: + using constructor_argument_type = std::uint32_t*; + + PointerCodeGeneratorPolicy(std::uint32_t* ptr_) + : m_ptr(ptr_) + {} + + void append(std::uint32_t instruction) + { + *m_ptr++ = instruction; + } + + std::uintptr_t current_address() + { + return reinterpret_cast(m_ptr); + } + + void set_at_address(std::uintptr_t addr, std::uint32_t value, std::uint32_t mask) + { + std::uint32_t* p = reinterpret_cast(addr); + *p = (*p & mask) | value; + } private: - template - void emit(Ts... args) - { - std::uint32_t encoding = detail::get_bits(); - encoding |= (0 | ... | encode()>(std::forward(args))); - - *m_ptr = encoding; - m_ptr++; - } - -#include "oaknut/impl/arm64_encode_helpers.inc.hpp" - std::uint32_t* m_ptr; }; +using CodeGenerator = BasicCodeGenerator; + namespace util { inline constexpr WReg W0{0}, W1{1}, W2{2}, W3{3}, W4{4}, W5{5}, W6{6}, W7{7}, W8{8}, W9{9}, W10{10}, W11{11}, W12{12}, W13{13}, W14{14}, W15{15}, W16{16}, W17{17}, W18{18}, W19{19}, W20{20}, W21{21}, W22{22}, W23{23}, W24{24}, W25{25}, W26{26}, W27{27}, W28{28}, W29{29}, W30{30};