Add Policy, Implement Labels

This commit is contained in:
Merry 2022-07-03 10:50:45 +01:00
parent 1d38501ca5
commit 7a9c21c613
5 changed files with 248 additions and 82 deletions

View file

@ -1,12 +1,25 @@
// SPDX-FileCopyrightText: 2022 merryhime
// SPDX-License-Identifier: MIT
#define OAKNUT_STD_ENCODE(TYPE, ACCESS, SIZE) \
template<std::uint32_t splat> \
std::uint32_t encode(TYPE v) \
{ \
static_assert(std::popcount(splat) == SIZE); \
return detail::pdep(static_cast<std::uint32_t>(ACCESS), splat); \
template<std::uint32_t mask_>
static constexpr std::uint32_t pdep(std::uint32_t val)
{
std::uint32_t mask = mask_;
std::uint32_t res = 0;
for (std::uint32_t bb = 1; mask; bb += bb) {
if (val & bb)
res |= mask & -mask;
mask &= mask - 1;
}
return res;
}
#define OAKNUT_STD_ENCODE(TYPE, ACCESS, SIZE) \
template<std::uint32_t splat> \
std::uint32_t encode(TYPE v) \
{ \
static_assert(std::popcount(splat) == SIZE); \
return pdep<splat>(static_cast<std::uint32_t>(ACCESS)); \
}
OAKNUT_STD_ENCODE(RReg, v.index() & 31, 5)
@ -44,54 +57,41 @@ std::uint32_t encode(MovImm16 v)
if ((v.m_encoded & mask) != v.m_encoded)
throw "invalid MovImm16";
}
return detail::pdep(v.m_encoded, splat);
return pdep<splat>(v.m_encoded);
}
template<std::uint32_t splat, std::size_t imm_size>
std::uint32_t encode(Imm<imm_size> v)
{
static_assert(std::popcount(splat) >= imm_size);
return detail::pdep(v.value(), splat);
return pdep<splat>(v.value());
}
template<std::uint32_t splat, int A, int B>
std::uint32_t encode(ImmChoice<A, B> v)
{
static_assert(std::popcount(splat) == 1);
return detail::pdep(v.m_encoded, splat);
}
template<std::uint32_t splat, std::size_t size, std::size_t align>
std::uint32_t encode(AddrOffset<size, align> v)
{
static_assert(std::popcount(splat) == size - align);
return detail::pdep(v.m_encoded, splat);
}
template<std::uint32_t splat, std::size_t size>
std::uint32_t encode(PageOffset<size> v)
{
throw "to be implemented";
return pdep<splat>(v.m_encoded);
}
template<std::uint32_t splat, std::size_t size, std::size_t align>
std::uint32_t encode(SOffset<size, align> v)
{
static_assert(std::popcount(splat) == size - align);
return detail::pdep(v.m_encoded, splat);
return pdep<splat>(v.m_encoded);
}
template<std::uint32_t splat, std::size_t size, std::size_t align>
std::uint32_t encode(POffset<size, align> v)
{
static_assert(std::popcount(splat) == size - align);
return detail::pdep(v.m_encoded, splat);
return pdep<splat>(v.m_encoded);
}
template<std::uint32_t splat>
std::uint32_t encode(std::uint32_t v)
{
return detail::pdep(v, splat);
return pdep<splat>(v);
}
#undef OAKNUT_STD_ENCODE

View file

