[amdgpu] Implement V_FMA_F32, IMAGE_SAMPLE_LZ, V_CVT_OFF_F32_I4

Loops fix
Decompile spirv on error
Wait for rpcsx-os if memory does not exist
This commit is contained in:
DH 2023-07-14 04:33:45 +03:00
parent d6c8353636
commit 665d74740a
10 changed files with 458 additions and 77 deletions

View file

@ -49,7 +49,7 @@ public:
function = fn;
memory = mem;
auto lastFragment = convertBlock(block, &function->entryFragment);
auto lastFragment = convertBlock(block, &function->entryFragment, nullptr);
if (lastFragment != nullptr) {
lastFragment->builder.createBranch(fn->exitFragment.entryBlockId);
@ -126,7 +126,8 @@ private:
return builder.createLogicalOr(boolT, loIsNotZero, hiIsNotZero);
}
Fragment *convertBlock(scf::Block *block, Fragment *rootFragment) {
Fragment *convertBlock(scf::Block *block, Fragment *rootFragment,
Fragment *loopMergeFragment) {
Fragment *currentFragment = nullptr;
for (scf::Node *node = block->getRootNode(); node != nullptr;
@ -178,6 +179,33 @@ private:
}
if (auto ifElse = dynCast<scf::IfElse>(node)) {
// Tells whether `block` is non-empty, its root node is also its last node,
// and that node is an scf::Break.
auto isBreakBlock = [](scf::Block *block) {
  if (block->isEmpty()) {
    return false;
  }
  auto *rootNode = block->getRootNode();
  const bool rootIsLast = block->getLastNode() == rootNode;
  return rootIsLast && dynamic_cast<scf::Break *>(rootNode) != nullptr;
};
// Fast path for a conditional loop exit: only taken while converting a loop
// body (loopMergeFragment is set) and when the if/else has an empty true
// branch and a false branch accepted by isBreakBlock.
if (loopMergeFragment != nullptr && ifElse->ifTrue->isEmpty() &&
isBreakBlock(ifElse->ifFalse)) {
// Fresh fragment in which conversion of the remaining nodes continues.
auto mergeFragment = function->createFragment();
currentFragment->appendBranch(*mergeFragment);
currentFragment->appendBranch(*loopMergeFragment);
// Condition true  -> fall through into mergeFragment;
// condition false -> jump to the enclosing loop's merge fragment.
// No selection merge is emitted on this path.
currentFragment->builder.createBranchConditional(
currentFragment->branchCondition, mergeFragment->entryBlockId,
loopMergeFragment->entryBlockId);
initState(mergeFragment);
releaseStateOf(currentFragment);
currentFragment = mergeFragment;
continue;
}
auto ifTrueFragment = function->createFragment();
auto ifFalseFragment = function->createFragment();
auto mergeFragment = function->createFragment();
@ -185,18 +213,16 @@ private:
currentFragment->appendBranch(*ifTrueFragment);
currentFragment->appendBranch(*ifFalseFragment);
currentFragment->builder.createSelectionMerge(
mergeFragment->entryBlockId, {});
currentFragment->builder.createBranchConditional(
currentFragment->branchCondition, ifTrueFragment->entryBlockId,
ifFalseFragment->entryBlockId);
auto ifTrueLastBlock = convertBlock(ifElse->ifTrue, ifTrueFragment);
auto ifFalseLastBlock = convertBlock(ifElse->ifFalse, ifFalseFragment);
auto ifTrueLastBlock =
convertBlock(ifElse->ifTrue, ifTrueFragment, loopMergeFragment);
auto ifFalseLastBlock =
convertBlock(ifElse->ifFalse, ifFalseFragment, loopMergeFragment);
if (ifTrueLastBlock != nullptr) {
ifTrueLastBlock->builder.createBranch(mergeFragment->entryBlockId);
ifTrueLastBlock->appendBranch(*mergeFragment);
if (!ifTrueLastBlock->hasTerminator) {
ifTrueLastBlock->builder.createBranch(mergeFragment->entryBlockId);
ifTrueLastBlock->appendBranch(*mergeFragment);
}
if (ifTrueLastBlock->registers == nullptr) {
initState(ifTrueLastBlock);
@ -204,14 +230,23 @@ private:
}
if (ifFalseLastBlock != nullptr) {
ifFalseLastBlock->builder.createBranch(mergeFragment->entryBlockId);
ifFalseLastBlock->appendBranch(*mergeFragment);
if (!ifFalseLastBlock->hasTerminator) {
ifFalseLastBlock->builder.createBranch(mergeFragment->entryBlockId);
ifFalseLastBlock->appendBranch(*mergeFragment);
}
if (ifFalseLastBlock->registers == nullptr) {
initState(ifFalseLastBlock);
}
}
currentFragment->builder.createSelectionMerge(
mergeFragment->entryBlockId, {});
currentFragment->builder.createBranchConditional(
currentFragment->branchCondition, ifTrueFragment->entryBlockId,
ifFalseFragment->entryBlockId);
releaseStateOf(currentFragment);
initState(mergeFragment);
@ -226,6 +261,56 @@ private:
continue;
}
// Lowers an scf::Loop node into four fragments: header, body, continue and
// merge. Conversion of the nodes after the loop resumes in the merge fragment.
if (auto loop = dynCast<scf::Loop>(node)) {
auto headerFragment = function->createFragment();
auto bodyFragment = function->createFragment();
// Merge and continue fragments are created detached and are only appended
// to the function after the body has been converted (see appendFragment
// calls below).
auto mergeFragment = function->createDetachedFragment();
auto continueFragment = function->createDetachedFragment();
// Enter the loop: current fragment branches unconditionally to the header.
currentFragment->builder.createBranch(headerFragment->entryBlockId);
currentFragment->appendBranch(*headerFragment);
initState(headerFragment);
releaseStateOf(currentFragment);
// The header declares the loop's merge and continue targets, then branches
// to the body.
headerFragment->builder.createLoopMerge(
mergeFragment->entryBlockId, continueFragment->entryBlockId,
spv::LoopControlMask::MaskNone, {});
headerFragment->builder.createBranch(bodyFragment->entryBlockId);
headerFragment->appendBranch(*bodyFragment);
// mergeFragment is handed down as loopMergeFragment so break patterns inside
// the body can branch to it.
auto bodyLastBlock =
convertBlock(loop->body, bodyFragment, mergeFragment);
// A non-null result is the fragment where the body ended; route it to the
// continue fragment.
if (bodyLastBlock != nullptr) {
if (bodyLastBlock->registers == nullptr) {
initState(bodyLastBlock);
}
bodyLastBlock->builder.createBranch(continueFragment->entryBlockId);
bodyLastBlock->appendBranch(*continueFragment);
}
// Back edge: the continue fragment jumps to the header.
continueFragment->builder.createBranch(headerFragment->entryBlockId);
continueFragment->appendBranch(*headerFragment);
// NOTE(review): the relative order of the initState/releaseStateOf calls
// below looks deliberate; their semantics are not visible here - confirm
// before reordering.
initState(continueFragment);
releaseStateOf(headerFragment);
initState(mergeFragment);
if (bodyLastBlock != nullptr) {
releaseStateOf(bodyLastBlock);
}
function->appendFragment(continueFragment);
function->appendFragment(mergeFragment);
releaseStateOf(continueFragment);
currentFragment = mergeFragment;
continue;
}
if (dynCast<scf::UnknownBlock>(node)) {
auto jumpAddress = currentFragment->jumpAddress;
@ -250,7 +335,7 @@ private:
auto targetFragment = function->createFragment();
currentFragment->builder.createBranch(targetFragment->entryBlockId);
currentFragment->appendBranch(*targetFragment);
auto result = convertBlock(scfBlock, targetFragment);
auto result = convertBlock(scfBlock, targetFragment, nullptr);
if (currentFragment->registers == nullptr) {
initState(targetFragment);
@ -264,9 +349,11 @@ private:
currentFragment->appendBranch(function->exitFragment);
currentFragment->builder.createBranch(
function->exitFragment.entryBlockId);
currentFragment->hasTerminator = true;
return nullptr;
}
node->dump();
util::unreachable();
}

View file

@ -1,10 +1,12 @@
#include "Fragment.hpp"
#include "ConverterContext.hpp"
#include "Instruction.hpp"
#include "RegisterId.hpp"
#include "RegisterState.hpp"
#include <spirv/GLSL.std.450.h>
#include <spirv/spirv-instruction.hpp>
#include <spirv_cross/spirv.hpp>
#include <util/unreachable.hpp>
#include <bit>
@ -553,7 +555,8 @@ enum class CmpKind {
NLT,
NE,
TRU,
T = TRU
T = TRU,
CLASS
};
enum class CmpFlags { None = 0, X = 1 << 0, S = 1 << 1, SX = S | X };
@ -562,7 +565,8 @@ inline CmpFlags operator&(CmpFlags a, CmpFlags b) {
}
Value doCmpOp(Fragment &fragment, TypeId type, spirv::Value src0,
spirv::Value src1, CmpKind kind, CmpFlags flags) {
spirv::Value src1, CmpKind kind, CmpFlags flags,
std::uint8_t typeMask = 0) {
spirv::BoolValue cmp;
auto boolT = fragment.context->getBoolType();
@ -652,6 +656,89 @@ Value doCmpOp(Fragment &fragment, TypeId type, spirv::Value src0,
case CmpKind::TRU:
cmp = fragment.context->getTrue();
break;
// Float class test: builds a boolean that is true when the operands fall into
// any of the float classes selected by the bits of typeMask.
// NOTE(review): typeMask is declared std::uint8_t in this function's
// signature, so bits 8 (PNorm) and 9 (PInf) can never be set and those two
// classes are never tested by the loop below - confirm the intended width.
// NOTE(review): the zero/infinity tests compare against a 32-bit float
// constant regardless of `type` - confirm this is valid for non-Float32 use.
case CmpKind::CLASS: {
// Enumerator values are the bit indices checked in typeMask.
enum class FloatClass {
SNan = 0,
QNan = 1,
NInf = 2,
NNorm = 3,
NDenom = 4,
NZero = 5,
PZero = 6,
PDenom = 7,
PNorm = 8,
PInf = 9,
};
// Emits the test for one float class on one value.
auto testCmpClass = [&](FloatClass fclass,
spirv::FloatValue val) -> spirv::BoolValue {
switch (fclass) {
// Both NaN kinds use the same IsNan test.
case FloatClass::SNan:
case FloatClass::QNan:
return fragment.builder.createIsNan(boolT, val);
// Negative infinity: val < 0 and IsInf(val).
case FloatClass::NInf:
return fragment.builder.createLogicalAnd(
boolT,
fragment.builder.createFOrdLessThan(
boolT, val, fragment.context->getFloat32(0)),
fragment.builder.createIsInf(boolT, val));
// Both zero kinds use an ordered equality test against 0.
case FloatClass::NZero:
case FloatClass::PZero:
return fragment.builder.createFOrdEqual(
boolT, val, fragment.context->getFloat32(0));
// Normal and denormal classes are not implemented: a mask selecting any of
// them reaches util::unreachable().
case FloatClass::NNorm:
case FloatClass::NDenom:
case FloatClass::PDenom:
case FloatClass::PNorm:
util::unreachable();
// Positive infinity: val > 0 and IsInf(val).
case FloatClass::PInf:
return fragment.builder.createLogicalAnd(
boolT,
fragment.builder.createFOrdGreaterThan(
boolT, val, fragment.context->getFloat32(0)),
fragment.builder.createIsInf(boolT, val));
}
util::unreachable();
};
// Signaling and quiet NaN cannot be distinguished here: if either NaN bit is
// set, collapse them into the QNan bit only.
if (typeMask & 3) {
typeMask = (typeMask & ~3) | 2;
}
// Positive and negative zero cannot be distinguished here: if either zero
// bit is set, collapse them into the PZero bit only.
if (typeMask & 0x60) {
typeMask = (typeMask & ~0x60) | 0x40;
}
// For every selected class, require that both src0 and src1 match it, and
// OR the per-class results together.
for (int i = 0; i < 10; ++i) {
if (typeMask & (1 << i)) {
auto lhs =
testCmpClass((FloatClass)i, spirv::cast<spirv::FloatValue>(src0));
auto rhs =
testCmpClass((FloatClass)i, spirv::cast<spirv::FloatValue>(src1));
auto bitResult = fragment.builder.createLogicalAnd(boolT, lhs, rhs);
if (cmp) {
cmp = fragment.builder.createLogicalOr(boolT, cmp, bitResult);
} else {
cmp = bitResult;
}
}
}
// No class selected: the comparison result is constant false.
if (!cmp) {
cmp = fragment.context->getFalse();
}
break;
}
}
if (!cmp) {
@ -1563,7 +1650,20 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
auto src0 = fragment.getScalarOperand(inst.src0, type).value;
auto src1 = fragment.getScalarOperand(inst.src1, type).value;
auto result = doCmpOp(fragment, type, src0, src1, kind, flags);
std::int8_t typeMask = 0;
if (kind == CmpKind::CLASS) {
auto value = fragment.context->findSint32Value(
fragment.getScalarOperand(inst.src2, type).value);
if (!value) {
// util::unreachable();
typeMask = 2;
} else {
typeMask = *value;
}
}
auto result = doCmpOp(fragment, type, src0, src1, kind, flags, typeMask);
fragment.setScalarOperand(inst.vdst, result);
fragment.setScalarOperand(inst.vdst + 1, {fragment.context->getUInt32Type(),
fragment.context->getUInt32(0)});
@ -1978,8 +2078,9 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
case Vop3::Op::V3_CMP_T_I32:
cmpOp(TypeId::SInt32, CmpKind::T);
break;
// case Vop3::Op::V3_CMP_CLASS_F32: cmpOp(TypeId::Float32, CmpKind::CLASS);
// break;
case Vop3::Op::V3_CMP_CLASS_F32:
cmpOp(TypeId::Float32, CmpKind::CLASS);
break;
case Vop3::Op::V3_CMP_LT_I16:
cmpOp(TypeId::SInt16, CmpKind::LT);
break;
@ -1998,8 +2099,9 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
case Vop3::Op::V3_CMP_GE_I16:
cmpOp(TypeId::SInt16, CmpKind::GE);
break;
// case Vop3::Op::V3_CMP_CLASS_F16: cmpOp(TypeId::Float16, CmpKind::CLASS);
// break;
case Vop3::Op::V3_CMP_CLASS_F16:
cmpOp(TypeId::Float16, CmpKind::CLASS);
break;
case Vop3::Op::V3_CMPX_F_I32:
cmpOp(TypeId::SInt32, CmpKind::F, CmpFlags::X);
break;
@ -2024,8 +2126,9 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
case Vop3::Op::V3_CMPX_T_I32:
cmpOp(TypeId::SInt32, CmpKind::T, CmpFlags::X);
break;
// case Vop3::Op::V3_CMPX_CLASS_F32: cmpOp(TypeId::Float32, CmpKind::CLASS,
// CmpFlags::X); break;
case Vop3::Op::V3_CMPX_CLASS_F32:
cmpOp(TypeId::Float32, CmpKind::CLASS, CmpFlags::X);
break;
case Vop3::Op::V3_CMPX_LT_I16:
cmpOp(TypeId::SInt16, CmpKind::LT, CmpFlags::X);
break;
@ -2044,8 +2147,9 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
case Vop3::Op::V3_CMPX_GE_I16:
cmpOp(TypeId::SInt16, CmpKind::GE, CmpFlags::X);
break;
// case Vop3::Op::V3_CMPX_CLASS_F16: cmpOp(TypeId::Float16, CmpKind::CLASS,
// CmpFlags::X); break;
case Vop3::Op::V3_CMPX_CLASS_F16:
cmpOp(TypeId::Float16, CmpKind::CLASS, CmpFlags::X);
break;
case Vop3::Op::V3_CMP_F_I64:
cmpOp(TypeId::SInt64, CmpKind::F);
break;
@ -2070,8 +2174,9 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
case Vop3::Op::V3_CMP_T_I64:
cmpOp(TypeId::SInt64, CmpKind::T);
break;
// case Vop3::Op::V3_CMP_CLASS_F64: cmpOp(TypeId::Float64, CmpKind::CLASS);
// break;
case Vop3::Op::V3_CMP_CLASS_F64:
cmpOp(TypeId::Float64, CmpKind::CLASS);
break;
case Vop3::Op::V3_CMP_LT_U16:
cmpOp(TypeId::UInt16, CmpKind::LT);
break;
@ -2114,8 +2219,9 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
case Vop3::Op::V3_CMPX_T_I64:
cmpOp(TypeId::SInt64, CmpKind::T, CmpFlags::X);
break;
// case Vop3::Op::V3_CMPX_CLASS_F64: cmpOp(TypeId::Float64, CmpKind::CLASS,
// CmpFlags::X); break;
case Vop3::Op::V3_CMPX_CLASS_F64:
cmpOp(TypeId::Float64, CmpKind::CLASS, CmpFlags::X);
break;
case Vop3::Op::V3_CMPX_LT_U16:
cmpOp(TypeId::UInt16, CmpKind::LT, CmpFlags::X);
break;
@ -2515,8 +2621,6 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
fragment.getScalarOperand(inst.src0, TypeId::SInt32).value);
auto src1 = spirv::cast<spirv::SIntValue>(
fragment.getScalarOperand(inst.src1, TypeId::SInt32).value);
auto src2 = spirv::cast<spirv::SIntValue>(
fragment.getScalarOperand(inst.src2, TypeId::SInt32).value);
auto operandT = fragment.context->getSint32Type();
src0 = fragment.builder.createShiftLeftLogical(
@ -2548,6 +2652,42 @@ void convertVop3(Fragment &fragment, Vop3 inst) {
fragment.setVectorOperand(inst.vdst, {floatT, result});
break;
}
case Vop3::Op::V3_MAX3_F32: {
  // Three-operand float maximum, built from two "a >= b ? a : b" steps
  // (ordered compare followed by select); the result goes to vdst.
  auto first = spirv::cast<spirv::FloatValue>(
      fragment.getScalarOperand(inst.src0, TypeId::Float32).value);
  auto second = spirv::cast<spirv::FloatValue>(
      fragment.getScalarOperand(inst.src1, TypeId::Float32).value);
  auto third = spirv::cast<spirv::FloatValue>(
      fragment.getScalarOperand(inst.src2, TypeId::Float32).value);
  auto floatT = fragment.context->getFloat32Type();
  auto boolT = fragment.context->getBoolType();

  // Picks lhs when lhs >= rhs (ordered), otherwise rhs.
  auto pickGreaterOrEqual = [&](auto lhs, auto rhs) {
    auto lhsWins =
        fragment.builder.createFOrdGreaterThanEqual(boolT, lhs, rhs);
    return fragment.builder.createSelect(floatT, lhsWins, lhs, rhs);
  };

  auto maxOfFirstTwo = pickGreaterOrEqual(first, second);
  auto maxOfAll = pickGreaterOrEqual(maxOfFirstTwo, third);

  fragment.setVectorOperand(inst.vdst, {floatT, maxOfAll});
  break;
}
case Vop3::Op::V3_FMA_F32: {
  // Emits the GLSL.std.450 Fma extended instruction with the operands in
  // the order src0, src1, src2 and stores the float result in vdst.
  auto operandA = spirv::cast<spirv::FloatValue>(
      fragment.getScalarOperand(inst.src0, TypeId::Float32).value);
  auto operandB = spirv::cast<spirv::FloatValue>(
      fragment.getScalarOperand(inst.src1, TypeId::Float32).value);
  auto operandC = spirv::cast<spirv::FloatValue>(
      fragment.getScalarOperand(inst.src2, TypeId::Float32).value);
  auto floatT = fragment.context->getFloat32Type();
  auto extInstSet = fragment.context->getGlslStd450();

  auto fused = fragment.builder.createExtInst(
      floatT, extInstSet, GLSLstd450Fma, {{operandA, operandB, operandC}});

  fragment.setVectorOperand(inst.vdst, {floatT, fused});
  break;
}
case Vop3::Op::V3_CNDMASK_B32: {
auto src0 = fragment.getScalarOperand(inst.src0, TypeId::UInt32).value;
auto src1 = fragment.getScalarOperand(inst.src1, TypeId::UInt32).value;
@ -3321,6 +3461,40 @@ void convertMimg(Fragment &fragment, Mimg inst) {
}
break;
}
case Mimg::Op::IMAGE_SAMPLE_LZ: {
  // Samples the image with an explicit level of detail of 0 and writes the
  // components enabled in dmask to consecutive registers starting at vdata.
  auto image =
      fragment.createImage(RegisterId::Raw(inst.srsrc << 2), inst.r128);
  auto sampler = fragment.createSampler(RegisterId::Raw(inst.ssamp << 2));

  // Three consecutive vector registers starting at vaddr form the
  // coordinate vector.
  auto coordX = fragment.getVectorOperand(inst.vaddr, TypeId::Float32).value;
  auto coordY =
      fragment.getVectorOperand(inst.vaddr + 1, TypeId::Float32).value;
  auto coordZ =
      fragment.getVectorOperand(inst.vaddr + 2, TypeId::Float32).value;
  auto coords = fragment.builder.createCompositeConstruct(
      fragment.context->getFloat32x3Type(),
      {{coordX, coordY, coordZ}}); // TODO

  auto sampledImage2dT = fragment.context->getSampledImage2DType();
  auto float4T = fragment.context->getFloat32x4Type();
  auto floatT = fragment.context->getFloat32Type();

  auto sampledImage =
      fragment.builder.createSampledImage(sampledImage2dT, image, sampler);
  auto texel = fragment.builder.createImageSampleExplicitLod(
      float4T, sampledImage, coords, spv::ImageOperandsMask::Lod,
      {{fragment.context->getFloat32(0)}});

  // Enabled components are packed: the destination index advances only
  // for bits set in dmask.
  std::uint32_t dstOffset = 0;
  for (std::uint32_t channel = 0; channel < 4; ++channel) {
    if ((inst.dmask & (1 << channel)) == 0) {
      continue;
    }
    auto component =
        fragment.builder.createCompositeExtract(floatT, texel, {{channel}});
    fragment.setVectorOperand(inst.vdata + dstOffset, {floatT, component});
    ++dstOffset;
  }
  break;
}
case Mimg::Op::IMAGE_SAMPLE: {
auto image =
fragment.createImage(RegisterId::Raw(inst.srsrc << 2), inst.r128);
@ -3499,11 +3673,28 @@ void convertVop1(Fragment &fragment, Vop1 inst) {
break;
}
case Vop1::Op::V_CVT_OFF_F32_I4: {
  // V_CVT_OFF_F32_I4: interpret the low 4 bits of src0 as a two's-complement
  // signed integer in [-8, 7] and convert it to float scaled by 1/16:
  //   0b0000 ->  0.0      0b0111 ->  0.4375
  //   0b1000 -> -0.5      0b1111 -> -0.0625
  // The previous implementation computed (x & 15) - 8, which is an
  // offset-binary decode and maps 0b0000 to -0.5 instead of 0.0.
  auto src = spirv::cast<spirv::SIntValue>(
      fragment.getScalarOperand(inst.src0, TypeId::SInt32).value);
  auto floatT = fragment.context->getFloat32Type();
  auto int32T = fragment.context->getSint32Type();

  // Sign-extend the nibble: value = (x & 0b0111) - (x & 0b1000).
  // The low three bits carry the magnitude, bit 3 carries weight -8.
  auto lowBits = spirv::cast<spirv::SIntValue>(fragment.builder.createBitwiseAnd(
      int32T, src, fragment.context->getSInt32(0b0111)));
  auto signBit = spirv::cast<spirv::SIntValue>(fragment.builder.createBitwiseAnd(
      int32T, src, fragment.context->getSInt32(0b1000)));
  auto nibble = fragment.builder.createISub(int32T, lowBits, signBit);

  auto fsrc = fragment.builder.createConvertSToF(floatT, nibble);
  auto result = fragment.builder.createFDiv(floatT, fsrc,
                                            fragment.context->getFloat32(16));
  fragment.setVectorOperand(inst.vdst, {floatT, result});
  break;
}
case Vop1::Op::V_RSQ_F32: {
auto src = spirv::cast<spirv::FloatValue>(
fragment.getScalarOperand(inst.src0, TypeId::Float32).value);
auto floatT = fragment.context->getFloat32Type();
auto float1 = fragment.context->getFloat32(1);
auto glslStd450 = fragment.context->getGlslStd450();
auto result = fragment.builder.createExtInst(

View file

@ -225,10 +225,9 @@ spirv::FunctionType Function::getFunctionType() {
return context->getFunctionType(getResultType(), params);
}
Fragment *Function::createFragment() {
Fragment *Function::createDetachedFragment() {
auto result = context->createFragment(0);
result->function = this;
fragments.push_back(result);
return result;
}