[amdgpu] Implement V_FMA_F32, IMAGE_SAMPLE_LZ, V_CVT_OFF_F32_I4

Loops fix
Decompile spirv on error
Wait for rpcsx-os if memory not exists
This commit is contained in:
DH 2023-07-14 04:33:45 +03:00
parent d6c8353636
commit 665d74740a
10 changed files with 458 additions and 77 deletions

View file

@ -1,10 +1,12 @@
#include "Fragment.hpp"
#include "ConverterContext.hpp"
#include "Instruction.hpp"
#include "RegisterId.hpp"
#include "RegisterState.hpp"
#include <spirv/GLSL.std.450.h>
#include <spirv/spirv-instruction.hpp>
#include <spirv_cross/spirv.hpp>
#include <util/unreachable.hpp>
#include <bit>
@ -553,7 +555,8 @@ enum class CmpKind {
NLT,
NE,
TRU,
T = TRU
T = TRU,
CLASS
};
enum class CmpFlags { None = 0, X = 1 << 0, S = 1 << 1, SX = S | X };
@ -562,7 +565,8 @@ inline CmpFlags operator&(CmpFlags a, CmpFlags b) {
}
Value doCmpOp(Fragment &fragment, TypeId type, spirv::Value src0,
spirv::Value src1, CmpKind kind, CmpFlags flags) {
spirv::Value src1, CmpKind kind, CmpFlags flags,
std::uint8_t typeMask = 0) {
spirv::BoolValue cmp;
auto boolT = fragment.context->getBoolType();
@ -652,6 +656,89 @@ Value doCmpOp(Fragment &fragment, TypeId type, spirv::Value src0,
case CmpKind::TRU:
cmp = fragment.context->getTrue();
break;
case CmpKind::CLASS: {
enum class FloatClass {
SNan = 0,
QNan = 1,
NInf = 2,
NNorm = 3,
NDenom = 4,
NZero = 5,
PZero = 6,
PDenom = 7,
PNorm = 8,
PInf = 9,
};
auto testCmpClass = [&](FloatClass fclass,
spirv::FloatValue val) -> spirv::BoolValue {
switch (fclass) {
case FloatClass::SNan:
case FloatClass::QNan:
return fragment.builder.createIsNan(boolT, val);
case FloatClass::NInf:
return fragment.builder.createLogicalAnd(
boolT,
fragment.builder.createFOrdLessThan(
boolT, val, fragment.context->getFloat32(0)),
fragment.builder.createIsInf(boolT, val));
case FloatClass::NZero:
case FloatClass::PZero:
return fragment.builder.createFOrdEqual(
boolT, val, fragment.context->getFloat32(0));
case FloatClass::NNorm:
case FloatClass::NDenom:
case FloatClass::PDenom:
case FloatClass::PNorm:
util::unreachable();
case FloatClass::PInf:
return fragment.builder.createLogicalAnd(
boolT,
fragment.builder.createFOrdGreaterThan(
boolT, val, fragment.context->getFloat32(0)),
fragment.builder.createIsInf(boolT, val));
}
util::unreachable();
};
// we cannot differ signaling and quiet nan
if (typeMask & 3) {
typeMask = (typeMask & ~3) | 2;
}
// we cannot differ positive and negative zero
if (typeMask & 0x60) {
typeMask = (typeMask & ~0x60) | 0x40;
}
for (int i = 0; i < 10; ++i) {
if (typeMask & (1 << i)) {
auto lhs =
testCmpClass((FloatClass)i, spirv::cast<spirv::FloatValue>(src0));
auto rhs =
testCmpClass((FloatClass)i, spirv::cast<spirv::FloatValue>(src1));
auto bitResult = fragment.builder.createLogicalAnd(boolT, lhs, rhs);
if (cmp) {
cmp = fragment.builder.createLogicalOr(boolT, cmp, bitResult);
} else {
cmp = bitResult;
}
}
}
if (!cmp) {
cmp = fragment.context->getFalse();
}
break;
}
}
if (!cmp) {
@ -1563,7 +1650,20 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
auto src0 = fragment.getScalarOperand(inst.src0, type).value;
auto src1 = fragment.getScalarOperand(inst.src1, type).value;
auto result = doCmpOp(fragment, type, src0, src1, kind, flags);
std::int8_t typeMask = 0;
if (kind == CmpKind::CLASS) {
auto value = fragment.context->findSint32Value(
fragment.getScalarOperand(inst.src2, type).value);
if (!value) {
// util::unreachable();
typeMask = 2;
} else {
typeMask = *value;
}
}
auto result = doCmpOp(fragment, type, src0, src1, kind, flags, typeMask);
fragment.setScalarOperand(inst.vdst, result);
fragment.setScalarOperand(inst.vdst + 1, {fragment.context->getUInt32Type(),
fragment.context->getUInt32(0)});
@ -1978,8 +2078,9 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
case Vop3::Op::V3_CMP_T_I32:
cmpOp(TypeId::SInt32, CmpKind::T);
break;
// case Vop3::Op::V3_CMP_CLASS_F32: cmpOp(TypeId::Float32, CmpKind::CLASS);
// break;
case Vop3::Op::V3_CMP_CLASS_F32:
cmpOp(TypeId::Float32, CmpKind::CLASS);
break;
case Vop3::Op::V3_CMP_LT_I16:
cmpOp(TypeId::SInt16, CmpKind::LT);
break;
@ -1998,8 +2099,9 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
case Vop3::Op::V3_CMP_GE_I16:
cmpOp(TypeId::SInt16, CmpKind::GE);
break;
// case Vop3::Op::V3_CMP_CLASS_F16: cmpOp(TypeId::Float16, CmpKind::CLASS);
// break;
case Vop3::Op::V3_CMP_CLASS_F16:
cmpOp(TypeId::Float16, CmpKind::CLASS);
break;
case Vop3::Op::V3_CMPX_F_I32:
cmpOp(TypeId::SInt32, CmpKind::F, CmpFlags::X);
break;
@ -2024,8 +2126,9 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
case Vop3::Op::V3_CMPX_T_I32:
cmpOp(TypeId::SInt32, CmpKind::T, CmpFlags::X);
break;
// case Vop3::Op::V3_CMPX_CLASS_F32: cmpOp(TypeId::Float32, CmpKind::CLASS,
// CmpFlags::X); break;
case Vop3::Op::V3_CMPX_CLASS_F32:
cmpOp(TypeId::Float32, CmpKind::CLASS, CmpFlags::X);
break;
case Vop3::Op::V3_CMPX_LT_I16:
cmpOp(TypeId::SInt16, CmpKind::LT, CmpFlags::X);
break;
@ -2044,8 +2147,9 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
case Vop3::Op::V3_CMPX_GE_I16:
cmpOp(TypeId::SInt16, CmpKind::GE, CmpFlags::X);
break;
// case Vop3::Op::V3_CMPX_CLASS_F16: cmpOp(TypeId::Float16, CmpKind::CLASS,
// CmpFlags::X); break;
case Vop3::Op::V3_CMPX_CLASS_F16:
cmpOp(TypeId::Float16, CmpKind::CLASS, CmpFlags::X);
break;
case Vop3::Op::V3_CMP_F_I64:
cmpOp(TypeId::SInt64, CmpKind::F);
break;
@ -2070,8 +2174,9 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
case Vop3::Op::V3_CMP_T_I64:
cmpOp(TypeId::SInt64, CmpKind::T);
break;
// case Vop3::Op::V3_CMP_CLASS_F64: cmpOp(TypeId::Float64, CmpKind::CLASS);
// break;
case Vop3::Op::V3_CMP_CLASS_F64:
cmpOp(TypeId::Float64, CmpKind::CLASS);
break;
case Vop3::Op::V3_CMP_LT_U16:
cmpOp(TypeId::UInt16, CmpKind::LT);
break;
@ -2114,8 +2219,9 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
case Vop3::Op::V3_CMPX_T_I64:
cmpOp(TypeId::SInt64, CmpKind::T, CmpFlags::X);
break;
// case Vop3::Op::V3_CMPX_CLASS_F64: cmpOp(TypeId::Float64, CmpKind::CLASS,
// CmpFlags::X); break;
case Vop3::Op::V3_CMPX_CLASS_F64:
cmpOp(TypeId::Float64, CmpKind::CLASS, CmpFlags::X);
break;
case Vop3::Op::V3_CMPX_LT_U16:
cmpOp(TypeId::UInt16, CmpKind::LT, CmpFlags::X);
break;
@ -2515,8 +2621,6 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
fragment.getScalarOperand(inst.src0, TypeId::SInt32).value);
auto src1 = spirv::cast<spirv::SIntValue>(
fragment.getScalarOperand(inst.src1, TypeId::SInt32).value);
auto src2 = spirv::cast<spirv::SIntValue>(
fragment.getScalarOperand(inst.src2, TypeId::SInt32).value);
auto operandT = fragment.context->getSint32Type();
src0 = fragment.builder.createShiftLeftLogical(
@ -2548,6 +2652,42 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
fragment.setVectorOperand(inst.vdst, {floatT, result});
break;
}
case Vop3::Op::V3_MAX3_F32: {
auto src0 = spirv::cast<spirv::FloatValue>(
fragment.getScalarOperand(inst.src0, TypeId::Float32).value);
auto src1 = spirv::cast<spirv::FloatValue>(
fragment.getScalarOperand(inst.src1, TypeId::Float32).value);
auto src2 = spirv::cast<spirv::FloatValue>(
fragment.getScalarOperand(inst.src2, TypeId::Float32).value);
auto floatT = fragment.context->getFloat32Type();
auto boolT = fragment.context->getBoolType();
auto max01 = fragment.builder.createSelect(
floatT, fragment.builder.createFOrdGreaterThanEqual(boolT, src0, src1),
src0, src1);
auto result = fragment.builder.createSelect(
floatT, fragment.builder.createFOrdGreaterThanEqual(boolT, max01, src2),
max01, src2);
fragment.setVectorOperand(inst.vdst, {floatT, result});
break;
}
case Vop3::Op::V3_FMA_F32: {
auto src0 = spirv::cast<spirv::FloatValue>(
fragment.getScalarOperand(inst.src0, TypeId::Float32).value);
auto src1 = spirv::cast<spirv::FloatValue>(
fragment.getScalarOperand(inst.src1, TypeId::Float32).value);
auto src2 = spirv::cast<spirv::FloatValue>(
fragment.getScalarOperand(inst.src2, TypeId::Float32).value);
auto floatT = fragment.context->getFloat32Type();
auto glslStd450 = fragment.context->getGlslStd450();
auto result = fragment.builder.createExtInst(
floatT, glslStd450, GLSLstd450Fma, {{src0, src1, src2}});
fragment.setVectorOperand(inst.vdst, {floatT, result});
break;
}
case Vop3::Op::V3_CNDMASK_B32: {
auto src0 = fragment.getScalarOperand(inst.src0, TypeId::UInt32).value;
auto src1 = fragment.getScalarOperand(inst.src1, TypeId::UInt32).value;
@ -3321,6 +3461,40 @@ void convertMimg(Fragment &fragment, Mimg inst) {
}
break;
}
case Mimg::Op::IMAGE_SAMPLE_LZ: {
auto image =
fragment.createImage(RegisterId::Raw(inst.srsrc << 2), inst.r128);
auto sampler = fragment.createSampler(RegisterId::Raw(inst.ssamp << 2));
auto coord0 = fragment.getVectorOperand(inst.vaddr, TypeId::Float32).value;
auto coord1 =
fragment.getVectorOperand(inst.vaddr + 1, TypeId::Float32).value;
auto coord2 =
fragment.getVectorOperand(inst.vaddr + 2, TypeId::Float32).value;
auto coords = fragment.builder.createCompositeConstruct(
fragment.context->getFloat32x3Type(),
{{coord0, coord1, coord2}}); // TODO
auto sampledImage2dT = fragment.context->getSampledImage2DType();
auto float4T = fragment.context->getFloat32x4Type();
auto floatT = fragment.context->getFloat32Type();
auto sampledImage =
fragment.builder.createSampledImage(sampledImage2dT, image, sampler);
auto value = fragment.builder.createImageSampleExplicitLod(
float4T, sampledImage, coords, spv::ImageOperandsMask::Lod,
{{fragment.context->getFloat32(0)}});
for (std::uint32_t dstOffset = 0, i = 0; i < 4; ++i) {
if (inst.dmask & (1 << i)) {
fragment.setVectorOperand(
inst.vdata + dstOffset++,
{floatT,
fragment.builder.createCompositeExtract(floatT, value, {{i}})});
}
}
break;
}
case Mimg::Op::IMAGE_SAMPLE: {
auto image =
fragment.createImage(RegisterId::Raw(inst.srsrc << 2), inst.r128);
@ -3499,11 +3673,28 @@ void convertVop1(Fragment &fragment, Vop1 inst) {
break;
}
case Vop1::Op::V_CVT_OFF_F32_I4: {
auto src = spirv::cast<spirv::SIntValue>(
fragment.getScalarOperand(inst.src0, TypeId::SInt32).value);
auto floatT = fragment.context->getFloat32Type();
auto int32T = fragment.context->getSint32Type();
src = spirv::cast<spirv::SIntValue>(fragment.builder.createBitwiseAnd(
int32T, src, fragment.context->getSInt32(0b1111)));
src = fragment.builder.createISub(int32T, src,
fragment.context->getSInt32(8));
auto fsrc = fragment.builder.createConvertSToF(floatT, src);
auto result = fragment.builder.createFDiv(floatT, fsrc,
fragment.context->getFloat32(16));
fragment.setVectorOperand(inst.vdst, {floatT, result});
break;
}
case Vop1::Op::V_RSQ_F32: {
auto src = spirv::cast<spirv::FloatValue>(
fragment.getScalarOperand(inst.src0, TypeId::Float32).value);
auto floatT = fragment.context->getFloat32Type();
auto float1 = fragment.context->getFloat32(1);
auto glslStd450 = fragment.context->getGlslStd450();
auto result = fragment.builder.createExtInst(