diff --git a/CMakeLists.txt b/CMakeLists.txt index bbd3d4c58..38b111916 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -288,3 +288,4 @@ if (WITH_PS3) add_subdirectory(ps3fw) endif() +add_subdirectory(test) diff --git a/rpcsx/gpu/lib/gcn-shader/CMakeLists.txt b/rpcsx/gpu/lib/gcn-shader/CMakeLists.txt index 606664be3..0c9ddef9f 100644 --- a/rpcsx/gpu/lib/gcn-shader/CMakeLists.txt +++ b/rpcsx/gpu/lib/gcn-shader/CMakeLists.txt @@ -30,6 +30,12 @@ add_library(gcn-shader STATIC src/SpvConverter.cpp src/SpvTypeInfo.cpp src/transform.cpp + src/transform/replace.cpp + src/transform/route.cpp + src/transform/merge.cpp + src/transform/construct.cpp + src/transform/transformations.cpp + src/transform/wrap.cpp ) target_include_directories(gcn-shader PUBLIC include PRIVATE include/shader) diff --git a/rpcsx/gpu/lib/gcn-shader/include/shader/transform/Edge.hpp b/rpcsx/gpu/lib/gcn-shader/include/shader/transform/Edge.hpp new file mode 100644 index 000000000..774d090dd --- /dev/null +++ b/rpcsx/gpu/lib/gcn-shader/include/shader/transform/Edge.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include "analyze.hpp" +#include "ir.hpp" +#include "replace.hpp" + +namespace shader::transform { +class Edge { + ir::Block mFromBlock; + int mToOperandIndex; + +public: + Edge(ir::Block fromBlock, int toOperandIndex) + : mFromBlock(fromBlock), mToOperandIndex(toOperandIndex) {} + + [[nodiscard]] ir::Block from() const { return mFromBlock; } + [[nodiscard]] ir::Block to() const { + return getTerminator(mFromBlock) + .getOperand(mToOperandIndex) + .getAsValue() + .staticCast(); + } + + [[nodiscard]] int operandIndex() const { return mToOperandIndex; } + + void replaceSuccessor(ir::Value newSuccessor) { + replaceTerminatorTarget(getTerminator(mFromBlock), mToOperandIndex, + newSuccessor); + } + + bool operator==(const Edge &) const = default; +}; +} \ No newline at end of file diff --git a/rpcsx/gpu/lib/gcn-shader/include/shader/transform/construct.hpp 
b/rpcsx/gpu/lib/gcn-shader/include/shader/transform/construct.hpp new file mode 100644 index 000000000..f858827a3 --- /dev/null +++ b/rpcsx/gpu/lib/gcn-shader/include/shader/transform/construct.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include "SpvConverter.hpp" +#include "analyze.hpp" +#include "ir.hpp" + +namespace shader::transform { + +bool isConstruct(ir::Instruction block); +bool isParentConstruct(ir::RegionLike parent, + ir::RegionLike construct); + +ir::Block getConstructOf(ir::Instruction inst); +ir::Block getConstructMergeBlock(ir::Block block); + +ir::SelectionConstruct createSelectionConstruct(spv::Context &context, + ir::RegionLike parentConstruct, + const std::unordered_set &components, + ir::Block header, ir::Block merge); + +ir::LoopConstruct createLoopConstruct(spv::Context &context, + ir::RegionLike parentConstruct, + ir::Block header, + ir::Block latch, + ir::Block cont, + ir::Block merge, + const std::unordered_set &scc); +} \ No newline at end of file diff --git a/rpcsx/gpu/lib/gcn-shader/include/shader/transform/merge.hpp b/rpcsx/gpu/lib/gcn-shader/include/shader/transform/merge.hpp new file mode 100644 index 000000000..c21d88ad6 --- /dev/null +++ b/rpcsx/gpu/lib/gcn-shader/include/shader/transform/merge.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "SpvConverter.hpp" +#include "analyze.hpp" +#include "dialect.hpp" +#include "ir.hpp" + + +namespace shader::transform { +ir::Block createMergeBlock(spv::Context &context, + ir::InsertionPoint insertPoint, + const std::unordered_set &preds, + ir::Block to); +} \ No newline at end of file diff --git a/rpcsx/gpu/lib/gcn-shader/include/shader/transform/replace.hpp b/rpcsx/gpu/lib/gcn-shader/include/shader/transform/replace.hpp new file mode 100644 index 000000000..3e8e76cec --- /dev/null +++ b/rpcsx/gpu/lib/gcn-shader/include/shader/transform/replace.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "analyze.hpp" +#include "dialect.hpp" +#include "ir.hpp" + +namespace shader::transform { +void 
replaceTerminatorTarget(ir::Instruction terminator, + int operandIndex, + ir::Value newTarget); + +bool replaceTerminatorTarget(ir::Instruction terminator, + ir::Value oldTarget, + ir::Value newTarget); +} \ No newline at end of file diff --git a/rpcsx/gpu/lib/gcn-shader/include/shader/transform/route.hpp b/rpcsx/gpu/lib/gcn-shader/include/shader/transform/route.hpp new file mode 100644 index 000000000..4aaffc0d9 --- /dev/null +++ b/rpcsx/gpu/lib/gcn-shader/include/shader/transform/route.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include "Edge.hpp" +#include "SpvConverter.hpp" +#include "ir.hpp" + +namespace shader::transform { +ir::Block createRouteBlock(spv::Context &context, + ir::InsertionPoint insertPoint, + const std::vector &edges); +} \ No newline at end of file diff --git a/rpcsx/gpu/lib/gcn-shader/include/shader/transform/transformations.hpp b/rpcsx/gpu/lib/gcn-shader/include/shader/transform/transformations.hpp new file mode 100644 index 000000000..abcbf970f --- /dev/null +++ b/rpcsx/gpu/lib/gcn-shader/include/shader/transform/transformations.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include "SpvConverter.hpp" +#include "ir.hpp" + + +namespace shader::transform { +ir::Value transformToCanonicalRegion(spv::Context &context, + ir::RegionLike region); +void transformToCf(spv::Context &context, ir::RegionLike region); +void transformToFlat(spv::Context &context, ir::RegionLike region); +} \ No newline at end of file diff --git a/rpcsx/gpu/lib/gcn-shader/include/shader/transform/wrap.hpp b/rpcsx/gpu/lib/gcn-shader/include/shader/transform/wrap.hpp new file mode 100644 index 000000000..33d396bf1 --- /dev/null +++ b/rpcsx/gpu/lib/gcn-shader/include/shader/transform/wrap.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include "SpvConverter.hpp" + +namespace shader::transform { +void wrapLoopConstructs(spv::Context &context, ir::RegionLike root); +void wrapSelectionConstructs(spv::Context &context, ir::RegionLike root); +} \ No newline at end of file diff --git 
a/rpcsx/gpu/lib/gcn-shader/src/transform.cpp b/rpcsx/gpu/lib/gcn-shader/src/transform.cpp index 7b8035ee9..772f8ea48 100644 --- a/rpcsx/gpu/lib/gcn-shader/src/transform.cpp +++ b/rpcsx/gpu/lib/gcn-shader/src/transform.cpp @@ -1,1433 +1,17 @@ #include "transform.hpp" +#include "transform/transformations.hpp" +#include "transform/wrap.hpp" #include "SpvConverter.hpp" -#include "analyze.hpp" #include "dialect.hpp" -#include -#include #include #include -#include -#include -#include using namespace shader; +using namespace shader::transform; using Builder = ir::Builder; -static bool isConstruct(ir::Instruction block) { - return block == ir::builtin::LOOP_CONSTRUCT || - block == ir::builtin::SELECTION_CONSTRUCT || - block == ir::builtin::CONTINUE_CONSTRUCT; -} -static ir::Block getConstructOf(ir::Instruction inst) { - auto block = inst.cast(); - if (block && isConstruct(block)) { - block = block.getParent().cast(); - } - - while (block) { - if (isConstruct(block)) { - return block; - } - - block = block.getParent().cast(); - } - - return {}; -} - -static ir::Instruction skipPhis(ir::Instruction inst) { - while (inst && inst == ir::spv::OpPhi) { - inst = inst.getNext(); - } - - return inst; -} - -static ir::Block getConstructMergeBlock(ir::Block block) { - if (auto construct = block.cast()) { - return construct.getMerge(); - } - - return {}; -} - -/** - * Tarjan's algorithm for finding strongly connected components (SCCs). 
- * This finds all cycles in the CFG - */ -static std::vector> -findSCCs(ir::Range nodes) { - std::unordered_map indices; - std::unordered_map lowlinks; - std::unordered_set onStack; - std::vector stack; - std::vector> sccs; - std::size_t index = 0; - - auto rootParent = (*nodes.begin()).getParent(); - - std::function strongConnect = [&](ir::Block node) { - indices[node] = index; - lowlinks[node] = index; - index++; - stack.push_back(node); - onStack.insert(node); - - // Consider successors of node - for (auto successor : getSuccessors(node)) { - if (successor.getParent() != rootParent) { - continue; - } - - if (!indices.contains(successor)) { - // Successor has not yet been visited; recurse on it - strongConnect(successor); - lowlinks[node] = std::min(lowlinks[node], lowlinks[successor]); - } else if (onStack.contains(successor)) { - // Successor is in stack and hence in the current SCC - lowlinks[node] = std::min(lowlinks[node], indices[successor]); - } - } - - // If node is a root node, pop the stack and create an SCC - if (lowlinks[node] == indices[node]) { - std::unordered_set scc; - scc.reserve(stack.size()); - ir::Block w; - do { - w = stack.back(); - stack.pop_back(); - onStack.erase(w); - scc.insert(w); - } while (w != node); - - // keep cycles only - if (!scc.empty()) { - auto isLoop = scc.size() > 1; - - if (!isLoop) { - // single node can contain branch to self - isLoop = hasSuccessor(w, w); - } - - if (isLoop) { - sccs.push_back(std::move(scc)); - } - } - } - }; - - for (auto node : nodes) { - if (node.getParent() != rootParent) { - continue; - } - - if (!indices.contains(node)) { - strongConnect(node); - } - } - return sccs; -} - -static void replaceTerminatorTarget(ir::Instruction terminator, - int operandIndex, ir::Value newTarget) { - auto prevTarget = terminator.getOperand(operandIndex).getAsValue(); - terminator.replaceOperand(operandIndex, newTarget); - auto selection = terminator.getPrev(); - - if (selection == ir::spv::OpSelectionMerge || - 
selection == ir::spv::OpLoopMerge) { - for (std::size_t i = 0, end = selection.getOperandCount(); i < end; ++i) { - if (selection.getOperand(i) == prevTarget) { - selection.replaceOperand(i, newTarget); - break; - } - } - } -} - -static bool replaceTerminatorTarget(ir::Instruction terminator, - ir::Value oldTarget, ir::Value newTarget) { - bool changes = false; - for (std::size_t i = 0, end = terminator.getOperandCount(); i < end; ++i) { - if (terminator.getOperand(i) == oldTarget) { - replaceTerminatorTarget(terminator, i, newTarget); - changes = true; - } - } - - return changes; -} - -class Edge { - ir::Block mFromBlock; - int mToOperandIndex; - -public: - Edge(ir::Block fromBlock, int toOperandIndex) - : mFromBlock(fromBlock), mToOperandIndex(toOperandIndex) {} - - [[nodiscard]] ir::Block from() const { return mFromBlock; } - [[nodiscard]] ir::Block to() const { - return getTerminator(mFromBlock) - .getOperand(mToOperandIndex) - .getAsValue() - .staticCast(); - } - - [[nodiscard]] int operandIndex() const { return mToOperandIndex; } - - void replaceSuccessor(ir::Value newSuccessor) { - replaceTerminatorTarget(getTerminator(mFromBlock), mToOperandIndex, - newSuccessor); - } - - bool operator==(const Edge &) const = default; -}; - -inline Edge createEdge(ir::Block from, ir::Block to) { - for (int index = 0; auto &op : from.getLast().getOperands()) { - - if (op.getAsValue() == to) { - return {from, index}; - } - index++; - } - - rx::die("attempt to create invalid edge"); -} - -struct CycleEdges { - std::vector entryEdges; - std::vector backEdges; - std::vector exitEdges; -}; - -static CycleEdges -calculateCycleEdges(const std::unordered_set &cycles) { - CycleEdges result; - std::unordered_set entryBlocks; - - for (auto block : cycles) { - for (auto [pred, operandIndex] : getAllPredecessors(block)) { - if (cycles.contains(pred)) { - continue; - } - - result.entryEdges.emplace_back(pred, operandIndex); - } - - for (auto [succ, operandIndex] : getAllSuccessors(block)) 
{ - if (cycles.contains(succ)) - continue; - - entryBlocks.insert(succ); - result.exitEdges.emplace_back(block, operandIndex); - } - } - - for (auto block : cycles) { - for (auto [succ, operandIndex] : getAllSuccessors(block)) { - if (entryBlocks.contains(succ)) - continue; - - result.backEdges.emplace_back(block, operandIndex); - } - } - - return result; -} - -static ir::Block createMergeBlock(spv::Context &context, - ir::InsertionPoint insertPoint, - const std::unordered_set &preds, - ir::Block to) { - rx::dieIf(preds.empty(), "createMergeBlock: unexpected edges count"); - - auto loc = to.getLocation(); - - auto mergeBlock = Builder::create(context, insertPoint).createBlock(loc); - Builder::createAppend(context, mergeBlock).createSpvBranch(loc, to); - - if (preds.size() == getPredecessorCount(to)) { - for (auto phi : ir::range(to.getFirst())) { - if (phi != ir::spv::OpPhi) { - break; - } - - phi.erase(); - mergeBlock.prependChild(phi); - } - } else if (preds.size() == 1) { - auto pred = *preds.begin(); - for (auto phi : ir::range(to.getFirst())) { - if (phi != ir::spv::OpPhi) { - break; - } - - for (std::size_t i = 2; i < phi.getOperandCount(); i += 2) { - if (phi.getOperand(i) == pred) { - phi.replaceOperand(i, mergeBlock); - } - } - } - } else { - for (auto phi : ir::range(to.getFirst())) { - if (phi != ir::spv::OpPhi) { - break; - } - - auto newPhi = - Builder::createPrepend(context, mergeBlock) - .createSpvPhi(phi.getLocation(), phi.getOperand(0).getAsValue()); - - for (std::size_t i = 1; i < phi.getOperandCount();) { - // auto value = phi.getOperand(i).getAsValue(); - auto label = phi.getOperand(i + 1).getAsValue().staticCast(); - if (preds.contains(label)) { - newPhi.addOperand(phi.eraseOperand(i)); - newPhi.addOperand(phi.eraseOperand(i)); - } else { - i += 2; - } - } - - phi.addOperand(newPhi); - phi.addOperand(mergeBlock); - } - } - - for (auto pred : preds) { - replaceTerminatorTarget(getTerminator(pred), to, mergeBlock); - } - - return mergeBlock; -} - 
-static ir::Block createRouteBlock(spv::Context &context, - ir::InsertionPoint insertPoint, - const std::vector &edges) { - auto loc = context.getUnknownLocation(); - - rx::dieIf(edges.empty(), "createRouteBlock: unexpected edges count"); - - std::unordered_map> fromSucc; - std::unordered_map> toPreds; - std::unordered_map> toAllPreds; - std::unordered_set patchPredecessors; - - { - std::unordered_set routePredecessors; - - for (auto edge : edges) { - if (!routePredecessors.insert(edge.from()).second) { - patchPredecessors.insert(edge.from()); - } - - toPreds[edge.to()].emplace(edge.from()); - fromSucc[edge.from()].emplace(edge.operandIndex()); - } - - for (auto &[to, preds] : toPreds) { - toAllPreds[to] = getPredecessors(to); - } - } - - if (toPreds.size() == 1) { - auto &[to, preds] = *toPreds.begin(); - return createMergeBlock(context, insertPoint, preds, to); - } - - auto route = Builder::create(context, insertPoint).createBlock(loc); - ir::Value routePhi; - - if (toPreds.size() > 1) { - routePhi = - Builder::createPrepend(context, route) - .createSpvPhi(loc, toPreds.size() == 2 ? 
context.getTypeBool() - : context.getTypeUInt32()); - } - - std::unordered_map successorToId; - - if (toPreds.size() == 1) { - // single successor, create unconditional branch - Builder::createAppend(context, route) - .createSpvBranch(loc, toPreds.begin()->first); - } else if (toPreds.size() == 2) { - // 2 successors, create conditional branch - auto it = toPreds.begin(); - auto firstSuccessor = it->first; - auto secondSuccessor = (++it)->first; - - Builder::createAppend(context, route) - .createSpvBranchConditional(loc, routePhi, firstSuccessor, - secondSuccessor); - } else { - // > 2 successors, create switch - auto routeSwitch = - Builder::createAppend(context, route) - .createSpvSwitch(loc, routePhi, toPreds.begin()->first); - - successorToId.reserve(toPreds.size()); - - for (std::uint32_t id = 0; auto &[succ, pred] : toPreds) { - if (id) { - routeSwitch.addOperand(id); - routeSwitch.addOperand(succ); - } - - successorToId[succ] = id++; - } - } - - auto getSuccessorId = [&](ir::Block successor) { - if (toPreds.size() == 2) { - return context.getBool(successor == toPreds.begin()->first); - } - - return context.imm32(successorToId.at(successor)); - }; - - for (auto patchBlock : patchPredecessors) { - auto predSuccessors = getAllSuccessors(patchBlock); - auto terminator = getTerminator(patchBlock); - auto &routeSuccessors = fromSucc.at(patchBlock); - - int keepSuccessors = predSuccessors.size() - routeSuccessors.size(); - - assert(keepSuccessors >= 0); - assert(terminator == ir::spv::OpSwitch || - terminator == ir::spv::OpBranchConditional); - - auto cond = terminator.getOperand(0).getAsValue(); - auto condType = cond.getOperand(0).getAsValue(); - std::map condValueToSucc; - ir::Block defaultSucc; - - if (keepSuccessors == 0) { - // we are going to replace all successors of this block, create direct - // jump to route block - Builder::createInsertAfter(context, terminator) - .createSpvBranch(terminator.getLocation(), route); - - if (terminator == 
ir::spv::OpBranchConditional) { - condValueToSucc[context.getTrue()] = - terminator.getOperand(1).getAsValue().staticCast(); - condValueToSucc[context.getFalse()] = - terminator.getOperand(2).getAsValue().staticCast(); - } else if (terminator == ir::spv::OpSwitch) { - defaultSucc = - terminator.getOperand(1).getAsValue().staticCast(); - - for (int i = 2, end = terminator.getOperandCount(); i < end; i += 2) { - condValueToSucc[terminator.getOperand(i)] = - terminator.getOperand(i + 1).getAsValue().staticCast(); - } - } - } else if (terminator == ir::spv::OpSwitch) { - if (routeSuccessors.contains(1)) { - defaultSucc = - terminator.getOperand(1).getAsValue().staticCast(); - } - - bool shouldReplaceDefault = defaultSucc != nullptr; - - for (int i = 2, id = 2, end = terminator.getOperandCount(); i < end; - id += 2) { - if (routeSuccessors.contains(id + 1)) { - if (shouldReplaceDefault) { - auto value = terminator.eraseOperand(i); - auto successor = terminator.eraseOperand(i); - - condValueToSucc[value] = - successor.getAsValue().staticCast(); - - continue; - } - - condValueToSucc[terminator.getOperand(i)] = - terminator.getOperand(i + 1).getAsValue().staticCast(); - - terminator.replaceOperand(i + 1, route); - } - - i += 2; - } - - if (shouldReplaceDefault) { - terminator.replaceOperand(1, route); - } - } else { - if (routeSuccessors.contains(1)) { - condValueToSucc[context.getTrue()] = - terminator.getOperand(1).getAsValue().staticCast(); - terminator.replaceOperand(1, route); - } else { - assert(routeSuccessors.contains(2)); - condValueToSucc[context.getFalse()] = - terminator.getOperand(2).getAsValue().staticCast(); - terminator.replaceOperand(2, route); - } - } - - if (routePhi) { - auto boolType = context.getTypeBool(); - auto builder = Builder::createInsertBefore(context, terminator); - - ir::Value selector; - - if (defaultSucc) { - selector = getSuccessorId(defaultSucc); - } - - auto selectorType = - toPreds.size() == 2 ? 
boolType : context.getTypeUInt32(); - for (auto &[value, to] : condValueToSucc) { - if (!selector) { - selector = getSuccessorId(to); - } else { - auto valueId = value.getAsValue(); - if (!valueId) { - valueId = context.imm32(*value.getAsInt32()); - } - - ir::Value selectionCond; - - if (condType == boolType) { - selectionCond = builder.createSpvLogicalEqual( - terminator.getLocation(), boolType, cond, valueId); - } else { - selectionCond = builder.createSpvIEqual(terminator.getLocation(), - boolType, cond, valueId); - } - selector = builder.createSpvSelect(terminator.getLocation(), - selectorType, selectionCond, - getSuccessorId(to), selector); - } - } - - routePhi.addOperand(selector); - routePhi.addOperand(patchBlock); - } - - if (keepSuccessors == 0) { - terminator.remove(); - } - } - - for (auto &[to, preds] : toPreds) { - if (toPreds.size() > 1) { - auto successorId = getSuccessorId(to); - - for (auto from : preds) { - // branches already resolved - if (patchPredecessors.contains(from)) { - continue; - } - - routePhi.addOperand(successorId); - routePhi.addOperand(from); - } - } - - for (auto from : preds) { - if (patchPredecessors.contains(from)) { - continue; - } - - replaceTerminatorTarget(getTerminator(from), to, route); - } - - if (toAllPreds.at(to).size() == preds.size()) { - // all predecessors will be replaced, move phi nodes - - for (auto phi : ir::range(ir::Block(to).getFirst())) { - if (phi != ir::spv::OpPhi) { - break; - } - - phi.erase(); - route.prependChild(phi); - - if (preds.size() != edges.size()) { - // route block has additional edges. 
add dummy nodes to phi, this - // block not reachable from new predecessors anyway - - auto undef = context.getUndef(phi.getOperand(0).getAsValue()); - - for (auto edge : edges) { - if (!preds.contains(edge.from())) { - phi.addOperand(undef); - phi.addOperand(edge.from()); - } - } - } - } - - continue; - } - - if (preds.size() == 1) { - auto pred = *preds.begin(); - for (auto phi : ir::range(ir::Block(to).getFirst())) { - if (phi != ir::spv::OpPhi) { - break; - } - - for (std::size_t i = 2; i < phi.getOperandCount(); i += 2) { - auto label = phi.getOperand(i).getAsValue(); - - if (label == pred) { - phi.replaceOperand(i, route); - } - } - } - - continue; - } - - // partial predecessors replacement, update PHIs - - for (auto phi : ir::range(ir::Block(to).getFirst())) { - if (phi != ir::spv::OpPhi) { - break; - } - - auto newPhi = - Builder::createPrepend(context, route) - .createSpvPhi(phi.getLocation(), phi.getOperand(0).getAsValue()); - - for (std::size_t i = 1; i < phi.getOperandCount();) { - // auto value = phi.getOperand(i).getAsValue(); - auto label = phi.getOperand(i + 1).getAsValue().cast(); - - if (preds.contains(label)) { - newPhi.addOperand(phi.eraseOperand(i)); - newPhi.addOperand(phi.eraseOperand(i)); - } else { - i += 2; - } - } - - phi.addOperand(newPhi); - phi.addOperand(route); - - if (preds.size() != edges.size()) { - // merge block has additional edges. 
add dummy nodes to phi, this - // block not reachable from new blocks - - auto dummyValue = phi.getOperand(1).getAsValue(); - - for (auto edge : edges) { - if (!preds.contains(edge.from())) { - phi.addOperand(dummyValue); - phi.addOperand(edge.from()); - } - } - } - } - } - - return route; -} - -static ir::Value transformToCanonicalRegion(spv::Context &context, - ir::RegionLike region) { - auto cfg = buildCFG(region.getFirst()); - std::vector exitNodes; - for (auto node : cfg.getPreorderNodes()) { - if (!node->hasSuccessors()) { - exitNodes.push_back(node); - } - } - - if (cfg.getEntryNode()->hasPredecessors()) { - auto builder = Builder::createPrepend(context, region); - auto prevEntry = cfg.getEntryLabel(); - auto newEntry = builder.createSpvLabel(prevEntry.getLocation()); - builder.createSpvBranch(prevEntry.getLocation(), prevEntry); - - for (auto it = prevEntry.getNext(); it && it == ir::spv::OpVariable;) { - auto moveInst = it; - it = it.getNext(); - - moveInst.erase(); - region.insertAfter(newEntry, moveInst); - } - } - - if (exitNodes.empty()) { - region.print(std::cerr, context.ns); - rx::die("scfg: cfg without termination block"); - } - - if (exitNodes.size() == 1) { - return exitNodes.back()->getLabel(); - } - - ir::Value returnType; - ir::Instruction returnInst; - - for (auto exitNode : exitNodes) { - auto terminator = exitNode->getTerminator(); - - if (terminator && terminator == ir::spv::OpReturnValue) { - auto terminatorReturnValue = terminator.getOperand(0).getAsValue(); - auto terminatorReturnType = - terminatorReturnValue.getOperand(0).getAsValue(); - if (returnType && terminatorReturnType == returnType) { - rx::die("scfg: unexpected terminator return type"); - } else { - returnType = terminatorReturnType; - } - } - - if (terminator) { - if (returnInst && returnInst.getInstId() != terminator.getInstId()) { - returnInst.print(std::cerr, context.ns); - std::cerr << '\n'; - terminator.print(std::cerr, context.ns); - std::cerr << '\n'; - rx::die("scfg: 
unexpected return instruction kind change"); - } else { - returnInst = terminator; - } - } - } - - if (returnType) { - auto variablePointerType = - context.getTypePointer(ir::spv::StorageClass::Function, returnType); - - auto returnValueVariable = - Builder::createInsertAfter(context, region.getFirst()) - .createSpvVariable(context.getUnknownLocation(), - variablePointerType, - ir::spv::StorageClass::Function); - - auto newExitBlock = [&] { - auto loc = context.getUnknownLocation(); - auto builder = Builder::createAppend(context, region); - auto newExitBlock = builder.createSpvLabel(loc); - - auto mergedReturnValue = - builder.createSpvLoad(loc, returnType, returnValueVariable); - builder.createSpvReturnValue(loc, mergedReturnValue); - return newExitBlock; - }(); - - for (auto exitNode : exitNodes) { - auto terminator = exitNode->getTerminator(); - - if (terminator) { - auto newTerminator = Builder::createInsertAfter(context, terminator); - - newTerminator.createSpvStore(terminator.getLocation(), - returnValueVariable, - terminator.getOperand(0).getAsValue()); - newTerminator.createSpvBranch(terminator.getLocation(), newExitBlock); - terminator.erase(); - } - } - - return newExitBlock; - } - - if (!returnInst) { - rx::die("scfg: unexpected cfg terminator"); - } - - auto newExitBlock = Builder::createAppend(context, region) - .createSpvLabel(context.getUnknownLocation()); - - for (auto exitNode : exitNodes) { - auto terminator = exitNode->getTerminator(); - - if (terminator) { - auto newTerminator = Builder::createInsertAfter(context, terminator); - newTerminator.createSpvBranch(terminator.getLocation(), newExitBlock); - terminator.erase(); - } - } - - region.insertAfter(newExitBlock, returnInst); - return newExitBlock; -} - -static void transformToCf(spv::Context &context, ir::RegionLike region) { - ir::Block currentBlock; - - for (auto inst : region.children()) { - if (inst == ir::builtin::BLOCK) { - continue; - } - - if (inst == ir::spv::OpLabel) { - currentBlock 
= Builder::createInsertBefore(context, inst) - .createBlock(inst.getLocation()); - - if (auto name = context.ns.tryGetNameOf(inst); !name.empty()) { - context.ns.setNameOf(currentBlock, std::string(name)); - } - - inst.staticCast().replaceAllUsesWith(currentBlock); - inst.remove(); - continue; - } - - if (!currentBlock) { - inst.print(std::cerr, context.ns); - std::cerr << "\n"; - region.print(std::cerr, context.ns); - std::cerr << "\n"; - rx::die("cfg: node without label"); - } - - inst.erase(); - currentBlock.addChild(inst); - - if (isTerminator(inst)) { - currentBlock = nullptr; - } - } -} - -static void transformToFlat(spv::Context &context, ir::RegionLike region) { - std::vector workList; - - workList.push_back(region.getFirst()); - - auto insertPoint = Builder::createPrepend(context, region); - - while (!workList.empty()) { - auto inst = workList.back(); - - workList.pop_back(); - - if (inst == nullptr) { - continue; - } - - auto unwrapBlock = [&](ir::Block block) { - if (auto construct = block.cast()) { - auto merge = construct.getMerge(); - auto cont = construct.getContinue().getHeader(); - auto body = construct.getHeader(); - - auto blockLabel = insertPoint.createSpvLabel(block.getLocation()); - construct.replaceAllUsesWith(blockLabel); - - if (auto name = context.ns.tryGetNameOf(block); !name.empty()) { - context.ns.setNameOf(blockLabel, std::string(name)); - } - - for (auto phi : ir::range(construct.getFirst())) { - if (phi != ir::spv::OpPhi) { - break; - } - - insertPoint.eraseAndInsert(phi); - } - - insertPoint.createSpvLoopMerge(construct.getLocation(), merge, cont, - ir::spv::LoopControl::None()); - insertPoint.createSpvBranch(construct.getLocation(), body); - - workList.emplace_back(cont); - workList.emplace_back(construct.getFirst()); - return; - } - - if (auto construct = block.cast()) { - auto constructBody = construct.getHeader(); - - auto header = ir::InsertionPoint::createPrepend(constructBody); - auto merge = construct.getMerge(); - - for 
(auto phi : ir::range(construct.getFirst())) { - if (phi != ir::spv::OpPhi) { - break; - } - - for (std::size_t i = 1; i < phi.getOperandCount();) { - if (phi.getOperand(i + 1) == construct) { - phi.eraseOperand(i); - phi.eraseOperand(i); - } else { - i += 2; - } - } - - header.eraseAndInsert(phi); - } - - Builder::createInsertBefore(context, constructBody.getLast()) - .createSpvSelectionMerge(construct.getLocation(), merge, - ir::spv::SelectionControl::None); - - construct.replaceAllUsesWith(constructBody); - workList.emplace_back(constructBody); - return; - } - - auto blockLabel = insertPoint.createSpvLabel(block.getLocation()); - - block.replaceAllUsesWith(blockLabel); - - workList.emplace_back(block.getFirst()); - - if (auto name = context.ns.tryGetNameOf(block); !name.empty()) { - context.ns.setNameOf(blockLabel, std::string(name)); - } - }; - - if (auto next = inst.getNext()) { - workList.push_back(next); - } - - if (auto block = inst.cast()) { - std::cout << "processing " << context.ns.getNameOf(block) << "\n"; - unwrapBlock(block); - block.erase(); - continue; - } - - insertPoint.eraseAndInsert(inst); - } -} - -bool isParentConstruct(ir::RegionLike parent, ir::RegionLike construct) { - while (parent != construct && construct) { - construct = construct.getParent(); - } - - return parent == construct; -} - -static ir::LoopConstruct -createLoopConstruct(spv::Context &context, ir::RegionLike parentConstruct, - ir::Block header, ir::Block latch, ir::Block cont, - ir::Block merge, - const std::unordered_set &scc) { - auto continueConstruct = - Builder::createInsertAfter(context, header) - .createContinueConstruct(header.getLocation(), cont, header); - - auto loopConstruct = Builder::createInsertBefore(context, header) - .createLoopConstruct(header.getLocation(), header, - merge, continueConstruct); - - continueConstruct.erase(); - - header.erase(); - loopConstruct.addChild(header); - - std::vector workList; - workList.emplace_back(header); - - while 
(!workList.empty()) { - ir::Block block = workList.back(); - workList.pop_back(); - - block.erase(); - loopConstruct.addChild(block); - - std::unordered_set successors; - if (isConstruct(block)) { - successors = {getConstructMergeBlock(block)}; - } else { - successors = getSuccessors(block); - } - - for (auto succ : successors) { - if (succ == merge || succ.getParent() != parentConstruct || - !scc.contains(succ)) { - continue; - } - - workList.push_back(succ); - } - } - - latch.erase(); - loopConstruct.addChild(latch); - - cont.erase(); - continueConstruct.addChild(cont); - - merge.erase(); - loopConstruct.getParent().insertAfter(loopConstruct, merge); - - return loopConstruct; -} - -static ir::SelectionConstruct -createSelectionConstruct(spv::Context &context, ir::RegionLike parentConstruct, - const std::unordered_set &components, - ir::Block header, ir::Block merge) { - auto selectionConstruct = - Builder::createInsertBefore(context, header) - .createSelectionConstruct(header.getLocation(), header, merge); - - std::vector workList; - workList.emplace_back(header); - - while (!workList.empty()) { - ir::Block block = workList.back(); - workList.pop_back(); - - block.erase(); - selectionConstruct.addChild(block); - - std::unordered_set successors; - if (auto construct = block.cast()) { - successors = {construct.getMerge()}; - } else { - successors = getSuccessors(block); - } - - for (auto succ : successors) { - if (succ == merge || succ.getParent() != parentConstruct || - !components.contains(succ)) { - continue; - } - - workList.push_back(succ); - } - } - - merge.erase(); - selectionConstruct.getParent().insertAfter(selectionConstruct, merge); - - return selectionConstruct; -} - -static void wrapLoopConstructs(spv::Context &context, ir::RegionLike root) { - auto region = root.children(); - auto sccs = findSCCs(region); - - for (auto scc : sccs) { - auto edges = calculateCycleEdges(scc); - - ir::Block bodyLabel; - ir::Block continueLabel; - ir::Block mergeLabel; - 
ir::Block latchLabel; - - if (!edges.entryEdges.empty()) { - if (edges.entryEdges.size() == 1 && edges.backEdges.size() == 1 && - edges.entryEdges[0].to() == edges.backEdges[0].to()) { - bodyLabel = edges.entryEdges[0].to(); - continueLabel = edges.backEdges[0].from(); - } - - if (!bodyLabel) { - std::vector entryEdges = edges.entryEdges; - // back edges should jump to entry block - entryEdges.insert(entryEdges.end(), edges.backEdges.begin(), - edges.backEdges.end()); - - // for loop no need to split blocks, we can just rotate loop - bodyLabel = createRouteBlock( - context, ir::InsertionPoint::createInsertBefore(*scc.begin()), - entryEdges); - scc.insert(bodyLabel); - edges = calculateCycleEdges(scc); - } - - if (!continueLabel || bodyLabel == continueLabel || - getSuccessorCount(continueLabel) != 1) { - - std::unordered_set preds; - for (auto edge : edges.backEdges) { - preds.insert(edge.from()); - } - continueLabel = createMergeBlock( - context, ir::InsertionPoint::createInsertAfter(bodyLabel), preds, - bodyLabel); - scc.insert(continueLabel); - edges = calculateCycleEdges(scc); - } - } - - if (!edges.exitEdges.empty()) { - mergeLabel = [&] -> ir::Block { - auto exitEdges = std::span(edges.exitEdges); - auto header = exitEdges[0].to(); - exitEdges = exitEdges.subspan(1); - - while (!exitEdges.empty()) { - if (header != exitEdges[0].to()) { - return {}; - } - - exitEdges = exitEdges.subspan(1); - } - - return header; - }(); - - if (mergeLabel) { - auto predecessors = getPredecessors(mergeLabel); - - for (auto pred : predecessors) { - if (!scc.contains(pred)) { - mergeLabel = {}; - break; - } - } - - if (mergeLabel && predecessors.size() == 1) { - latchLabel = *predecessors.begin(); - - auto latchSuccessors = getSuccessors(latchLabel); - - auto it = latchSuccessors.begin(); - auto firstSuccessor = *it; - auto secondSuccessor = *++it; - - if ((firstSuccessor != continueLabel && - secondSuccessor != continueLabel)) { - latchLabel = {}; - mergeLabel = {}; - } - - if 
(latchLabel && getPredecessorCount(continueLabel) != 1) { - latchLabel = {}; - } - } - } - - if (!mergeLabel) { - mergeLabel = createRouteBlock( - context, - ir::InsertionPoint::createInsertAfter(edges.exitEdges[0].from()), - edges.exitEdges); - - edges = calculateCycleEdges(scc); - } - - if (!latchLabel) { - std::vector exitEdges = edges.exitEdges; - - for (auto [pred, operandIndex] : getAllPredecessors(continueLabel)) { - exitEdges.emplace_back(pred, operandIndex); - } - - latchLabel = createRouteBlock( - context, - ir::InsertionPoint::createInsertAfter(edges.exitEdges[0].from()), - exitEdges); - scc.insert(latchLabel); - } - } - - if (bodyLabel && continueLabel && mergeLabel) { - auto loopConstruct = createLoopConstruct( - context, root, bodyLabel, latchLabel, continueLabel, mergeLabel, scc); - - // replace references to body outside this construct with header (i.e. - // loop construct node) - bodyLabel.replaceUsesIf(loopConstruct, [=](ir::ValueUse use) { - return (isTerminator(use.user) || - (use.user != loopConstruct && isConstruct(use.user))) && - getConstructOf(use.user) != loopConstruct; - }); - - // move PHIs to construct - for (auto phi : ir::range(bodyLabel.getFirst())) { - if (phi != ir::spv::OpPhi) { - break; - } - - phi.erase(); - loopConstruct.prependChild(phi); - } - } - } -} - -static void wrapSelectionConstructs(spv::Context &context, - ir::RegionLike root) { - std::vector> workList; - workList.push_back(root.children()); - std::unordered_set usedMergeBlocks; - - while (!workList.empty()) { - auto region = workList.back(); - workList.pop_back(); - - for (auto entryBlock : region) { - if (isConstruct(entryBlock)) { - if (entryBlock == ir::builtin::SELECTION_CONSTRUCT) { - if (auto body = - skipPhis(entryBlock.getFirst()).getNext().cast()) { - workList.emplace_back(ir::range(body)); - } - } else if (auto body = - skipPhis(entryBlock.getFirst()).cast()) { - workList.emplace_back(ir::range(body)); - } - continue; - } - - auto terminator = 
entryBlock.getLast(); - if (!terminator || !isTerminator(terminator) || - (terminator != ir::spv::OpBranchConditional && - terminator != ir::spv::OpSwitch)) { - continue; - } - - ir::RegionLike parentConstruct = getConstructOf(entryBlock); - - if (auto parentSelection = - parentConstruct.cast()) { - if (parentSelection.getHeader() == entryBlock) { - continue; - } - } - - auto successors = getSuccessors(entryBlock); - - if (parentConstruct) { - if (parentConstruct.getLast() == entryBlock) { - // do not look at latch/continuation blocks - continue; - } - } - - if (!parentConstruct) { - parentConstruct = root; - } - - std::unordered_set components; - components.insert(entryBlock); - - auto addConstructComponent = [&](ir::Construct construct) { - components.insert(construct); - - // add whole body of construct - for (auto child : construct.children()) { - components.insert(child); - } - - if (auto loop = construct.cast()) { - // it if is loop, add continue construct also - for (auto child : loop.getContinue().children()) { - components.insert(child); - } - } - - auto constructMerge = construct.getMerge(); - if (parentConstruct != root && - getSuccessorCount(construct.getMerge()) == 0) { - // we cannot take this merge block, it is exit from function block - // create trampoline node and replace merge block of this node - - auto newMerge = createMergeBlock( - context, ir::InsertionPoint::createInsertBefore(constructMerge), - getPredecessors(constructMerge), constructMerge); - - construct.setMerge(newMerge); - constructMerge = newMerge; - } - - components.insert(constructMerge); - - return getSuccessors(constructMerge); - }; - - auto addComponent = [&](ir::Block block) { - if (auto construct = block.cast()) { - return addConstructComponent(construct); - } - - if (hasAtLeastSuccessors(block, 1)) { - components.insert(block); - return getSuccessors(block); - } - - auto trampoline = createMergeBlock( - context, ir::InsertionPoint::createInsertBefore(block), - 
getPredecessors(block), block); - - components.insert(trampoline); - return getSuccessors(block); - }; - - { - // try to find blocks that has no other predecessors - - auto parentEntry = - skipPhis(parentConstruct.getFirst()).staticCast(); - - auto headerSuccessors = getSuccessors(entryBlock); - - std::vector workList(headerSuccessors.begin(), - headerSuccessors.end()); - while (!workList.empty()) { - auto block = workList.back(); - workList.pop_back(); - - if (components.contains(block)) { - continue; - } - - if (block.getParent() != parentConstruct) { - continue; - } - - if (block == parentEntry || block == parentConstruct.getLast()) { - // do not take entry/latch/continuation of parent construct - continue; - } - - bool hasAllPreds = true; - auto loop = block.cast(); - for (auto pred : getPredecessors(block)) { - if (components.contains(pred)) { - continue; - } - - if (loop && pred == loop.getContinue().getLast()) { - // ignore continue predecessor of loop - continue; - } - - hasAllPreds = false; - break; - } - - if (hasAllPreds) { - addComponent(block); - } - } - } - - if (components.size() == 1) { - // all successors are used by nodes outside this header, it means it is - // not structured loop node or case block of OpSwitch with fallthrough - continue; - } - - ir::Block entryLabel = entryBlock; - ir::Block mergeLabel; - bool mergeInserted = false; - - std::unordered_set exitBlocks; - std::vector exitEdges; - for (auto block : components) { - for (auto [succ, operandIndex] : getAllSuccessors(block)) { - if (!components.contains(succ)) { - exitEdges.emplace_back(block, operandIndex); - exitBlocks.insert(block); - } - } - } - - if (!exitBlocks.empty()) { - if (exitBlocks.size() == 1) { - mergeLabel = *exitBlocks.begin(); - } - - if (!mergeLabel || - getAllPredecessors(mergeLabel).size() != exitEdges.size() || - isConstruct(mergeLabel)) { - mergeLabel = createRouteBlock( - context, ir::InsertionPoint::createInsertAfter(entryBlock), - exitEdges); - - 
workList.emplace_back(ir::range(mergeLabel)); - mergeInserted = true; - } - } else { - mergeLabel = parentConstruct.getLast().staticCast(); - } - - if (!mergeInserted) { - for (auto user : mergeLabel.getUserList()) { - if (auto construct = user.cast()) { - if (construct.getMerge() != mergeLabel) { - continue; - } - } - mergeLabel = createMergeBlock( - context, ir::InsertionPoint::createInsertBefore(mergeLabel), - getPredecessors(mergeLabel), mergeLabel); - mergeInserted = true; - break; - } - } - - if (!mergeInserted) { - auto mergePreds = getPredecessors(mergeLabel); - std::unordered_set branchesInsideConstruct; - - for (auto pred : mergePreds) { - if (components.contains(pred)) { - branchesInsideConstruct.insert(pred); - } - } - - if (branchesInsideConstruct.size() != mergePreds.size()) { - mergeLabel = createMergeBlock( - context, ir::InsertionPoint::createInsertBefore(mergeLabel), - branchesInsideConstruct, mergeLabel); - } - } - - auto construct = createSelectionConstruct( - context, parentConstruct, components, entryLabel, mergeLabel); - - // update merge label - construct.setMerge(mergeLabel); - - construct.getHeader().replaceUsesIf(construct, [=](ir::ValueUse use) { - if (getConstructOf(use.user) != construct) { - if (isTerminator(use.user)) { - return true; - } - - // allow update block merges - if (isConstruct(use.user) && use.operandIndex == 1) { - return true; - } - } - - if (use.user != construct && isConstruct(use.user)) { - return true; - } - - return false; - }); - - // move PHIs to construct - for (auto phi : ir::range(construct.getHeader().getFirst())) { - if (phi != ir::spv::OpPhi) { - break; - } - - phi.erase(); - construct.prependChild(phi); - } - - // view child constructs - if (auto child = construct.getHeader().getNext().cast()) { - workList.emplace_back(ir::range(child)); - } - - // view next constructs - if (auto next = construct.getNext()) { - workList.emplace_back(ir::range(next)); - } - - break; - } - } -} void 
shader::structurizeCfg(spv::Context &context, ir::RegionLike region) { // std::cerr << "before transforms: "; diff --git a/rpcsx/gpu/lib/gcn-shader/src/transform/construct.cpp b/rpcsx/gpu/lib/gcn-shader/src/transform/construct.cpp new file mode 100644 index 000000000..e160595a5 --- /dev/null +++ b/rpcsx/gpu/lib/gcn-shader/src/transform/construct.cpp @@ -0,0 +1,154 @@ +#include "SpvConverter.hpp" +#include "analyze.hpp" +#include "transform/construct.hpp" +#include "dialect.hpp" +#include + +using namespace shader; +using namespace shader::transform; + +using Builder = ir::Builder; + +bool shader::transform::isConstruct(ir::Instruction block) { + return block == ir::builtin::LOOP_CONSTRUCT || + block == ir::builtin::SELECTION_CONSTRUCT || + block == ir::builtin::CONTINUE_CONSTRUCT; +} + +ir::Block shader::transform::getConstructOf(ir::Instruction inst) { + auto block = inst.cast(); + if (block && isConstruct(block)) { + block = block.getParent().cast(); + } + + while (block) { + if (isConstruct(block)) { + return block; + } + + block = block.getParent().cast(); + } + + return {}; +} + +ir::Block shader::transform::getConstructMergeBlock(ir::Block block) { + if (auto construct = block.cast()) { + return construct.getMerge(); + } + + return {}; +} + +bool shader::transform::isParentConstruct(ir::RegionLike parent, + ir::RegionLike construct) { + while (parent != construct && construct) { + construct = construct.getParent(); + } + + return parent == construct; +} + + +ir::SelectionConstruct +shader::transform::createSelectionConstruct(spv::Context &context, + ir::RegionLike parentConstruct, + const std::unordered_set &components, + ir::Block header, + ir::Block merge) { + auto selectionConstruct = + Builder::createInsertBefore(context, header) + .createSelectionConstruct(header.getLocation(), header, merge); + + std::vector workList; + workList.emplace_back(header); + + while (!workList.empty()) { + ir::Block block = workList.back(); + workList.pop_back(); + + 
block.erase(); + selectionConstruct.addChild(block); + + std::unordered_set successors; + if (auto construct = block.cast()) { + successors = {construct.getMerge()}; + } else { + successors = getSuccessors(block); + } + + for (auto succ : successors) { + if (succ == merge || succ.getParent() != parentConstruct || + !components.contains(succ)) { + continue; + } + + workList.push_back(succ); + } + } + + merge.erase(); + selectionConstruct.getParent().insertAfter(selectionConstruct, merge); + + return selectionConstruct; +} + +ir::LoopConstruct +shader::transform::createLoopConstruct(spv::Context &context, + ir::RegionLike parentConstruct, + ir::Block header, + ir::Block latch, + ir::Block cont, + ir::Block merge, + const std::unordered_set &scc) { + auto continueConstruct = + Builder::createInsertAfter(context, header) + .createContinueConstruct(header.getLocation(), cont, header); + + auto loopConstruct = Builder::createInsertBefore(context, header) + .createLoopConstruct(header.getLocation(), header, + merge, continueConstruct); + + continueConstruct.erase(); + + header.erase(); + loopConstruct.addChild(header); + + std::vector workList; + workList.emplace_back(header); + + while (!workList.empty()) { + ir::Block block = workList.back(); + workList.pop_back(); + + block.erase(); + loopConstruct.addChild(block); + + std::unordered_set successors; + if (isConstruct(block)) { + successors = {getConstructMergeBlock(block)}; + } else { + successors = getSuccessors(block); + } + + for (auto succ : successors) { + if (succ == merge || succ.getParent() != parentConstruct || + !scc.contains(succ)) { + continue; + } + + workList.push_back(succ); + } + } + + latch.erase(); + loopConstruct.addChild(latch); + + cont.erase(); + continueConstruct.addChild(cont); + + merge.erase(); + loopConstruct.getParent().insertAfter(loopConstruct, merge); + + return loopConstruct; +} diff --git a/rpcsx/gpu/lib/gcn-shader/src/transform/merge.cpp 
b/rpcsx/gpu/lib/gcn-shader/src/transform/merge.cpp new file mode 100644 index 000000000..d0d1498af --- /dev/null +++ b/rpcsx/gpu/lib/gcn-shader/src/transform/merge.cpp @@ -0,0 +1,77 @@ +#include "SpvConverter.hpp" +#include "transform/merge.hpp" +#include "analyze.hpp" +#include "transform/replace.hpp" +#include "dialect.hpp" +#include + +using namespace shader; +using namespace shader::transform; + +using Builder = ir::Builder; + +ir::Block shader::transform::createMergeBlock(spv::Context &context, + ir::InsertionPoint insertPoint, + const std::unordered_set &preds, + ir::Block to) { + rx::dieIf(preds.empty(), "createMergeBlock: unexpected edges count"); + + auto loc = to.getLocation(); + + auto mergeBlock = Builder::create(context, insertPoint).createBlock(loc); + Builder::createAppend(context, mergeBlock).createSpvBranch(loc, to); + + if (preds.size() == getPredecessorCount(to)) { + for (auto phi : ir::range(to.getFirst())) { + if (phi != ir::spv::OpPhi) { + break; + } + + phi.erase(); + mergeBlock.prependChild(phi); + } + } else if (preds.size() == 1) { + auto pred = *preds.begin(); + for (auto phi : ir::range(to.getFirst())) { + if (phi != ir::spv::OpPhi) { + break; + } + + for (std::size_t i = 2; i < phi.getOperandCount(); i += 2) { + if (phi.getOperand(i) == pred) { + phi.replaceOperand(i, mergeBlock); + } + } + } + } else { + for (auto phi : ir::range(to.getFirst())) { + if (phi != ir::spv::OpPhi) { + break; + } + + auto newPhi = + Builder::createPrepend(context, mergeBlock) + .createSpvPhi(phi.getLocation(), phi.getOperand(0).getAsValue()); + + for (std::size_t i = 1; i < phi.getOperandCount();) { + // auto value = phi.getOperand(i).getAsValue(); + auto label = phi.getOperand(i + 1).getAsValue().staticCast(); + if (preds.contains(label)) { + newPhi.addOperand(phi.eraseOperand(i)); + newPhi.addOperand(phi.eraseOperand(i)); + } else { + i += 2; + } + } + + phi.addOperand(newPhi); + phi.addOperand(mergeBlock); + } + } + + for (auto pred : preds) { + 
replaceTerminatorTarget(getTerminator(pred), to, mergeBlock); + } + + return mergeBlock; +} \ No newline at end of file diff --git a/rpcsx/gpu/lib/gcn-shader/src/transform/replace.cpp b/rpcsx/gpu/lib/gcn-shader/src/transform/replace.cpp new file mode 100644 index 000000000..8402333ec --- /dev/null +++ b/rpcsx/gpu/lib/gcn-shader/src/transform/replace.cpp @@ -0,0 +1,36 @@ +#include "transform/replace.hpp" +#include "dialect.hpp" +#include + +using namespace shader; +using namespace shader::transform; + +void shader::transform::replaceTerminatorTarget(ir::Instruction terminator, + int operandIndex, ir::Value newTarget) { + auto prevTarget = terminator.getOperand(operandIndex).getAsValue(); + terminator.replaceOperand(operandIndex, newTarget); + auto selection = terminator.getPrev(); + + if (selection == ir::spv::OpSelectionMerge || + selection == ir::spv::OpLoopMerge) { + for (std::size_t i = 0, end = selection.getOperandCount(); i < end; ++i) { + if (selection.getOperand(i) == prevTarget) { + selection.replaceOperand(i, newTarget); + break; + } + } + } +} + +bool shader::transform::replaceTerminatorTarget(ir::Instruction terminator, + ir::Value oldTarget, ir::Value newTarget) { + bool changes = false; + for (std::size_t i = 0, end = terminator.getOperandCount(); i < end; ++i) { + if (terminator.getOperand(i) == oldTarget) { + replaceTerminatorTarget(terminator, i, newTarget); + changes = true; + } + } + + return changes; +} \ No newline at end of file diff --git a/rpcsx/gpu/lib/gcn-shader/src/transform/route.cpp b/rpcsx/gpu/lib/gcn-shader/src/transform/route.cpp new file mode 100644 index 000000000..625a55732 --- /dev/null +++ b/rpcsx/gpu/lib/gcn-shader/src/transform/route.cpp @@ -0,0 +1,428 @@ + +#include "transform/route.hpp" +#include "transform/merge.hpp" +#include "SpvConverter.hpp" +#include "analyze.hpp" +#include "dialect.hpp" +#include +#include +#include +#include +#include + +using namespace shader; +using namespace shader::transform; + +using Builder = 
ir::Builder; + +// Data structures for route block creation +struct RouteBlockData { + std::unordered_map> fromSucc; + std::unordered_map> toPreds; + std::unordered_map> toAllPreds; + std::unordered_set patchPredecessors; +}; + +// Analyze edges and build routing data structures +static RouteBlockData analyzeEdges(const std::vector &edges) { + RouteBlockData data; + std::unordered_set routePredecessors; + + for (auto edge : edges) { + if (!routePredecessors.insert(edge.from()).second) { + data.patchPredecessors.insert(edge.from()); + } + + data.toPreds[edge.to()].emplace(edge.from()); + data.fromSucc[edge.from()].emplace(edge.operandIndex()); + } + + for (auto &[to, preds] : data.toPreds) { + data.toAllPreds[to] = getPredecessors(to); + } + + return data; +} + +// Create route block with appropriate phi node +static std::pair createRouteBlockWithPhi( + spv::Context &context, ir::InsertionPoint insertPoint, + ir::Location loc, size_t predsCount) { + auto route = Builder::create(context, insertPoint).createBlock(loc); + ir::Value routePhi; + + if (predsCount > 1) { + routePhi = Builder::createPrepend(context, route) + .createSpvPhi(loc, predsCount == 2 + ? 
context.getTypeBool() + : context.getTypeUInt32()); + } + + return {route, routePhi}; +} + +// Create terminator based on number of successors +static std::unordered_map createRouteTerminator( + spv::Context &context, ir::Block route, ir::Value routePhi, + ir::Location loc, + const std::unordered_map> + &toPreds) { + std::unordered_map successorToId; + + if (toPreds.size() == 1) { + // Single successor: unconditional branch + Builder::createAppend(context, route) + .createSpvBranch(loc, toPreds.begin()->first); + } else if (toPreds.size() == 2) { + // Two successors: conditional branch + auto it = toPreds.begin(); + auto firstSuccessor = it->first; + auto secondSuccessor = (++it)->first; + + Builder::createAppend(context, route) + .createSpvBranchConditional(loc, routePhi, firstSuccessor, + secondSuccessor); + } else { + // Multiple successors: switch statement + auto routeSwitch = Builder::createAppend(context, route) + .createSpvSwitch(loc, routePhi, toPreds.begin()->first); + + successorToId.reserve(toPreds.size()); + + for (std::uint32_t id = 0; auto &[succ, pred] : toPreds) { + if (id) { + routeSwitch.addOperand(id); + routeSwitch.addOperand(succ); + } + successorToId[succ] = id++; + } + } + + return successorToId; +} + +// Get successor ID based on routing strategy +static ir::Value getSuccessorIdValue( + spv::Context &context, ir::Block successor, + const std::unordered_map> + &toPreds, + const std::unordered_map &successorToId) { + if (toPreds.size() == 2) { + return context.getBool(successor == toPreds.begin()->first); + } + return context.imm32(successorToId.at(successor)); +} + +// Process single predecessor block that needs patching +static void patchPredecessorBlock( + spv::Context &context, ir::Block patchBlock, ir::Block route, + ir::Value routePhi, const RouteBlockData &data, + const std::unordered_map> &toPreds, + const std::function &getSuccessorId) { + + auto predSuccessors = getAllSuccessors(patchBlock); + auto terminator = 
getTerminator(patchBlock); + auto &routeSuccessors = data.fromSucc.at(patchBlock); + + int keepSuccessors = predSuccessors.size() - routeSuccessors.size(); + + assert(keepSuccessors >= 0); + assert(terminator == ir::spv::OpSwitch || + terminator == ir::spv::OpBranchConditional); + + auto cond = terminator.getOperand(0).getAsValue(); + auto condType = cond.getOperand(0).getAsValue(); + std::map condValueToSucc; + ir::Block defaultSucc; + + if (keepSuccessors == 0) { + // we are going to replace all successors of this block, create direct + // jump to route block + Builder::createInsertAfter(context, terminator) + .createSpvBranch(terminator.getLocation(), route); + + if (terminator == ir::spv::OpBranchConditional) { + condValueToSucc[context.getTrue()] = + terminator.getOperand(1).getAsValue().staticCast(); + condValueToSucc[context.getFalse()] = + terminator.getOperand(2).getAsValue().staticCast(); + } else if (terminator == ir::spv::OpSwitch) { + defaultSucc = + terminator.getOperand(1).getAsValue().staticCast(); + + for (int i = 2, end = terminator.getOperandCount(); i < end; i += 2) { + condValueToSucc[terminator.getOperand(i)] = + terminator.getOperand(i + 1).getAsValue().staticCast(); + } + } + } else if (terminator == ir::spv::OpSwitch) { + if (routeSuccessors.contains(1)) { + defaultSucc = + terminator.getOperand(1).getAsValue().staticCast(); + } + + bool shouldReplaceDefault = defaultSucc != nullptr; + + for (int i = 2, id = 2, end = terminator.getOperandCount(); i < end; + id += 2) { + if (routeSuccessors.contains(id + 1)) { + if (shouldReplaceDefault) { + auto value = terminator.eraseOperand(i); + auto successor = terminator.eraseOperand(i); + + condValueToSucc[value] = + successor.getAsValue().staticCast(); + + continue; + } + + condValueToSucc[terminator.getOperand(i)] = + terminator.getOperand(i + 1).getAsValue().staticCast(); + + terminator.replaceOperand(i + 1, route); + } + + i += 2; + } + + if (shouldReplaceDefault) { + terminator.replaceOperand(1, 
route); + } + } else { + if (routeSuccessors.contains(1)) { + condValueToSucc[context.getTrue()] = + terminator.getOperand(1).getAsValue().staticCast(); + terminator.replaceOperand(1, route); + } else { + assert(routeSuccessors.contains(2)); + condValueToSucc[context.getFalse()] = + terminator.getOperand(2).getAsValue().staticCast(); + terminator.replaceOperand(2, route); + } + } + + if (routePhi) { + auto boolType = context.getTypeBool(); + auto builder = Builder::createInsertBefore(context, terminator); + + ir::Value selector; + + if (defaultSucc) { + selector = getSuccessorId(defaultSucc); + } + + auto selectorType = + toPreds.size() == 2 ? boolType : context.getTypeUInt32(); + for (auto &[value, to] : condValueToSucc) { + if (!selector) { + selector = getSuccessorId(to); + } else { + auto valueId = value.getAsValue(); + if (!valueId) { + valueId = context.imm32(*value.getAsInt32()); + } + + ir::Value selectionCond; + + if (condType == boolType) { + selectionCond = builder.createSpvLogicalEqual( + terminator.getLocation(), boolType, cond, valueId); + } else { + selectionCond = builder.createSpvIEqual(terminator.getLocation(), + boolType, cond, valueId); + } + selector = builder.createSpvSelect(terminator.getLocation(), + selectorType, selectionCond, + getSuccessorId(to), selector); + } + } + + routePhi.addOperand(selector); + routePhi.addOperand(patchBlock); + } + + if (keepSuccessors == 0) { + terminator.remove(); + } +} + +// Move all phi nodes from target to route block +static void moveAllPhiNodes(spv::Context &context, ir::Block to, ir::Block route, + const std::unordered_set &preds, + const std::vector &edges) { + for (auto phi : ir::range(ir::Block(to).getFirst())) { + if (phi != ir::spv::OpPhi) { + break; + } + + phi.erase(); + route.prependChild(phi); + + if (preds.size() != edges.size()) { + // route block has additional edges. 
add dummy nodes to phi, this + // block not reachable from new predecessors anyway + + auto undef = context.getUndef(phi.getOperand(0).getAsValue()); + + for (auto edge : edges) { + if (!preds.contains(edge.from())) { + phi.addOperand(undef); + phi.addOperand(edge.from()); + } + } + } + } +} + +// Update phi nodes for single predecessor +static void updatePhiNodesForSinglePred(ir::Block to, ir::Block pred, + ir::Block route) { + for (auto phi : ir::range(ir::Block(to).getFirst())) { + if (phi != ir::spv::OpPhi) { + break; + } + + for (std::size_t i = 2; i < phi.getOperandCount(); i += 2) { + auto label = phi.getOperand(i).getAsValue(); + + if (label == pred) { + phi.replaceOperand(i, route); + } + } + } +} + +// Update phi nodes for partial predecessor replacement +static void updatePhiNodesPartial(spv::Context &context, ir::Block to, + ir::Block route, + const std::unordered_set &preds, + const std::vector &edges) { + for (auto phi : ir::range(ir::Block(to).getFirst())) { + if (phi != ir::spv::OpPhi) { + break; + } + + auto newPhi = + Builder::createPrepend(context, route) + .createSpvPhi(phi.getLocation(), phi.getOperand(0).getAsValue()); + + for (std::size_t i = 1; i < phi.getOperandCount();) { + auto label = phi.getOperand(i + 1).getAsValue().cast(); + + if (preds.contains(label)) { + newPhi.addOperand(phi.eraseOperand(i)); + newPhi.addOperand(phi.eraseOperand(i)); + } else { + i += 2; + } + } + + phi.addOperand(newPhi); + phi.addOperand(route); + + if (preds.size() != edges.size()) { + // merge block has additional edges. 
add dummy nodes to phi, this + // block not reachable from new blocks + + auto dummyValue = phi.getOperand(1).getAsValue(); + + for (auto edge : edges) { + if (!preds.contains(edge.from())) { + phi.addOperand(dummyValue); + phi.addOperand(edge.from()); + } + } + } + } +} + +// Process all target blocks and update their phi nodes +static void processTargetBlocks( + spv::Context &context, ir::Block route, ir::Value routePhi, + const RouteBlockData &data, + const std::unordered_map> &toPreds, + const std::vector &edges, + const std::function &getSuccessorId) { + + for (auto &[to, preds] : toPreds) { + if (toPreds.size() > 1) { + auto successorId = getSuccessorId(to); + + for (auto from : preds) { + // branches already resolved + if (data.patchPredecessors.contains(from)) { + continue; + } + + routePhi.addOperand(successorId); + routePhi.addOperand(from); + } + } + + for (auto from : preds) { + if (data.patchPredecessors.contains(from)) { + continue; + } + + replaceTerminatorTarget(getTerminator(from), to, route); + } + + if (data.toAllPreds.at(to).size() == preds.size()) { + // all predecessors will be replaced, move phi nodes + moveAllPhiNodes(context, to, route, preds, edges); + continue; + } + + if (preds.size() == 1) { + auto pred = *preds.begin(); + updatePhiNodesForSinglePred(to, pred, route); + continue; + } + + // partial predecessors replacement, update PHIs + updatePhiNodesPartial(context, to, route, preds, edges); + } +} + +// Main function +ir::Block shader::transform::createRouteBlock(spv::Context &context, + ir::InsertionPoint insertPoint, + const std::vector &edges) { + auto loc = context.getUnknownLocation(); + + rx::dieIf(edges.empty(), "createRouteBlock: unexpected edges count"); + + // Step 1: Analyze edges and build data structures + auto data = analyzeEdges(edges); + + // Step 2: Handle simple case - single target block + if (data.toPreds.size() == 1) { + auto &[to, preds] = *data.toPreds.begin(); + return createMergeBlock(context, insertPoint, 
preds, to); + } + + // Step 3: Create route block and phi node + auto [route, routePhi] = createRouteBlockWithPhi(context, insertPoint, + loc, data.toPreds.size()); + + // Step 4: Create appropriate terminator (branch/conditional/switch) + auto successorToId = createRouteTerminator(context, route, routePhi, + loc, data.toPreds); + + // Step 5: Create lambda for getting successor IDs + auto getSuccessorId = [&](ir::Block successor) { + return getSuccessorIdValue(context, successor, data.toPreds, successorToId); + }; + + // Step 6: Patch predecessor blocks that have multiple routes + for (auto patchBlock : data.patchPredecessors) { + patchPredecessorBlock(context, patchBlock, route, routePhi, data, + data.toPreds, getSuccessorId); + } + + // Step 7: Process target blocks and update phi nodes + processTargetBlocks(context, route, routePhi, data, data.toPreds, edges, + getSuccessorId); + + return route; +} diff --git a/rpcsx/gpu/lib/gcn-shader/src/transform/transformations.cpp b/rpcsx/gpu/lib/gcn-shader/src/transform/transformations.cpp new file mode 100644 index 000000000..6f36682f3 --- /dev/null +++ b/rpcsx/gpu/lib/gcn-shader/src/transform/transformations.cpp @@ -0,0 +1,280 @@ +#include "SpvConverter.hpp" +#include "analyze.hpp" +#include "transform/transformations.hpp" +#include "dialect.hpp" +#include + +#include +#include +#include + +using namespace shader; +using namespace shader::transform; + +using Builder = ir::Builder; + +ir::Value shader::transform::transformToCanonicalRegion(spv::Context &context, + ir::RegionLike region) { + auto cfg = buildCFG(region.getFirst()); + std::vector exitNodes; + for (auto node : cfg.getPreorderNodes()) { + if (!node->hasSuccessors()) { + exitNodes.push_back(node); + } + } + + if (cfg.getEntryNode()->hasPredecessors()) { + auto builder = Builder::createPrepend(context, region); + auto prevEntry = cfg.getEntryLabel(); + auto newEntry = builder.createSpvLabel(prevEntry.getLocation()); + 
builder.createSpvBranch(prevEntry.getLocation(), prevEntry); + + for (auto it = prevEntry.getNext(); it && it == ir::spv::OpVariable;) { + auto moveInst = it; + it = it.getNext(); + + moveInst.erase(); + region.insertAfter(newEntry, moveInst); + } + } + + if (exitNodes.empty()) { + region.print(std::cerr, context.ns); + rx::die("scfg: cfg without termination block"); + } + + if (exitNodes.size() == 1) { + return exitNodes.back()->getLabel(); + } + + ir::Value returnType; + ir::Instruction returnInst; + + for (auto exitNode : exitNodes) { + auto terminator = exitNode->getTerminator(); + + if (terminator && terminator == ir::spv::OpReturnValue) { + auto terminatorReturnValue = terminator.getOperand(0).getAsValue(); + auto terminatorReturnType = + terminatorReturnValue.getOperand(0).getAsValue(); + if (returnType && terminatorReturnType == returnType) { + rx::die("scfg: unexpected terminator return type"); + } else { + returnType = terminatorReturnType; + } + } + + if (terminator) { + if (returnInst && returnInst.getInstId() != terminator.getInstId()) { + returnInst.print(std::cerr, context.ns); + std::cerr << '\n'; + terminator.print(std::cerr, context.ns); + std::cerr << '\n'; + rx::die("scfg: unexpected return instruction kind change"); + } else { + returnInst = terminator; + } + } + } + + if (returnType) { + auto variablePointerType = + context.getTypePointer(ir::spv::StorageClass::Function, returnType); + + auto returnValueVariable = + Builder::createInsertAfter(context, region.getFirst()) + .createSpvVariable(context.getUnknownLocation(), + variablePointerType, + ir::spv::StorageClass::Function); + + auto newExitBlock = [&] { + auto loc = context.getUnknownLocation(); + auto builder = Builder::createAppend(context, region); + auto newExitBlock = builder.createSpvLabel(loc); + + auto mergedReturnValue = + builder.createSpvLoad(loc, returnType, returnValueVariable); + builder.createSpvReturnValue(loc, mergedReturnValue); + return newExitBlock; + }(); + + for (auto 
exitNode : exitNodes) {
+      auto terminator = exitNode->getTerminator();
+
+      if (terminator) {
+        auto newTerminator = Builder::createInsertAfter(context, terminator);
+
+        newTerminator.createSpvStore(terminator.getLocation(),
+                                     returnValueVariable,
+                                     terminator.getOperand(0).getAsValue());
+        newTerminator.createSpvBranch(terminator.getLocation(), newExitBlock);
+        terminator.erase();
+      }
+    }
+
+    return newExitBlock;
+  }
+
+  if (!returnInst) {
+    rx::die("scfg: unexpected cfg terminator");
+  }
+
+  auto newExitBlock = Builder::createAppend(context, region)
+                          .createSpvLabel(context.getUnknownLocation());
+
+  for (auto exitNode : exitNodes) {
+    auto terminator = exitNode->getTerminator();
+
+    if (terminator) {
+      auto newTerminator = Builder::createInsertAfter(context, terminator);
+      newTerminator.createSpvBranch(terminator.getLocation(), newExitBlock);
+      terminator.erase();
+    }
+  }
+
+  region.insertAfter(newExitBlock, returnInst);
+  return newExitBlock;
+}
+
+// Groups the flat instruction list of `region` into block nodes.
+//
+// Every OpLabel is replaced by a block created in front of it: the label's
+// name (if it has one) is copied to the block and all uses of the label are
+// redirected to the block. The instructions that follow are moved into that
+// block, up to and including its terminator. Children that already are
+// ir::builtin::BLOCK nodes are skipped. Dies with "cfg: node without label"
+// when an instruction is met while no block is open.
+void shader::transform::transformToCf(spv::Context &context, ir::RegionLike region) {
+  // block currently being filled; reset to null after each terminator
+  ir::Block currentBlock;
+
+  for (auto inst : region.children()) {
+    if (inst == ir::builtin::BLOCK) {
+      continue;
+    }
+
+    if (inst == ir::spv::OpLabel) {
+      currentBlock = Builder::createInsertBefore(context, inst)
+                         .createBlock(inst.getLocation());
+
+      if (auto name = context.ns.tryGetNameOf(inst); !name.empty()) {
+        context.ns.setNameOf(currentBlock, std::string(name));
+      }
+
+      inst.staticCast().replaceAllUsesWith(currentBlock);
+      inst.remove();
+      continue;
+    }
+
+    if (!currentBlock) {
+      // dump the offending instruction and the whole region before aborting
+      inst.print(std::cerr, context.ns);
+      std::cerr << "\n";
+      region.print(std::cerr, context.ns);
+      std::cerr << "\n";
+      rx::die("cfg: node without label");
+    }
+
+    // detach from the region, then re-attach as the last child of the block
+    inst.erase();
+    currentBlock.addChild(inst);
+
+    if (isTerminator(inst)) {
+      currentBlock = nullptr;
+    }
+  }
+}
+
+// Flattens block / construct nodes of `region` back into a plain list of
+// SPIR-V instructions.
+//
+// A work list (used as a stack) is seeded with the first child of the region.
+// Plain instructions are moved to `insertPoint`, which starts at the front of
+// the region. Block nodes are expanded by `unwrapBlock` and then erased.
+void shader::transform::transformToFlat(spv::Context &context, ir::RegionLike region) {
+  std::vector workList;
+
+  workList.push_back(region.getFirst());
+
+  auto insertPoint = Builder::createPrepend(context, region);
+
+  while (!workList.empty()) {
+    auto inst = workList.back();
+
+    workList.pop_back();
+
+    if (inst == nullptr) {
+      continue;
+    }
+
+    // Expands one block node. Three cases, tried in order:
+    //  1. first cast succeeds (the node exposes getMerge/getContinue, i.e. it
+    //     is handled as a loop): emit label, leading OpPhi's, OpLoopMerge and
+    //     a branch to the header; queue the continue header and the children.
+    //  2. second cast succeeds (handled as a selection): move leading OpPhi's
+    //     into the header block, dropping phi operand pairs whose second
+    //     element is the construct itself, and put OpSelectionMerge in front
+    //     of the header's last instruction.
+    //  3. plain block: emit a label and queue the block's children.
+    // NOTE(review): the template arguments of both cast() calls are missing
+    // in this copy of the source; the loop/selection reading above is inferred
+    // from the members used — confirm against the repository.
+    auto unwrapBlock = [&](ir::Block block) {
+      if (auto construct = block.cast()) {
+        auto merge = construct.getMerge();
+        auto cont = construct.getContinue().getHeader();
+        auto body = construct.getHeader();
+
+        auto blockLabel = insertPoint.createSpvLabel(block.getLocation());
+        construct.replaceAllUsesWith(blockLabel);
+
+        if (auto name = context.ns.tryGetNameOf(block); !name.empty()) {
+          context.ns.setNameOf(blockLabel, std::string(name));
+        }
+
+        // only the leading run of OpPhi is moved; stop at the first non-phi
+        for (auto phi : ir::range(construct.getFirst())) {
+          if (phi != ir::spv::OpPhi) {
+            break;
+          }
+
+          insertPoint.eraseAndInsert(phi);
+        }
+
+        insertPoint.createSpvLoopMerge(construct.getLocation(), merge, cont,
+                                       ir::spv::LoopControl::None());
+        insertPoint.createSpvBranch(construct.getLocation(), body);
+
+        // pushed last => the construct's children are processed before `cont`
+        workList.emplace_back(cont);
+        workList.emplace_back(construct.getFirst());
+        return;
+      }
+
+      if (auto construct = block.cast()) {
+        auto constructBody = construct.getHeader();
+
+        auto header = ir::InsertionPoint::createPrepend(constructBody);
+        auto merge = construct.getMerge();
+
+        for (auto phi : ir::range(construct.getFirst())) {
+          if (phi != ir::spv::OpPhi) {
+            break;
+          }
+
+          // operands are walked in pairs starting at index 1; a pair is
+          // removed when its second element refers to this construct
+          for (std::size_t i = 1; i < phi.getOperandCount();) {
+            if (phi.getOperand(i + 1) == construct) {
+              phi.eraseOperand(i);
+              phi.eraseOperand(i);
+            } else {
+              i += 2;
+            }
+          }
+
+          header.eraseAndInsert(phi);
+        }
+
+        Builder::createInsertBefore(context, constructBody.getLast())
+            .createSpvSelectionMerge(construct.getLocation(), merge,
+                                     ir::spv::SelectionControl::None);
+
+        construct.replaceAllUsesWith(constructBody);
+        workList.emplace_back(constructBody);
+        return;
+      }
+
+      auto blockLabel = insertPoint.createSpvLabel(block.getLocation());
+
+      block.replaceAllUsesWith(blockLabel);
+
+      workList.emplace_back(block.getFirst());
+
+      if (auto name = context.ns.tryGetNameOf(block); !name.empty()) {
+        context.ns.setNameOf(blockLabel, std::string(name));
+      }
+    };
+
+    // queue the sibling first so that anything pushed by unwrapBlock (the
+    // children of the block) is popped before it
+    if (auto next = inst.getNext()) {
+      workList.push_back(next);
+    }
+
+    if (auto block = inst.cast()) {
+      // NOTE(review): this writes to stdout for every block, while the other
+      // diagnostics in this file go to std::cerr — looks like leftover debug
+      // output. getNameOf() is not visible here and may have side effects
+      // (e.g. assigning a name); confirm before removing the line.
+      std::cout << "processing " << context.ns.getNameOf(block) << "\n";
+      unwrapBlock(block);
+      block.erase();
+      continue;
+    }
+
+    insertPoint.eraseAndInsert(inst);
+  }
+}
+
diff --git a/rpcsx/gpu/lib/gcn-shader/src/transform/wrap.cpp b/rpcsx/gpu/lib/gcn-shader/src/transform/wrap.cpp
new file mode 100644
index 000000000..cb61f4859
--- /dev/null
+++ b/rpcsx/gpu/lib/gcn-shader/src/transform/wrap.cpp
@@ -0,0 +1,565 @@
+#include "SpvConverter.hpp"
+#include "transform/Edge.hpp"
+#include "transform/construct.hpp"
+#include "transform/merge.hpp"
+#include "transform/route.hpp"
+#include "transform/wrap.hpp"
+#include "dialect.hpp"
+#include
+
+using namespace shader;
+using namespace shader::transform;
+
+using Builder = ir::Builder;
+
+// Edges of one cycle (set of blocks), grouped by direction.
+struct CycleEdges {
+  // edges whose source block is outside the cycle and target is inside
+  std::vector entryEdges;
+  // see the note in calculateCycleEdges for what ends up here
+  std::vector backEdges;
+  // edges whose source block is inside the cycle and target is outside
+  std::vector exitEdges;
+};
+
+// Classifies the edges touching the block set `cycles`.
+//
+// Each edge is recorded as (source block, operand index in the source's
+// terminator), matching the Edge constructor.
+static CycleEdges
+calculateCycleEdges(const std::unordered_set &cycles) {
+  CycleEdges result;
+  std::unordered_set entryBlocks;
+
+  for (auto block : cycles) {
+    for (auto [pred, operandIndex] : getAllPredecessors(block)) {
+      if (cycles.contains(pred)) {
+        continue;
+      }
+
+      result.entryEdges.emplace_back(pred, operandIndex);
+    }
+
+    for (auto [succ, operandIndex] : getAllSuccessors(block)) {
+      if (cycles.contains(succ))
+        continue;
+
+      // NOTE(review): despite its name, `entryBlocks` collects the *targets
+      // of exit edges* (blocks outside the cycle). Confirm this is intended.
+      entryBlocks.insert(succ);
+      result.exitEdges.emplace_back(block, operandIndex);
+    }
+  }
+
+  // As written, every successor edge of a cycle block that does not leave the
+  // cycle is reported as a back edge, i.e. all edges internal to the cycle —
+  // not only those that target an entry block.
+  // NOTE(review): wrapLoopConstructs comments "back edges should jump to entry
+  // block"; verify whether the filter below was meant to be
+  // "target is an entry block of the cycle".
+  for (auto block : cycles) {
+    for (auto [succ, operandIndex] : getAllSuccessors(block)) {
+      if (entryBlocks.contains(succ))
+        continue;
+
+      result.backEdges.emplace_back(block, operandIndex);
+    }
+  }
+
+  return result;
+}
+
+
+// Returns the first instruction at or after `inst` that is not an OpPhi
+// (a null instruction if the chain ends first).
+static ir::Instruction skipPhis(ir::Instruction inst) {
+  while (inst && inst == ir::spv::OpPhi) {
+    inst = inst.getNext();
+  }
+
+  return inst;
+}
+
+/**
+ * Tarjan's algorithm
for finding strongly connected components (SCCs).
+ * This finds all cycles in the CFG
+ *
+ * Only blocks whose parent equals the parent of the first node are visited.
+ * A component is reported only if it is a real cycle: more than one block, or
+ * a single block that branches to itself. An empty range yields no components.
+ */
+static std::vector>
+findSCCs(ir::Range nodes) {
+  std::unordered_map indices;
+  std::unordered_map lowlinks;
+  std::unordered_set onStack;
+  std::vector stack;
+  std::vector> sccs;
+  std::size_t index = 0;
+
+  // Nothing to analyse; also keeps the dereference of begin() below valid.
+  // Written with != because that is the only comparison range-for requires
+  // from the iterator type.
+  if (!(nodes.begin() != nodes.end())) {
+    return sccs;
+  }
+
+  auto rootParent = (*nodes.begin()).getParent();
+
+  std::function strongConnect = [&](ir::Block node) {
+    indices[node] = index;
+    lowlinks[node] = index;
+    index++;
+    stack.push_back(node);
+    onStack.insert(node);
+
+    // Consider successors of node
+    for (auto successor : getSuccessors(node)) {
+      if (successor.getParent() != rootParent) {
+        continue;
+      }
+
+      if (!indices.contains(successor)) {
+        // Successor has not yet been visited; recurse on it
+        strongConnect(successor);
+        lowlinks[node] = std::min(lowlinks[node], lowlinks[successor]);
+      } else if (onStack.contains(successor)) {
+        // Successor is in stack and hence in the current SCC
+        lowlinks[node] = std::min(lowlinks[node], indices[successor]);
+      }
+    }
+
+    // If node is a root node, pop the stack and create an SCC
+    if (lowlinks[node] == indices[node]) {
+      std::unordered_set scc;
+      scc.reserve(stack.size());
+      ir::Block w;
+      do {
+        w = stack.back();
+        stack.pop_back();
+        onStack.erase(w);
+        scc.insert(w);
+      } while (w != node);
+
+      // keep cycles only
+      if (!scc.empty()) {
+        auto isLoop = scc.size() > 1;
+
+        if (!isLoop) {
+          // single node can contain branch to self
+          isLoop = hasSuccessor(w, w);
+        }
+
+        if (isLoop) {
+          sccs.push_back(std::move(scc));
+        }
+      }
+    }
+  };
+
+  for (auto node : nodes) {
+    if (node.getParent() != rootParent) {
+      continue;
+    }
+
+    if (!indices.contains(node)) {
+      strongConnect(node);
+    }
+  }
+  return sccs;
+}
+
+// Wraps every cycle found among the direct children of `root` into a loop
+// construct.
+//
+// For each cycle the four roles passed to createLoopConstruct are determined,
+// inserting route/merge blocks where the existing CFG does not provide a
+// suitable block:
+//  - bodyLabel:     the loop header (single entry target, or a new route block)
+//  - continueLabel: the block that jumps back to the header
+//  - mergeLabel:    the single block all exits lead to
+//  - latchLabel:    the block that decides between continue and merge
+// The construct is created only when body, continue and merge are all set.
+void shader::transform::wrapLoopConstructs(spv::Context &context, ir::RegionLike root) {
+  auto region = root.children();
+  auto sccs = findSCCs(region);
+
+  // by reference: each set is extended in place with the helper blocks created
+  // below, and `sccs` is not used afterwards, so no copy is needed
+  for (auto &scc : sccs) {
+    auto edges = calculateCycleEdges(scc);
+
+    ir::Block bodyLabel;
+    ir::Block continueLabel;
+    ir::Block mergeLabel;
+    ir::Block latchLabel;
+
+    if (!edges.entryEdges.empty()) {
+      // fast path: one entry edge and one back edge that share their target
+      if (edges.entryEdges.size() == 1 && edges.backEdges.size() == 1 &&
+          edges.entryEdges[0].to() == edges.backEdges[0].to()) {
+        bodyLabel = edges.entryEdges[0].to();
+        continueLabel = edges.backEdges[0].from();
+      }
+
+      if (!bodyLabel) {
+        std::vector entryEdges = edges.entryEdges;
+        // back edges should jump to entry block
+        entryEdges.insert(entryEdges.end(), edges.backEdges.begin(),
+                          edges.backEdges.end());
+
+        // for loop no need to split blocks, we can just rotate loop
+        bodyLabel = createRouteBlock(
+            context, ir::InsertionPoint::createInsertBefore(*scc.begin()),
+            entryEdges);
+        scc.insert(bodyLabel);
+        edges = calculateCycleEdges(scc);
+      }
+
+      // the continue block must be distinct from the header and must have
+      // exactly one successor; otherwise a dedicated one is inserted
+      if (!continueLabel || bodyLabel == continueLabel ||
+          getSuccessorCount(continueLabel) != 1) {
+
+        std::unordered_set preds;
+        for (auto edge : edges.backEdges) {
+          preds.insert(edge.from());
+        }
+        continueLabel = createMergeBlock(
+            context, ir::InsertionPoint::createInsertAfter(bodyLabel), preds,
+            bodyLabel);
+        scc.insert(continueLabel);
+        edges = calculateCycleEdges(scc);
+      }
+    }
+
+    if (!edges.exitEdges.empty()) {
+      // the common target of all exit edges, or null if they disagree
+      mergeLabel = [&] -> ir::Block {
+        auto exitEdges = std::span(edges.exitEdges);
+        auto header = exitEdges[0].to();
+        exitEdges = exitEdges.subspan(1);
+
+        while (!exitEdges.empty()) {
+          if (header != exitEdges[0].to()) {
+            return {};
+          }
+
+          exitEdges = exitEdges.subspan(1);
+        }
+
+        return header;
+      }();
+
+      if (mergeLabel) {
+        auto predecessors = getPredecessors(mergeLabel);
+
+        // a merge block may only be reached from inside the cycle
+        for (auto pred : predecessors) {
+          if (!scc.contains(pred)) {
+            mergeLabel = {};
+            break;
+          }
+        }
+
+        if (mergeLabel && predecessors.size() == 1) {
+          latchLabel = *predecessors.begin();
+
+          auto latchSuccessors = getSuccessors(latchLabel);
+
+          // NOTE(review): reads two successors unconditionally; this relies on
+          // the latch candidate having at least two distinct successors.
+          auto it = latchSuccessors.begin();
+          auto firstSuccessor = *it;
+          auto secondSuccessor = *++it;
+
+          if ((firstSuccessor != continueLabel &&
+               secondSuccessor != continueLabel)) {
+            latchLabel = {};
+            mergeLabel = {};
+          }
+
+          if (latchLabel && getPredecessorCount(continueLabel) != 1) {
+            latchLabel = {};
+          }
+        }
+      }
+
+      if (!mergeLabel) {
+        mergeLabel = createRouteBlock(
+            context,
+            ir::InsertionPoint::createInsertAfter(edges.exitEdges[0].from()),
+            edges.exitEdges);
+
+        edges = calculateCycleEdges(scc);
+      }
+
+      if (!latchLabel) {
+        std::vector exitEdges = edges.exitEdges;
+
+        // NOTE(review): continueLabel is still null here when the cycle has no
+        // entry edges; confirm getAllPredecessors tolerates a null block.
+        for (auto [pred, operandIndex] : getAllPredecessors(continueLabel)) {
+          exitEdges.emplace_back(pred, operandIndex);
+        }
+
+        latchLabel = createRouteBlock(
+            context,
+            ir::InsertionPoint::createInsertAfter(edges.exitEdges[0].from()),
+            exitEdges);
+        scc.insert(latchLabel);
+      }
+    }
+
+    if (bodyLabel && continueLabel && mergeLabel) {
+      auto loopConstruct = createLoopConstruct(
+          context, root, bodyLabel, latchLabel, continueLabel, mergeLabel, scc);
+
+      // replace references to body outside this construct with header (i.e.
+      // loop construct node)
+      bodyLabel.replaceUsesIf(loopConstruct, [=](ir::ValueUse use) {
+        return (isTerminator(use.user) ||
+                (use.user != loopConstruct && isConstruct(use.user))) &&
+               getConstructOf(use.user) != loopConstruct;
+      });
+
+      // move PHIs to construct
+      for (auto phi : ir::range(bodyLabel.getFirst())) {
+        if (phi != ir::spv::OpPhi) {
+          break;
+        }
+
+        phi.erase();
+        loopConstruct.prependChild(phi);
+      }
+    }
+  }
+}
+
+// Wraps conditional headers (blocks ending in OpBranchConditional or OpSwitch)
+// and the blocks they exclusively control into selection constructs.
+//
+// Regions are processed from a work list that starts with the children of
+// `root`. After one construct has been created in a region, the scan of that
+// region stops (`break`) and the construct's children and following siblings
+// are queued as new regions.
+void shader::transform::wrapSelectionConstructs(spv::Context &context,
+                                                ir::RegionLike root) {
+  std::vector> workList;
+  workList.push_back(root.children());
+
+  while (!workList.empty()) {
+    auto region = workList.back();
+    workList.pop_back();
+
+    for (auto entryBlock : region) {
+      if (isConstruct(entryBlock)) {
+        // existing construct: descend into its body instead of wrapping it
+        if (entryBlock == ir::builtin::SELECTION_CONSTRUCT) {
+          if (auto body =
+                  skipPhis(entryBlock.getFirst()).getNext().cast()) {
+            workList.emplace_back(ir::range(body));
+          }
+        } else if (auto body =
+                       skipPhis(entryBlock.getFirst()).cast()) {
+          workList.emplace_back(ir::range(body));
+        }
+        continue;
+      }
+
+      auto terminator = entryBlock.getLast();
+      if (!terminator || !isTerminator(terminator) ||
+          (terminator != ir::spv::OpBranchConditional &&
+           terminator != ir::spv::OpSwitch)) {
+        continue;
+      }
+
+      ir::RegionLike parentConstruct = getConstructOf(entryBlock);
+
+      // already the header of its enclosing construct: nothing to wrap
+      if (auto parentSelection =
+              parentConstruct.cast()) {
+        if (parentSelection.getHeader() == entryBlock) {
+          continue;
+        }
+      }
+
+      // NOTE(review): `successors` is not read again in this function.
+      auto successors = getSuccessors(entryBlock);
+
+      if (parentConstruct) {
+        if (parentConstruct.getLast() == entryBlock) {
+          // do not look at latch/continuation blocks
+          continue;
+        }
+      }
+
+      if (!parentConstruct) {
+        parentConstruct = root;
+      }
+
+      // blocks that will become the body of the new construct
+      std::unordered_set components;
+      components.insert(entryBlock);
+
+      auto addConstructComponent = [&](ir::Construct construct) {
+        components.insert(construct);
+
+        // add whole body of construct
+        for (auto child : construct.children()) {
+          components.insert(child);
+        }
+
+        if (auto loop = construct.cast()) {
+          // if it is a loop, also add the continue construct
+          for (auto child : loop.getContinue().children()) {
+            components.insert(child);
+          }
+        }
+
+        auto constructMerge = construct.getMerge();
+        if (parentConstruct != root &&
+            getSuccessorCount(construct.getMerge()) == 0) {
+          // we cannot take this merge block, it is exit from function block
+          // create trampoline node and replace merge block of this node
+
+          auto newMerge = createMergeBlock(
+              context, ir::InsertionPoint::createInsertBefore(constructMerge),
+              getPredecessors(constructMerge), constructMerge);
+
+          construct.setMerge(newMerge);
+          constructMerge = newMerge;
+        }
+
+        components.insert(constructMerge);
+
+        return getSuccessors(constructMerge);
+      };
+
+      auto addComponent = [&](ir::Block block) {
+        if (auto construct = block.cast()) {
+          return addConstructComponent(construct);
+        }
+
+        if (hasAtLeastSuccessors(block, 1)) {
+          components.insert(block);
+          return getSuccessors(block);
+        }
+
+        // block without successors: take a trampoline in front of it instead
+        auto trampoline = createMergeBlock(
+            context, ir::InsertionPoint::createInsertBefore(block),
+            getPredecessors(block), block);
+
+        components.insert(trampoline);
+        return getSuccessors(block);
+      };
+
+      {
+        // try to find blocks that have no other predecessors
+
+        auto parentEntry =
+            skipPhis(parentConstruct.getFirst()).staticCast();
+
+        auto headerSuccessors = getSuccessors(entryBlock);
+
+        std::vector workList(headerSuccessors.begin(),
+                             headerSuccessors.end());
+        while (!workList.empty()) {
+          auto block = workList.back();
+          workList.pop_back();
+
+          if (components.contains(block)) {
+            continue;
+          }
+
+          if (block.getParent() != parentConstruct) {
+            continue;
+          }
+
+          if (block == parentEntry || block == parentConstruct.getLast()) {
+            // do not take entry/latch/continuation of parent construct
+            continue;
+          }
+
+          bool hasAllPreds = true;
+          auto loop = block.cast();
+          for (auto pred : getPredecessors(block)) {
+            if (components.contains(pred)) {
+              continue;
+            }
+
+            if (loop && pred == loop.getContinue().getLast()) {
+              // ignore continue predecessor of loop
+              continue;
+            }
+
+            hasAllPreds = false;
+            break;
+          }
+
+          if (hasAllPreds) {
+            // NOTE(review): the successors returned by addComponent are
+            // discarded, so only direct successors of the header are ever
+            // examined. Confirm this is intended.
+            addComponent(block);
+          }
+        }
+      }
+
+      if (components.size() == 1) {
+        // all successors are used by nodes outside this header, it means it is
+        // not structured loop node or case block of OpSwitch with fallthrough
+        continue;
+      }
+
+      ir::Block entryLabel = entryBlock;
+      ir::Block mergeLabel;
+      bool mergeInserted = false;
+
+      std::unordered_set exitBlocks;
+      std::vector exitEdges;
+      for (auto block : components) {
+        for (auto [succ, operandIndex] : getAllSuccessors(block)) {
+          if (!components.contains(succ)) {
+            exitEdges.emplace_back(block, operandIndex);
+            // NOTE(review): this records the *source* block of the exit edge,
+            // and mergeLabel below is taken from this set. Verify whether
+            // `succ` (the exit target) was intended.
+            exitBlocks.insert(block);
+          }
+        }
+      }
+
+      if (!exitBlocks.empty()) {
+        if (exitBlocks.size() == 1) {
+          mergeLabel = *exitBlocks.begin();
+        }
+
+        if (!mergeLabel ||
+            getAllPredecessors(mergeLabel).size() != exitEdges.size() ||
+            isConstruct(mergeLabel)) {
+          mergeLabel = createRouteBlock(
+              context, ir::InsertionPoint::createInsertAfter(entryBlock),
+              exitEdges);
+
+          workList.emplace_back(ir::range(mergeLabel));
+          mergeInserted = true;
+        }
+      } else {
+        mergeLabel = parentConstruct.getLast().staticCast();
+      }
+
+      if (!mergeInserted) {
+        // skips users that are constructs with a different merge block; the
+        // first remaining user triggers insertion of a fresh merge block
+        for (auto user : mergeLabel.getUserList()) {
+          if (auto construct = user.cast()) {
+            if (construct.getMerge() != mergeLabel) {
+              continue;
+            }
+          }
+          mergeLabel = createMergeBlock(
+              context, ir::InsertionPoint::createInsertBefore(mergeLabel),
+              getPredecessors(mergeLabel), mergeLabel);
+          mergeInserted = true;
+          break;
+        }
+      }
+
+      if (!mergeInserted) {
+        auto mergePreds = getPredecessors(mergeLabel);
+        std::unordered_set branchesInsideConstruct;
+
+        for (auto pred : mergePreds) {
+          if (components.contains(pred)) {
+            branchesInsideConstruct.insert(pred);
+          }
+        }
+
+        // merge is also reached from outside: give the construct its own one
+        if (branchesInsideConstruct.size() != mergePreds.size()) {
+          mergeLabel = createMergeBlock(
+              context, ir::InsertionPoint::createInsertBefore(mergeLabel),
+              branchesInsideConstruct, mergeLabel);
+        }
+      }
+
+      auto construct = createSelectionConstruct(
+          context, parentConstruct, components, entryLabel, mergeLabel);
+
+      // update merge label
+      construct.setMerge(mergeLabel);
+
+      construct.getHeader().replaceUsesIf(construct, [=](ir::ValueUse use) {
+        if (getConstructOf(use.user) != construct) {
+          if (isTerminator(use.user)) {
+            return true;
+          }
+
+          // allow update block merges
+          if (isConstruct(use.user) && use.operandIndex == 1) {
+            return true;
+          }
+        }
+
+        if (use.user != construct && isConstruct(use.user)) {
+          return true;
+        }
+
+        return false;
+      });
+
+      // move PHIs to construct
+      for (auto phi : ir::range(construct.getHeader().getFirst())) {
+        if (phi != ir::spv::OpPhi) {
+          break;
+        }
+
+        phi.erase();
+        construct.prependChild(phi);
+      }
+
+      // view child constructs
+      if (auto child = construct.getHeader().getNext().cast()) {
+        workList.emplace_back(ir::range(child));
+      }
+
+      // view next constructs
+      if (auto next = construct.getNext()) {
+        workList.emplace_back(ir::range(next));
+      }
+
+      break;
+    }
+  }
+}
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
new file mode 100644
index 000000000..2a48a1042
--- /dev/null
+++ b/test/CMakeLists.txt
@@ -0,0 +1,66 @@
+# Tests and Benchmarks Configuration
+cmake_minimum_required(VERSION 3.16)
+project(MemoryTableTests)
+
+# Register the tests below with CTest. enable_testing() only affects this
+# directory and below, so run ctest from this directory's build tree (or also
+# call enable_testing() in the top-level CMakeLists.txt).
+enable_testing()
+
+# Add Google Benchmark as an external project if not found
+# NOTE: no target in this file links against benchmark yet.
+find_package(benchmark QUIET)
+
+if(NOT benchmark_FOUND)
+  include(FetchContent)
+
+  FetchContent_Declare(
+    googlebenchmark
+    GIT_REPOSITORY https://github.com/google/benchmark.git
+    GIT_TAG v1.9.4
+  )
+
+  # Disable benchmark tests and examples to speed up build
+  set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable benchmark tests" FORCE)
+  set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "Disable Google Test for benchmark" FORCE)
+  set(BENCHMARK_ENABLE_WERROR OFF CACHE BOOL "Disable warnings as errors" FORCE)
+
+  FetchContent_MakeAvailable(googlebenchmark)
+endif()
+
+# Add Google Test for unit testing.
+# CONFIG mode is requested so that a found package always provides the
+# namespaced GTest::gtest / GTest::gtest_main targets; googletest fetched
+# below defines aliases with the same names.
+find_package(GTest CONFIG QUIET)
+
+if(NOT GTest_FOUND)
+  include(FetchContent)
+
+  FetchContent_Declare(
+    googletest
+    GIT_REPOSITORY https://github.com/google/googletest.git
+    GIT_TAG v1.15.2
+  )
+
+  # Disable gtest installation
+  set(INSTALL_GTEST OFF CACHE BOOL "Disable gtest installation" FORCE)
+  set(gtest_force_shared_crt ON CACHE BOOL "Force shared CRT for gtest" FORCE)
+
+  FetchContent_MakeAvailable(googletest)
+endif()
+
+add_executable(gcn_shader_tests
+  gcn_shader_tests.cpp
+)
+
+# Link the namespaced targets. The bare names `gtest`/`gtest_main` are CMake
+# targets only when googletest was fetched above; with an installed package
+# they would degrade to plain `-lgtest`-style linker arguments.
+target_link_libraries(gcn_shader_tests
+  GTest::gtest
+  GTest::gtest_main
+  gcn-shader
+)
+
+add_test(NAME gcn_shader_tests COMMAND gcn_shader_tests)
diff --git a/test/gcn_shader_tests.cpp b/test/gcn_shader_tests.cpp
new file mode 100644
index 000000000..32eda1293
--- /dev/null
+++ b/test/gcn_shader_tests.cpp
@@ -0,0 +1,253 @@
+#include
+#include
+
+// Include shader framework for CFG testing
+#include "shader/SpvConverter.hpp"
+#include "shader/analyze.hpp"
+#include "shader/dialect.hpp"
+#include "shader/ir.hpp"
+#include "shader/ir/Context.hpp"
+#include
"shader/spv.hpp" +#include "shader/transform.hpp" + +using namespace shader; +using Builder = ir::Builder; + +class GcnShaderTest : public ::testing::Test { +protected: + void SetUp() override { + // Setup SPIR-V context for CFG testing + context = std::make_unique(); + loc = context->getUnknownLocation(); + trueV = context->getTrue(); + falseV = context->getFalse(); + } + + void TearDown() override { context.reset(); } + + ir::Value createLabel(const std::string &name) { + auto builder = Builder::createAppend( + *context, context->layout.getOrCreateFunctions(*context)); + auto label = builder.createSpvLabel(loc); + context->ns.setNameOf(label, name); + return label; + } + + void createBranch(ir::Value from, ir::Value to) { + Builder::createInsertAfter(*context, from).createSpvBranch(loc, to); + } + + void createConditionalBranch(ir::Value from, ir::Value a, ir::Value b) { + Builder::createInsertAfter(*context, from) + .createSpvBranchConditional(loc, trueV, a, b); + } + + void createReturn(ir::Value from) { + Builder::createInsertAfter(*context, from).createSpvReturn(loc); + } + + void createSwitch(ir::Value from, std::span cases) { + auto globals = Builder::createAppend( + *context, context->layout.getOrCreateGlobals(*context)); + auto globalVariable = globals.createSpvVariable( + loc, context->getTypeUInt32(), ir::spv::StorageClass::Private, + context->imm32(0)); + + auto switchOp = Builder::createInsertAfter(*context, from) + .createSpvSwitch(loc, globalVariable, cases[0]); + + std::uint32_t i = 0; + for (auto c : cases.subspan(1)) { + switchOp.addOperand(i++); + switchOp.addOperand(c); + } + } + + void createSwitchBranch(ir::Value from, ir::Value defaultTarget, + const std::vector>& cases) { + // Create a switch value (use a constant for testing) + auto type = context->getTypeUInt32(); + auto globals = Builder::createAppend( + *context, context->layout.getOrCreateGlobals(*context)); + auto globalVariable = globals.createSpvConstant( + loc, type, 0); + + auto 
builder = Builder::createInsertAfter(*context, from); + auto switchInst = + builder.createSpvSwitch(loc, globalVariable, defaultTarget); + + // Add each case + for (const auto& [value, target] : cases) { + switchInst.addOperand(value); + switchInst.addOperand(target); + } + } + + bool testStructurization() { + auto region = context->layout.getOrCreateFunctions(*context); + context->layout.regions[spv::BinaryLayout::kFunctions] = {}; + auto functions = context->layout.getOrCreateFunctions(*context); + + structurizeCfg(*context, region); + + { + auto debugs = Builder::createAppend( + *context, context->layout.getOrCreateDebugs(*context)); + + auto cfg = buildCFG(region.getFirst()); + for (auto node : cfg.getPreorderNodes()) { + auto value = node->getLabel(); + if (auto name = context->ns.tryGetNameOf(value); !name.empty()) { + debugs.createSpvName(loc, value, std::string(name)); + } + } + + for (auto bb : cfg.getPreorderNodes()) { + for (auto child : bb->range()) { + child.erase(); + functions.addChild(child); + } + } + } + region = functions; + + auto entryLabel = region.getFirst().cast(); + + auto memModel = Builder::createAppend( + *context, context->layout.getOrCreateMemoryModels(*context)); + auto capabilities = Builder::createAppend( + *context, context->layout.getOrCreateCapabilities(*context)); + + capabilities.createSpvCapability(loc, ir::spv::Capability::Shader); + + memModel.createSpvMemoryModel(loc, ir::spv::AddressingModel::Logical, + ir::spv::MemoryModel::GLSL450); + + auto mainReturnT = context->getTypeVoid(); + auto mainFnT = context->getTypeFunction(mainReturnT, {}); + + auto builder = Builder::createPrepend(*context, region); + auto mainFn = builder.createSpvFunction( + loc, mainReturnT, ir::spv::FunctionControl::None, mainFnT); + + builder.createSpvLabel(loc); + builder.createSpvBranch(loc, entryLabel); + + Builder::createAppend(*context, region).createSpvFunctionEnd(loc); + + auto entryPoints = Builder::createAppend( + *context, 
context->layout.getOrCreateEntryPoints(*context)); + + auto executionModes = Builder::createAppend( + *context, context->layout.getOrCreateExecutionModes(*context)); + + executionModes.createSpvExecutionMode( + mainFn.getLocation(), mainFn, + ir::spv::ExecutionMode::LocalSize(1, 1, 1)); + + entryPoints.createSpvEntryPoint(mainFn.getLocation(), + ir::spv::ExecutionModel::GLCompute, mainFn, + "main", {}); + + auto spv = shader::spv::serialize(context->layout.merge(*context)); + if (shader::spv::validate(spv)) { + return true; + } + + shader::spv::dump(spv, true); + return false; + } + +protected: + std::unique_ptr context; + ir::Location loc; + ir::Value trueV; + ir::Value falseV; +}; + +TEST_F(GcnShaderTest, ProjectDivaTest1) { + auto _1 = createLabel("1"); + auto _2 = createLabel("2"); + auto _3 = createLabel("3"); + auto _4 = createLabel("4"); + auto _5 = createLabel("5"); + auto _6 = createLabel("6"); + auto _7 = createLabel("7"); + auto _8 = createLabel("8"); + auto _9 = createLabel("9"); + auto _10 = createLabel("10"); + auto _11 = createLabel("11"); + auto _12 = createLabel("12"); + auto _13 = createLabel("13"); + createBranch(_1, _2); + createConditionalBranch(_2, _4, _3); + createConditionalBranch(_3, _12, _11); + createConditionalBranch(_4, _6, _5); + createConditionalBranch(_5, _9, _8); + createBranch(_6, _7); + createBranch(_7, _6); + createBranch(_8, _3); + createBranch(_9, _10); + createBranch(_10, _7); + createBranch(_11, _12); + createBranch(_12, _13); + createReturn(_13); + + EXPECT_TRUE(testStructurization()); +} + +TEST_F(GcnShaderTest, BatmanReturnToArkham1) { + auto _1 = createLabel("1"); + auto _2 = createLabel("2"); + auto _3 = createLabel("3"); + auto _4 = createLabel("4"); + auto _5 = createLabel("5"); + auto _6 = createLabel("6"); + auto _7 = createLabel("7"); + auto _8 = createLabel("8"); + auto _9 = createLabel("9"); + auto _10 = createLabel("10"); + auto _11 = createLabel("11"); + auto _12 = createLabel("12"); + auto _13 = 
createLabel("13"); + auto _14 = createLabel("14"); + auto _15 = createLabel("15"); + auto _16 = createLabel("16"); + auto _17 = createLabel("17"); + auto _18 = createLabel("18"); + auto _19 = createLabel("19"); + auto _20 = createLabel("20"); + auto _21 = createLabel("21"); + auto _22 = createLabel("22"); + auto _23 = createLabel("23"); + auto _24 = createLabel("24"); + auto _25 = createLabel("25"); + createBranch(_1, _2); + createConditionalBranch(_2, _4, _3); + createConditionalBranch(_3, _6, _5); + createBranch(_4, _3); + createConditionalBranch(_5, _8, _7); + createBranch(_6, _5); + createConditionalBranch(_7, _10, _9); + createBranch(_8, _7); + createConditionalBranch(_9, _12, _11); + createBranch(_10, _9); + createConditionalBranch(_11, _14, _13); + createBranch(_12, _11); + createConditionalBranch(_13, _16, _15); + createBranch(_14, _13); + createBranch(_15, _25); + createConditionalBranch(_16, _18, _17); + createBranch(_17, _18); + createConditionalBranch(_18, _20, _19); + createBranch(_19, _20); + createConditionalBranch(_20, _22, _21); + createBranch(_21, _22); + createConditionalBranch(_22, _24, _23); + createBranch(_23, _24); + createBranch(_24, _15); + createReturn(_25); + + EXPECT_TRUE(testStructurization()); +} +