shader: improve to structural conversion transform

fix memory leak in spv::dump
properly print blocks
add loop, continue and selection construct nodes
add getParent() and print(...) to RegionLike
This commit is contained in:
DH 2025-09-20 22:05:32 +03:00
parent 4b854f46a8
commit 67f3ece45a
29 changed files with 2172 additions and 1211 deletions

View file

@ -28,6 +28,7 @@ struct Context : ir::Context {
Context(); Context();
ir::Region createRegion(ir::Location loc);
ir::Value createRegionWithLabel(ir::Location loc); ir::Value createRegionWithLabel(ir::Location loc);
void setName(ir::spv::IdRef inst, std::string name); void setName(ir::spv::IdRef inst, std::string name);
@ -35,6 +36,7 @@ struct Context : ir::Context {
ir::Value getOrCreateConstant(ir::Value typeValue, const ir::Operand &value); ir::Value getOrCreateConstant(ir::Value typeValue, const ir::Operand &value);
ir::Value getNull(ir::Value typeValue); ir::Value getNull(ir::Value typeValue);
ir::Value getUndef(ir::Value typeValue);
ir::Value getType(ir::spv::Op baseType, int width, bool isSigned); ir::Value getType(ir::spv::Op baseType, int width, bool isSigned);
ir::Value getType(const TypeInfo &info); ir::Value getType(const TypeInfo &info);

View file

@ -15,8 +15,6 @@
#include <vector> #include <vector>
namespace shader { namespace shader {
struct DomTree;
struct PostDomTree;
class CFG { class CFG {
public: public:
class Node { class Node {
@ -42,7 +40,6 @@ public:
mSuccessors.insert(to); mSuccessors.insert(to);
} }
bool hasPredecessor(Node *node) { return mPredecessors.contains(node); }
bool hasSuccessor(Node *node) { return mSuccessors.contains(node); } bool hasSuccessor(Node *node) { return mSuccessors.contains(node); }
auto &getPredecessors() { return mPredecessors; } auto &getPredecessors() { return mPredecessors; }
auto &getSuccessors() { return mSuccessors; } auto &getSuccessors() { return mSuccessors; }
@ -118,16 +115,6 @@ public:
void print(std::ostream &os, ir::NameStorage &ns, bool subgraph = false, void print(std::ostream &os, ir::NameStorage &ns, bool subgraph = false,
std::string_view nameSuffix = ""); std::string_view nameSuffix = "");
std::string genTest(); std::string genTest();
CFG buildView(CFG::Node *from, PostDomTree *domTree = nullptr,
const std::unordered_set<ir::Value> &stopLabels = {},
ir::Value continueLabel = nullptr);
CFG buildView(ir::Value from, PostDomTree *domTree = nullptr,
const std::unordered_set<ir::Value> &stopLabels = {},
ir::Value continueLabel = nullptr) {
return buildView(getNode(from), domTree, stopLabels, continueLabel);
}
}; };
class MemorySSA { class MemorySSA {
@ -182,11 +169,27 @@ bool isWithoutSideEffects(ir::InstructionId id);
bool isTerminator(ir::Instruction inst); bool isTerminator(ir::Instruction inst);
bool isBranch(ir::Instruction inst); bool isBranch(ir::Instruction inst);
ir::Value unwrapPointer(ir::Value pointer); ir::Value unwrapPointer(ir::Value pointer);
ir::Instruction getTerminator(ir::RegionLike region);
std::vector<std::pair<ir::Block, int>> getAllSuccessors(ir::Block region);
std::vector<std::pair<ir::Block, int>> getAllPredecessors(ir::Block region);
std::unordered_set<ir::Block> getSuccessors(ir::Block region);
std::unordered_set<ir::Block> getPredecessors(ir::Block region);
std::size_t getSuccessorCount(ir::Block region);
std::size_t getPredecessorCount(ir::Block region);
bool hasSuccessor(ir::Block region, ir::Block successor);
bool hasAtLeastSuccessors(ir::Block region, std::size_t count);
ir::Block getUniqSuccessor(ir::Block region);
graph::DomTree<ir::Block> buildDomTree(ir::Block block);
graph::DomTree<ir::Block> buildPostDomTree(ir::Block block);
graph::DomTree<ir::Block> buildDomTree(ir::RegionLike region);
graph::DomTree<ir::Block> buildPostDomTree(ir::RegionLike region);
graph::DomTree<ir::Value> buildDomTree(CFG &cfg, ir::Value root = nullptr); graph::DomTree<ir::Value> buildDomTree(CFG &cfg, ir::Value root = nullptr);
graph::DomTree<ir::Value> buildPostDomTree(CFG &cfg, ir::Value root); graph::DomTree<ir::Value> buildPostDomTree(CFG &cfg, ir::Value root);
CFG buildCFG(ir::Instruction firstInstruction, CFG buildCFG(ir::Instruction firstInstruction, ir::Value exitLabel = nullptr,
const std::unordered_set<ir::Value> &exitLabels = {},
ir::Value continueLabel = nullptr); ir::Value continueLabel = nullptr);
MemorySSA buildMemorySSA(CFG &cfg, ModuleInfo *moduleInfo = nullptr); MemorySSA buildMemorySSA(CFG &cfg, ModuleInfo *moduleInfo = nullptr);
@ -200,23 +203,6 @@ bool dominates(ir::Instruction a, ir::Instruction b, bool isPostDom,
ir::Value findNearestCommonDominator(ir::Instruction a, ir::Instruction b, ir::Value findNearestCommonDominator(ir::Instruction a, ir::Instruction b,
graph::DomTree<ir::Value> &domTree); graph::DomTree<ir::Value> &domTree);
class BackEdgeStorage {
std::unordered_map<ir::Value, std::unordered_set<ir::Value>> backEdges;
public:
BackEdgeStorage() = default;
BackEdgeStorage(CFG &cfg);
const std::unordered_set<ir::Value> *get(ir::Value value) {
if (auto it = backEdges.find(value); it != backEdges.end()) {
return &it->second;
}
return nullptr;
}
auto &all() { return backEdges; }
};
struct AnalysisStorage { struct AnalysisStorage {
template <typename... T> template <typename... T>
requires(sizeof...(T) > 0) requires(sizeof...(T) > 0)
@ -245,9 +231,7 @@ struct AnalysisStorage {
{ {
void *result = getImpl( void *result = getImpl(
rx::TypeId::get<T>(), getDeleter<T>(), rx::TypeId::get<T>(), getDeleter<T>(),
[&] { [&] { return new T(std::forward<ArgsT>(args)...); },
return std::make_unique<T>(std::forward<ArgsT>(args)...).release();
},
[&](void *object) { [&](void *object) {
*reinterpret_cast<T *>(object) = T(std::forward<ArgsT>(args)...); *reinterpret_cast<T *>(object) = T(std::forward<ArgsT>(args)...);
}); });
@ -261,10 +245,7 @@ struct AnalysisStorage {
{ {
void *result = getImpl( void *result = getImpl(
rx::TypeId::get<T>(), getDeleter<T>(), rx::TypeId::get<T>(), getDeleter<T>(),
[&] { [&] { return new T(std::forward<BuilderFn>(builder)()); },
return std::make_unique<T>(std::forward<BuilderFn>(builder)())
.release();
},
[&](void *object) { [&](void *object) {
*reinterpret_cast<T *>(object) = std::forward<BuilderFn>(builder)(); *reinterpret_cast<T *>(object) = std::forward<BuilderFn>(builder)();
}); });
@ -304,20 +285,20 @@ private:
std::map<rx::TypeId, Entry> mStorage; std::map<rx::TypeId, Entry> mStorage;
}; };
struct PostDomTree : graph::DomTree<ir::Value> { struct PostDomTree : graph::DomTree<ir::Block> {
PostDomTree() = default; PostDomTree() = default;
PostDomTree(graph::DomTree<ir::Value> &&other) PostDomTree(graph::DomTree<ir::Block> &&other)
: graph::DomTree<ir::Value>::DomTree(std::move(other)) {} : graph::DomTree<ir::Block>::DomTree(std::move(other)) {}
PostDomTree(CFG &cfg, ir::Value root) PostDomTree(ir::Block block) : PostDomTree(buildPostDomTree(block)) {}
: PostDomTree(buildPostDomTree(cfg, root)) {} PostDomTree(ir::RegionLike region) : DomTree(buildPostDomTree(region)) {}
}; };
struct DomTree : graph::DomTree<ir::Value> { struct DomTree : graph::DomTree<ir::Block> {
DomTree() = default; DomTree() = default;
DomTree(graph::DomTree<ir::Value> &&other) DomTree(graph::DomTree<ir::Block> &&other)
: graph::DomTree<ir::Value>::DomTree(std::move(other)) {} : graph::DomTree<ir::Block>::DomTree(std::move(other)) {}
DomTree(CFG &cfg, ir::Value root = nullptr) DomTree(ir::Block block) : DomTree(buildDomTree(block)) {}
: DomTree(buildDomTree(cfg, root)) {} DomTree(ir::RegionLike region) : DomTree(buildDomTree(region)) {}
}; };
template <typename T, std::size_t> struct Tag : T { template <typename T, std::size_t> struct Tag : T {
@ -337,107 +318,4 @@ template <typename T, std::size_t> struct Tag : T {
} }
}; };
struct Construct {
Construct *parent;
std::forward_list<Construct> children;
ir::Value header;
ir::Value merge;
ir::Value loopBody;
ir::Value loopContinue;
AnalysisStorage analysis;
static std::unique_ptr<Construct> createRoot(ir::RegionLike region,
ir::Value merge) {
auto result = std::make_unique<Construct>();
auto &cfg =
result->analysis.get<CFG>([&] { return buildCFG(region.getFirst()); });
result->header = cfg.getEntryLabel();
result->merge = merge;
return result;
}
Construct *createChild(ir::Value header, ir::Value merge) {
auto &result = children.emplace_front();
result.parent = this;
result.header = header;
result.merge = merge;
return &result;
}
Construct *createChild(ir::Value header, ir::Value merge,
ir::Value loopContinue, ir::Value loopBody) {
auto &result = children.emplace_front();
result.parent = this;
result.header = header;
result.merge = merge;
result.loopContinue = loopContinue;
result.loopBody = loopBody;
return &result;
}
Construct createTemporaryChild(ir::Value header, ir::Value merge) {
Construct result;
result.parent = this;
result.header = header;
result.merge = merge;
return result;
}
CFG &getCfg() {
return analysis.get<CFG>([this] {
if (parent != nullptr) {
return parent->getCfg().buildView(header, &parent->getPostDomTree(),
{header, merge});
}
return buildCFG(header);
});
}
CFG &getCfgWithoutContinue() {
if (loopContinue == nullptr) {
return getCfg();
}
return analysis.get<Tag<CFG, kWithoutContinue>>([this] {
if (parent != nullptr) {
return parent->getCfg().buildView(header, &parent->getPostDomTree(),
{header, merge}, loopContinue);
}
return buildCFG(header, {}, loopContinue);
});
}
DomTree &getDomTree() { return analysis.get<DomTree>(getCfg(), header); }
PostDomTree &getPostDomTree() {
return analysis.get<PostDomTree>(getCfg(), merge);
}
BackEdgeStorage &getBackEdgeStorage() {
return analysis.get<BackEdgeStorage>(getCfg());
}
BackEdgeStorage &getBackEdgeWithoutContinueStorage() {
if (loopContinue == nullptr) {
return getBackEdgeStorage();
}
return analysis.get<Tag<BackEdgeStorage, kWithoutContinue>>(
getCfgWithoutContinue());
}
auto getBackEdges(ir::Value node) { return getBackEdgeStorage().get(node); }
auto getBackEdgesWithoutContinue(ir::Value node) {
return getBackEdgeWithoutContinueStorage().get(node);
}
auto getBackEdges() { return getBackEdges(header); }
void invalidate();
void invalidateAll();
bool isNull() const { return header == nullptr; }
void removeLastChild() { children.pop_front(); }
private:
enum {
kWithoutContinue,
};
};
} // namespace shader } // namespace shader

View file

@ -1,6 +1,8 @@
#pragma once #pragma once
#include "../ir/Block.hpp" #include "../ir/Block.hpp"
#include "../ir/Builder.hpp" #include "../ir/Builder.hpp"
#include "../ir/LoopConstruct.hpp"
#include "../ir/SelectionConstruct.hpp"
#include "../ir/Value.hpp" #include "../ir/Value.hpp"
namespace shader::ir { namespace shader::ir {
@ -11,8 +13,9 @@ namespace shader::ir::builtin {
enum Op { enum Op {
INVALID_INSTRUCTION, INVALID_INSTRUCTION,
BLOCK, BLOCK,
IF_ELSE, LOOP_CONSTRUCT,
LOOP, CONTINUE_CONSTRUCT,
SELECTION_CONSTRUCT,
}; };
inline const char *getInstructionName(unsigned id) { inline const char *getInstructionName(unsigned id) {
@ -23,11 +26,14 @@ inline const char *getInstructionName(unsigned id) {
case BLOCK: case BLOCK:
return "block"; return "block";
case IF_ELSE: case LOOP_CONSTRUCT:
return "ifElse"; return "loop_construct";
case LOOP: case CONTINUE_CONSTRUCT:
return "loop"; return "continue_construct";
case SELECTION_CONSTRUCT:
return "selection_construct";
} }
return nullptr; return nullptr;
} }
@ -46,23 +52,30 @@ struct Builder : BuilderFacade<Builder<ImplT>, ImplT> {
INVALID_INSTRUCTION); INVALID_INSTRUCTION);
} }
Instruction createIfElse(Location location, Value cond, Block ifTrue, SelectionConstruct createSelectionConstruct(Location location, ir::Block body,
Block ifFalse = {}) { ir::Block merge) {
std::vector<Operand> operands = {{cond, ifTrue}}; return this->template create<SelectionConstruct>(
if (ifFalse) { location, Kind::Builtin, SELECTION_CONSTRUCT,
operands.push_back(ifFalse); std::span<const Operand>{{body, merge}});
}
return this->template create<Instruction>(location, Kind::Builtin, IF_ELSE,
operands);
} }
Instruction createLoop(Location location, Block body) { ContinueConstruct createContinueConstruct(Location location, ir::Block body,
return this->template create<Instruction>(location, Kind::Builtin, IF_ELSE, ir::Block merge) {
{{body}}); return this->template create<ContinueConstruct>(
location, Kind::Builtin, CONTINUE_CONSTRUCT,
std::span<const Operand>{{body, merge}});
}
LoopConstruct createLoopConstruct(Location location, ir::Block body,
ir::Block merge,
ContinueConstruct continueConstruct) {
return this->template create<LoopConstruct>(
location, Kind::Builtin, LOOP_CONSTRUCT,
std::span<const Operand>{{body, merge, continueConstruct}});
} }
auto createBlock(Location location) { auto createBlock(Location location) {
return this->template create<Block>(location); return this->template create<Block>(location, Kind::Builtin, BLOCK);
} }
auto createRegion(Location location) { auto createRegion(Location location) {

View file

@ -14,6 +14,7 @@ enum Op {
OpBarrier, OpBarrier,
OpJump, OpJump,
OpExit, OpExit,
OpScope,
OpCount, OpCount,
}; };
@ -24,8 +25,9 @@ template <typename BaseT> struct BaseImpl : BaseT {
using BaseT::BaseT; using BaseT::BaseT;
using BaseT::operator=; using BaseT::operator=;
void print(std::ostream &os, NameStorage &ns) const override { void print(std::ostream &os, NameStorage &ns,
BaseT::print(os, ns); const PrintOptions &opts) const override {
BaseT::print(os, ns, opts);
if (link) { if (link) {
os << " : "; os << " : ";
@ -220,7 +222,7 @@ struct ScopeWrapper : BaseWrapper<ImplT, ir::BlockWrapper> {
std::set<Var> result; std::set<Var> result;
std::vector<Var> workList; std::vector<Var> workList;
for (auto comp : var.getOperands()) { for (auto &comp : var.getOperands()) {
auto compVar = comp.getAsValue().staticCast<Var>(); auto compVar = comp.getAsValue().staticCast<Var>();
result.insert(compVar); result.insert(compVar);
@ -235,7 +237,7 @@ struct ScopeWrapper : BaseWrapper<ImplT, ir::BlockWrapper> {
auto var = workList.back(); auto var = workList.back();
workList.pop_back(); workList.pop_back();
for (auto comp : var.getOperands()) { for (auto &comp : var.getOperands()) {
auto compVar = comp.getAsValue().staticCast<Var>(); auto compVar = comp.getAsValue().staticCast<Var>();
result.insert(compVar); result.insert(compVar);
@ -353,7 +355,8 @@ struct Builder : BuilderFacade<Builder<ImplT>, ImplT> {
} }
Scope createScope(ir::Instruction labelInst) { Scope createScope(ir::Instruction labelInst) {
Scope result = this->template create<Scope>(labelInst.getLocation()); Scope result = this->template create<Scope>(labelInst.getLocation(),
ir::Kind::MemSSA, OpScope);
result.impl->link = labelInst; result.impl->link = labelInst;
return result; return result;
} }
@ -417,6 +420,8 @@ inline const char *getInstructionName(unsigned op) {
return "jump"; return "jump";
case OpExit: case OpExit:
return "exit"; return "exit";
case OpScope:
return "scope";
} }
return nullptr; return nullptr;
} }

View file

@ -19,10 +19,12 @@ struct Block : BlockWrapper<BlockImpl> {
}; };
struct BlockImpl : ValueImpl, RegionLikeImpl { struct BlockImpl : ValueImpl, RegionLikeImpl {
BlockImpl(Location loc); using ValueImpl::ValueImpl;
Node clone(Context &context, CloneMap &map) const override; Node clone(Context &context, CloneMap &map) const override;
void print(std::ostream &os, NameStorage &ns) const override { void print(std::ostream &os, NameStorage &ns,
const PrintOptions &opts) const override {
os << '%' << ns.getNameOf(const_cast<BlockImpl *>(this)); os << '%' << ns.getNameOf(const_cast<BlockImpl *>(this));
os << " = "; os << " = ";
@ -41,11 +43,13 @@ struct BlockImpl : ValueImpl, RegionLikeImpl {
} }
os << "{\n"; os << "{\n";
auto childOpts = opts.nextLevel();
for (auto child : children()) { for (auto child : children()) {
os << " "; childOpts.printIdent(os);
child.print(os, ns); child.print(os, ns, childOpts);
os << "\n"; os << "\n";
} }
opts.printIdent(os);
os << "}"; os << "}";
} }
}; };

View file

@ -22,6 +22,46 @@ template <typename BuilderT, typename ImplT> struct BuilderFacade {
} }
}; };
class InsertionPoint {
RegionLike mInsertionStorage;
Instruction mInsertionPoint;
public:
InsertionPoint() = default;
InsertionPoint(RegionLike storage, Instruction point)
: mInsertionStorage(storage), mInsertionPoint(point) {}
static InsertionPoint createInsertAfter(Instruction point) {
return {point.getParent(), point};
}
static InsertionPoint createInsertBefore(Instruction point) {
return {point.getParent(), point.getPrev()};
}
static InsertionPoint createAppend(RegionLike storage) {
return {storage, storage.getLast()};
}
static InsertionPoint createPrepend(RegionLike storage) {
return {storage, nullptr};
}
RegionLike getInsertionStorage() { return mInsertionStorage; }
Instruction getInsertionPoint() { return mInsertionPoint; }
void insert(ir::Instruction inst) {
getInsertionStorage().insertAfter(getInsertionPoint(), inst);
mInsertionPoint = inst;
}
void eraseAndInsert(ir::Instruction inst) {
inst.erase();
insert(inst);
}
};
template <template <typename> typename... InterfaceTs> template <template <typename> typename... InterfaceTs>
class Builder : public InterfaceTs<Builder<InterfaceTs...>>... { class Builder : public InterfaceTs<Builder<InterfaceTs...>>... {
Context *mContext{}; Context *mContext{};
@ -32,6 +72,13 @@ public:
Builder() = default; Builder() = default;
Builder(Context &context) : mContext(&context) {} Builder(Context &context) : mContext(&context) {}
static Builder create(Context &context, InsertionPoint point) {
auto result = Builder(context);
result.mInsertionStorage = point.getInsertionStorage();
result.mInsertionPoint = point.getInsertionPoint();
return result;
}
static Builder createInsertAfter(Context &context, Instruction point) { static Builder createInsertAfter(Context &context, Instruction point) {
auto result = Builder(context); auto result = Builder(context);
result.mInsertionStorage = point.getParent(); result.mInsertionStorage = point.getParent();
@ -42,14 +89,14 @@ public:
static Builder createInsertBefore(Context &context, Instruction point) { static Builder createInsertBefore(Context &context, Instruction point) {
auto result = Builder(context); auto result = Builder(context);
result.mInsertionStorage = point.getParent(); result.mInsertionStorage = point.getParent();
result.mInsertionPoint = point.getPrev().cast<Instruction>(); result.mInsertionPoint = point.getPrev();
return result; return result;
} }
static Builder createAppend(Context &context, RegionLike storage) { static Builder createAppend(Context &context, RegionLike storage) {
auto result = Builder(context); auto result = Builder(context);
result.mInsertionStorage = storage; result.mInsertionStorage = storage;
result.mInsertionPoint = storage.getLast().cast<Instruction>(); result.mInsertionPoint = storage.getLast();
return result; return result;
} }
@ -65,6 +112,10 @@ public:
Instruction getInsertionPoint() { return mInsertionPoint; } Instruction getInsertionPoint() { return mInsertionPoint; }
void setInsertionPoint(Instruction inst) { mInsertionPoint = inst; } void setInsertionPoint(Instruction inst) { mInsertionPoint = inst; }
InsertionPoint saveInsertionPoint() {
return { mInsertionStorage, mInsertionPoint };
}
template <typename T, typename... ArgsT> template <typename T, typename... ArgsT>
requires requires { requires requires {
typename T::underlying_type; typename T::underlying_type;
@ -73,12 +124,21 @@ public:
} }
T create(ArgsT &&...args) { T create(ArgsT &&...args) {
auto result = getContext().template create<T>(std::forward<ArgsT>(args)...); auto result = getContext().template create<T>(std::forward<ArgsT>(args)...);
using InstanceType = typename T::underlying_type;
getInsertionStorage().insertAfter(getInsertionPoint(), result); getInsertionStorage().insertAfter(getInsertionPoint(), result);
if constexpr (requires { mInsertionPoint = Instruction(result); }) { if constexpr (requires { mInsertionPoint = Instruction(result); }) {
mInsertionPoint = Instruction(result); mInsertionPoint = Instruction(result);
} }
return result; return result;
} }
void insert(ir::Instruction inst) {
getInsertionStorage().insertAfter(getInsertionPoint(), inst);
mInsertionPoint = inst;
}
void eraseAndInsert(ir::Instruction inst) {
inst.erase();
insert(inst);
}
}; };
} // namespace shader::ir } // namespace shader::ir

View file

@ -0,0 +1,42 @@
#pragma once
#include "Block.hpp"
namespace shader::ir {
template <typename ImplT>
struct ConstructWrapper : BlockWrapper<ImplT> {
using BlockWrapper<ImplT>::BlockWrapper;
using BlockWrapper<ImplT>::operator=;
ir::Block getHeader() {
return this->impl->getOperand(0)
.getAsValue()
.template staticCast<ir::Block>();
}
void setHeader(ir::Block block) {
this->impl->replaceOperand(0, block);
}
ir::Block getMerge() {
return this->impl->getOperand(1)
.getAsValue()
.template staticCast<ir::Block>();
}
void setMerge(ir::Block block) {
this->impl->replaceOperand(1, block);
}
};
struct ConstructImpl;
struct Construct : ConstructWrapper<ConstructImpl> {
using ConstructWrapper<ConstructImpl>::ConstructWrapper;
using ConstructWrapper<ConstructImpl>::operator=;
};
struct ConstructImpl : BlockImpl {
using BlockImpl::BlockImpl;
};
} // namespace shader::ir

View file

@ -4,8 +4,10 @@
#include "Block.hpp" #include "Block.hpp"
#include "Context.hpp" #include "Context.hpp"
#include "InstructionImpl.hpp" #include "InstructionImpl.hpp"
#include "LoopConstruct.hpp"
#include "NodeImpl.hpp" #include "NodeImpl.hpp"
#include "RegionImpl.hpp" #include "RegionImpl.hpp"
#include "SelectionConstruct.hpp"
#include "ValueImpl.hpp" #include "ValueImpl.hpp"
namespace shader::ir { namespace shader::ir {
@ -41,8 +43,8 @@ inline Operand InstructionImpl::eraseOperand(int index, int count) {
if (index + count == operands.size()) { if (index + count == operands.size()) {
auto result = replaceOperand(index, nullptr); auto result = replaceOperand(index, nullptr);
for (int i = 1; i < count; ++i) { for (std::size_t i = index + 1; i < operands.size(); ++i) {
replaceOperand(i + index, nullptr); replaceOperand(i, nullptr);
} }
operands.resize(operands.size() - count); operands.resize(operands.size() - count);
@ -51,11 +53,7 @@ inline Operand InstructionImpl::eraseOperand(int index, int count) {
auto result = replaceOperand(index, replaceOperand(index + 1, nullptr)); auto result = replaceOperand(index, replaceOperand(index + 1, nullptr));
for (int i = 1; i < count; ++i) { for (std::size_t i = index + 1; i < operands.size() - count; ++i) {
replaceOperand(index + i, nullptr);
}
for (int i = index + 1; i < operands.size() - count; ++i) {
replaceOperand(i, replaceOperand(i + count, nullptr)); replaceOperand(i, replaceOperand(i + count, nullptr));
} }
@ -146,6 +144,12 @@ inline void RegionLikeImpl::prependChild(Instruction node) {
assert(node.getPrev() == nullptr); assert(node.getPrev() == nullptr);
assert(node.getNext() == nullptr); assert(node.getNext() == nullptr);
#ifndef NDEBUG
if (auto thisInst = dynamic_cast<InstructionImpl *>(this)) {
assert(node != thisInst);
}
#endif
node.get()->parent = this; node.get()->parent = this;
if (last == nullptr) { if (last == nullptr) {
last = node; last = node;
@ -161,6 +165,12 @@ inline void RegionLikeImpl::addChild(Instruction node) {
assert(node.getPrev() == nullptr); assert(node.getPrev() == nullptr);
assert(node.getNext() == nullptr); assert(node.getNext() == nullptr);
#ifndef NDEBUG
if (auto thisInst = dynamic_cast<InstructionImpl *>(this)) {
assert(node != thisInst);
}
#endif
node.get()->parent = this; node.get()->parent = this;
if (first == nullptr) { if (first == nullptr) {
first = node; first = node;
@ -171,13 +181,33 @@ inline void RegionLikeImpl::addChild(Instruction node) {
last = node; last = node;
} }
inline void RegionImpl::print(std::ostream &os, NameStorage &ns) const { inline void RegionLikeImpl::printRegion(std::ostream &os, NameStorage &ns,
const PrintOptions &opts) const {
if (auto node = dynamic_cast<const NodeImpl *>(this)) {
node->print(os, ns, opts);
} else {
os << "<detached region>";
}
}
inline auto RegionLikeImpl::getParent() const {
if (auto inst = dynamic_cast<const InstructionImpl *>(this)) {
return inst->parent;
}
return RegionLike();
}
inline void RegionImpl::print(std::ostream &os, NameStorage &ns,
const PrintOptions &opts) const {
opts.printIdent(os);
os << "{\n"; os << "{\n";
for (auto child : children()) { for (auto childIdent = opts.nextLevel(); auto child : children()) {
os << " "; childIdent.printIdent(os);
child.print(os, ns); child.print(os, ns, childIdent);
os << "\n"; os << "\n";
} }
opts.printIdent(os);
os << "}"; os << "}";
} }
@ -230,6 +260,24 @@ T cloneInstructionImpl(const U *object, Context &context, CloneMap &map,
return result; return result;
} }
template <typename T, typename U, typename... ArgsT>
requires(std::is_same_v<typename T::underlying_type, U>)
T cloneBlockImpl(const U *object, Context &context, CloneMap &map,
ArgsT &&...args) {
auto result = context.create<T>(clone(object->getLocation(), context),
std::forward<ArgsT>(args)...);
for (auto &&operand : object->getOperands()) {
result.addOperand(operand.clone(context, map));
}
for (auto &&child : object->children()) {
result.addChild(ir::clone(child, context, map));
}
return result;
}
} // namespace detail } // namespace detail
inline Node InstructionImpl::clone(Context &context, CloneMap &map) const { inline Node InstructionImpl::clone(Context &context, CloneMap &map) const {
@ -250,20 +298,24 @@ inline Node RegionImpl::clone(Context &context, CloneMap &map) const {
return result; return result;
} }
inline BlockImpl::BlockImpl(Location loc)
: ValueImpl(loc, ir::Kind::Builtin, builtin::BLOCK) {}
inline Node BlockImpl::clone(Context &context, CloneMap &map) const { inline Node BlockImpl::clone(Context &context, CloneMap &map) const {
auto result = context.create<Block>(ir::clone(getLocation(), context)); return detail::cloneBlockImpl<Block>(this, context, map, kind, op);
for (auto &&operand : getOperands()) { }
result.addOperand(operand.clone(context, map));
}
for (auto &&child : children()) { inline Node ContinueConstructImpl::clone(Context &context,
result.addChild(ir::clone(child, context, map)); CloneMap &map) const {
} return detail::cloneBlockImpl<ContinueConstruct>(this, context, map, kind,
op);
}
return result; inline Node LoopConstructImpl::clone(Context &context, CloneMap &map) const {
return detail::cloneBlockImpl<LoopConstruct>(this, context, map, kind, op);
}
inline Node SelectionConstructImpl::clone(Context &context,
CloneMap &map) const {
return detail::cloneBlockImpl<SelectionConstruct>(this, context, map, kind,
op);
} }
inline Operand Operand::clone(Context &context, CloneMap &map) const { inline Operand Operand::clone(Context &context, CloneMap &map) const {
@ -324,7 +376,8 @@ inline Node memssa::DefImpl::clone(Context &context, CloneMap &map) const {
inline Node memssa::ScopeImpl::clone(Context &context, CloneMap &map) const { inline Node memssa::ScopeImpl::clone(Context &context, CloneMap &map) const {
auto self = Scope(const_cast<ScopeImpl *>(this)); auto self = Scope(const_cast<ScopeImpl *>(this));
auto result = context.create<Scope>(ir::clone(self.getLocation(), context)); auto result =
context.create<Scope>(ir::clone(self.getLocation(), context), kind, op);
for (auto &&operand : self.getOperands()) { for (auto &&operand : self.getOperands()) {
result.addOperand(operand.clone(context, map)); result.addOperand(operand.clone(context, map));

View file

@ -47,12 +47,13 @@ struct InstructionImpl : NodeImpl {
decltype(auto) getOperands() const { return std::span(operands); } decltype(auto) getOperands() const { return std::span(operands); }
void print(std::ostream &os, NameStorage &ns) const override { void print(std::ostream &os, NameStorage &ns,
const PrintOptions &) const override {
os << getInstructionName(kind, op); os << getInstructionName(kind, op);
if (!operands.empty()) { if (!operands.empty()) {
os << "("; os << "(";
for (bool first = true; auto operand : operands) { for (bool first = true; auto &operand : operands) {
if (first) { if (first) {
first = false; first = false;
} else { } else {

View file

@ -0,0 +1,105 @@
#pragma once
#include "Construct.hpp"
namespace shader::ir {
template <typename ImplT>
struct ContinueConstructWrapper : ConstructWrapper<ImplT> {
using ConstructWrapper<ImplT>::ConstructWrapper;
using ConstructWrapper<ImplT>::operator=;
};
struct ContinueConstructImpl;
struct ContinueConstruct : ContinueConstructWrapper<ContinueConstructImpl> {
using ContinueConstructWrapper<
ContinueConstructImpl>::ContinueConstructWrapper;
using ContinueConstructWrapper<ContinueConstructImpl>::operator=;
};
template <typename ImplT>
struct LoopConstructWrapper : ConstructWrapper<ImplT> {
using ConstructWrapper<ImplT>::ConstructWrapper;
using ConstructWrapper<ImplT>::operator=;
Block getLatch() { return this->impl->last.template staticCast<Block>(); }
ContinueConstruct getContinue() {
return this->impl->getOperand(2)
.getAsValue()
.template staticCast<ContinueConstruct>();
}
};
struct LoopConstructImpl;
struct LoopConstruct : LoopConstructWrapper<LoopConstructImpl> {
using LoopConstructWrapper<LoopConstructImpl>::LoopConstructWrapper;
using LoopConstructWrapper<LoopConstructImpl>::operator=;
};
struct LoopConstructImpl : ConstructImpl {
using ConstructImpl::ConstructImpl;
Node clone(Context &context, CloneMap &map) const override;
void print(std::ostream &os, NameStorage &ns,
const PrintOptions &opts) const override {
os << '%' << ns.getNameOf(const_cast<LoopConstructImpl *>(this));
os << " = ";
if (getOperands().size() > 3) {
os << '[';
for (bool first = true; auto &operand : getOperands().subspan(3)) {
if (first) {
first = false;
} else {
os << ", ";
}
operand.print(os, ns);
}
os << "] ";
}
auto bodyOpts = opts.nextLevel();
os << "loop (header = ";
getOperand(0).print(os, ns);
os << ", merge = ";
getOperand(1).print(os, ns);
os << ", latch = ";
os << "%" << ns.getNameOf(last);
os << ") {\n";
{
bodyOpts.printIdent(os);
os << "body {\n";
for (auto childOpts = bodyOpts.nextLevel(); auto child : children()) {
childOpts.printIdent(os);
child.print(os, ns, childOpts);
os << "\n";
}
bodyOpts.printIdent(os);
os << "}\n";
}
{
bodyOpts.printIdent(os);
os << "continue {\n";
bodyOpts.printIdent(os, 1);
getOperand(2).getAsValue().print(os, ns, bodyOpts.nextLevel());
os << "\n";
bodyOpts.printIdent(os);
os << "}\n";
}
opts.printIdent(os);
os << "}";
}
};
struct ContinueConstructImpl : ConstructImpl {
using ConstructImpl::ConstructImpl;
Node clone(Context &context, CloneMap &map) const override;
};
} // namespace shader::ir

View file

@ -1,5 +1,6 @@
#pragma once #pragma once
#include "PrintOptions.hpp"
#include "Location.hpp" #include "Location.hpp"
#include "Node.hpp" #include "Node.hpp"
#include "Operand.hpp" #include "Operand.hpp"
@ -59,7 +60,7 @@ struct NodeImpl {
void setLocation(Location newLocation) { location = newLocation; } void setLocation(Location newLocation) { location = newLocation; }
Location getLocation() const { return location; } Location getLocation() const { return location; }
virtual void print(std::ostream &os, NameStorage &ns) const = 0; virtual void print(std::ostream &os, NameStorage &ns, const PrintOptions &opts) const = 0;
virtual Node clone(Context &context, CloneMap &map) const = 0; virtual Node clone(Context &context, CloneMap &map) const = 0;
}; };
} // namespace shader::ir } // namespace shader::ir

View file

@ -1,9 +1,10 @@
#pragma once #pragma once
#include "InstructionImpl.hpp" // IWYU pragma: keep #include "InstructionImpl.hpp" // IWYU pragma: keep
#include <type_traits>
namespace shader::ir { namespace shader::ir {
template <typename T> struct PreincNodeIterable { template <typename T = Instruction> struct Range {
struct EndIterator {}; struct EndIterator {};
struct Iterator { struct Iterator {
@ -52,18 +53,24 @@ template <typename T> struct PreincNodeIterable {
} }
}; };
PreincNodeIterable(Instruction beginIt, Instruction endIt) Range(Instruction beginIt, Instruction endIt)
: mBeginIt(beginIt), mEndIt(endIt) {} : mBeginIt(beginIt), mEndIt(endIt) {}
template <typename OtherT>
requires(!std::is_same_v<OtherT, Range>)
Range(OtherT other) : Range(other.mBeginIt, other.mEndIt) {}
Iterator begin() const { return Iterator(mBeginIt, mEndIt); } Iterator begin() const { return Iterator(mBeginIt, mEndIt); }
EndIterator end() const { return EndIterator{}; } EndIterator end() const { return EndIterator{}; }
private: private:
Instruction mBeginIt; Instruction mBeginIt;
Instruction mEndIt; Instruction mEndIt;
template <typename> friend struct Range;
}; };
template <typename T> struct RevPreincNodeIterable { template <typename T = Instruction> struct RevRange {
struct EndIterator {}; struct EndIterator {};
struct Iterator { struct Iterator {
@ -112,25 +119,37 @@ template <typename T> struct RevPreincNodeIterable {
} }
}; };
RevPreincNodeIterable(Instruction beginIt, Instruction endIt) RevRange(Instruction beginIt, Instruction endIt)
: mBeginIt(beginIt), mEndIt(endIt) {} : mBeginIt(beginIt), mEndIt(endIt) {}
template <typename OtherT>
requires(!std::is_same_v<OtherT, RevRange>)
RevRange(OtherT other) : RevRange(other.mBeginIt, other.mEndIt) {}
Iterator begin() const { return Iterator(mBeginIt, mEndIt); } Iterator begin() const { return Iterator(mBeginIt, mEndIt); }
EndIterator end() const { return EndIterator{}; } EndIterator end() const { return EndIterator{}; }
private: private:
Instruction mBeginIt; Instruction mBeginIt;
Instruction mEndIt; Instruction mEndIt;
template <typename> friend struct RevRange;
}; };
template <typename T = Instruction> template <typename T = Instruction>
inline PreincNodeIterable<T> range(Instruction begin, inline Range<T> range(Instruction begin, Instruction end = nullptr) {
Instruction end = nullptr) { if (end) {
assert(begin.getParent() == end.getParent());
}
return {begin, end}; return {begin, end};
} }
template <typename T = Instruction> template <typename T = Instruction>
inline RevPreincNodeIterable<T> revRange(Instruction begin, inline RevRange<T> revRange(Instruction begin, Instruction end = nullptr) {
Instruction end = nullptr) { if (end) {
assert(begin.getParent() == end.getParent());
}
return {begin, end}; return {begin, end};
} }
} // namespace shader::ir } // namespace shader::ir

View file

@ -0,0 +1,22 @@
#pragma once
#include <ostream>
#include <string>
namespace shader {
struct PrintOptions {
int identLevel = 0;
int identCount = 2;
char identChar = ' ';
[[nodiscard]] PrintOptions nextLevel() const {
auto result = *this;
result.identLevel++;
return result;
}
void printIdent(std::ostream &os, int offset = 0) const {
os << std::string((identLevel + offset) * identCount, identChar);
}
};
} // namespace shader

View file

@ -1,6 +1,7 @@
#pragma once #pragma once
#include "PointerWrapper.hpp" #include "PointerWrapper.hpp"
#include "PrintOptions.hpp"
#include <ostream> #include <ostream>
namespace shader::ir { namespace shader::ir {
@ -9,18 +10,28 @@ template <typename T> struct PrintableWrapper : PointerWrapper<T> {
using PointerWrapper<T>::PointerWrapper; using PointerWrapper<T>::PointerWrapper;
using PointerWrapper<T>::operator=; using PointerWrapper<T>::operator=;
void print(std::ostream &os, NameStorage &ns) const { void print(std::ostream &os, NameStorage &ns, const PrintOptions &opts = {}) const {
if constexpr (requires { this->impl->print(os, ns); }) { if constexpr (requires { this->impl->print(os, ns, opts); }) {
this->impl->print(os, ns, opts);
} else if constexpr (requires { this->impl->print(os, ns); }) {
this->impl->print(os, ns); this->impl->print(os, ns);
} else if constexpr (requires { this->impl->print(os, opts); }) {
this->impl->print(os, opts);
} else { } else {
this->impl->print(os); this->impl->print(os);
} }
} }
void print(std::ostream &os) const void print(std::ostream &os, const PrintOptions &opts = {}) const
requires requires { this->impl->print(os); } requires(
requires { this->impl->print(os, opts); } ||
requires { this->impl->print(os); })
{ {
this->impl->print(os); if constexpr (requires { this->impl->print(os, opts); }) {
this->impl->print(os, opts);
} else {
this->impl->print(os);
}
} }
}; };
} // namespace shader::ir } // namespace shader::ir

View file

@ -9,7 +9,8 @@ namespace shader::ir {
struct RegionImpl : NodeImpl, RegionLikeImpl { struct RegionImpl : NodeImpl, RegionLikeImpl {
RegionImpl(Location loc) { setLocation(loc); } RegionImpl(Location loc) { setLocation(loc); }
void print(std::ostream &os, NameStorage &ns) const override; void print(std::ostream &os, NameStorage &ns,
const PrintOptions &opts) const override;
Node clone(Context &context, CloneMap &map) const override; Node clone(Context &context, CloneMap &map) const override;
}; };
} // namespace shader::ir } // namespace shader::ir

View file

@ -28,9 +28,19 @@ struct RegionLikeWrapper : BaseWrapper<ImplT> {
template <typename T = Instruction> auto revChildren() { template <typename T = Instruction> auto revChildren() {
return this->impl->template revChildren<T>(); return this->impl->template revChildren<T>();
} }
void print(std::ostream &os, NameStorage &ns,
const PrintOptions &opts = {}) const {
this->impl->printRegion(os, ns, opts);
}
auto getParent() const {
return this->impl->getParent();
}
}; };
struct RegionLikeImpl; struct RegionLikeImpl;
struct RegionLike : RegionLikeWrapper<RegionLikeImpl, PointerWrapper> { struct RegionLike : RegionLikeWrapper<RegionLikeImpl, PointerWrapper> {
using RegionLikeWrapper::RegionLikeWrapper; using RegionLikeWrapper::RegionLikeWrapper;
using RegionLikeWrapper::operator=; using RegionLikeWrapper::operator=;

View file

@ -11,15 +11,20 @@ struct RegionLikeImpl {
virtual ~RegionLikeImpl() = default; virtual ~RegionLikeImpl() = default;
template <typename T = Instruction> auto children() const { template <typename T = Instruction> auto children() const {
return PreincNodeIterable<T>{first, nullptr}; return Range<T>{first, nullptr};
} }
template <typename T = Instruction> auto revChildren() const { template <typename T = Instruction> auto revChildren() const {
return RevPreincNodeIterable<T>{last, nullptr}; return RevRange<T>{last, nullptr};
} }
virtual void insertAfter(Instruction point, Instruction node); void insertAfter(Instruction point, Instruction node);
virtual void prependChild(Instruction node); void prependChild(Instruction node);
virtual void addChild(Instruction node); void addChild(Instruction node);
void printRegion(std::ostream &os, NameStorage &ns,
const PrintOptions &opts) const;
auto getParent() const;
}; };
} // namespace shader::ir } // namespace shader::ir

View file

@ -0,0 +1,62 @@
#pragma once
#include "Construct.hpp"
namespace shader::ir {
template <typename ImplT>
struct SelectionConstructWrapper : ConstructWrapper<ImplT> {
using ConstructWrapper<ImplT>::ConstructWrapper;
using ConstructWrapper<ImplT>::operator=;
};
struct SelectionConstructImpl;
struct SelectionConstruct : SelectionConstructWrapper<SelectionConstructImpl> {
using SelectionConstructWrapper<
SelectionConstructImpl>::SelectionConstructWrapper;
using SelectionConstructWrapper<SelectionConstructImpl>::operator=;
};
struct SelectionConstructImpl : ConstructImpl {
using ConstructImpl::ConstructImpl;
Node clone(Context &context, CloneMap &map) const override;
void print(std::ostream &os, NameStorage &ns,
const PrintOptions &opts) const override {
os << '%' << ns.getNameOf(const_cast<SelectionConstructImpl *>(this));
os << " = ";
if (getOperands().size() > 2) {
os << '[';
for (bool first = true; auto &operand : getOperands().subspan(2)) {
if (first) {
first = false;
} else {
os << ", ";
}
operand.print(os, ns);
}
os << "] ";
}
auto bodyOpts = opts.nextLevel();
os << "selection (header = ";
getOperand(0).print(os, ns);
os << ", merge = ";
getOperand(1).print(os, ns);
os << ") {\n";
for (auto child : children()) {
bodyOpts.printIdent(os);
child.print(os, ns, bodyOpts);
os << "\n";
}
opts.printIdent(os);
os << "}";
}
};
} // namespace shader::ir

View file

@ -2,9 +2,12 @@
#include "Instruction.hpp" #include "Instruction.hpp"
#include "Operand.hpp" #include "Operand.hpp"
#include "rx/FunctionRef.hpp"
namespace shader::ir { namespace shader::ir {
struct Value; struct Value;
struct ValueUse;
template <typename T> struct ValueWrapper : InstructionWrapper<T> { template <typename T> struct ValueWrapper : InstructionWrapper<T> {
using InstructionWrapper<T>::InstructionWrapper; using InstructionWrapper<T>::InstructionWrapper;
using InstructionWrapper<T>::operator=; using InstructionWrapper<T>::operator=;
@ -12,6 +15,7 @@ template <typename T> struct ValueWrapper : InstructionWrapper<T> {
decltype(auto) getUserList() const { return this->impl->getUserList(); } decltype(auto) getUserList() const { return this->impl->getUserList(); }
auto &getUseList() const { return this->impl->uses; } auto &getUseList() const { return this->impl->uses; }
void replaceAllUsesWith(Value other) const; void replaceAllUsesWith(Value other) const;
void replaceUsesIf(Value other, rx::FunctionRef<bool(ValueUse)> cb);
bool isUnused() const { return this->impl->uses.empty(); } bool isUnused() const { return this->impl->uses.empty(); }
}; };
@ -22,15 +26,21 @@ struct Value : ValueWrapper<ValueImpl> {
using ValueWrapper::operator=; using ValueWrapper::operator=;
}; };
template <typename T>
void ValueWrapper<T>::replaceAllUsesWith(Value other) const {
this->impl->replaceAllUsesWith(other);
}
struct ValueUse { struct ValueUse {
Instruction user; Instruction user;
Value node; Value node;
int operandIndex; int operandIndex;
auto operator<=>(const ValueUse &) const = default; auto operator<=>(const ValueUse &) const = default;
}; };
template <typename T>
void ValueWrapper<T>::replaceAllUsesWith(Value other) const {
this->impl->replaceAllUsesWith(other);
}
template <typename T>
void ValueWrapper<T>::replaceUsesIf(Value other,
rx::FunctionRef<bool(ValueUse)> cb) {
this->impl->replaceUsesIf(other, cb);
}
} // namespace shader::ir } // namespace shader::ir

View file

@ -4,25 +4,24 @@
#include "NameStorage.hpp" #include "NameStorage.hpp"
#include "Node.hpp" #include "Node.hpp"
#include "Value.hpp" #include "Value.hpp"
#include "rx/FunctionRef.hpp"
namespace shader::ir { namespace shader::ir {
struct ValueImpl : InstructionImpl { struct ValueImpl : InstructionImpl {
std::set<ValueUse> uses; std::set<ValueUse> uses;
ValueImpl(Location location, Kind kind, unsigned op, using InstructionImpl::InstructionImpl;
std::span<const Operand> operands = {})
: InstructionImpl(location, kind, op, operands) {}
void addUse(Instruction user, int operandIndex) { void addUse(Instruction user, int operandIndex) {
uses.insert({user, this, operandIndex}); uses.insert({.user=user, .node=this, .operandIndex=operandIndex});
} }
void removeUse(Instruction user, int operandIndex) { void removeUse(Instruction user, int operandIndex) {
uses.erase({user, this, operandIndex}); uses.erase({.user=user, .node=this, .operandIndex=operandIndex});
} }
std::set<Node> getUserList() const { std::set<Instruction> getUserList() const {
std::set<Node> list; std::set<Instruction> list;
for (auto use : uses) { for (auto use : uses) {
list.insert(use.user); list.insert(use.user);
} }
@ -44,10 +43,29 @@ struct ValueImpl : InstructionImpl {
} }
} }
void print(std::ostream &os, NameStorage &ns) const override { void replaceUsesIf(Value other, rx::FunctionRef<bool(ValueUse)> cond) {
if (other == this) {
std::abort();
}
auto savedUses = uses;
for (auto &use : savedUses) {
if (cond(use)) {
if (other == nullptr) {
use.user.replaceOperand(use.operandIndex, nullptr);
} else {
use.user.replaceOperand(use.operandIndex, other);
}
}
}
}
void print(std::ostream &os, NameStorage &ns,
const PrintOptions &opts) const override {
os << '%' << ns.getNameOf(const_cast<ValueImpl *>(this)); os << '%' << ns.getNameOf(const_cast<ValueImpl *>(this));
os << " = "; os << " = ";
InstructionImpl::print(os, ns); InstructionImpl::print(os, ns, opts);
} }
Node clone(Context &context, CloneMap &map) const override; Node clone(Context &context, CloneMap &map) const override;

View file

@ -113,7 +113,7 @@ std::optional<BinaryLayout> deserialize(ir::Context &context,
/// ///
/// \returns A vector of u32 values representing the SPIR-V binary. /// \returns A vector of u32 values representing the SPIR-V binary.
/// ///
std::vector<std::uint32_t> serialize(ir::Region body); std::vector<std::uint32_t> serialize(ir::RegionLike body);
inline std::vector<std::uint32_t> serialize(ir::Context &context, inline std::vector<std::uint32_t> serialize(ir::Context &context,
BinaryLayout &&layout) { BinaryLayout &&layout) {

View file

@ -3,6 +3,5 @@
#include "ir.hpp" #include "ir.hpp"
namespace shader { namespace shader {
void structurizeCfg(spv::Context &context, ir::RegionLike region, void structurizeCfg(spv::Context &context, ir::RegionLike region);
ir::Value exitLabel);
} }

View file

@ -374,7 +374,7 @@ struct ResourcesBuilder {
} else if (auto value = inst.cast<ir::Value>()) { } else if (auto value = inst.cast<ir::Value>()) {
resourcePhi.addOperand(value); resourcePhi.addOperand(value);
} else { } else {
auto block = resources.context.create<ir::Block>(inst.getLocation()); auto block = resources.context.create<ir::Block>(inst.getLocation(), ir::Kind::Builtin, ir::builtin::BLOCK);
inst.erase(); inst.erase();
block.addChild(inst); block.addChild(inst);
resourcePhi.addOperand(block); resourcePhi.addOperand(block);
@ -396,7 +396,8 @@ struct ResourcesBuilder {
return value; return value;
} }
auto block = resources.context.create<ir::Block>(inst.getLocation()); auto block = resources.context.create<ir::Block>(
inst.getLocation(), ir::Kind::Builtin, ir::builtin::BLOCK);
block.addChild(inst); block.addChild(inst);
return block; return block;
} }

View file

@ -12,7 +12,7 @@ static std::string getTypeName(ir::Value type);
static std::string getConstantName(ir::Value constant) { static std::string getConstantName(ir::Value constant) {
if (constant == ir::spv::OpConstant) { if (constant == ir::spv::OpConstant) {
auto typeValue = constant.getOperand(0).getAsValue(); auto typeValue = constant.getOperand(0).getAsValue();
auto value = constant.getOperand(1); auto &value = constant.getOperand(1);
if (typeValue == ir::spv::OpTypeInt) { if (typeValue == ir::spv::OpTypeInt) {
auto width = *typeValue.getOperand(0).getAsInt32(); auto width = *typeValue.getOperand(0).getAsInt32();
@ -310,6 +310,10 @@ ir::Node spv::Import::getOrCloneImpl(ir::Context &context, ir::Node node,
return CloneMap::getOrCloneImpl(context, node, isOperand); return CloneMap::getOrCloneImpl(context, node, isOperand);
} }
ir::Region spv::Context::createRegion(ir::Location loc) {
return create<ir::Region>(loc);
}
ir::Value spv::Context::createRegionWithLabel(ir::Location loc) { ir::Value spv::Context::createRegionWithLabel(ir::Location loc) {
return Builder::createAppend(*this, create<ir::Region>(loc)) return Builder::createAppend(*this, create<ir::Region>(loc))
.createSpvLabel(loc); .createSpvLabel(loc);
@ -340,6 +344,10 @@ ir::Value spv::Context::getNull(ir::Value typeValue) {
return getOrCreateGlobal(ir::spv::OpConstantNull, {{typeValue}}); return getOrCreateGlobal(ir::spv::OpConstantNull, {{typeValue}});
} }
ir::Value spv::Context::getUndef(ir::Value typeValue) {
return getOrCreateGlobal(ir::spv::OpUndef, {{typeValue}});
}
ir::Value spv::Context::getType(ir::spv::Op baseType, int width, ir::Value spv::Context::getType(ir::spv::Op baseType, int width,
bool isSigned) { bool isSigned) {
switch (baseType) { switch (baseType) {

View file

@ -3,8 +3,8 @@
#include "ir.hpp" #include "ir.hpp"
#include "rx/die.hpp" #include "rx/die.hpp"
#include "spv.hpp" #include "spv.hpp"
#include <algorithm>
#include <iostream> #include <iostream>
#include <print>
using namespace shader; using namespace shader;
@ -174,6 +174,274 @@ ir::Value shader::unwrapPointer(ir::Value pointer) {
} }
} }
ir::Instruction shader::getTerminator(ir::RegionLike region) {
if (auto block = region.cast<ir::Block>()) {
if (block == ir::builtin::LOOP_CONSTRUCT ||
block == ir::builtin::SELECTION_CONSTRUCT) {
return block;
}
}
auto terminator = region.getLast();
if (!terminator || !isTerminator(terminator)) {
return {};
}
return terminator;
}
static int getTotalSuccessorCount(ir::Instruction terminator) {
if (terminator == ir::spv::OpBranch) {
return 1;
}
if (terminator == ir::spv::OpBranchConditional) {
return 2;
}
if (terminator == ir::spv::OpSwitch) {
return terminator.getOperandCount() / 2;
}
if (terminator == ir::builtin::LOOP_CONSTRUCT ||
terminator == ir::builtin::SELECTION_CONSTRUCT) {
return 1;
}
return 0;
}
static void walkSuccessors(ir::Instruction terminator, auto &&cb) {
if (terminator == ir::spv::OpBranch) {
cb(terminator.getOperand(0).getAsValue(), 0);
return;
}
if (terminator == ir::spv::OpBranchConditional) {
cb(terminator.getOperand(1).getAsValue(), 1);
cb(terminator.getOperand(2).getAsValue(), 2);
return;
}
if (terminator == ir::spv::OpSwitch) {
for (std::size_t i = 1, end = terminator.getOperandCount(); i < end;
i += 2) {
cb(terminator.getOperand(i).getAsValue(), i);
}
return;
}
if (terminator == ir::builtin::LOOP_CONSTRUCT ||
terminator == ir::builtin::SELECTION_CONSTRUCT) {
cb(terminator.getOperand(0).getAsValue(), 0);
}
}
std::vector<std::pair<ir::Block, int>>
shader::getAllSuccessors(ir::Block region) {
auto terminator = getTerminator(region);
if (!terminator) {
return {};
}
std::vector<std::pair<ir::Block, int>> result;
result.reserve(getTotalSuccessorCount(terminator));
walkSuccessors(terminator, [&](ir::Value successor, int operandIndex) {
if (auto block = successor.cast<ir::Block>()) {
result.emplace_back(block, operandIndex);
}
});
return result;
}
std::vector<std::pair<ir::Block, int>>
shader::getAllPredecessors(ir::Block region) {
std::vector<std::pair<ir::Block, int>> result;
result.reserve(region.getUseList().size());
for (auto &use : region.getUseList()) {
if (isBranch(use.user)) {
if (auto block = use.user.getParent().cast<ir::Block>()) {
result.emplace_back(block, use.operandIndex);
}
continue;
}
if (use.operandIndex == 0 &&
(use.user == ir::builtin::LOOP_CONSTRUCT ||
use.user == ir::builtin::SELECTION_CONSTRUCT)) {
result.emplace_back(use.user.staticCast<ir::Block>(), use.operandIndex);
continue;
}
}
return result;
}
std::unordered_set<ir::Block> shader::getSuccessors(ir::Block block) {
auto terminator = getTerminator(block);
if (!terminator) {
return {};
}
std::unordered_set<ir::Block> result;
result.reserve(getTotalSuccessorCount(terminator));
walkSuccessors(terminator, [&](ir::Value successor, int) {
if (auto block = successor.cast<ir::Block>()) {
result.insert(block);
}
});
return result;
}
std::unordered_set<ir::Block> shader::getPredecessors(ir::Block block) {
std::unordered_set<ir::Block> result;
result.reserve(block.getUseList().size());
for (auto &use : block.getUseList()) {
if (use.operandIndex == 0 &&
(use.user == ir::builtin::LOOP_CONSTRUCT ||
use.user == ir::builtin::SELECTION_CONSTRUCT)) {
result.insert(use.user.staticCast<ir::Block>());
continue;
}
if (isBranch(use.user)) {
if (auto block = use.user.getParent().cast<ir::Block>()) {
result.insert(block);
}
}
}
return result;
}
std::size_t shader::getSuccessorCount(ir::Block region) {
return getSuccessors(region).size();
}
std::size_t shader::getPredecessorCount(ir::Block region) {
return getPredecessors(region).size();
}
bool shader::hasSuccessor(ir::Block region, ir::Block successor) {
auto terminator = getTerminator(region);
if (!terminator) {
return false;
}
bool result = false;
walkSuccessors(terminator, [&](ir::Value currentSuccessor, int) {
if (result) {
return;
}
if (currentSuccessor == successor) {
result = true;
return;
}
});
return result;
}
bool shader::hasAtLeastSuccessors(ir::Block region, std::size_t count) {
auto terminator = getTerminator(region);
if (!terminator) {
return false;
}
if (getTotalSuccessorCount(terminator) < count) {
return false;
}
std::vector<ir::Block> successors;
successors.reserve(count - 1);
bool result = false;
walkSuccessors(terminator, [&](ir::Value successor, int) {
if (result) {
return;
}
if (auto block = successor.cast<ir::Block>()) {
if (!std::ranges::contains(successors, block)) {
if (successors.size() + 1 >= count) {
result = true;
return;
}
successors.push_back(block);
}
}
});
return result;
}
ir::Block shader::getUniqSuccessor(ir::Block region) {
auto terminator = getTerminator(region);
if (!terminator) {
return {};
}
ir::Block result;
bool noUniqSuccessor = false;
walkSuccessors(terminator, [&](ir::Value successor, int) {
if (noUniqSuccessor) {
return;
}
if (auto block = successor.cast<ir::Block>()) {
if (!result) {
result = block;
} else if (result != block) {
noUniqSuccessor = true;
}
}
});
if (noUniqSuccessor) {
return {};
}
return result;
}
graph::DomTree<ir::Block> shader::buildDomTree(ir::Block block) {
return graph::buildDomTree(block, [&](ir::Block region, const auto &cb) {
for (auto succ : getSuccessors(region)) {
cb(succ);
}
});
}
graph::DomTree<ir::Block> shader::buildPostDomTree(ir::Block block) {
return graph::buildDomTree(block, [&](ir::Block region, const auto &cb) {
for (auto pred : getPredecessors(region)) {
cb(pred);
}
});
}
graph::DomTree<ir::Block> shader::buildDomTree(ir::RegionLike region) {
return buildDomTree(region.getFirst().staticCast<ir::Block>());
}
graph::DomTree<ir::Block> shader::buildPostDomTree(ir::RegionLike region) {
return buildPostDomTree(region.getLast().staticCast<ir::Block>());
}
graph::DomTree<ir::Value> shader::buildDomTree(CFG &cfg, ir::Value root) { graph::DomTree<ir::Value> shader::buildDomTree(CFG &cfg, ir::Value root) {
if (root == nullptr) { if (root == nullptr) {
root = cfg.getEntryLabel(); root = cfg.getEntryLabel();
@ -223,195 +491,38 @@ void CFG::print(std::ostream &os, ir::NameStorage &ns, bool subgraph,
std::string CFG::genTest() { std::string CFG::genTest() {
std::string result; std::string result;
result += "ir::Value genCfg(spv::Context &context) {\n"; result += "void cfgTest() {\n";
result += " auto loc = context.getUnknownLocation();\n";
result += " auto boolT = context.getTypeBool();\n";
result += " auto trueV = context.getTrue();\n";
result += " auto builder = Builder::createAppend(context, "
"context.layout.getOrCreateFunctions(context));\n";
result += " auto debugs = Builder::createAppend(context, "
"context.layout.getOrCreateDebugs(context));\n";
ir::NameStorage ns; ir::NameStorage ns;
for (auto node : getPreorderNodes()) { for (auto node : getPreorderNodes()) {
auto name = ns.getNameOf(node->getLabel()); auto name = ns.getNameOf(node->getLabel());
result += " auto _" + name + " = builder.createSpvLabel(loc);\n"; result += " auto _" + name + " = createLabel(\"" + name + "\");\n";
result += " context.ns.setNameOf(_" + name + ", \"" + name + "\");\n";
result += " debugs.createSpvName(loc, _" + name + ", \"" + name + "\");\n";
} }
for (auto node : getPreorderNodes()) { for (auto node : getPreorderNodes()) {
auto name = ns.getNameOf(node->getLabel()); auto name = ns.getNameOf(node->getLabel());
result +=
" builder = Builder::createInsertAfter(context, _" + name + ");\n";
if (node->getSuccessorCount() == 1) { if (node->getSuccessorCount() == 1) {
result += " builder.createSpvBranch(loc, _" + result += " createBranch(_" +
ns.getNameOf((*node->getSuccessors().begin())->getLabel()) + ns.getNameOf((*node->getSuccessors().begin())->getLabel()) +
");\n"; ");\n";
} else if (node->getSuccessorCount() == 2) { } else if (node->getSuccessorCount() == 2) {
auto firstIt = node->getSuccessors().begin(); auto firstIt = node->getSuccessors().begin();
auto secondIt = std::next(firstIt); auto secondIt = std::next(firstIt);
result += " builder.createSpvBranchConditional(loc, trueV, _" + result += " createConditionalBranch(_" +
ns.getNameOf((*firstIt)->getLabel()) + ", _" + ns.getNameOf((*firstIt)->getLabel()) + ", _" +
ns.getNameOf((*secondIt)->getLabel()) + ");\n"; ns.getNameOf((*secondIt)->getLabel()) + ");\n";
} else if (node->getSuccessorCount() == 0) { } else if (node->getSuccessorCount() == 0) {
result += " builder.createSpvReturn(loc);\n"; result += " createReturn(_" + name + ");\n";
result += " auto returnBlock = _" + name + ";\n";
} }
} }
result += " return returnBlock;\n";
result += "}\n"; result += "}\n";
return result; return result;
} }
static void walkSuccessors(ir::Instruction terminator, auto &&cb) { CFG shader::buildCFG(ir::Instruction firstInstruction, ir::Value exitLabel,
if (terminator == ir::spv::OpBranch) {
cb(terminator.getOperand(0).getAsValue());
return;
}
if (terminator == ir::spv::OpBranchConditional) {
cb(terminator.getOperand(1).getAsValue());
cb(terminator.getOperand(2).getAsValue());
return;
}
if (terminator == ir::spv::OpSwitch) {
for (std::size_t i = 1, end = terminator.getOperandCount(); i < end;
i += 2) {
cb(terminator.getOperand(i).getAsValue());
}
return;
}
}
CFG CFG::buildView(CFG::Node *from, PostDomTree *domTree,
const std::unordered_set<ir::Value> &stopLabels,
ir::Value continueLabel) {
struct Item {
CFG::Node *node;
std::vector<CFG::Node *> successors;
};
std::vector<CFG::Node *> workList;
std::unordered_set<ir::Value> visited;
workList.push_back(from);
CFG result;
result.mEntryNode = result.getOrCreateNode(from->getLabel());
visited.insert(from->getLabel());
// for (auto pred : from->getPredecessors()) {
// result.getOrCreateNode(pred->getLabel());
// }
auto createResultNode = [&](CFG::Node *node) {
auto newNode = result.getOrCreateNode(node->getLabel());
newNode->setTerminator(node->getTerminator());
return newNode;
};
while (!workList.empty()) {
auto item = workList.back();
workList.pop_back();
auto resultItem = createResultNode(item);
result.addPreorderNode(resultItem);
if (item != from) {
if (item->getLabel() == continueLabel) {
continue;
}
if (stopLabels.contains(item->getLabel())) {
if (domTree == nullptr) {
continue;
}
for (auto succ : item->getSuccessors()) {
if (!domTree->dominates(item->getLabel(), succ->getLabel())) {
continue;
}
auto resultSucc = createResultNode(succ);
resultItem->addEdge(resultSucc);
if (visited.insert(succ->getLabel()).second) {
workList.push_back(succ);
}
}
continue;
}
}
for (auto succ : item->getSuccessors()) {
auto resultSucc = createResultNode(succ);
resultItem->addEdge(resultSucc);
if (visited.insert(succ->getLabel()).second) {
workList.push_back(succ);
}
}
}
if (domTree != nullptr) {
return result;
}
for (auto exitLabel : stopLabels) {
if (exitLabel == nullptr) {
continue;
}
// collect internal branches from exitLabel. Need to collect all blocks
// first to be able discard edges to not exists in this CFG target blocks
if (auto from = result.getNode(exitLabel)) {
for (auto succ : getNode(exitLabel)->getSuccessors()) {
if (auto to = result.getNode(succ->getLabel())) {
from->addEdge(to);
}
}
}
}
return result;
}
void Construct::invalidateAll() {
Construct *root = this;
while (root->parent != nullptr) {
root = root->parent;
}
std::vector<Construct *> workList;
workList.push_back(root);
while (!workList.empty()) {
auto item = workList.back();
workList.pop_back();
item->analysis.invalidateAll();
for (auto &child : item->children) {
workList.push_back(&child);
}
}
}
void Construct::invalidate() {
invalidateAll();
// Construct *item = this;
// while (item != nullptr) {
// item->analysis.invalidateAll();
// item = item->parent;
// }
}
CFG shader::buildCFG(ir::Instruction firstInstruction,
const std::unordered_set<ir::Value> &exitLabels,
ir::Value continueLabel) { ir::Value continueLabel) {
struct Item { struct Item {
CFG::Node *node; CFG::Node *node;
@ -426,14 +537,16 @@ CFG shader::buildCFG(ir::Instruction firstInstruction,
std::unordered_set<CFG::Node *> visited; std::unordered_set<CFG::Node *> visited;
bool force = true;
auto addSuccessor = [&](Item &from, ir::Value toLabel) { auto addSuccessor = [&](Item &from, ir::Value toLabel) {
if (toLabel == continueLabel) {
return;
}
auto to = result.getOrCreateNode(toLabel); auto to = result.getOrCreateNode(toLabel);
from.node->addEdge(to); from.node->addEdge(to);
if (!force && (exitLabels.contains(from.node->getLabel()) || if (from.node->getLabel() == exitLabel ||
from.node->getLabel() == continueLabel)) { from.node->getLabel() == continueLabel) {
return; return;
} }
@ -471,7 +584,6 @@ CFG shader::buildCFG(ir::Instruction firstInstruction,
visited.insert(item.node); visited.insert(item.node);
} else { } else {
item.iterator = nullptr; item.iterator = nullptr;
force = false;
} }
continue; continue;
@ -481,7 +593,8 @@ CFG shader::buildCFG(ir::Instruction firstInstruction,
item.node->setTerminator(inst); item.node->setTerminator(inst);
item.iterator = nullptr; item.iterator = nullptr;
walkSuccessors(inst, [&](ir::Value label) { addSuccessor(item, label); }); walkSuccessors(inst,
[&](ir::Value label, int) { addSuccessor(item, label); });
continue; continue;
} }
@ -492,22 +605,6 @@ CFG shader::buildCFG(ir::Instruction firstInstruction,
} }
} }
for (auto exitLabel : exitLabels) {
if (exitLabel == nullptr) {
continue;
}
// collect internal branches from exitLabel. Need to collect all blocks
// first to be able discard edges to not exists in this CFG target blocks
if (auto from = result.getNode(exitLabel)) {
walkSuccessors(from->getTerminator(), [&](ir::Value toLabel) {
if (auto to = result.getNode(toLabel)) {
from->addEdge(to);
}
});
}
}
return result; return result;
} }
@ -930,13 +1027,6 @@ MemorySSA MemorySSABuilder::build(CFG &cfg, auto &&handleInst) {
} }
} }
// auto domTree = graph::DomTreeBuilder<ir::memssa::Scope>{}.build(
// entryScope, [&](ir::memssa::Scope scope, const auto &cb) {
// for (auto succ : scope.getSuccessors()) {
// cb(succ);
// }
// });
for (auto scope : ir::range<ir::memssa::Scope>(entryScope)) { for (auto scope : ir::range<ir::memssa::Scope>(entryScope)) {
for (auto use : scope.children<ir::memssa::Use>()) { for (auto use : scope.children<ir::memssa::Use>()) {
auto &user = memSSA.userDefs[use.getLinkedInst()]; auto &user = memSSA.userDefs[use.getLinkedInst()];
@ -1220,56 +1310,3 @@ shader::findNearestCommonDominator(ir::Instruction a, ir::Instruction b,
return domTree.findNearestCommonDominator(a.staticCast<ir::Value>(), return domTree.findNearestCommonDominator(a.staticCast<ir::Value>(),
b.staticCast<ir::Value>()); b.staticCast<ir::Value>());
} }
BackEdgeStorage::BackEdgeStorage(CFG &cfg) {
struct Entry {
ir::Value bb;
CFG::Node::Iterator successorsIt;
CFG::Node::Iterator successorsEnd;
};
std::vector<Entry> workList;
std::unordered_set<ir::Value> inWorkList;
// std::unordered_set<ir::Value> viewed;
workList.reserve(cfg.getPostorderNodes().size());
inWorkList.reserve(cfg.getPostorderNodes().size());
auto addToWorkList = [&](CFG::Node *node) {
if (inWorkList.insert(node->getLabel()).second) {
workList.push_back({
.bb = node->getLabel(),
.successorsIt = node->getSuccessors().begin(),
.successorsEnd = node->getSuccessors().end(),
});
return true;
}
return false;
};
addToWorkList(cfg.getEntryNode());
while (!workList.empty()) {
auto &entry = workList.back();
if (entry.successorsIt == entry.successorsEnd) {
// viewed.insert(inWorkList.extract(entry.bb));
workList.pop_back();
continue;
}
auto label = entry.bb;
auto it = entry.successorsIt;
++entry.successorsIt;
auto successor = *it;
// if (viewed.contains(successor->getLabel())) {
// continue;
// }
if (!addToWorkList(successor)) {
backEdges[successor->getLabel()].insert(label);
}
}
}

View file

@ -1902,7 +1902,6 @@ gcn::deserialize(gcn::Context &context, const gcn::Environment &environment,
std::print("\n\n{}\n\n", buildCFG(context.entryPoint).genTest()); std::print("\n\n{}\n\n", buildCFG(context.entryPoint).genTest());
structurizeCfg(context, context.body, structurizeCfg(context, context.body);
context.epilogue.getFirst().cast<ir::Value>());
return context.body; return context.body;
} }

View file

@ -7,7 +7,7 @@
using namespace shader; using namespace shader;
static std::uint32_t generateSpv(std::vector<std::uint32_t> &result, static std::uint32_t generateSpv(std::vector<std::uint32_t> &result,
shader::ir::Region body) { shader::ir::RegionLike body) {
std::map<shader::ir::Value, std::uint32_t> valueToId; std::map<shader::ir::Value, std::uint32_t> valueToId;
std::uint32_t bounds = 1; std::uint32_t bounds = 1;
@ -63,7 +63,7 @@ static std::uint32_t generateSpv(std::vector<std::uint32_t> &result,
addWord(getValueId(value)); addWord(getValueId(value));
} }
for (auto operand : operands) { for (auto &operand : operands) {
if (auto value = operand.getAsValue()) { if (auto value = operand.getAsValue()) {
addWord(getValueId(value)); addWord(getValueId(value));
continue; continue;
@ -126,7 +126,7 @@ shader::spv::deserialize(ir::Context &context,
return {}; return {};
} }
std::vector<std::uint32_t> shader::spv::serialize(ir::Region body) { std::vector<std::uint32_t> shader::spv::serialize(ir::RegionLike body) {
std::vector<std::uint32_t> result; std::vector<std::uint32_t> result;
result.resize(5); result.resize(5);
result[0] = 0x07230203; result[0] = 0x07230203;
@ -170,6 +170,7 @@ std::string shader::spv::disassembly(std::span<const std::uint32_t> spv,
std::string result; std::string result;
if (text != nullptr) { if (text != nullptr) {
result = std::string(text->str, text->length); result = std::string(text->str, text->length);
spvTextDestroy(text);
} }
spvDiagnosticDestroy(diagnostic); spvDiagnosticDestroy(diagnostic);

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <compare> #include <compare>
#include <type_traits>
#include <utility> #include <utility>
namespace rx { namespace rx {
@ -12,7 +13,7 @@ template <typename RT, typename... ArgsT> class FunctionRef<RT(ArgsT...)> {
public: public:
constexpr FunctionRef() = default; constexpr FunctionRef() = default;
template <typename T> template <typename T> requires (!std::is_same_v<std::remove_cvref_t<T>, FunctionRef>)
constexpr FunctionRef(T &&object) constexpr FunctionRef(T &&object)
requires requires(ArgsT... args) { RT(object(args...)); } requires requires(ArgsT... args) { RT(object(args...)); }
: context( : context(