@ -36,7 +36,8 @@ public:
}
private:
friend class CodeGenerator;
template<typename Policy>
friend class BasicCodeGenerator;
std::uint32_t m_value;
};
@ -71,7 +72,8 @@ public:
}
private:
friend class CodeGenerator;
template<typename Policy>
friend class BasicCodeGenerator;
std::uint32_t m_encoded;
};
@ -110,7 +112,8 @@ public:
}
private:
friend class CodeGenerator;
template<typename Policy>
friend class BasicCodeGenerator;
std::uint32_t m_encoded;
};
@ -121,15 +124,6 @@ constexpr std::uint64_t mask_from_esize(std::size_t esize)
return (~std::uint64_t{0}) >> (64 - esize);
}
constexpr std::optional<std::uint32_t> encode_bit_imm(std::uint32_t value)
{
const std::uint64_t value_u64 = (static_cast<std::uint64_t>(value) << 32) | static_cast<std::uint64_t>(value);
const auto result = encode_bit_imm(value_u64);
if (result && (*result & 0x3FF) != *result)
return std::nullopt;
return result;
}
constexpr std::uint64_t inverse_mask_from_trailing_ones(std::uint64_t value)
{
return ~value | (value + 1);
@ -170,6 +164,15 @@ constexpr std::optional<std::uint32_t> encode_bit_imm(std::uint64_t value)
return static_cast<std::uint32_t>(((((-esize) << 7) | (S << 6) | R) ^ 0x1000) & 0x1fff);
}
constexpr std::optional<std::uint32_t> encode_bit_imm(std::uint32_t value)
{
const std::uint64_t value_u64 = (static_cast<std::uint64_t>(value) << 32) | static_cast<std::uint64_t>(value);
const auto result = encode_bit_imm(value_u64);
if (result && (*result & 0x3FF) != *result)
return std::nullopt;
return result;
}
} // namespace detail
struct BitImm32 {
@ -187,7 +190,8 @@ public:
}
private:
friend class CodeGenerator;
template<typename Policy>
friend class BasicCodeGenerator;
std::uint32_t m_encoded;
};
@ -206,7 +210,8 @@ public:
}
private:
friend class CodeGenerator;
template<typename Policy>
friend class BasicCodeGenerator;
std::uint32_t m_encoded;
};
@ -224,7 +229,8 @@ struct ImmChoice {
}
private:
friend class CodeGenerator;
template<typename Policy>
friend class BasicCodeGenerator;
std::uint32_t m_encoded;
};
@ -238,7 +244,8 @@ struct LslShift {
}
private:
friend class CodeGenerator;
template<typename Policy>
friend class BasicCodeGenerator;
std::uint32_t m_encoded;
};

View file

