shader: add switch canonicalization transform

This commit is contained in:
DH 2025-12-04 21:20:17 +03:00
parent 92703954d0
commit 0f8a3dd1db
4 changed files with 140 additions and 70 deletions

View file

@ -4,8 +4,9 @@
#include "ir.hpp"
namespace shader::transform {
ir::Value transformToCanonicalRegion(spv::Context &context,
ir::RegionLike region);
void transformToCf(spv::Context &context, ir::RegionLike region);
void transformToFlat(spv::Context &context, ir::RegionLike region);
ir::Value toCanonicalRegion(spv::Context &context, ir::RegionLike region);
void toCf(spv::Context &context, ir::RegionLike region);
void toFlat(spv::Context &context, ir::RegionLike region);
void canonicalizeSwitchSelectionConstructs(spv::Context &context,
ir::RegionLike root);
} // namespace shader::transform

View file

@ -1,31 +1,27 @@
#include "transform.hpp"
#include "transform/transformations.hpp"
#include "transform/wrap.hpp"
#include "SpvConverter.hpp"
#include "dialect.hpp"
#include <iostream>
#include <rx/die.hpp>
using namespace shader;
using namespace shader::transform;
using Builder = ir::Builder<ir::builtin::Builder, ir::spv::Builder>;
void shader::structurizeCfg(spv::Context &context, ir::RegionLike region) {
// std::cerr << "before transforms: ";
// region.print(std::cerr, context.ns);
// std::cerr << "\n";
transformToCanonicalRegion(context, region);
transformToCf(context, region);
transform::toCanonicalRegion(context, region);
transform::toCf(context, region);
wrapLoopConstructs(context, region);
wrapSelectionConstructs(context, region);
transform::wrapLoopConstructs(context, region);
transform::wrapSelectionConstructs(context, region);
transform::canonicalizeSwitchSelectionConstructs(context, region);
// std::cerr << "structured: ";
// region.print(std::cerr, context.ns);
// std::cerr << "\n";
transformToFlat(context, region);
transform::toFlat(context, region);
// std::cerr << "flat: ";
// region.print(std::cerr, context.ns);

View file

@ -151,32 +151,6 @@ static std::unordered_map<ir::Value, std::uint32_t> createRouteTerminator(
successorToId.reserve(toPreds.size());
auto hasBranchesTo = [](ir::Block from, ir::Block to) {
std::vector<ir::Block> workList;
std::unordered_set<ir::Block> visited;
workList.push_back(from);
visited.insert(from);
while (!workList.empty()) {
auto block = workList.back();
workList.pop_back();
if (block == to) {
return true;
}
for (auto succ : getSuccessors(block)) {
if (visited.insert(succ).second) {
workList.push_back(succ);
}
}
}
visited.insert(from);
return false;
};
for (std::uint32_t id = 0; auto &[succ, pred] : toPreds) {
if (id) {
routeSwitch.addOperand(id);
@ -184,29 +158,6 @@ static std::unordered_map<ir::Value, std::uint32_t> createRouteTerminator(
}
successorToId[succ] = id++;
}
auto caseCount = routeSwitch.getOperandCount() / 2 - 1;
for (std::size_t i = 1; i < caseCount; ++i) {
auto caseValue0 = routeSwitch.getOperand(2 + i * 2);
auto caseTarget0 = routeSwitch.getOperand(2 + i * 2 + 1)
.getAsValue()
.staticCast<ir::Block>();
for (std::size_t t = 0; t < i; ++t) {
auto caseValue1 = routeSwitch.getOperand(2 + t * 2);
auto caseTarget1 = routeSwitch.getOperand(2 + t * 2 + 1)
.getAsValue()
.staticCast<ir::Block>();
if (hasBranchesTo(caseTarget0, caseTarget1)) {
routeSwitch.replaceOperand(2 + i * 2, caseValue1);
routeSwitch.replaceOperand(2 + i * 2 + 1, caseTarget1);
routeSwitch.replaceOperand(2 + t * 2, caseValue0);
routeSwitch.replaceOperand(2 + t * 2 + 1, caseTarget0);
break;
}
}
}
}
return successorToId;

View file

@ -2,6 +2,7 @@
#include "SpvConverter.hpp"
#include "analyze.hpp"
#include "dialect.hpp"
#include <list>
#include <rx/die.hpp>
#include <iostream>
@ -13,8 +14,8 @@ using namespace shader::transform;
using Builder = ir::Builder<ir::builtin::Builder, ir::spv::Builder>;
ir::Value shader::transform::transformToCanonicalRegion(spv::Context &context,
ir::RegionLike region) {
ir::Value shader::transform::toCanonicalRegion(spv::Context &context,
ir::RegionLike region) {
auto cfg = buildCFG(region.getFirst());
std::vector<CFG::Node *> exitNodes;
for (auto node : cfg.getPreorderNodes()) {
@ -136,9 +137,9 @@ ir::Value shader::transform::transformToCanonicalRegion(spv::Context &context,
return newExitBlock;
}
void shader::transform::transformToCf(spv::Context &context,
ir::RegionLike region) {
void shader::transform::toCf(spv::Context &context, ir::RegionLike region) {
ir::Block currentBlock;
ir::Block terminationBlock;
for (auto inst : region.children()) {
if (inst == ir::builtin::BLOCK) {
@ -170,13 +171,21 @@ void shader::transform::transformToCf(spv::Context &context,
currentBlock.addChild(inst);
if (isTerminator(inst)) {
if (!isBranch(inst)) {
terminationBlock = currentBlock;
}
currentBlock = nullptr;
}
}
if (terminationBlock != nullptr) {
terminationBlock.erase();
region.addChild(terminationBlock);
}
}
void shader::transform::transformToFlat(spv::Context &context,
ir::RegionLike region) {
void shader::transform::toFlat(spv::Context &context, ir::RegionLike region) {
std::vector<ir::Instruction> workList;
workList.push_back(region.getFirst());
@ -279,3 +288,116 @@ void shader::transform::transformToFlat(spv::Context &context,
insertPoint.eraseAndInsert(inst);
}
}
static void
toCanonicalSwitchSelectionConstruct(spv::Context &context,
ir::SelectionConstruct switchConstruct) {
auto switchOp = switchConstruct.getHeader().getLast();
auto mergeBlock = switchConstruct.getMerge();
struct CaseInfo {
ir::Operand value;
ir::Block fallthroughBlock;
};
std::unordered_map<ir::Block, CaseInfo> cases;
for (std::size_t i = 2; i < switchOp.getOperandCount();) {
if (switchOp.getOperand(i + 1) == mergeBlock) {
i += 2;
} else {
auto value = switchOp.eraseOperand(i);
auto target = switchOp.eraseOperand(i).getAsValue().cast<ir::Block>();
cases[target] = {.value = std::move(value)};
}
}
if (cases.empty()) {
return;
}
std::vector<ir::Block> workList;
for (auto &[target, caseInfo] : cases) {
workList.push_back(target);
while (!workList.empty()) {
auto block = workList.back();
workList.pop_back();
if (block == mergeBlock) {
continue;
}
if (block != target && cases.contains(block)) {
caseInfo.fallthroughBlock = block;
workList.clear();
break;
}
if (auto construct = block.cast<ir::Construct>()) {
workList.push_back(construct.getMerge());
continue;
}
for (auto succ : getSuccessors(block)) {
workList.push_back(succ);
}
}
}
std::list<ir::Block> sortedCases;
for (auto &[target, caseInfo] : cases) {
if (caseInfo.fallthroughBlock == nullptr) {
sortedCases.push_back(target);
}
}
assert(!sortedCases.empty());
for (auto &[target, caseInfo] : cases) {
if (caseInfo.fallthroughBlock == nullptr) {
continue;
}
auto it = sortedCases.begin();
while (it != sortedCases.end()) {
if (caseInfo.fallthroughBlock == *it) {
break;
}
++it;
}
sortedCases.insert(it, target);
}
for (auto target : sortedCases) {
auto &info = cases.at(target);
switchOp.addOperand(info.value);
switchOp.addOperand(target);
}
}
void shader::transform::canonicalizeSwitchSelectionConstructs(
spv::Context &context, ir::RegionLike root) {
std::vector<ir::Range<ir::Block>> workList;
workList.push_back(root.children<ir::Block>());
while (!workList.empty()) {
auto region = workList.back();
workList.pop_back();
for (auto entryBlock : region) {
if (auto selection = entryBlock.cast<ir::SelectionConstruct>()) {
if (selection.getHeader().getLast() == ir::spv::OpSwitch) {
toCanonicalSwitchSelectionConstruct(context, selection);
}
}
workList.emplace_back(entryBlock.children<ir::Block>());
}
}
}