diff --git a/rpcs3/Emu/CPU/CPUTranslator.cpp b/rpcs3/Emu/CPU/CPUTranslator.cpp index f799e4b6be..08e8e9ad30 100644 --- a/rpcs3/Emu/CPU/CPUTranslator.cpp +++ b/rpcs3/Emu/CPU/CPUTranslator.cpp @@ -201,6 +201,13 @@ void cpu_translator::initialize(llvm::LLVMContext& context, llvm::ExecutionEngin m_use_vnni = true; m_use_gfni = true; } + +#ifdef ARCH_ARM64 + if (utils::has_dotprod()) + { + m_use_dotprod = true; + } +#endif } llvm::Value* cpu_translator::bitcast(llvm::Value* val, llvm::Type* type) const diff --git a/rpcs3/Emu/CPU/CPUTranslator.h b/rpcs3/Emu/CPU/CPUTranslator.h index 27abb22219..81049c6074 100644 --- a/rpcs3/Emu/CPU/CPUTranslator.h +++ b/rpcs3/Emu/CPU/CPUTranslator.h @@ -3090,6 +3090,9 @@ protected: // For now, setting this flag will speed up SPU verification // but I will remove this later with explicit parralelism - Whatcookie bool m_use_avx = true; + + // ARMv8 SDOT/UDOT + bool m_use_dotprod = false; #else // Allow FMA bool m_use_fma = false; @@ -3658,6 +3661,40 @@ public: return result; } +template + value_t udot(T1 a, T2 b, T3 c) + { + value_t result; + + const auto data0 = a.eval(m_ir); + const auto data1 = b.eval(m_ir); + const auto data2 = c.eval(m_ir); + + // ARM hardware requires the multipliers to be treated as 16-byte vectors + //const auto op1 = bitcast(data1, get_type()); + //const auto op2 = bitcast(data2, get_type()); + + // Use the variadic get_intrinsic to resolve the overloaded AArch64 intrinsic + result.value = m_ir->CreateCall(get_intrinsic(llvm::Intrinsic::aarch64_neon_udot), {data0, data1, data2}); + return result; + } + + template + value_t sdot(T1 a, T2 b, T3 c) + { + value_t result; + + const auto data0 = a.eval(m_ir); + const auto data1 = b.eval(m_ir); + const auto data2 = c.eval(m_ir); + + //const auto op1 = bitcast(data1, get_type()); + //const auto op2 = bitcast(data2, get_type()); + + result.value = m_ir->CreateCall(get_intrinsic(llvm::Intrinsic::aarch64_neon_sdot), {data0, data1, data2}); + return result; + } + template value_t vpermb(T1 a, T2 b) { diff --git a/rpcs3/Emu/Cell/SPULLVMRecompiler.cpp b/rpcs3/Emu/Cell/SPULLVMRecompiler.cpp index d42acd3560..45b4c83ecf 100644 --- a/rpcs3/Emu/Cell/SPULLVMRecompiler.cpp +++ b/rpcs3/Emu/Cell/SPULLVMRecompiler.cpp @@ -5313,13 +5313,24 @@ public: return; } +#ifdef ARCH_ARM64 + if (m_use_dotprod) +#else if (m_use_vnni) +#endif { - const auto [a, b] = get_vrs(op.ra, op.rb); const auto zeroes = splat(0); +#ifdef ARCH_ARM64 + const auto [a, b] = get_vrs(op.ra, op.rb); + const auto ones = splat(0x01); + const auto ax = bitcast(udot(zeroes, a, ones)); + const auto bx = bitcast(udot(zeroes, b, ones)); +#else + const auto [a, b] = get_vrs(op.ra, op.rb); const auto ones = splat(0x01010101); const auto ax = bitcast(vpdpbusd(zeroes, a, ones)); const auto bx = bitcast(vpdpbusd(zeroes, b, ones)); +#endif set_vr(op.rt, shuffle2(ax, bx, 0, 8, 2, 10, 4, 12, 6, 14)); return; }