Refactor gcn shader transform methods

This commit is contained in:
O1L 2025-10-14 15:09:04 +03:00
parent 00c55e9688
commit 587834fe33
18 changed files with 1976 additions and 1419 deletions

View file

@ -288,3 +288,4 @@ if (WITH_PS3)
add_subdirectory(ps3fw)
endif()
add_subdirectory(test)

View file

@ -30,6 +30,12 @@ add_library(gcn-shader STATIC
src/SpvConverter.cpp
src/SpvTypeInfo.cpp
src/transform.cpp
src/transform/replace.cpp
src/transform/route.cpp
src/transform/merge.cpp
src/transform/construct.cpp
src/transform/transformations.cpp
src/transform/wrap.cpp
)
target_include_directories(gcn-shader PUBLIC include PRIVATE include/shader)

View file

@ -0,0 +1,33 @@
#pragma once
#include "analyze.hpp"
#include "ir.hpp"
#include "replace.hpp"
namespace shader::transform {
class Edge {
ir::Block mFromBlock;
int mToOperandIndex;
public:
Edge(ir::Block fromBlock, int toOperandIndex)
: mFromBlock(fromBlock), mToOperandIndex(toOperandIndex) {}
[[nodiscard]] ir::Block from() const { return mFromBlock; }
[[nodiscard]] ir::Block to() const {
return getTerminator(mFromBlock)
.getOperand(mToOperandIndex)
.getAsValue()
.staticCast<ir::Block>();
}
[[nodiscard]] int operandIndex() const { return mToOperandIndex; }
void replaceSuccessor(ir::Value newSuccessor) {
replaceTerminatorTarget(getTerminator(mFromBlock), mToOperandIndex,
newSuccessor);
}
bool operator==(const Edge &) const = default;
};
}

View file

@ -0,0 +1,28 @@
#pragma once
#include "SpvConverter.hpp"
#include "analyze.hpp"
#include "ir.hpp"
namespace shader::transform {
bool isConstruct(ir::Instruction block);
bool isParentConstruct(ir::RegionLike parent,
ir::RegionLike construct);
ir::Block getConstructOf(ir::Instruction inst);
ir::Block getConstructMergeBlock(ir::Block block);
ir::SelectionConstruct createSelectionConstruct(spv::Context &context,
ir::RegionLike parentConstruct,
const std::unordered_set<ir::Block> &components,
ir::Block header, ir::Block merge);
ir::LoopConstruct createLoopConstruct(spv::Context &context,
ir::RegionLike parentConstruct,
ir::Block header,
ir::Block latch,
ir::Block cont,
ir::Block merge,
const std::unordered_set<shader::ir::Block> &scc);
}

View file

@ -0,0 +1,14 @@
#pragma once
#include "SpvConverter.hpp"
#include "analyze.hpp"
#include "dialect.hpp"
#include "ir.hpp"
namespace shader::transform {
ir::Block createMergeBlock(spv::Context &context,
ir::InsertionPoint insertPoint,
const std::unordered_set<ir::Block> &preds,
ir::Block to);
}

View file

@ -0,0 +1,15 @@
#pragma once
#include "analyze.hpp"
#include "dialect.hpp"
#include "ir.hpp"
namespace shader::transform {
void replaceTerminatorTarget(ir::Instruction terminator,
int operandIndex,
ir::Value newTarget);
bool replaceTerminatorTarget(ir::Instruction terminator,
ir::Value oldTarget,
ir::Value newTarget);
}

View file

@ -0,0 +1,11 @@
#pragma once
#include "Edge.hpp"
#include "SpvConverter.hpp"
#include "ir.hpp"
namespace shader::transform {
ir::Block createRouteBlock(spv::Context &context,
ir::InsertionPoint insertPoint,
const std::vector<Edge> &edges);
}

View file

@ -0,0 +1,12 @@
#pragma once
#include "SpvConverter.hpp"
#include "ir.hpp"
namespace shader::transform {
ir::Value transformToCanonicalRegion(spv::Context &context,
ir::RegionLike region);
void transformToCf(spv::Context &context, ir::RegionLike region);
void transformToFlat(spv::Context &context, ir::RegionLike region);
}

View file

