#pragma once #include "ModuleInfo.hpp" #include "SemanticInfo.hpp" #include "dialect/memssa.hpp" #include "graph.hpp" #include "ir/Instruction.hpp" #include "ir/Value.hpp" #include "rx/FunctionRef.hpp" #include "rx/TypeId.hpp" #include #include #include #include #include namespace shader { struct DomTree; struct PostDomTree; class CFG { public: class Node { ir::Value mLabel; ir::Instruction mTerminator; std::unordered_set mPredecessors; std::unordered_set mSuccessors; public: using Iterator = std::unordered_set::iterator; Node() = default; Node(ir::Value label) : mLabel(label) {} ir::Value getLabel() { return mLabel; } void setTerminator(ir::Instruction inst) { mTerminator = inst; } bool hasTerminator() { return mTerminator != nullptr; } ir::Instruction getTerminator() { return mTerminator; } void addEdge(Node *to) { to->mPredecessors.insert(this); mSuccessors.insert(to); } bool hasPredecessor(Node *node) { return mPredecessors.contains(node); } bool hasSuccessor(Node *node) { return mSuccessors.contains(node); } auto &getPredecessors() { return mPredecessors; } auto &getSuccessors() { return mSuccessors; } std::size_t getPredecessorCount() { return mPredecessors.size(); } std::size_t getSuccessorCount() { return mSuccessors.size(); } bool hasPredecessors() { return !mPredecessors.empty(); } bool hasSuccessors() { return !mSuccessors.empty(); } template auto range() { return ir::range(mLabel, mTerminator.getNext()); } template auto rangeWithoutLabel() { return ir::range(mLabel.getNext(), mTerminator ? mTerminator.getNext() : nullptr); } template auto rangeWithoutTerminator() { return ir::range(mLabel, mTerminator); } template auto rangeWithoutLabelAndTerminator() { return ir::range(mLabel.getNext(), mTerminator); } }; private: std::map mNodes; std::vector mPreorderNodes; std::vector mPostorderNodes; Node *mEntryNode = nullptr; public: bool empty() { return mNodes.empty(); } void clear() { mNodes.clear(); mPreorderNodes.clear(); mPostorderNodes.clear(); mEntryNode = nullptr; } void addPreorderNode(Node *node) { mPreorderNodes.push_back(node); } void addPostorderNode(Node *node) { mPostorderNodes.push_back(node); } Node *getEntryNode() { return mEntryNode; } ir::Value getEntryLabel() { return getEntryNode()->getLabel(); } void setEntryNode(Node *node) { mEntryNode = node; } std::span getPreorderNodes() { return mPreorderNodes; } std::span getPostorderNodes() { return mPostorderNodes; } Node *getOrCreateNode(ir::Value label) { return &mNodes.emplace(label, label).first->second; } Node *getNode(ir::Value label) { if (auto it = mNodes.find(label); it != mNodes.end()) { return &it->second; } return nullptr; } auto &getSuccessors(ir::Value label) { return getNode(label)->getSuccessors(); } auto &getPredecessors(ir::Value label) { return getNode(label)->getPredecessors(); } void print(std::ostream &os, ir::NameStorage &ns, bool subgraph = false, std::string_view nameSuffix = ""); std::string genTest(); CFG buildView(CFG::Node *from, PostDomTree *domTree = nullptr, const std::unordered_set &stopLabels = {}, ir::Value continueLabel = nullptr); CFG buildView(ir::Value from, PostDomTree *domTree = nullptr, const std::unordered_set &stopLabels = {}, ir::Value continueLabel = nullptr) { return buildView(getNode(from), domTree, stopLabels, continueLabel); } }; class MemorySSA { public: ir::Context context; ir::Region region; std::map variableToVar; std::map> userDefs; ir::memssa::Var getVar(ir::Value variable, std::span path); ir::memssa::Var getVar(ir::Value pointer); ir::memssa::Def getDef(ir::Instruction user, ir::memssa::Var var) { auto userIt = userDefs.find(user); if (userIt == userDefs.end()) { return {}; } if (auto it = userIt->second.find(var); it != userIt->second.end()) { return it->second; } return {}; } ir::memssa::Def getDef(ir::Instruction user, ir::Value pointer) { if (auto var = getVar(pointer)) { return getDef(user, var); } return {}; } ir::Instruction getDefInst(ir::Instruction user, ir::Value pointer) { if (auto def = getDef(user, pointer)) { return def.getLinkedInst(); } return {}; } void print(std::ostream &os, ir::Region irRegion, ir::NameStorage &ns); void print(std::ostream &os, ir::NameStorage &ns); void dump(); private: ir::memssa::Var getVarImpl(ir::Value variable); }; bool isWithoutSideEffects(ir::InstructionId id); bool isTerminator(ir::Instruction inst); bool isBranch(ir::Instruction inst); ir::Value unwrapPointer(ir::Value pointer); graph::DomTree buildDomTree(CFG &cfg, ir::Value root = nullptr); graph::DomTree buildPostDomTree(CFG &cfg, ir::Value root); CFG buildCFG(ir::Instruction firstInstruction, const std::unordered_set &exitLabels = {}, ir::Value continueLabel = nullptr); MemorySSA buildMemorySSA(CFG &cfg, ModuleInfo *moduleInfo = nullptr); MemorySSA buildMemorySSA(CFG &cfg, const SemanticInfo &instructionSemantic, std::function getRegisterVarCb); bool dominates(ir::Instruction a, ir::Instruction b, bool isPostDom, graph::DomTree &domTree); ir::Value findNearestCommonDominator(ir::Instruction a, ir::Instruction b, graph::DomTree &domTree); class BackEdgeStorage { std::unordered_map> backEdges; public: BackEdgeStorage() = default; BackEdgeStorage(CFG &cfg); const std::unordered_set *get(ir::Value value) { if (auto it = backEdges.find(value); it != backEdges.end()) { return &it->second; } return nullptr; } auto &all() { return backEdges; } }; struct AnalysisStorage { template requires(sizeof...(T) > 0) bool invalidate() { bool invalidated = false; ((invalidated = invalidate(rx::TypeId::get()) || invalidated), ...); return invalidated; } bool invalidate(rx::TypeId id) { if (auto it = mStorage.find(id); it != mStorage.end()) { return std::exchange(it->second.invalid, true) == false; } return false; } void invalidateAll() { for (auto &entry : mStorage) { entry.second.invalid = true; } } template T &get(ArgsT &&...args) requires requires { T(std::forward(args)...); } { void *result = getImpl( rx::TypeId::get(), getDeleter(), [&] { return std::make_unique(std::forward(args)...).release(); }, [&](void *object) { *reinterpret_cast(object) = T(std::forward(args)...); }); return *static_cast(result); } template T &get(BuilderFn &&builder) requires requires { T(std::forward(builder)()); } { void *result = getImpl( rx::TypeId::get(), getDeleter(), [&] { return std::make_unique(std::forward(builder)()) .release(); }, [&](void *object) { *reinterpret_cast(object) = std::forward(builder)(); }); return *static_cast(result); } private: template static void (*getDeleter())(void *) { return +[](void *data) { delete static_cast(data); }; } void *getImpl(rx::TypeId typeId, void (*deleter)(void *), rx::FunctionRef constructor, rx::FunctionRef placementConstructor) { auto [it, inserted] = mStorage.emplace(typeId, getNullPointer()); if (inserted) { it->second.object = std::unique_ptr(constructor(), deleter); } else if (it->second.invalid) { placementConstructor(it->second.object.get()); it->second.invalid = false; } return it->second.object.get(); } static constexpr std::unique_ptr getNullPointer() { return {nullptr, [](void *) {}}; } struct Entry { std::unique_ptr object; bool invalid = false; }; std::map mStorage; }; struct PostDomTree : graph::DomTree { PostDomTree() = default; PostDomTree(graph::DomTree &&other) : graph::DomTree::DomTree(std::move(other)) {} PostDomTree(CFG &cfg, ir::Value root) : PostDomTree(buildPostDomTree(cfg, root)) {} }; struct DomTree : graph::DomTree { DomTree() = default; DomTree(graph::DomTree &&other) : graph::DomTree::DomTree(std::move(other)) {} DomTree(CFG &cfg, ir::Value root = nullptr) : DomTree(buildDomTree(cfg, root)) {} }; template struct Tag : T { using T::T; using T::operator=; Tag(T &&other) : T(std::move(other)) {} Tag(const T &other) : T(other) {} Tag &operator=(T &&other) { T::operator=(std::move(other)); return *this; } Tag &operator=(const T &other) { T::operator=(other); return *this; } }; struct Construct { Construct *parent; std::forward_list children; ir::Value header; ir::Value merge; ir::Value loopBody; ir::Value loopContinue; AnalysisStorage analysis; static std::unique_ptr createRoot(ir::RegionLike region, ir::Value merge) { auto result = std::make_unique(); auto &cfg = result->analysis.get([&] { return buildCFG(region.getFirst()); }); result->header = cfg.getEntryLabel(); result->merge = merge; return result; } Construct *createChild(ir::Value header, ir::Value merge) { auto &result = children.emplace_front(); result.parent = this; result.header = header; result.merge = merge; return &result; } Construct *createChild(ir::Value header, ir::Value merge, ir::Value loopContinue, ir::Value loopBody) { auto &result = children.emplace_front(); result.parent = this; result.header = header; result.merge = merge; result.loopContinue = loopContinue; result.loopBody = loopBody; return &result; } Construct createTemporaryChild(ir::Value header, ir::Value merge) { Construct result; result.parent = this; result.header = header; result.merge = merge; return result; } CFG &getCfg() { return analysis.get([this] { if (parent != nullptr) { return parent->getCfg().buildView( header, &parent->getPostDomTree(), {header, merge}); } return buildCFG(header); }); } CFG &getCfgWithoutContinue() { if (loopContinue == nullptr) { return getCfg(); } return analysis.get>([this] { if (parent != nullptr) { return parent->getCfg().buildView( header, &parent->getPostDomTree(), {header, merge}, loopContinue); } return buildCFG(header, {}, loopContinue); }); } DomTree &getDomTree() { return analysis.get(getCfg(), header); } PostDomTree &getPostDomTree() { return analysis.get(getCfg(), merge); } BackEdgeStorage &getBackEdgeStorage() { return analysis.get(getCfg()); } BackEdgeStorage &getBackEdgeWithoutContinueStorage() { if (loopContinue == nullptr) { return getBackEdgeStorage(); } return analysis.get>( getCfgWithoutContinue()); } auto getBackEdges(ir::Value node) { return getBackEdgeStorage().get(node); } auto getBackEdgesWithoutContinue(ir::Value node) { return getBackEdgeWithoutContinueStorage().get(node); } auto getBackEdges() { return getBackEdges(header); } void invalidate(); void invalidateAll(); bool isNull() const { return header == nullptr; } void removeLastChild() { children.pop_front(); } private: enum { kWithoutContinue, }; }; } // namespace shader