@ -5,9 +5,12 @@
#include <cstddef>
#include <cstdint>
#include <variant>
namespace oaknut {
struct Label;
namespace detail {
constexpr std::uint64_t inverse_mask_from_size(std::size_t size)
@ -33,6 +36,18 @@ constexpr std::uint64_t sign_extend(std::uint64_t value)
template<std::size_t bitsize, std::size_t alignment>
struct AddrOffset {
AddrOffset(std::ptrdiff_t diff)
: m_payload(encode(diff))
{}
AddrOffset(Label& label)
: m_payload(&label)
{}
AddrOffset(void* ptr)
: m_payload(ptr)
{}
static std::uint32_t encode(std::ptrdiff_t diff)
{
const std::uint64_t diff_u64 = static_cast<std::uint64_t>(diff);
if (detail::sign_extend<bitsize>(diff_u64) != diff_u64)
@ -40,23 +55,37 @@ struct AddrOffset {
if (diff_u64 != (diff_u64 & detail::inverse_mask_from_size(alignment)))
throw "misalignment";
m_encoded = static_cast<std::uint32_t>((diff_u64 & detail::mask_from_size(bitsize)) >> alignment);
return static_cast<std::uint32_t>((diff_u64 & detail::mask_from_size(bitsize)) >> alignment);
}
private:
friend class CodeGenerator;
std::uint32_t m_encoded;
template<typename Policy>
friend class BasicCodeGenerator;
std::variant<std::uint32_t, Label*, void*> m_payload;
};
template<std::size_t bitsize>
struct PageOffset {
PageOffset(void* ptr)
: m_ptr(ptr)
: m_payload(ptr)
{}
PageOffset(Label& label)
: m_payload(&label)
{}
static std::uint32_t encode(std::uintptr_t current_addr, std::uintptr_t target)
{
const std::int64_t page_diff = (static_cast<std::int64_t>(target) >> 12) - (static_cast<std::int64_t>(current_addr) >> 12);
if (detail::sign_extend<bitsize>(page_diff) != page_diff)
throw "out of range";
return static_cast<std::uint32_t>(page_diff & detail::mask_from_size(bitsize));
}
private:
friend class CodeGenerator;
void* m_ptr;
template<typename Policy>
friend class BasicCodeGenerator;
std::variant<Label*, void*> m_payload;
};
template<std::size_t bitsize, std::size_t alignment>
@ -73,7 +102,8 @@ struct SOffset {
}
private:
friend class CodeGenerator;
template<typename Policy>
friend class BasicCodeGenerator;
std::uint32_t m_encoded;
};
@ -91,7 +121,8 @@ struct POffset {
}
private:
friend class CodeGenerator;
template<typename Policy>
friend class BasicCodeGenerator;
std::uint32_t m_encoded;
};

View file

@ -50,7 +50,8 @@ struct RReg : public Reg {
XReg toX() const;
WReg toW() const;
friend class CodeGenerator;
template<typename Policy>
friend class BasicCodeGenerator;
};
struct ZrReg : public RReg {
@ -70,7 +71,8 @@ struct XReg : public RReg {
constexpr /* implicit */ XReg(ZrReg)
: RReg(64, 31) {}
friend class CodeGenerator;
template<typename Policy>
friend class BasicCodeGenerator;
};
struct WReg : public RReg {
@ -80,7 +82,8 @@ struct WReg : public RReg {
constexpr /* implicit */ WReg(WzrReg)
: RReg(32, 31) {}
friend class CodeGenerator;
template<typename Policy>
friend class BasicCodeGenerator;
};
XReg RReg::toX() const
@ -118,7 +121,8 @@ struct XRegSp : public RReg {
throw "unexpected ZR passed into an XRegSp";
}
friend class CodeGenerator;
template<typename Policy>
friend class BasicCodeGenerator;
};
struct WRegWsp : public RReg {
@ -132,7 +136,8 @@ struct WRegWsp : public RReg {
throw "unexpected WZR passed into an WRegWsp";
}
friend class CodeGenerator;
template<typename Policy>
friend class BasicCodeGenerator;
};
} // namespace oaknut

View file

@ -4,7 +4,11 @@
#include <bit>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <tuple>
#include <type_traits>
#include <variant>
#include <vector>
#include "oaknut/impl/enum.hpp"
#include "oaknut/impl/imm.hpp"
@ -31,25 +35,133 @@ constexpr std::uint32_t get_bits()
return result;
}
constexpr std::uint32_t pdep(std::uint32_t val, std::uint32_t mask)
{
std::uint32_t res = 0;
for (std::uint32_t bb = 1; mask; bb += bb) {
if (val & bb)
res |= mask & -mask;
mask &= mask - 1;
}
return res;
}
template<class... Ts>
struct overloaded : Ts... {
using Ts::operator()...;
};
} // namespace detail
class CodeGenerator {
struct Label {
public:
explicit CodeGenerator(std::uint32_t* ptr)
: m_ptr(ptr)
Label() = default;
private:
template<typename Policy>
friend class BasicCodeGenerator;
explicit Label(std::uintptr_t addr)
: m_addr(addr)
{}
using EmitFunctionType = std::uint32_t (*)(std::uintptr_t wb_addr, std::uintptr_t resolved_addr);
struct Writeback {
std::uintptr_t m_wb_addr;
std::uint32_t m_mask;
EmitFunctionType m_fn;
};
std::optional<std::uintptr_t> m_addr;
std::vector<Writeback> m_wbs;
};
template<typename Policy>
class BasicCodeGenerator : public Policy {
public:
BasicCodeGenerator(typename Policy::constructor_argument_type arg)
: Policy(arg)
{}
Label l()
{
return Label{Policy::current_address()};
}
void l(Label& label)
{
if (label.m_addr)
throw "label already resolved";
const auto target_addr = Policy::current_address();
label.m_addr = target_addr;
for (auto& wb : label.m_wbs) {
const std::uint32_t value = wb.m_fn(wb.m_wb_addr, target_addr);
Policy::set_at_address(wb.m_wb_addr, value, wb.m_mask);
}
label.m_wbs.clear();
}
#include "oaknut/impl/arm64_mnemonics.inc.hpp"
private:
#include "oaknut/impl/arm64_encode_helpers.inc.hpp"
template<StringLiteral bs, StringLiteral... bargs, typename... Ts>
void emit(Ts... args)
{
std::uint32_t encoding = detail::get_bits<bs, "1">();
encoding |= (0 | ... | encode<detail::get_bits<bs, bargs>()>(std::forward<Ts>(args)));
Policy::append(encoding);
}
template<std::uint32_t splat, std::size_t size, std::size_t align>
std::uint32_t encode(AddrOffset<size, align> v)
{
static_assert(std::popcount(splat) == size - align);
const auto encode_fn = [](std::uintptr_t current_addr, std::uintptr_t target) {
const std::ptrdiff_t diff = target - current_addr;
return pdep<splat>(AddrOffset<size, align>::encode(diff));
};
return std::visit(detail::overloaded{
[&](std::uint32_t encoding) {
return pdep<splat>(encoding);
},
[&](Label* label) {
if (label->m_addr) {
return encode_fn(Policy::current_address(), *label->m_addr);
}
label->m_wbs.emplace_back({Policy::current_address(), ~splat, static_cast<Label::EmitFunctionType>(encode_fn)});
return 0;
},
[&](void* p) {
return encode_fn(Policy::current_address(), reinterpret_cast<std::uintptr_t>(p));
},
},
v.m_payload);
}
template<std::uint32_t splat, std::size_t size>
std::uint32_t encode(PageOffset<size> v)
{
static_assert(std::popcount(splat) == size);
const auto encode_fn = [](std::uintptr_t current_addr, std::uintptr_t target) {
return pdep<splat>(PageOffset<size>::encode(current_addr, target));
};
return std::visit(detail::overloaded{
[&](Label* label) {
if (label->m_addr) {
return encode_fn(Policy::current_address(), *label->m_addr);
}
label->m_wbs.emplace_back({Policy::current_address(), ~splat, static_cast<Label::EmitFunctionType>(encode_fn)});
return 0;
},
[&](void* p) {
return encode_fn(Policy::current_address(), reinterpret_cast<std::uintptr_t>(p));
},
},
v.m_payload);
}
};
struct PointerCodeGeneratorPolicy {
public:
template<typename T>
T ptr()
{
@ -57,29 +169,40 @@ public:
return reinterpret_cast<T>(m_ptr);
}
void set_ptr(std::uint32_t* ptr)
void set_ptr(std::uint32_t* ptr_)
{
m_ptr = ptr;
m_ptr = ptr_;
}
#include "oaknut/impl/arm64_mnemonics.inc.hpp"
protected:
using constructor_argument_type = std::uint32_t*;
PointerCodeGeneratorPolicy(std::uint32_t* ptr_)
: m_ptr(ptr_)
{}
void append(std::uint32_t instruction)
{
*m_ptr++ = instruction;
}
std::uintptr_t current_address()
{
return reinterpret_cast<std::uintptr_t>(m_ptr);
}
void set_at_address(std::uintptr_t addr, std::uint32_t value, std::uint32_t mask)
{
std::uint32_t* p = reinterpret_cast<std::uint32_t*>(addr);
*p = (*p & mask) | value;
}
private:
template<StringLiteral bs, StringLiteral... bargs, typename... Ts>
void emit(Ts... args)
{
std::uint32_t encoding = detail::get_bits<bs, "1">();
encoding |= (0 | ... | encode<detail::get_bits<bs, bargs>()>(std::forward<Ts>(args)));
*m_ptr = encoding;
m_ptr++;
}
#include "oaknut/impl/arm64_encode_helpers.inc.hpp"
std::uint32_t* m_ptr;
};
using CodeGenerator = BasicCodeGenerator<PointerCodeGeneratorPolicy>;
namespace util {
inline constexpr WReg W0{0}, W1{1}, W2{2}, W3{3}, W4{4}, W5{5}, W6{6}, W7{7}, W8{8}, W9{9}, W10{10}, W11{11}, W12{12}, W13{13}, W14{14}, W15{15}, W16{16}, W17{17}, W18{18}, W19{19}, W20{20}, W21{21}, W22{22}, W23{23}, W24{24}, W25{25}, W26{26}, W27{27}, W28{28}, W29{29}, W30{30};