@ -0,0 +1,8 @@
#pragma once
#include "SpvConverter.hpp"
namespace shader::transform {
void wrapLoopConstructs(spv::Context &context, ir::RegionLike root);
void wrapSelectionConstructs(spv::Context &context, ir::RegionLike root);
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,154 @@
#include "SpvConverter.hpp"
#include "analyze.hpp"
#include "transform/construct.hpp"
#include "dialect.hpp"
#include <rx/die.hpp>
using namespace shader;
using namespace shader::transform;
using Builder = ir::Builder<ir::builtin::Builder, ir::spv::Builder>;
bool shader::transform::isConstruct(ir::Instruction block) {
return block == ir::builtin::LOOP_CONSTRUCT ||
block == ir::builtin::SELECTION_CONSTRUCT ||
block == ir::builtin::CONTINUE_CONSTRUCT;
}
ir::Block shader::transform::getConstructOf(ir::Instruction inst) {
auto block = inst.cast<ir::Block>();
if (block && isConstruct(block)) {
block = block.getParent().cast<ir::Block>();
}
while (block) {
if (isConstruct(block)) {
return block;
}
block = block.getParent().cast<ir::Block>();
}
return {};
}
ir::Block shader::transform::getConstructMergeBlock(ir::Block block) {
if (auto construct = block.cast<ir::Construct>()) {
return construct.getMerge();
}
return {};
}
bool shader::transform::isParentConstruct(ir::RegionLike parent,
ir::RegionLike construct) {
while (parent != construct && construct) {
construct = construct.getParent();
}
return parent == construct;
}
ir::SelectionConstruct
shader::transform::createSelectionConstruct(spv::Context &context,
ir::RegionLike parentConstruct,
const std::unordered_set<ir::Block> &components,
ir::Block header,
ir::Block merge) {
auto selectionConstruct =
Builder::createInsertBefore(context, header)
.createSelectionConstruct(header.getLocation(), header, merge);
std::vector<ir::Block> workList;
workList.emplace_back(header);
while (!workList.empty()) {
ir::Block block = workList.back();
workList.pop_back();
block.erase();
selectionConstruct.addChild(block);
std::unordered_set<ir::Block> successors;
if (auto construct = block.cast<ir::Construct>()) {
successors = {construct.getMerge()};
} else {
successors = getSuccessors(block);
}
for (auto succ : successors) {
if (succ == merge || succ.getParent() != parentConstruct ||
!components.contains(succ)) {
continue;
}
workList.push_back(succ);
}
}
merge.erase();
selectionConstruct.getParent().insertAfter(selectionConstruct, merge);
return selectionConstruct;
}
ir::LoopConstruct
shader::transform::createLoopConstruct(spv::Context &context,
ir::RegionLike parentConstruct,
ir::Block header,
ir::Block latch,
ir::Block cont,
ir::Block merge,
const std::unordered_set<shader::ir::Block> &scc) {
auto continueConstruct =
Builder::createInsertAfter(context, header)
.createContinueConstruct(header.getLocation(), cont, header);
auto loopConstruct = Builder::createInsertBefore(context, header)
.createLoopConstruct(header.getLocation(), header,
merge, continueConstruct);
continueConstruct.erase();
header.erase();
loopConstruct.addChild(header);
std::vector<ir::Block> workList;
workList.emplace_back(header);
while (!workList.empty()) {
ir::Block block = workList.back();
workList.pop_back();
block.erase();
loopConstruct.addChild(block);
std::unordered_set<ir::Block> successors;
if (isConstruct(block)) {
successors = {getConstructMergeBlock(block)};
} else {
successors = getSuccessors(block);
}
for (auto succ : successors) {
if (succ == merge || succ.getParent() != parentConstruct ||
!scc.contains(succ)) {
continue;
}
workList.push_back(succ);
}
}
latch.erase();
loopConstruct.addChild(latch);
cont.erase();
continueConstruct.addChild(cont);
merge.erase();
loopConstruct.getParent().insertAfter(loopConstruct, merge);
return loopConstruct;
}

View file

@ -0,0 +1,77 @@
#include "SpvConverter.hpp"
#include "transform/merge.hpp"
#include "analyze.hpp"
#include "transform/replace.hpp"
#include "dialect.hpp"
#include <rx/die.hpp>
using namespace shader;
using namespace shader::transform;
using Builder = ir::Builder<ir::builtin::Builder, ir::spv::Builder>;
ir::Block shader::transform::createMergeBlock(spv::Context &context,
ir::InsertionPoint insertPoint,
const std::unordered_set<ir::Block> &preds,
ir::Block to) {
rx::dieIf(preds.empty(), "createMergeBlock: unexpected edges count");
auto loc = to.getLocation();
auto mergeBlock = Builder::create(context, insertPoint).createBlock(loc);
Builder::createAppend(context, mergeBlock).createSpvBranch(loc, to);
if (preds.size() == getPredecessorCount(to)) {
for (auto phi : ir::range(to.getFirst())) {
if (phi != ir::spv::OpPhi) {
break;
}
phi.erase();
mergeBlock.prependChild(phi);
}
} else if (preds.size() == 1) {
auto pred = *preds.begin();
for (auto phi : ir::range(to.getFirst())) {
if (phi != ir::spv::OpPhi) {
break;
}
for (std::size_t i = 2; i < phi.getOperandCount(); i += 2) {
if (phi.getOperand(i) == pred) {
phi.replaceOperand(i, mergeBlock);
}
}
}
} else {
for (auto phi : ir::range(to.getFirst())) {
if (phi != ir::spv::OpPhi) {
break;
}
auto newPhi =
Builder::createPrepend(context, mergeBlock)
.createSpvPhi(phi.getLocation(), phi.getOperand(0).getAsValue());
for (std::size_t i = 1; i < phi.getOperandCount();) {
// auto value = phi.getOperand(i).getAsValue();
auto label = phi.getOperand(i + 1).getAsValue().staticCast<ir::Block>();
if (preds.contains(label)) {
newPhi.addOperand(phi.eraseOperand(i));
newPhi.addOperand(phi.eraseOperand(i));
} else {
i += 2;
}
}
phi.addOperand(newPhi);
phi.addOperand(mergeBlock);
}
}
for (auto pred : preds) {
replaceTerminatorTarget(getTerminator(pred), to, mergeBlock);
}
return mergeBlock;
}

View file

@ -0,0 +1,36 @@
#include "transform/replace.hpp"
#include "dialect.hpp"
#include <rx/die.hpp>
using namespace shader;
using namespace shader::transform;
void shader::transform::replaceTerminatorTarget(ir::Instruction terminator,
int operandIndex, ir::Value newTarget) {
auto prevTarget = terminator.getOperand(operandIndex).getAsValue();
terminator.replaceOperand(operandIndex, newTarget);
auto selection = terminator.getPrev();
if (selection == ir::spv::OpSelectionMerge ||
selection == ir::spv::OpLoopMerge) {
for (std::size_t i = 0, end = selection.getOperandCount(); i < end; ++i) {
if (selection.getOperand(i) == prevTarget) {
selection.replaceOperand(i, newTarget);
break;
}
}
}
}
bool shader::transform::replaceTerminatorTarget(ir::Instruction terminator,
ir::Value oldTarget, ir::Value newTarget) {
bool changes = false;
for (std::size_t i = 0, end = terminator.getOperandCount(); i < end; ++i) {
if (terminator.getOperand(i) == oldTarget) {
replaceTerminatorTarget(terminator, i, newTarget);
changes = true;
}
}
return changes;
}

View file

@ -0,0 +1,428 @@
#include "transform/route.hpp"
#include "transform/merge.hpp"
#include "SpvConverter.hpp"
#include "analyze.hpp"
#include "dialect.hpp"
#include <functional>
#include <rx/die.hpp>
#include <unordered_map>
#include <unordered_set>
#include <vector>
using namespace shader;
using namespace shader::transform;
using Builder = ir::Builder<ir::builtin::Builder, ir::spv::Builder>;
// Data structures for route block creation
struct RouteBlockData {
std::unordered_map<ir::Block, std::unordered_set<unsigned>> fromSucc;
std::unordered_map<ir::Block, std::unordered_set<ir::Block>> toPreds;
std::unordered_map<ir::Block, std::unordered_set<ir::Block>> toAllPreds;
std::unordered_set<ir::Block> patchPredecessors;
};
// Analyze edges and build routing data structures
static RouteBlockData analyzeEdges(const std::vector<Edge> &edges) {
RouteBlockData data;
std::unordered_set<ir::Block> routePredecessors;
for (auto edge : edges) {
if (!routePredecessors.insert(edge.from()).second) {
data.patchPredecessors.insert(edge.from());
}
data.toPreds[edge.to()].emplace(edge.from());
data.fromSucc[edge.from()].emplace(edge.operandIndex());
}
for (auto &[to, preds] : data.toPreds) {
data.toAllPreds[to] = getPredecessors(to);
}
return data;
}
// Create route block with appropriate phi node
static std::pair<ir::Block, ir::Value> createRouteBlockWithPhi(
spv::Context &context, ir::InsertionPoint insertPoint,
ir::Location loc, size_t predsCount) {
auto route = Builder::create(context, insertPoint).createBlock(loc);
ir::Value routePhi;
if (predsCount > 1) {
routePhi = Builder::createPrepend(context, route)
.createSpvPhi(loc, predsCount == 2
? context.getTypeBool()
: context.getTypeUInt32());
}
return {route, routePhi};
}
// Create terminator based on number of successors
static std::unordered_map<ir::Value, std::uint32_t> createRouteTerminator(
spv::Context &context, ir::Block route, ir::Value routePhi,
ir::Location loc,
const std::unordered_map<ir::Block, std::unordered_set<ir::Block>>
&toPreds) {
std::unordered_map<ir::Value, std::uint32_t> successorToId;
if (toPreds.size() == 1) {
// Single successor: unconditional branch
Builder::createAppend(context, route)
.createSpvBranch(loc, toPreds.begin()->first);
} else if (toPreds.size() == 2) {
// Two successors: conditional branch
auto it = toPreds.begin();
auto firstSuccessor = it->first;
auto secondSuccessor = (++it)->first;
Builder::createAppend(context, route)
.createSpvBranchConditional(loc, routePhi, firstSuccessor,
secondSuccessor);
} else {
// Multiple successors: switch statement
auto routeSwitch = Builder::createAppend(context, route)
.createSpvSwitch(loc, routePhi, toPreds.begin()->first);
successorToId.reserve(toPreds.size());
for (std::uint32_t id = 0; auto &[succ, pred] : toPreds) {
if (id) {
routeSwitch.addOperand(id);
routeSwitch.addOperand(succ);
}
successorToId[succ] = id++;
}
}
return successorToId;
}
// Get successor ID based on routing strategy
static ir::Value getSuccessorIdValue(
spv::Context &context, ir::Block successor,
const std::unordered_map<ir::Block, std::unordered_set<ir::Block>>
&toPreds,
const std::unordered_map<ir::Value, std::uint32_t> &successorToId) {
if (toPreds.size() == 2) {
return context.getBool(successor == toPreds.begin()->first);
}
return context.imm32(successorToId.at(successor));
}
// Process single predecessor block that needs patching
static void patchPredecessorBlock(
spv::Context &context, ir::Block patchBlock, ir::Block route,
ir::Value routePhi, const RouteBlockData &data,
const std::unordered_map<ir::Block, std::unordered_set<ir::Block>> &toPreds,
const std::function<ir::Value(ir::Block)> &getSuccessorId) {
auto predSuccessors = getAllSuccessors(patchBlock);
auto terminator = getTerminator(patchBlock);
auto &routeSuccessors = data.fromSucc.at(patchBlock);
int keepSuccessors = predSuccessors.size() - routeSuccessors.size();
assert(keepSuccessors >= 0);
assert(terminator == ir::spv::OpSwitch ||
terminator == ir::spv::OpBranchConditional);
auto cond = terminator.getOperand(0).getAsValue();
auto condType = cond.getOperand(0).getAsValue();
std::map<ir::Operand, ir::Block> condValueToSucc;
ir::Block defaultSucc;
if (keepSuccessors == 0) {
// we are going to replace all successors of this block, create direct
// jump to route block
Builder::createInsertAfter(context, terminator)
.createSpvBranch(terminator.getLocation(), route);
if (terminator == ir::spv::OpBranchConditional) {
condValueToSucc[context.getTrue()] =
terminator.getOperand(1).getAsValue().staticCast<ir::Block>();
condValueToSucc[context.getFalse()] =
terminator.getOperand(2).getAsValue().staticCast<ir::Block>();
} else if (terminator == ir::spv::OpSwitch) {
defaultSucc =
terminator.getOperand(1).getAsValue().staticCast<ir::Block>();
for (int i = 2, end = terminator.getOperandCount(); i < end; i += 2) {
condValueToSucc[terminator.getOperand(i)] =
terminator.getOperand(i + 1).getAsValue().staticCast<ir::Block>();
}
}
} else if (terminator == ir::spv::OpSwitch) {
if (routeSuccessors.contains(1)) {
defaultSucc =
terminator.getOperand(1).getAsValue().staticCast<ir::Block>();
}
bool shouldReplaceDefault = defaultSucc != nullptr;
for (int i = 2, id = 2, end = terminator.getOperandCount(); i < end;
id += 2) {
if (routeSuccessors.contains(id + 1)) {
if (shouldReplaceDefault) {
auto value = terminator.eraseOperand(i);
auto successor = terminator.eraseOperand(i);
condValueToSucc[value] =
successor.getAsValue().staticCast<ir::Block>();
continue;
}
condValueToSucc[terminator.getOperand(i)] =
terminator.getOperand(i + 1).getAsValue().staticCast<ir::Block>();
terminator.replaceOperand(i + 1, route);
}
i += 2;
}
if (shouldReplaceDefault) {
terminator.replaceOperand(1, route);
}
} else {
if (routeSuccessors.contains(1)) {
condValueToSucc[context.getTrue()] =
terminator.getOperand(1).getAsValue().staticCast<ir::Block>();
terminator.replaceOperand(1, route);
} else {
assert(routeSuccessors.contains(2));
condValueToSucc[context.getFalse()] =
terminator.getOperand(2).getAsValue().staticCast<ir::Block>();
terminator.replaceOperand(2, route);
}
}
if (routePhi) {
auto boolType = context.getTypeBool();
auto builder = Builder::createInsertBefore(context, terminator);
ir::Value selector;
if (defaultSucc) {
selector = getSuccessorId(defaultSucc);
}
auto selectorType =
toPreds.size() == 2 ? boolType : context.getTypeUInt32();
for (auto &[value, to] : condValueToSucc) {
if (!selector) {
selector = getSuccessorId(to);
} else {
auto valueId = value.getAsValue();
if (!valueId) {
valueId = context.imm32(*value.getAsInt32());
}
ir::Value selectionCond;
if (condType == boolType) {
selectionCond = builder.createSpvLogicalEqual(
terminator.getLocation(), boolType, cond, valueId);
} else {
selectionCond = builder.createSpvIEqual(terminator.getLocation(),
boolType, cond, valueId);
}
selector = builder.createSpvSelect(terminator.getLocation(),
selectorType, selectionCond,
getSuccessorId(to), selector);
}
}
routePhi.addOperand(selector);
routePhi.addOperand(patchBlock);
}
if (keepSuccessors == 0) {
terminator.remove();
}
}
// Move all phi nodes from target to route block
static void moveAllPhiNodes(spv::Context &context, ir::Block to, ir::Block route,
const std::unordered_set<ir::Block> &preds,
const std::vector<Edge> &edges) {
for (auto phi : ir::range(ir::Block(to).getFirst())) {
if (phi != ir::spv::OpPhi) {
break;
}
phi.erase();
route.prependChild(phi);
if (preds.size() != edges.size()) {
// route block has additional edges. add dummy nodes to phi, this
// block not reachable from new predecessors anyway
auto undef = context.getUndef(phi.getOperand(0).getAsValue());
for (auto edge : edges) {
if (!preds.contains(edge.from())) {
phi.addOperand(undef);
phi.addOperand(edge.from());
}
}
}
}
}
// Update phi nodes for single predecessor
static void updatePhiNodesForSinglePred(ir::Block to, ir::Block pred,
ir::Block route) {
for (auto phi : ir::range(ir::Block(to).getFirst())) {
if (phi != ir::spv::OpPhi) {
break;
}
for (std::size_t i = 2; i < phi.getOperandCount(); i += 2) {
auto label = phi.getOperand(i).getAsValue();
if (label == pred) {
phi.replaceOperand(i, route);
}
}
}
}
// Update phi nodes for partial predecessor replacement
static void updatePhiNodesPartial(spv::Context &context, ir::Block to,
ir::Block route,
const std::unordered_set<ir::Block> &preds,
const std::vector<Edge> &edges) {
for (auto phi : ir::range(ir::Block(to).getFirst())) {
if (phi != ir::spv::OpPhi) {
break;
}
auto newPhi =
Builder::createPrepend(context, route)
.createSpvPhi(phi.getLocation(), phi.getOperand(0).getAsValue());
for (std::size_t i = 1; i < phi.getOperandCount();) {
auto label = phi.getOperand(i + 1).getAsValue().cast<ir::Block>();
if (preds.contains(label)) {
newPhi.addOperand(phi.eraseOperand(i));
newPhi.addOperand(phi.eraseOperand(i));
} else {
i += 2;
}
}
phi.addOperand(newPhi);
phi.addOperand(route);
if (preds.size() != edges.size()) {
// merge block has additional edges. add dummy nodes to phi, this
// block not reachable from new blocks
auto dummyValue = phi.getOperand(1).getAsValue();
for (auto edge : edges) {
if (!preds.contains(edge.from())) {
phi.addOperand(dummyValue);
phi.addOperand(edge.from());
}
}
}
}
}
// Process all target blocks and update their phi nodes
static void processTargetBlocks(
spv::Context &context, ir::Block route, ir::Value routePhi,
const RouteBlockData &data,
const std::unordered_map<ir::Block, std::unordered_set<ir::Block>> &toPreds,
const std::vector<Edge> &edges,
const std::function<ir::Value(ir::Block)> &getSuccessorId) {
for (auto &[to, preds] : toPreds) {
if (toPreds.size() > 1) {
auto successorId = getSuccessorId(to);
for (auto from : preds) {
// branches already resolved
if (data.patchPredecessors.contains(from)) {
continue;
}
routePhi.addOperand(successorId);
routePhi.addOperand(from);
}
}
for (auto from : preds) {
if (data.patchPredecessors.contains(from)) {
continue;
}
replaceTerminatorTarget(getTerminator(from), to, route);
}
if (data.toAllPreds.at(to).size() == preds.size()) {
// all predecessors will be replaced, move phi nodes
moveAllPhiNodes(context, to, route, preds, edges);
continue;
}
if (preds.size() == 1) {
auto pred = *preds.begin();
updatePhiNodesForSinglePred(to, pred, route);
continue;
}
// partial predecessors replacement, update PHIs
updatePhiNodesPartial(context, to, route, preds, edges);
}
}
// Main function
ir::Block shader::transform::createRouteBlock(spv::Context &context,
ir::InsertionPoint insertPoint,
const std::vector<Edge> &edges) {
auto loc = context.getUnknownLocation();
rx::dieIf(edges.empty(), "createRouteBlock: unexpected edges count");
// Step 1: Analyze edges and build data structures
auto data = analyzeEdges(edges);
// Step 2: Handle simple case - single target block
if (data.toPreds.size() == 1) {
auto &[to, preds] = *data.toPreds.begin();
return createMergeBlock(context, insertPoint, preds, to);
}
// Step 3: Create route block and phi node
auto [route, routePhi] = createRouteBlockWithPhi(context, insertPoint,
loc, data.toPreds.size());
// Step 4: Create appropriate terminator (branch/conditional/switch)
auto successorToId = createRouteTerminator(context, route, routePhi,
loc, data.toPreds);
// Step 5: Create lambda for getting successor IDs
auto getSuccessorId = [&](ir::Block successor) {
return getSuccessorIdValue(context, successor, data.toPreds, successorToId);
};
// Step 6: Patch predecessor blocks that have multiple routes
for (auto patchBlock : data.patchPredecessors) {
patchPredecessorBlock(context, patchBlock, route, routePhi, data,
data.toPreds, getSuccessorId);
}
// Step 7: Process target blocks and update phi nodes
processTargetBlocks(context, route, routePhi, data, data.toPreds, edges,
getSuccessorId);
return route;
}

View file

@ -0,0 +1,280 @@
#include "SpvConverter.hpp"
#include "analyze.hpp"
#include "transform/transformations.hpp"
#include "dialect.hpp"
#include <rx/die.hpp>
#include <iostream>
#include <rx/die.hpp>
#include <vector>
using namespace shader;
using namespace shader::transform;
using Builder = ir::Builder<ir::builtin::Builder, ir::spv::Builder>;
ir::Value shader::transform::transformToCanonicalRegion(spv::Context &context,
ir::RegionLike region) {
auto cfg = buildCFG(region.getFirst());
std::vector<CFG::Node *> exitNodes;
for (auto node : cfg.getPreorderNodes()) {
if (!node->hasSuccessors()) {
exitNodes.push_back(node);
}
}
if (cfg.getEntryNode()->hasPredecessors()) {
auto builder = Builder::createPrepend(context, region);
auto prevEntry = cfg.getEntryLabel();
auto newEntry = builder.createSpvLabel(prevEntry.getLocation());
builder.createSpvBranch(prevEntry.getLocation(), prevEntry);
for (auto it = prevEntry.getNext(); it && it == ir::spv::OpVariable;) {
auto moveInst = it;
it = it.getNext();
moveInst.erase();
region.insertAfter(newEntry, moveInst);
}
}
if (exitNodes.empty()) {
region.print(std::cerr, context.ns);
rx::die("scfg: cfg without termination block");
}
if (exitNodes.size() == 1) {
return exitNodes.back()->getLabel();
}
ir::Value returnType;
ir::Instruction returnInst;
for (auto exitNode : exitNodes) {
auto terminator = exitNode->getTerminator();
if (terminator && terminator == ir::spv::OpReturnValue) {
auto terminatorReturnValue = terminator.getOperand(0).getAsValue();
auto terminatorReturnType =
terminatorReturnValue.getOperand(0).getAsValue();
if (returnType && terminatorReturnType == returnType) {
rx::die("scfg: unexpected terminator return type");
} else {
returnType = terminatorReturnType;
}
}
if (terminator) {
if (returnInst && returnInst.getInstId() != terminator.getInstId()) {
returnInst.print(std::cerr, context.ns);
std::cerr << '\n';
terminator.print(std::cerr, context.ns);
std::cerr << '\n';
rx::die("scfg: unexpected return instruction kind change");
} else {
returnInst = terminator;
}
}
}
if (returnType) {
auto variablePointerType =
context.getTypePointer(ir::spv::StorageClass::Function, returnType);
auto returnValueVariable =
Builder::createInsertAfter(context, region.getFirst())
.createSpvVariable(context.getUnknownLocation(),
variablePointerType,
ir::spv::StorageClass::Function);
auto newExitBlock = [&] {
auto loc = context.getUnknownLocation();
auto builder = Builder::createAppend(context, region);
auto newExitBlock = builder.createSpvLabel(loc);
auto mergedReturnValue =
builder.createSpvLoad(loc, returnType, returnValueVariable);
builder.createSpvReturnValue(loc, mergedReturnValue);
return newExitBlock;
}();
for (auto exitNode : exitNodes) {
auto terminator = exitNode->getTerminator();
if (terminator) {
auto newTerminator = Builder::createInsertAfter(context, terminator);
newTerminator.createSpvStore(terminator.getLocation(),
returnValueVariable,
terminator.getOperand(0).getAsValue());
newTerminator.createSpvBranch(terminator.getLocation(), newExitBlock);
terminator.erase();
}
}
return newExitBlock;
}
if (!returnInst) {
rx::die("scfg: unexpected cfg terminator");
}
auto newExitBlock = Builder::createAppend(context, region)
.createSpvLabel(context.getUnknownLocation());
for (auto exitNode : exitNodes) {
auto terminator = exitNode->getTerminator();
if (terminator) {
auto newTerminator = Builder::createInsertAfter(context, terminator);
newTerminator.createSpvBranch(terminator.getLocation(), newExitBlock);
terminator.erase();
}
}
region.insertAfter(newExitBlock, returnInst);
return newExitBlock;
}
void shader::transform::transformToCf(spv::Context &context, ir::RegionLike region) {
ir::Block currentBlock;
for (auto inst : region.children()) {
if (inst == ir::builtin::BLOCK) {
continue;
}
if (inst == ir::spv::OpLabel) {
currentBlock = Builder::createInsertBefore(context, inst)
.createBlock(inst.getLocation());
if (auto name = context.ns.tryGetNameOf(inst); !name.empty()) {
context.ns.setNameOf(currentBlock, std::string(name));
}
inst.staticCast<ir::Value>().replaceAllUsesWith(currentBlock);
inst.remove();
continue;
}
if (!currentBlock) {
inst.print(std::cerr, context.ns);
std::cerr << "\n";
region.print(std::cerr, context.ns);
std::cerr << "\n";
rx::die("cfg: node without label");
}
inst.erase();
currentBlock.addChild(inst);
if (isTerminator(inst)) {
currentBlock = nullptr;
}
}
}
void shader::transform::transformToFlat(spv::Context &context, ir::RegionLike region) {
std::vector<ir::Instruction> workList;
workList.push_back(region.getFirst());
auto insertPoint = Builder::createPrepend(context, region);
while (!workList.empty()) {
auto inst = workList.back();
workList.pop_back();
if (inst == nullptr) {
continue;
}
auto unwrapBlock = [&](ir::Block block) {
if (auto construct = block.cast<ir::LoopConstruct>()) {
auto merge = construct.getMerge();
auto cont = construct.getContinue().getHeader();
auto body = construct.getHeader();
auto blockLabel = insertPoint.createSpvLabel(block.getLocation());
construct.replaceAllUsesWith(blockLabel);
if (auto name = context.ns.tryGetNameOf(block); !name.empty()) {
context.ns.setNameOf(blockLabel, std::string(name));
}
for (auto phi : ir::range(construct.getFirst())) {
if (phi != ir::spv::OpPhi) {
break;
}
insertPoint.eraseAndInsert(phi);
}
insertPoint.createSpvLoopMerge(construct.getLocation(), merge, cont,
ir::spv::LoopControl::None());
insertPoint.createSpvBranch(construct.getLocation(), body);
workList.emplace_back(cont);
workList.emplace_back(construct.getFirst());
return;
}
if (auto construct = block.cast<ir::SelectionConstruct>()) {
auto constructBody = construct.getHeader();
auto header = ir::InsertionPoint::createPrepend(constructBody);
auto merge = construct.getMerge();
for (auto phi : ir::range(construct.getFirst())) {
if (phi != ir::spv::OpPhi) {
break;
}
for (std::size_t i = 1; i < phi.getOperandCount();) {
if (phi.getOperand(i + 1) == construct) {
phi.eraseOperand(i);
phi.eraseOperand(i);
} else {
i += 2;
}
}
header.eraseAndInsert(phi);
}
Builder::createInsertBefore(context, constructBody.getLast())
.createSpvSelectionMerge(construct.getLocation(), merge,
ir::spv::SelectionControl::None);
construct.replaceAllUsesWith(constructBody);
workList.emplace_back(constructBody);
return;
}
auto blockLabel = insertPoint.createSpvLabel(block.getLocation());
block.replaceAllUsesWith(blockLabel);
workList.emplace_back(block.getFirst());
if (auto name = context.ns.tryGetNameOf(block); !name.empty()) {
context.ns.setNameOf(blockLabel, std::string(name));
}
};
if (auto next = inst.getNext()) {
workList.push_back(next);
}
if (auto block = inst.cast<ir::Block>()) {
std::cout << "processing " << context.ns.getNameOf(block) << "\n";
unwrapBlock(block);
block.erase();
continue;
}
insertPoint.eraseAndInsert(inst);
}
}

View file

@ -0,0 +1,565 @@
#include "SpvConverter.hpp"
#include "transform/Edge.hpp"
#include "transform/construct.hpp"
#include "transform/merge.hpp"
#include "transform/route.hpp"
#include "transform/wrap.hpp"
#include "dialect.hpp"
#include <rx/die.hpp>
using namespace shader;
using namespace shader::transform;
using Builder = ir::Builder<ir::builtin::Builder, ir::spv::Builder>;
struct CycleEdges {
std::vector<Edge> entryEdges;
std::vector<Edge> backEdges;
std::vector<Edge> exitEdges;
};
static CycleEdges
calculateCycleEdges(const std::unordered_set<ir::Block> &cycles) {
CycleEdges result;
std::unordered_set<ir::Block> entryBlocks;
for (auto block : cycles) {
for (auto [pred, operandIndex] : getAllPredecessors(block)) {
if (cycles.contains(pred)) {
continue;
}
result.entryEdges.emplace_back(pred, operandIndex);
}
for (auto [succ, operandIndex] : getAllSuccessors(block)) {
if (cycles.contains(succ))
continue;
entryBlocks.insert(succ);
result.exitEdges.emplace_back(block, operandIndex);
}
}
for (auto block : cycles) {
for (auto [succ, operandIndex] : getAllSuccessors(block)) {
if (entryBlocks.contains(succ))
continue;
result.backEdges.emplace_back(block, operandIndex);
}
}
return result;
}
static ir::Instruction skipPhis(ir::Instruction inst) {
while (inst && inst == ir::spv::OpPhi) {
inst = inst.getNext();
}
return inst;
}
/**
* Tarjan's algorithm for finding strongly connected components (SCCs).
* This finds all cycles in the CFG
*/
static std::vector<std::unordered_set<ir::Block>>
findSCCs(ir::Range<ir::Block> nodes) {
std::unordered_map<ir::Block, std::size_t> indices;
std::unordered_map<ir::Block, std::size_t> lowlinks;
std::unordered_set<ir::Block> onStack;
std::vector<ir::Block> stack;
std::vector<std::unordered_set<ir::Block>> sccs;
std::size_t index = 0;
auto rootParent = (*nodes.begin()).getParent();
std::function<void(ir::Block)> strongConnect = [&](ir::Block node) {
indices[node] = index;
lowlinks[node] = index;
index++;
stack.push_back(node);
onStack.insert(node);
// Consider successors of node
for (auto successor : getSuccessors(node)) {
if (successor.getParent() != rootParent) {
continue;
}
if (!indices.contains(successor)) {
// Successor has not yet been visited; recurse on it
strongConnect(successor);
lowlinks[node] = std::min(lowlinks[node], lowlinks[successor]);
} else if (onStack.contains(successor)) {
// Successor is in stack and hence in the current SCC
lowlinks[node] = std::min(lowlinks[node], indices[successor]);
}
}
// If node is a root node, pop the stack and create an SCC
if (lowlinks[node] == indices[node]) {
std::unordered_set<ir::Block> scc;
scc.reserve(stack.size());
ir::Block w;
do {
w = stack.back();
stack.pop_back();
onStack.erase(w);
scc.insert(w);
} while (w != node);
// keep cycles only
if (!scc.empty()) {
auto isLoop = scc.size() > 1;
if (!isLoop) {
// single node can contain branch to self
isLoop = hasSuccessor(w, w);
}
if (isLoop) {
sccs.push_back(std::move(scc));
}
}
}
};
for (auto node : nodes) {
if (node.getParent() != rootParent) {
continue;
}
if (!indices.contains(node)) {
strongConnect(node);
}
}
return sccs;
}
void shader::transform::wrapLoopConstructs(spv::Context &context, ir::RegionLike root) {
auto region = root.children<ir::Block>();
auto sccs = findSCCs(region);
for (auto scc : sccs) {
auto edges = calculateCycleEdges(scc);
ir::Block bodyLabel;
ir::Block continueLabel;
ir::Block mergeLabel;
ir::Block latchLabel;
if (!edges.entryEdges.empty()) {
if (edges.entryEdges.size() == 1 && edges.backEdges.size() == 1 &&
edges.entryEdges[0].to() == edges.backEdges[0].to()) {
bodyLabel = edges.entryEdges[0].to();
continueLabel = edges.backEdges[0].from();
}
if (!bodyLabel) {
std::vector<Edge> entryEdges = edges.entryEdges;
// back edges should jump to entry block
entryEdges.insert(entryEdges.end(), edges.backEdges.begin(),
edges.backEdges.end());
// for loop no need to split blocks, we can just rotate loop
bodyLabel = createRouteBlock(
context, ir::InsertionPoint::createInsertBefore(*scc.begin()),
entryEdges);
scc.insert(bodyLabel);
edges = calculateCycleEdges(scc);
}
if (!continueLabel || bodyLabel == continueLabel ||
getSuccessorCount(continueLabel) != 1) {
std::unordered_set<ir::Block> preds;
for (auto edge : edges.backEdges) {
preds.insert(edge.from());
}
continueLabel = createMergeBlock(
context, ir::InsertionPoint::createInsertAfter(bodyLabel), preds,
bodyLabel);
scc.insert(continueLabel);
edges = calculateCycleEdges(scc);
}
}
if (!edges.exitEdges.empty()) {
mergeLabel = [&] -> ir::Block {
auto exitEdges = std::span(edges.exitEdges);
auto header = exitEdges[0].to();
exitEdges = exitEdges.subspan(1);
while (!exitEdges.empty()) {
if (header != exitEdges[0].to()) {
return {};
}
exitEdges = exitEdges.subspan(1);
}
return header;
}();
if (mergeLabel) {
auto predecessors = getPredecessors(mergeLabel);
for (auto pred : predecessors) {
if (!scc.contains(pred)) {
mergeLabel = {};
break;
}
}
if (mergeLabel && predecessors.size() == 1) {
latchLabel = *predecessors.begin();
auto latchSuccessors = getSuccessors(latchLabel);
auto it = latchSuccessors.begin();
auto firstSuccessor = *it;
auto secondSuccessor = *++it;
if ((firstSuccessor != continueLabel &&
secondSuccessor != continueLabel)) {
latchLabel = {};
mergeLabel = {};
}
if (latchLabel && getPredecessorCount(continueLabel) != 1) {
latchLabel = {};
}
}
}
if (!mergeLabel) {
mergeLabel = createRouteBlock(
context,
ir::InsertionPoint::createInsertAfter(edges.exitEdges[0].from()),
edges.exitEdges);
edges = calculateCycleEdges(scc);
}
if (!latchLabel) {
std::vector<Edge> exitEdges = edges.exitEdges;
for (auto [pred, operandIndex] : getAllPredecessors(continueLabel)) {
exitEdges.emplace_back(pred, operandIndex);
}
latchLabel = createRouteBlock(
context,
ir::InsertionPoint::createInsertAfter(edges.exitEdges[0].from()),
exitEdges);
scc.insert(latchLabel);
}
}
if (bodyLabel && continueLabel && mergeLabel) {
auto loopConstruct = createLoopConstruct(
context, root, bodyLabel, latchLabel, continueLabel, mergeLabel, scc);
// replace references to body outside this construct with header (i.e.
// loop construct node)
bodyLabel.replaceUsesIf(loopConstruct, [=](ir::ValueUse use) {
return (isTerminator(use.user) ||
(use.user != loopConstruct && isConstruct(use.user))) &&
getConstructOf(use.user) != loopConstruct;
});
// move PHIs to construct
for (auto phi : ir::range(bodyLabel.getFirst())) {
if (phi != ir::spv::OpPhi) {
break;
}
phi.erase();
loopConstruct.prependChild(phi);
}
}
}
}
void shader::transform::wrapSelectionConstructs(spv::Context &context,
ir::RegionLike root) {
std::vector<ir::Range<ir::Block>> workList;
workList.push_back(root.children<ir::Block>());
std::unordered_set<ir::Block> usedMergeBlocks;
while (!workList.empty()) {
auto region = workList.back();
workList.pop_back();
for (auto entryBlock : region) {
if (isConstruct(entryBlock)) {
if (entryBlock == ir::builtin::SELECTION_CONSTRUCT) {
if (auto body =
skipPhis(entryBlock.getFirst()).getNext().cast<ir::Block>()) {
workList.emplace_back(ir::range(body));
}
} else if (auto body =
skipPhis(entryBlock.getFirst()).cast<ir::Block>()) {
workList.emplace_back(ir::range(body));
}
continue;
}
auto terminator = entryBlock.getLast();
if (!terminator || !isTerminator(terminator) ||
(terminator != ir::spv::OpBranchConditional &&
terminator != ir::spv::OpSwitch)) {
continue;
}
ir::RegionLike parentConstruct = getConstructOf(entryBlock);
if (auto parentSelection =
parentConstruct.cast<ir::SelectionConstruct>()) {
if (parentSelection.getHeader() == entryBlock) {
continue;
}
}
auto successors = getSuccessors(entryBlock);
if (parentConstruct) {
if (parentConstruct.getLast() == entryBlock) {
// do not look at latch/continuation blocks
continue;
}
}
if (!parentConstruct) {
parentConstruct = root;
}
std::unordered_set<ir::Block> components;
components.insert(entryBlock);
auto addConstructComponent = [&](ir::Construct construct) {
components.insert(construct);
// add whole body of construct
for (auto child : construct.children<ir::Block>()) {
components.insert(child);
}
if (auto loop = construct.cast<ir::LoopConstruct>()) {
// it if is loop, add continue construct also
for (auto child : loop.getContinue().children<ir::Block>()) {
components.insert(child);
}
}
auto constructMerge = construct.getMerge();
if (parentConstruct != root &&
getSuccessorCount(construct.getMerge()) == 0) {
// we cannot take this merge block, it is exit from function block
// create trampoline node and replace merge block of this node
auto newMerge = createMergeBlock(
context, ir::InsertionPoint::createInsertBefore(constructMerge),
getPredecessors(constructMerge), constructMerge);
construct.setMerge(newMerge);
constructMerge = newMerge;
}
components.insert(constructMerge);
return getSuccessors(constructMerge);
};
auto addComponent = [&](ir::Block block) {
if (auto construct = block.cast<ir::Construct>()) {
return addConstructComponent(construct);
}
if (hasAtLeastSuccessors(block, 1)) {
components.insert(block);
return getSuccessors(block);
}
auto trampoline = createMergeBlock(
context, ir::InsertionPoint::createInsertBefore(block),
getPredecessors(block), block);
components.insert(trampoline);
return getSuccessors(block);
};
{
// try to find blocks that has no other predecessors
auto parentEntry =
skipPhis(parentConstruct.getFirst()).staticCast<ir::Block>();
auto headerSuccessors = getSuccessors(entryBlock);
std::vector<ir::Block> workList(headerSuccessors.begin(),
headerSuccessors.end());
while (!workList.empty()) {
auto block = workList.back();
workList.pop_back();
if (components.contains(block)) {
continue;
}
if (block.getParent() != parentConstruct) {
continue;
}
if (block == parentEntry || block == parentConstruct.getLast()) {
// do not take entry/latch/continuation of parent construct
continue;
}
bool hasAllPreds = true;
auto loop = block.cast<ir::LoopConstruct>();
for (auto pred : getPredecessors(block)) {
if (components.contains(pred)) {
continue;
}
if (loop && pred == loop.getContinue().getLast()) {
// ignore continue predecessor of loop
continue;
}
hasAllPreds = false;
break;
}
if (hasAllPreds) {
addComponent(block);
}
}
}
if (components.size() == 1) {
// all successors are used by nodes outside this header, it means it is
// not structured loop node or case block of OpSwitch with fallthrough
continue;
}
ir::Block entryLabel = entryBlock;
ir::Block mergeLabel;
bool mergeInserted = false;
std::unordered_set<ir::Block> exitBlocks;
std::vector<Edge> exitEdges;
for (auto block : components) {
for (auto [succ, operandIndex] : getAllSuccessors(block)) {
if (!components.contains(succ)) {
exitEdges.emplace_back(block, operandIndex);
exitBlocks.insert(block);
}
}
}
if (!exitBlocks.empty()) {
if (exitBlocks.size() == 1) {
mergeLabel = *exitBlocks.begin();
}
if (!mergeLabel ||
getAllPredecessors(mergeLabel).size() != exitEdges.size() ||
isConstruct(mergeLabel)) {
mergeLabel = createRouteBlock(
context, ir::InsertionPoint::createInsertAfter(entryBlock),
exitEdges);
workList.emplace_back(ir::range(mergeLabel));
mergeInserted = true;
}
} else {
mergeLabel = parentConstruct.getLast().staticCast<ir::Block>();
}
if (!mergeInserted) {
for (auto user : mergeLabel.getUserList()) {
if (auto construct = user.cast<ir::Construct>()) {
if (construct.getMerge() != mergeLabel) {
continue;
}
}
mergeLabel = createMergeBlock(
context, ir::InsertionPoint::createInsertBefore(mergeLabel),
getPredecessors(mergeLabel), mergeLabel);
mergeInserted = true;
break;
}
}
if (!mergeInserted) {
auto mergePreds = getPredecessors(mergeLabel);
std::unordered_set<ir::Block> branchesInsideConstruct;
for (auto pred : mergePreds) {
if (components.contains(pred)) {
branchesInsideConstruct.insert(pred);
}
}
if (branchesInsideConstruct.size() != mergePreds.size()) {
mergeLabel = createMergeBlock(
context, ir::InsertionPoint::createInsertBefore(mergeLabel),
branchesInsideConstruct, mergeLabel);
}
}
auto construct = createSelectionConstruct(
context, parentConstruct, components, entryLabel, mergeLabel);
// update merge label
construct.setMerge(mergeLabel);
construct.getHeader().replaceUsesIf(construct, [=](ir::ValueUse use) {
if (getConstructOf(use.user) != construct) {
if (isTerminator(use.user)) {
return true;
}
// allow update block merges
if (isConstruct(use.user) && use.operandIndex == 1) {
return true;
}
}
if (use.user != construct && isConstruct(use.user)) {
return true;
}
return false;
});
// move PHIs to construct
for (auto phi : ir::range(construct.getHeader().getFirst())) {
if (phi != ir::spv::OpPhi) {
break;
}
phi.erase();
construct.prependChild(phi);
}
// view child constructs
if (auto child = construct.getHeader().getNext().cast<ir::Block>()) {
workList.emplace_back(ir::range(child));
}
// view next constructs
if (auto next = construct.getNext()) {
workList.emplace_back(ir::range(next));
}
break;
}
}
}

52
test/CMakeLists.txt Normal file
View file

@ -0,0 +1,52 @@
# Tests and Benchmarks Configuration
cmake_minimum_required(VERSION 3.16)
project(MemoryTableTests)
# Add Google Benchmark as an external project if not found
find_package(benchmark QUIET)
if(NOT benchmark_FOUND)
include(FetchContent)
FetchContent_Declare(
googlebenchmark
GIT_REPOSITORY https://github.com/google/benchmark.git
GIT_TAG v1.9.4
)
# Disable benchmark tests and examples to speed up build
set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable benchmark tests" FORCE)
set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "Disable Google Test for benchmark" FORCE)
set(BENCHMARK_ENABLE_WERROR OFF CACHE BOOL "Disable warnings as errors" FORCE)
FetchContent_MakeAvailable(googlebenchmark)
endif()
# Add Google Test for unit testing
find_package(GTest QUIET)
if(NOT GTest_FOUND)
include(FetchContent)
FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG v1.15.2
)
# Disable gtest installation
set(INSTALL_GTEST OFF CACHE BOOL "Disable gtest installation" FORCE)
set(gtest_force_shared_crt ON CACHE BOOL "Force shared CRT for gtest" FORCE)
FetchContent_MakeAvailable(googletest)
endif()
add_executable(gcn_shader_tests
gcn_shader_tests.cpp
)
target_link_libraries(gcn_shader_tests
gtest
gtest_main
gcn-shader
)

