From e27d2f0306ecfb05ac4ecf96916bfceeab5bd812 Mon Sep 17 00:00:00 2001 From: Merry Date: Sat, 9 Jul 2022 20:14:06 +0100 Subject: [PATCH] Implement CodeBlock --- include/oaknut/code_block.hpp | 124 ++++++++++++++++++++++++++++++++++ tests/basic.cpp | 24 +++---- 2 files changed, 133 insertions(+), 15 deletions(-) create mode 100644 include/oaknut/code_block.hpp diff --git a/include/oaknut/code_block.hpp b/include/oaknut/code_block.hpp new file mode 100644 index 0000000..c7b0666 --- /dev/null +++ b/include/oaknut/code_block.hpp @@ -0,0 +1,124 @@ +// SPDX-FileCopyrightText: Copyright (c) 2022 merryhime +// SPDX-License-Identifier: MIT + +#include +#include +#include + +#if defined(_WIN32) +# include +#elif defined(__APPLE__) +# include +# include +# include +# include +#else +# include +#endif + +namespace oaknut { + +class CodeBlock { +public: + explicit CodeBlock(std::size_t size) + : m_size(size) + { +#if defined(_WIN32) + m_memory = (std::uint32_t*)VirtualAlloc(nullptr, size, MEM_COMMIT, PAGE_EXECUTE_READWRITE); +#elif defined(__APPLE__) + m_memory = (std::uint32_t*)mmap(nullptr, size, PROT_READ | PROT_WRITE | PROT_EXEC, MAP_ANON | MAP_PRIVATE | MAP_JIT, -1, 0); +#else + m_memory = (std::uint32_t*)mmap(nullptr, size, PROT_READ | PROT_WRITE | PROT_EXEC, MAP_ANON | MAP_PRIVATE, -1, 0); +#endif + + if (m_memory == nullptr) + throw std::bad_alloc{}; + } + + ~CodeBlock() + { + if (m_memory == nullptr) + return; + +#if defined(_WIN32) + VirtualFree((void*)m_memory, 0, MEM_RELEASE); +#else + munmap(m_memory, m_size); +#endif + } + + CodeBlock(const CodeBlock&) = delete; + CodeBlock& operator=(const CodeBlock&) = delete; + CodeBlock(CodeBlock&&) = delete; + CodeBlock& operator=(CodeBlock&&) = delete; + + std::uint32_t* ptr() const + { + return m_memory; + } + + void protect() + { +#if defined(__APPLE__) + pthread_jit_write_protect_np(1); +#endif + } + + void unprotect() + { +#if defined(__APPLE__) + pthread_jit_write_protect_np(0); +#endif + } + + void invalidate(std::uint32_t* mem, std::size_t size) + { +#if defined(__APPLE__) + sys_icache_invalidate(mem, size); +#else + static std::size_t icache_line_size = 0x10000, dcache_line_size = 0x10000; + + std::uint64_t ctr; + __asm__ volatile("mrs %0, ctr_el0" + : "=r"(ctr)); + + const std::size_t isize = icache_line_size = std::min(icache_line_size, 4 << ((ctr >> 0) & 0xf)); + const std::size_t dsize = dcache_line_size = std::min(dcache_line_size, 4 << ((ctr >> 16) & 0xf)); + + const std::uintptr_t end = (std::uintptr_t)mem + size; + + for (std::uintptr_t addr = ((std::uintptr_t)mem) & ~(dsize - 1); addr < end; addr += dsize) { + __asm__ volatile("dc cvau, %0" + : + : "r"(addr) + : "memory"); + } + __asm__ volatile("dsb ish\n" + : + : + : "memory"); + + for (std::uintptr_t addr = ((std::uintptr_t)mem) & ~(isize - 1); addr < end; addr += isize) { + __asm__ volatile("ic ivau, %0" + : + : "r"(addr) + : "memory"); + } + __asm__ volatile("dsb ish\nisb\n" + : + : + : "memory"); +#endif + } + + void invalidate_all() + { + invalidate(m_memory, m_size); + } + +protected: + std::uint32_t* m_memory; + std::size_t m_size = 0; +}; + +} // namespace oaknut diff --git a/tests/basic.cpp b/tests/basic.cpp index 7bd1f50..1561436 100644 --- a/tests/basic.cpp +++ b/tests/basic.cpp @@ -5,32 +5,26 @@ #include #include -#include -#include -#include -#include +#include "oaknut/code_block.hpp" #include "oaknut/oaknut.hpp" TEST_CASE("Basic Test") { - const size_t page_size = getpagesize(); - std::printf("page size: %zu\n", page_size); - - std::uint32_t* mem = (std::uint32_t*)mmap(nullptr, page_size, PROT_READ | PROT_WRITE | PROT_EXEC, MAP_JIT | MAP_ANONYMOUS | MAP_PRIVATE, -1, 0); - - pthread_jit_write_protect_np(false); - using namespace oaknut; using namespace oaknut::util; - CodeGenerator code{mem}; + CodeBlock mem{4096}; + CodeGenerator code{mem.ptr()}; + + mem.unprotect(); + code.MOVZ(W0, 42); code.RET(X30); - pthread_jit_write_protect_np(true); - sys_icache_invalidate(mem, page_size); + mem.protect(); + mem.invalidate_all(); - int result = ((int (*)())mem)(); + int result = ((int (*)())mem.ptr())(); REQUIRE(result == 42); }