SPU LLVM: ARM64: Use UDOT for emulating SUMB

This commit is contained in:
Malcolm 2026-02-17 18:27:32 -05:00 committed by Elad
parent b734ceb2e7
commit 4542020c86
3 changed files with 56 additions and 1 deletions

View file

@ -201,6 +201,13 @@ void cpu_translator::initialize(llvm::LLVMContext& context, llvm::ExecutionEngin
m_use_vnni = true;
m_use_gfni = true;
}
#ifdef ARCH_ARM64
if (utils::has_dotprod())
{
m_use_dotprod = true;
}
#endif
}
llvm::Value* cpu_translator::bitcast(llvm::Value* val, llvm::Type* type) const

View file

@ -3090,6 +3090,9 @@ protected:
// For now, setting this flag will speed up SPU verification
// but I will remove this later with explicit parralelism - Whatcookie
bool m_use_avx = true;
// ARMv8 SDOT/UDOT
bool m_use_dotprod = false;
#else
// Allow FMA
bool m_use_fma = false;
@ -3658,6 +3661,40 @@ public:
return result;
}
template <typename T1, typename T2, typename T3>
value_t<u32[4]> udot(T1 a, T2 b, T3 c)
{
value_t<u32[4]> result;
const auto data0 = a.eval(m_ir);
const auto data1 = b.eval(m_ir);
const auto data2 = c.eval(m_ir);
// ARM hardware requires the multipliers to be treated as 16-byte vectors
//const auto op1 = bitcast(data1, get_type<u8[16]>());
//const auto op2 = bitcast(data2, get_type<u8[16]>());
// Use the variadic get_intrinsic to resolve the overloaded AArch64 intrinsic
result.value = m_ir->CreateCall(get_intrinsic<u32[4], u8[16]>(llvm::Intrinsic::aarch64_neon_udot), {data0, data1, data2});
return result;
}
template <typename T1, typename T2, typename T3>
value_t<u32[4]> sdot(T1 a, T2 b, T3 c)
{
value_t<u32[4]> result;
const auto data0 = a.eval(m_ir);
const auto data1 = b.eval(m_ir);
const auto data2 = c.eval(m_ir);
//const auto op1 = bitcast(data1, get_type<u8[16]>());
//const auto op2 = bitcast(data2, get_type<u8[16]>());
result.value = m_ir->CreateCall(get_intrinsic<u32[4], u8[16]>(llvm::Intrinsic::aarch64_neon_sdot), {data0, data1, data2});
return result;
}
template <typename T1, typename T2>
value_t<u8[16]> vpermb(T1 a, T2 b)
{

View file

@ -5313,13 +5313,24 @@ public:
return;
}
#ifdef ARCH_ARM64
if (m_use_dotprod)
#else
if (m_use_vnni)
#endif
{
const auto [a, b] = get_vrs<u32[4]>(op.ra, op.rb);
const auto zeroes = splat<u32[4]>(0);
#ifdef ARCH_ARM64
const auto [a, b] = get_vrs<u8[16]>(op.ra, op.rb);
const auto ones = splat<u8[16]>(0x01);
const auto ax = bitcast<u16[8]>(udot(zeroes, a, ones));
const auto bx = bitcast<u16[8]>(udot(zeroes, b, ones));
#else
const auto [a, b] = get_vrs<u32[4]>(op.ra, op.rb);
const auto ones = splat<u32[4]>(0x01010101);
const auto ax = bitcast<u16[8]>(vpdpbusd(zeroes, a, ones));
const auto bx = bitcast<u16[8]>(vpdpbusd(zeroes, b, ones));
#endif
set_vr(op.rt, shuffle2(ax, bx, 0, 8, 2, 10, 4, 12, 6, 14));
return;
}