253
test/gcn_shader_tests.cpp Normal file
View file

@ -0,0 +1,253 @@
#include <gtest/gtest.h>
#include <memory>
// Include shader framework for CFG testing
#include "shader/SpvConverter.hpp"
#include "shader/analyze.hpp"
#include "shader/dialect.hpp"
#include "shader/ir.hpp"
#include "shader/ir/Context.hpp"
#include "shader/spv.hpp"
#include "shader/transform.hpp"
using namespace shader;
using Builder = ir::Builder<ir::spv::Builder>;
class GcnShaderTest : public ::testing::Test {
protected:
void SetUp() override {
// Setup SPIR-V context for CFG testing
context = std::make_unique<spv::Context>();
loc = context->getUnknownLocation();
trueV = context->getTrue();
falseV = context->getFalse();
}
void TearDown() override { context.reset(); }
ir::Value createLabel(const std::string &name) {
auto builder = Builder::createAppend(
*context, context->layout.getOrCreateFunctions(*context));
auto label = builder.createSpvLabel(loc);
context->ns.setNameOf(label, name);
return label;
}
void createBranch(ir::Value from, ir::Value to) {
Builder::createInsertAfter(*context, from).createSpvBranch(loc, to);
}
void createConditionalBranch(ir::Value from, ir::Value a, ir::Value b) {
Builder::createInsertAfter(*context, from)
.createSpvBranchConditional(loc, trueV, a, b);
}
void createReturn(ir::Value from) {
Builder::createInsertAfter(*context, from).createSpvReturn(loc);
}
void createSwitch(ir::Value from, std::span<const ir::Value> cases) {
auto globals = Builder::createAppend(
*context, context->layout.getOrCreateGlobals(*context));
auto globalVariable = globals.createSpvVariable(
loc, context->getTypeUInt32(), ir::spv::StorageClass::Private,
context->imm32(0));
auto switchOp = Builder::createInsertAfter(*context, from)
.createSpvSwitch(loc, globalVariable, cases[0]);
std::uint32_t i = 0;
for (auto c : cases.subspan(1)) {
switchOp.addOperand(i++);
switchOp.addOperand(c);
}
}
void createSwitchBranch(ir::Value from, ir::Value defaultTarget,
const std::vector<std::pair<std::uint32_t, ir::Value>>& cases) {
// Create a switch value (use a constant for testing)
auto type = context->getTypeUInt32();
auto globals = Builder::createAppend(
*context, context->layout.getOrCreateGlobals(*context));
auto globalVariable = globals.createSpvConstant(
loc, type, 0);
auto builder = Builder::createInsertAfter(*context, from);
auto switchInst =
builder.createSpvSwitch(loc, globalVariable, defaultTarget);
// Add each case
for (const auto& [value, target] : cases) {
switchInst.addOperand(value);
switchInst.addOperand(target);
}
}
bool testStructurization() {
auto region = context->layout.getOrCreateFunctions(*context);
context->layout.regions[spv::BinaryLayout::kFunctions] = {};
auto functions = context->layout.getOrCreateFunctions(*context);
structurizeCfg(*context, region);
{
auto debugs = Builder::createAppend(
*context, context->layout.getOrCreateDebugs(*context));
auto cfg = buildCFG(region.getFirst());
for (auto node : cfg.getPreorderNodes()) {
auto value = node->getLabel();
if (auto name = context->ns.tryGetNameOf(value); !name.empty()) {
debugs.createSpvName(loc, value, std::string(name));
}
}
for (auto bb : cfg.getPreorderNodes()) {
for (auto child : bb->range()) {
child.erase();
functions.addChild(child);
}
}
}
region = functions;
auto entryLabel = region.getFirst().cast<ir::Value>();
auto memModel = Builder::createAppend(
*context, context->layout.getOrCreateMemoryModels(*context));
auto capabilities = Builder::createAppend(
*context, context->layout.getOrCreateCapabilities(*context));
capabilities.createSpvCapability(loc, ir::spv::Capability::Shader);
memModel.createSpvMemoryModel(loc, ir::spv::AddressingModel::Logical,
ir::spv::MemoryModel::GLSL450);
auto mainReturnT = context->getTypeVoid();
auto mainFnT = context->getTypeFunction(mainReturnT, {});
auto builder = Builder::createPrepend(*context, region);
auto mainFn = builder.createSpvFunction(
loc, mainReturnT, ir::spv::FunctionControl::None, mainFnT);
builder.createSpvLabel(loc);
builder.createSpvBranch(loc, entryLabel);
Builder::createAppend(*context, region).createSpvFunctionEnd(loc);
auto entryPoints = Builder::createAppend(
*context, context->layout.getOrCreateEntryPoints(*context));
auto executionModes = Builder::createAppend(
*context, context->layout.getOrCreateExecutionModes(*context));
executionModes.createSpvExecutionMode(
mainFn.getLocation(), mainFn,
ir::spv::ExecutionMode::LocalSize(1, 1, 1));
entryPoints.createSpvEntryPoint(mainFn.getLocation(),
ir::spv::ExecutionModel::GLCompute, mainFn,
"main", {});
auto spv = shader::spv::serialize(context->layout.merge(*context));
if (shader::spv::validate(spv)) {
return true;
}
shader::spv::dump(spv, true);
return false;
}
protected:
std::unique_ptr<spv::Context> context;
ir::Location loc;
ir::Value trueV;
ir::Value falseV;
};
TEST_F(GcnShaderTest, ProjectDivaTest1) {
auto _1 = createLabel("1");
auto _2 = createLabel("2");
auto _3 = createLabel("3");
auto _4 = createLabel("4");
auto _5 = createLabel("5");
auto _6 = createLabel("6");
auto _7 = createLabel("7");
auto _8 = createLabel("8");
auto _9 = createLabel("9");
auto _10 = createLabel("10");
auto _11 = createLabel("11");
auto _12 = createLabel("12");
auto _13 = createLabel("13");
createBranch(_1, _2);
createConditionalBranch(_2, _4, _3);
createConditionalBranch(_3, _12, _11);
createConditionalBranch(_4, _6, _5);
createConditionalBranch(_5, _9, _8);
createBranch(_6, _7);
createBranch(_7, _6);
createBranch(_8, _3);
createBranch(_9, _10);
createBranch(_10, _7);
createBranch(_11, _12);
createBranch(_12, _13);
createReturn(_13);
EXPECT_TRUE(testStructurization());
}
TEST_F(GcnShaderTest, BatmanReturnToArkham1) {
auto _1 = createLabel("1");
auto _2 = createLabel("2");
auto _3 = createLabel("3");
auto _4 = createLabel("4");
auto _5 = createLabel("5");
auto _6 = createLabel("6");
auto _7 = createLabel("7");
auto _8 = createLabel("8");
auto _9 = createLabel("9");
auto _10 = createLabel("10");
auto _11 = createLabel("11");
auto _12 = createLabel("12");
auto _13 = createLabel("13");
auto _14 = createLabel("14");
auto _15 = createLabel("15");
auto _16 = createLabel("16");
auto _17 = createLabel("17");
auto _18 = createLabel("18");
auto _19 = createLabel("19");
auto _20 = createLabel("20");
auto _21 = createLabel("21");
auto _22 = createLabel("22");
auto _23 = createLabel("23");
auto _24 = createLabel("24");
auto _25 = createLabel("25");
createBranch(_1, _2);
createConditionalBranch(_2, _4, _3);
createConditionalBranch(_3, _6, _5);
createBranch(_4, _3);
createConditionalBranch(_5, _8, _7);
createBranch(_6, _5);
createConditionalBranch(_7, _10, _9);
createBranch(_8, _7);
createConditionalBranch(_9, _12, _11);
createBranch(_10, _9);
createConditionalBranch(_11, _14, _13);
createBranch(_12, _11);
createConditionalBranch(_13, _16, _15);
createBranch(_14, _13);
createBranch(_15, _25);
createConditionalBranch(_16, _18, _17);
createBranch(_17, _18);
createConditionalBranch(_18, _20, _19);
createBranch(_19, _20);
createConditionalBranch(_20, _22, _21);
createBranch(_21, _22);
createConditionalBranch(_22, _24, _23);
createBranch(_23, _24);
createBranch(_24, _15);
createReturn(_25);
EXPECT_TRUE(testStructurization());
}