From 86e2d8b1298a7761cdf4967e5dbaba6dfad28fd6 Mon Sep 17 00:00:00 2001 From: DH Date: Tue, 3 Sep 2024 10:09:35 +0300 Subject: [PATCH] simplified MemoryTable utility --- hw/amdgpu/device/src/device.cpp | 134 ++++++++++++++++++-------------- rpcsx-os/iodev/dmem.cpp | 51 ++++++------ rpcsx-os/vm.cpp | 26 +++---- rx/include/rx/MemoryTable.hpp | 55 ++++++++++--- 4 files changed, 153 insertions(+), 113 deletions(-) diff --git a/hw/amdgpu/device/src/device.cpp b/hw/amdgpu/device/src/device.cpp index 024796135..500b024ab 100644 --- a/hw/amdgpu/device/src/device.cpp +++ b/hw/amdgpu/device/src/device.cpp @@ -2215,22 +2215,29 @@ struct CacheOverlayBase { virtual void release(std::uint64_t tag) {} - std::optional getSyncTag(std::uint64_t address, - std::uint64_t size) { + struct SyncTag { + std::uint64_t beginAddress; + std::uint64_t endAddress; + std::uint64_t value; + }; + + std::optional getSyncTag(std::uint64_t address, std::uint64_t size) { std::lock_guard lock(mtx); auto it = syncState.queryArea(address); if (it == syncState.end()) { return {}; } - auto state = *it; - - if (state.endAddress < address + size || state.beginAddress > address) { + if (it.endAddress() < address + size || it.beginAddress() > address) { // has no single sync state return {}; } - return state; + return SyncTag{ + .beginAddress = it.beginAddress(), + .endAddress = it.endAddress(), + .value = it.get(), + }; } bool isInSync(util::MemoryTableWithPayload &table, @@ -2250,14 +2257,12 @@ struct CacheOverlayBase { return false; } - auto tableTag = *tableArea; - - if (tableTag.beginAddress > address || - tableTag.endAddress < address + size) { + if (tableArea.beginAddress() > address || + tableArea.endAddress() < address + size) { return false; } - return tableTag.payload.tag == syncTag.payload; + return tableArea->tag == syncTag.value; } virtual void writeBuffer(TaskChain &taskChain, @@ -2277,6 +2282,13 @@ struct CacheOverlayBase { } }; +struct CacheEntry { + std::uint64_t beginAddress; + std::uint64_t endAddress; + std::uint64_t tag; + Ref overlay; +}; + struct CacheBufferOverlay : CacheOverlayBase { vk::Buffer buffer; std::uint64_t bufferAddress; @@ -2305,7 +2317,12 @@ struct CacheBufferOverlay : CacheOverlayBase { util::unreachable(); } - return *it; + return CacheEntry{ + .beginAddress = it.beginAddress(), + .endAddress = it.endAddress(), + .tag = it->tag, + .overlay = it->overlay, + }; } std::lock_guard lock(tableMtx); @@ -2313,7 +2330,12 @@ struct CacheBufferOverlay : CacheOverlayBase { if (it == table.end()) { util::unreachable(); } - return *it; + return CacheEntry{ + .beginAddress = it.beginAddress(), + .endAddress = it.endAddress(), + .tag = it->tag, + .overlay = it->overlay, + }; }; while (size > 0) { @@ -2324,8 +2346,7 @@ struct CacheBufferOverlay : CacheOverlayBase { auto areaSize = origAreaSize; if (!cache) { - state.payload.overlay->readBuffer(taskChain, this, address, areaSize, - waitTask); + state.overlay->readBuffer(taskChain, this, address, areaSize, waitTask); size -= areaSize; address += areaSize; continue; @@ -2335,17 +2356,16 @@ struct CacheBufferOverlay : CacheOverlayBase { auto blockSyncStateIt = syncState.queryArea(address); if (blockSyncStateIt == syncState.end()) { - doRead(address, areaSize, state.payload.tag, state.payload.overlay); + doRead(address, areaSize, state.tag, state.overlay); address += areaSize; break; } - auto blockSyncState = *blockSyncStateIt; auto blockSize = - std::min(blockSyncState.endAddress - address, areaSize); + std::min(blockSyncStateIt.endAddress() - address, areaSize); - 
if (blockSyncState.payload != state.payload.tag) { - doRead(address, areaSize, state.payload.tag, state.payload.overlay); + if (blockSyncStateIt.get() != state.tag) { + doRead(address, areaSize, state.tag, state.overlay); } areaSize -= blockSize; @@ -2445,7 +2465,7 @@ struct CacheImageOverlay : CacheOverlayBase { VK_IMAGE_LAYOUT_GENERAL, 1, ®ion); auto tag = *srcBuffer->getSyncTag(address, size); std::lock_guard lock(self->mtx); - self->syncState.map(address, address + size, tag.payload); + self->syncState.map(address, address + size, tag.value); }); return; @@ -2469,7 +2489,7 @@ struct CacheImageOverlay : CacheOverlayBase { auto tag = *srcBuffer->getSyncTag(address, size); std::lock_guard lock(self->mtx); - self->syncState.map(address, address + size, tag.payload); + self->syncState.map(address, address + size, tag.value); }); } @@ -2666,8 +2686,9 @@ struct CacheLine { std::mutex writeBackTableMtx; util::MemoryTableWithPayload> writeBackTable; - CacheLine(RemoteMemory memory, std::uint64_t areaAddress, std::uint64_t areaSize) - :memory(memory), areaAddress(areaAddress), areaSize(areaSize) { + CacheLine(RemoteMemory memory, std::uint64_t areaAddress, + std::uint64_t areaSize) + : memory(memory), areaAddress(areaAddress), areaSize(areaSize) { memoryOverlay = new MemoryOverlay(); memoryOverlay->memory = memory; hostSyncTable.map(areaAddress, areaAddress + areaSize, {1, memoryOverlay}); @@ -2720,25 +2741,24 @@ struct CacheLine { auto it = writeBackTable.queryArea(address); while (it != writeBackTable.end()) { - auto taskInfo = *it; - - if (taskInfo.beginAddress >= address + size) { + if (it.beginAddress() >= address + size) { break; } - if (taskInfo.beginAddress >= address && - taskInfo.endAddress <= address + size) { - if (taskInfo.payload != nullptr) { + auto task = it.get(); + + if (it.beginAddress() >= address && it.endAddress() <= address + size) { + if (task != nullptr) { // another task with smaller range already in progress, we can // cancel it // std::printf("prev upload task cancelation\n"); - taskInfo.payload->cancel(); + task->cancel(); } } - if (taskInfo.payload != nullptr) { - taskInfo.payload->wait(); + if (task != nullptr) { + task->wait(); } ++it; @@ -2751,8 +2771,9 @@ struct CacheLine { void lazyMemoryUpdate(std::uint64_t tag, std::uint64_t address) { // std::printf("memory lazy update, address %lx\n", address); - decltype(hostSyncTable)::AreaInfo area; + std::size_t beginAddress; + std::size_t areaSize; { std::lock_guard lock(hostSyncMtx); auto it = hostSyncTable.queryArea(address); @@ -2761,20 +2782,18 @@ struct CacheLine { util::unreachable(); } - area = *it; + beginAddress = it.beginAddress(); + areaSize = it.size(); } - auto areaSize = area.endAddress - area.beginAddress; - auto updateTaskChain = TaskChain::Create(); - auto uploadBuffer = - getBuffer(tag, *updateTaskChain.get(), area.beginAddress, areaSize, 1, - 1, shader::AccessOp::Load); + auto uploadBuffer = getBuffer(tag, *updateTaskChain.get(), beginAddress, + areaSize, 1, 1, shader::AccessOp::Load); memoryOverlay->writeBuffer(*updateTaskChain.get(), uploadBuffer, - area.beginAddress, areaSize); + beginAddress, areaSize); updateTaskChain->wait(); uploadBuffer->unlock(tag); - unlockReadWrite(memory.vmId, area.beginAddress, areaSize); + unlockReadWrite(memory.vmId, beginAddress, areaSize); // std::printf("memory lazy update, %lx finish\n", address); } @@ -3020,32 +3039,27 @@ private: auto &table = bufferTable[offset]; if (auto it = table.queryArea(address); it != table.end()) { - auto bufferInfo = *it; - - if 
(bufferInfo.beginAddress <= address && - bufferInfo.endAddress >= address + size) { - if (!isAligned(address - bufferInfo.beginAddress, alignment)) { + if (it.beginAddress() <= address && it.endAddress() >= address + size) { + if (!isAligned(address - it.beginAddress(), alignment)) { util::unreachable(); } - return bufferInfo.payload; + return it.get(); } - assert(bufferInfo.beginAddress <= address); + assert(it.beginAddress() <= address); - auto endAddress = std::max(bufferInfo.endAddress, address + size); - address = bufferInfo.beginAddress; + auto endAddress = std::max(it.endAddress(), address + size); + address = it.beginAddress(); while (it != table.end()) { - bufferInfo = *it; - if (endAddress > bufferInfo.endAddress) { + if (endAddress > it.endAddress()) { auto nextIt = it; if (++nextIt != table.end()) { - auto nextInfo = *nextIt; - if (nextInfo.beginAddress >= endAddress) { + if (nextIt.beginAddress() >= endAddress) { break; } - endAddress = nextInfo.endAddress; + endAddress = nextIt.endAddress(); } } ++it; @@ -4817,8 +4831,8 @@ void amdgpu::device::AmdgpuDevice::handleProtectMemory(RemoteMemory memory, protStr = "unknown"; break; } - std::fprintf(stderr, "Allocated area at %zx, size %lx, prot %s, vmid %u\n", address, - size, protStr, memory.vmId); + std::fprintf(stderr, "Allocated area at %zx, size %lx, prot %s, vmid %u\n", + address, size, protStr, memory.vmId); } else { memoryAreaTable[memory.vmId].unmap(beginPage, endPage); std::fprintf(stderr, "Unmapped area at %zx, size %lx\n", address, size); @@ -5069,8 +5083,8 @@ bool amdgpu::device::AmdgpuDevice::handleFlip( g_bridge->flipBuffer[memory.vmId] = bufferIndex; g_bridge->flipArg[memory.vmId] = arg; g_bridge->flipCount[memory.vmId] = g_bridge->flipCount[memory.vmId] + 1; - auto bufferInUse = - memory.getPointer(g_bridge->bufferInUseAddress[memory.vmId]); + auto bufferInUse = memory.getPointer( + g_bridge->bufferInUseAddress[memory.vmId]); if (bufferInUse != nullptr) { bufferInUse[bufferIndex] = 0; } diff --git a/rpcsx-os/iodev/dmem.cpp b/rpcsx-os/iodev/dmem.cpp index 3fbe5550d..0ce4ed25b 100644 --- a/rpcsx-os/iodev/dmem.cpp +++ b/rpcsx-os/iodev/dmem.cpp @@ -47,8 +47,7 @@ orbis::ErrorCode DmemDevice::mmap(void **address, std::uint64_t len, int memoryType = 0; if (auto allocationInfoIt = allocations.queryArea(directMemoryStart); allocationInfoIt != allocations.end()) { - auto allocationInfo = *allocationInfoIt; - memoryType = allocationInfo.payload.memoryType; + memoryType = allocationInfoIt->memoryType; } auto result = @@ -183,25 +182,24 @@ static orbis::ErrorCode dmem_ioctl(orbis::File *file, std::uint64_t request, auto queryInfo = *it; - if (queryInfo.payload.memoryType == -1u) { + if (it->memoryType == -1u) { return orbis::ErrorCode::ACCES; } if ((args->flags & 1) == 0) { - if (queryInfo.endAddress <= args->offset) { + if (it.endAddress() <= args->offset) { return orbis::ErrorCode::ACCES; } } else { - if (queryInfo.beginAddress > args->offset || - queryInfo.endAddress <= args->offset) { + if (it.beginAddress() > args->offset || it.endAddress() <= args->offset) { return orbis::ErrorCode::ACCES; } } DirectMemoryQueryInfo info{ - .start = queryInfo.beginAddress, - .end = queryInfo.endAddress, - .memoryType = queryInfo.payload.memoryType, + .start = it.beginAddress(), + .end = it.endAddress(), + .memoryType = it->memoryType, }; ORBIS_LOG_WARNING("dmem directMemoryQuery", device->index, args->devIndex, @@ -255,20 +253,19 @@ orbis::ErrorCode DmemDevice::allocate(std::uint64_t *start, auto it = allocations.lowerBound(offset); if (it 
!= allocations.end()) { - auto allocation = *it; - if (allocation.payload.memoryType == -1u) { - if (offset < allocation.beginAddress) { - offset = allocation.beginAddress + alignment - 1; + if (it->memoryType == -1u) { + if (offset < it.beginAddress()) { + offset = it.beginAddress() + alignment - 1; offset &= ~(alignment - 1); } - if (offset + len >= allocation.endAddress) { - offset = allocation.endAddress; + if (offset + len >= it.endAddress()) { + offset = it.endAddress(); continue; } } else { - if (offset + len > allocation.beginAddress) { - offset = allocation.endAddress; + if (offset + len > it.beginAddress()) { + offset = it.endAddress(); continue; } } @@ -315,25 +312,23 @@ orbis::ErrorCode DmemDevice::queryMaxFreeChunkSize(std::uint64_t *start, break; } - auto allocation = *it; - if (allocation.payload.memoryType == -1u) { - if (offset < allocation.beginAddress) { - offset = allocation.beginAddress + alignment - 1; + if (it->memoryType == -1u) { + if (offset < it.beginAddress()) { + offset = it.beginAddress() + alignment - 1; offset &= ~(alignment - 1); } - if (allocation.endAddress > offset && - resultSize < allocation.endAddress - offset) { - resultSize = allocation.endAddress - offset; + if (it.endAddress() > offset && resultSize < it.endAddress() - offset) { + resultSize = it.endAddress() - offset; resultOffset = offset; } - } else if (offset > allocation.beginAddress && - resultSize < offset - allocation.beginAddress) { - resultSize = offset - allocation.beginAddress; + } else if (offset > it.beginAddress() && + resultSize < offset - it.beginAddress()) { + resultSize = offset - it.beginAddress(); resultOffset = offset; } - offset = allocation.endAddress; + offset = it.endAddress(); } *start = resultOffset; diff --git a/rpcsx-os/vm.cpp b/rpcsx-os/vm.cpp index 78e4d1623..119787d5b 100644 --- a/rpcsx-os/vm.cpp +++ b/rpcsx-os/vm.cpp @@ -929,7 +929,7 @@ void *rx::vm::map(void *addr, std::uint64_t len, std::int32_t prot, { MapInfo info; if (auto it = gMapInfo.queryArea(address); it != gMapInfo.end()) { - info = (*it).payload; + info = it.get(); } info.device = device; info.flags = flags; @@ -1124,29 +1124,27 @@ bool rx::vm::virtualQuery(const void *addr, std::int32_t flags, return false; } - auto queryInfo = *it; if ((flags & 1) == 0) { - if (queryInfo.endAddress <= address) { + if (it.endAddress() <= address) { return false; } } else { - if (queryInfo.beginAddress > address || queryInfo.endAddress <= address) { + if (it.beginAddress() > address || it.endAddress() <= address) { return false; } } std::int32_t memoryType = 0; std::uint32_t blockFlags = 0; - if (queryInfo.payload.device != nullptr) { + if (it->device != nullptr) { if (auto dmem = - dynamic_cast(queryInfo.payload.device.get())) { - auto dmemIt = dmem->allocations.queryArea(queryInfo.payload.offset); + dynamic_cast(it->device.get())) { + auto dmemIt = dmem->allocations.queryArea(it->offset); if (dmemIt == dmem->allocations.end()) { return false; } - auto alloc = *dmemIt; - memoryType = alloc.payload.memoryType; + memoryType = dmemIt->memoryType; blockFlags = kBlockFlagDirectMemory; std::fprintf(stderr, "virtual query %p", addr); std::fprintf(stderr, "memory type: %u\n", memoryType); @@ -1154,11 +1152,11 @@ bool rx::vm::virtualQuery(const void *addr, std::int32_t flags, // TODO } - std::int32_t prot = getPageProtectionImpl(queryInfo.beginAddress); + std::int32_t prot = getPageProtectionImpl(it.beginAddress()); *info = { - .start = queryInfo.beginAddress, - .end = queryInfo.endAddress, + .start = it.beginAddress(), + .end 
= it.endAddress(), .protection = prot, .memoryType = memoryType, .flags = blockFlags, @@ -1167,7 +1165,7 @@ bool rx::vm::virtualQuery(const void *addr, std::int32_t flags, ORBIS_LOG_ERROR("virtualQuery", addr, flags, info->start, info->end, info->protection, info->memoryType, info->flags); - std::memcpy(info->name, queryInfo.payload.name, sizeof(info->name)); + std::memcpy(info->name, it->name, sizeof(info->name)); return true; } @@ -1177,7 +1175,7 @@ void rx::vm::setName(std::uint64_t start, std::uint64_t size, MapInfo info; if (auto it = gMapInfo.queryArea(start); it != gMapInfo.end()) { - info = (*it).payload; + info = it.get(); } std::strncpy(info.name, name, sizeof(info.name)); diff --git a/rx/include/rx/MemoryTable.hpp b/rx/include/rx/MemoryTable.hpp index 70d7530f6..83cd67a4d 100644 --- a/rx/include/rx/MemoryTable.hpp +++ b/rx/include/rx/MemoryTable.hpp @@ -214,7 +214,9 @@ public: struct AreaInfo { std::uint64_t beginAddress; std::uint64_t endAddress; - PayloadT payload; + PayloadT &payload; + + std::size_t size() const { return endAddress - beginAddress; } }; class iterator { @@ -230,6 +232,12 @@ public: return {it->first, std::next(it)->first, it->second.second}; } + std::uint64_t beginAddress() const { return it->first; } + std::uint64_t endAddress() const { return std::next(it)->first; } + std::uint64_t size() const { return endAddress() - beginAddress(); } + + PayloadT &get() const { return it->second.second; } + PayloadT *operator->() const { return &it->second.second; } iterator &operator++() { ++it; @@ -242,6 +250,8 @@ public: bool operator==(iterator other) const { return it == other.it; } bool operator!=(iterator other) const { return it != other.it; } + + friend MemoryTableWithPayload; }; iterator begin() { return iterator(mAreas.begin()); } @@ -252,18 +262,14 @@ public: iterator lowerBound(std::uint64_t address) { auto it = mAreas.lower_bound(address); - if (it == mAreas.end()) { + if (it == mAreas.end() || it->second.first != Kind::X) { return it; } if (it->first == address) { - if (it->second.first == Kind::X) { - ++it; - } + ++it; } else { - if (it->second.first != Kind::O) { - --it; - } + --it; } return it; @@ -296,8 +302,8 @@ public: return endAddress < address ? mAreas.end() : it; } - void map(std::uint64_t beginAddress, std::uint64_t endAddress, - PayloadT payload, bool merge = true) { + iterator map(std::uint64_t beginAddress, std::uint64_t endAddress, + PayloadT payload, bool merge = true) { assert(beginAddress < endAddress); auto [beginIt, beginInserted] = mAreas.emplace(beginAddress, std::pair{Kind::O, payload}); @@ -370,7 +376,7 @@ public: } if (!merge) { - return; + return origBegin; } if (origBegin->second.first == Kind::XO) { @@ -378,6 +384,7 @@ public: if (prevBegin->second.second == origBegin->second.second) { mAreas.erase(origBegin); + origBegin = prevBegin; } } @@ -386,6 +393,32 @@ public: mAreas.erase(endIt); } } + + return origBegin; + } + + void unmap(iterator it) { + auto openIt = it.it; + auto closeIt = openIt; + ++closeIt; + + if (openIt->second.first == Kind::XO) { + openIt->second.first = Kind::X; + openIt->second.second = {}; + } else { + mAreas.erase(openIt); + } + + if (closeIt->second.first == Kind::XO) { + closeIt->second.first = Kind::O; + } else { + mAreas.erase(closeIt); + } + } + + void unmap(std::uint64_t beginAddress, std::uint64_t endAddress) { + // FIXME: can be optimized + unmap(map(beginAddress, endAddress, PayloadT{}, false)); } }; } // namespace rx
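
For reference, the reworked iterator interface declared in
rx/include/rx/MemoryTable.hpp is used as follows. This is a minimal sketch and
not part of the patch: the Allocation payload type, the addresses and main()
are illustrative only, and the #include assumes rx/include is on the include
path.

  #include "rx/MemoryTable.hpp"

  #include <cassert>
  #include <cstdint>

  struct Allocation {
    std::uint32_t memoryType;
  };

  int main() {
    rx::MemoryTableWithPayload<Allocation> table;
    table.map(0x1000, 0x3000, Allocation{.memoryType = 3});

    auto it = table.queryArea(0x1800);
    assert(it != table.end());

    // Bounds and payload are read straight from the iterator; no AreaInfo
    // copy (and no payload copy) is made, unlike the old `auto info = *it;`
    // pattern that the device.cpp, dmem.cpp and vm.cpp hunks replace.
    assert(it.beginAddress() == 0x1000);
    assert(it.endAddress() == 0x3000);
    assert(it.size() == 0x2000);
    assert(it->memoryType == 3); // operator-> yields a PayloadT pointer
    it.get().memoryType = 4;     // get() yields a PayloadT reference
    assert(it->memoryType == 4);
  }

operator* still exists, but AreaInfo::payload is now a reference, so
dereferencing no longer copies the payload either.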
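
map() now returns the iterator of the area it just created, and two unmap()
overloads are added on top of that; unmap(beginAddress, endAddress) carves the
range out via map(beginAddress, endAddress, PayloadT{}, /*merge=*/false) and
unmaps the result, which the FIXME in the header marks as a candidate for
optimization. A short sketch of the added overloads, reusing the illustrative
Allocation type from the previous snippet:

  void demo() {
    rx::MemoryTableWithPayload<Allocation> table;

    // map() hands back an iterator to the mapped area, so the caller no
    // longer needs a second queryArea() lookup to inspect it.
    auto area = table.map(0x4000, 0x6000, Allocation{.memoryType = 1});
    assert(area.beginAddress() == 0x4000 && area.endAddress() == 0x6000);

    // unmap(iterator) erases exactly the area the iterator points at.
    table.unmap(area);
    assert(table.queryArea(0x5000) == table.end());

    // unmap(begin, end) removes an address range without needing an iterator.
    table.map(0x4000, 0x6000, Allocation{.memoryType = 2});
    table.unmap(0x4000, 0x6000);
    assert(table.queryArea(0x5000) == table.end());
  }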