simplified MemoryTable utility

This commit is contained in:
DH 2024-09-03 10:09:35 +03:00
parent bd39f9a070
commit 86e2d8b129
4 changed files with 153 additions and 113 deletions

View file

@ -2215,22 +2215,29 @@ struct CacheOverlayBase {
virtual void release(std::uint64_t tag) {}
std::optional<decltype(syncState)::AreaInfo> getSyncTag(std::uint64_t address,
std::uint64_t size) {
// Value copy of a single syncState area, as returned by getSyncTag().
// The fields are populated with designated initializers, so their
// declaration order must stay as-is.
struct SyncTag {
// Taken from the area iterator's beginAddress().
std::uint64_t beginAddress;
// Taken from the area iterator's endAddress(). getSyncTag() only returns an
// area when endAddress >= address + size, which suggests an exclusive end —
// TODO confirm against MemoryTable.
std::uint64_t endAddress;
// Taken from the area iterator's get(); isInSync() compares it with the
// `tag` of a table entry.
std::uint64_t value;
};
// Returns the bounds and value of the area that syncState.queryArea(address)
// reports, or an empty optional when no area is found or when that area does
// not cover the whole range [address, address + size). The lookup runs while
// holding `mtx`.
// NOTE(review): this listing appears to be a commit diff whose +/- markers
// were lost; `auto state = *it;`, the `state.*` range check and
// `return state;` look like the superseded lines shown next to their
// replacements — confirm against the real file before relying on it.
std::optional<SyncTag> getSyncTag(std::uint64_t address, std::uint64_t size) {
std::lock_guard lock(mtx);
auto it = syncState.queryArea(address);
if (it == syncState.end()) {
// nothing is mapped for this address
return {};
}
auto state = *it;
if (state.endAddress < address + size || state.beginAddress > address) {
if (it.endAddress() < address + size || it.beginAddress() > address) {
// has no single sync state
return {};
}
return state;
// copy the iterator's bounds and value into a plain SyncTag
return SyncTag{
.beginAddress = it.beginAddress(),
.endAddress = it.endAddress(),
.value = it.get(),
};
}
bool isInSync(util::MemoryTableWithPayload<CacheSyncEntry> &table,
@ -2250,14 +2257,12 @@ struct CacheOverlayBase {
return false;
}
auto tableTag = *tableArea;
if (tableTag.beginAddress > address ||
tableTag.endAddress < address + size) {
if (tableArea.beginAddress() > address ||
tableArea.endAddress() < address + size) {
return false;
}
return tableTag.payload.tag == syncTag.payload;
return tableArea->tag == syncTag.value;
}
virtual void writeBuffer(TaskChain &taskChain,
@ -2277,6 +2282,13 @@ struct CacheOverlayBase {
}
};
// Value copy of one cache-table area together with its payload. It is built
// with designated initializers from a table iterator, so the declaration
// order of the fields must stay as-is.
struct CacheEntry {
// Taken from the table iterator's beginAddress().
std::uint64_t beginAddress;
// Taken from the table iterator's endAddress().
std::uint64_t endAddress;
// Copied from the entry's `tag`; compared against the sync-state value of a
// block to decide whether doRead() has to run.
std::uint64_t tag;
// Copied from the entry's `overlay`; readBuffer() is invoked on it and it is
// passed on to doRead().
Ref<CacheOverlayBase> overlay;
};
struct CacheBufferOverlay : CacheOverlayBase {
vk::Buffer buffer;
std::uint64_t bufferAddress;
@ -2305,7 +2317,12 @@ struct CacheBufferOverlay : CacheOverlayBase {
util::unreachable();
}
return *it;
return CacheEntry{
.beginAddress = it.beginAddress(),
.endAddress = it.endAddress(),
.tag = it->tag,
.overlay = it->overlay,
};
}
std::lock_guard lock(tableMtx);
@ -2313,7 +2330,12 @@ struct CacheBufferOverlay : CacheOverlayBase {
if (it == table.end()) {
util::unreachable();
}
return *it;
return CacheEntry{
.beginAddress = it.beginAddress(),
.endAddress = it.endAddress(),
.tag = it->tag,
.overlay = it->overlay,
};
};
while (size > 0) {
@ -2324,8 +2346,7 @@ struct CacheBufferOverlay : CacheOverlayBase {
auto areaSize = origAreaSize;
if (!cache) {
state.payload.overlay->readBuffer(taskChain, this, address, areaSize,
waitTask);
state.overlay->readBuffer(taskChain, this, address, areaSize, waitTask);
size -= areaSize;
address += areaSize;
continue;
@ -2335,17 +2356,16 @@ struct CacheBufferOverlay : CacheOverlayBase {
auto blockSyncStateIt = syncState.queryArea(address);
if (blockSyncStateIt == syncState.end()) {
doRead(address, areaSize, state.payload.tag, state.payload.overlay);
doRead(address, areaSize, state.tag, state.overlay);
address += areaSize;
break;
}
auto blockSyncState = *blockSyncStateIt;
auto blockSize =
std::min(blockSyncState.endAddress - address, areaSize);
std::min(blockSyncStateIt.endAddress() - address, areaSize);
if (blockSyncState.payload != state.payload.tag) {
doRead(address, areaSize, state.payload.tag, state.payload.overlay);
if (blockSyncStateIt.get() != state.tag) {
doRead(address, areaSize, state.tag, state.overlay);
}
areaSize -= blockSize;
@ -2445,7 +2465,7 @@ struct CacheImageOverlay : CacheOverlayBase {
VK_IMAGE_LAYOUT_GENERAL, 1, &region);
auto tag = *srcBuffer->getSyncTag(address, size);
std::lock_guard lock(self->mtx);
self->syncState.map(address, address + size, tag.payload);
self->syncState.map(address, address + size, tag.value);
});
return;
@ -2469,7 +2489,7 @@ struct CacheImageOverlay : CacheOverlayBase {
auto tag = *srcBuffer->getSyncTag(address, size);
std::lock_guard lock(self->mtx);
self->syncState.map(address, address + size, tag.payload);
self->syncState.map(address, address + size, tag.value);
});
}
@ -2666,8 +2686,9 @@ struct CacheLine {
std::mutex writeBackTableMtx;
util::MemoryTableWithPayload<Ref<AsyncTaskCtl>> writeBackTable;
CacheLine(RemoteMemory memory, std::uint64_t areaAddress, std::uint64_t areaSize)
:memory(memory), areaAddress(areaAddress), areaSize(areaSize) {
CacheLine(RemoteMemory memory, std::uint64_t areaAddress,
std::uint64_t areaSize)
: memory(memory), areaAddress(areaAddress), areaSize(areaSize) {
memoryOverlay = new MemoryOverlay();
memoryOverlay->memory = memory;
hostSyncTable.map(areaAddress, areaAddress + areaSize, {1, memoryOverlay});
@ -2720,25 +2741,24 @@ struct CacheLine {
auto it = writeBackTable.queryArea(address);
while (it != writeBackTable.end()) {
auto taskInfo = *it;
if (taskInfo.beginAddress >= address + size) {
if (it.beginAddress() >= address + size) {
break;
}
if (taskInfo.beginAddress >= address &&
taskInfo.endAddress <= address + size) {
if (taskInfo.payload != nullptr) {
auto task = it.get();
if (it.beginAddress() >= address && it.endAddress() <= address + size) {
if (task != nullptr) {
// another task with smaller range already in progress, we can
// cancel it
// std::printf("prev upload task cancelation\n");
taskInfo.payload->cancel();
task->cancel();
}
}
if (taskInfo.payload != nullptr) {
taskInfo.payload->wait();
if (task != nullptr) {
task->wait();
}
++it;
@ -2751,8 +2771,9 @@ struct CacheLine {
void lazyMemoryUpdate(std::uint64_t tag, std::uint64_t address) {
// std::printf("memory lazy update, address %lx\n", address);
decltype(hostSyncTable)::AreaInfo area;
std::size_t beginAddress;
std::size_t areaSize;
{
std::lock_guard lock(hostSyncMtx);
auto it = hostSyncTable.queryArea(address);
@ -2761,20 +2782,18 @@ struct CacheLine {
util::unreachable();
}
area = *it;
beginAddress = it.beginAddress();
areaSize = it.size();
}
auto areaSize = area.endAddress - area.beginAddress;
auto updateTaskChain = TaskChain::Create();
auto uploadBuffer =
getBuffer(tag, *updateTaskChain.get(), area.beginAddress, areaSize, 1,
1, shader::AccessOp::Load);
auto uploadBuffer = getBuffer(tag, *updateTaskChain.get(), beginAddress,
areaSize, 1, 1, shader::AccessOp::Load);
memoryOverlay->writeBuffer(*updateTaskChain.get(), uploadBuffer,
area.beginAddress, areaSize);
beginAddress, areaSize);
updateTaskChain->wait();
uploadBuffer->unlock(tag);
unlockReadWrite(memory.vmId, area.beginAddress, areaSize);
unlockReadWrite(memory.vmId, beginAddress, areaSize);
// std::printf("memory lazy update, %lx finish\n", address);
}
@ -3020,32 +3039,27 @@ private:
auto &table = bufferTable[offset];
if (auto it = table.queryArea(address); it != table.end()) {
auto bufferInfo = *it;
if (bufferInfo.beginAddress <= address &&
bufferInfo.endAddress >= address + size) {
if (!isAligned(address - bufferInfo.beginAddress, alignment)) {
if (it.beginAddress() <= address && it.endAddress() >= address + size) {
if (!isAligned(address - it.beginAddress(), alignment)) {
util::unreachable();
}
return bufferInfo.payload;
return it.get();
}
assert(bufferInfo.beginAddress <= address);
assert(it.beginAddress() <= address);
auto endAddress = std::max(bufferInfo.endAddress, address + size);
address = bufferInfo.beginAddress;
auto endAddress = std::max(it.endAddress(), address + size);
address = it.beginAddress();
while (it != table.end()) {
bufferInfo = *it;
if (endAddress > bufferInfo.endAddress) {
if (endAddress > it.endAddress()) {
auto nextIt = it;
if (++nextIt != table.end()) {
auto nextInfo = *nextIt;
if (nextInfo.beginAddress >= endAddress) {
if (nextIt.beginAddress() >= endAddress) {
break;
}
endAddress = nextInfo.endAddress;
endAddress = nextIt.endAddress();
}
}
++it;
@ -4817,8 +4831,8 @@ void amdgpu::device::AmdgpuDevice::handleProtectMemory(RemoteMemory memory,
protStr = "unknown";
break;
}
std::fprintf(stderr, "Allocated area at %zx, size %lx, prot %s, vmid %u\n", address,
size, protStr, memory.vmId);
std::fprintf(stderr, "Allocated area at %zx, size %lx, prot %s, vmid %u\n",
address, size, protStr, memory.vmId);
} else {
memoryAreaTable[memory.vmId].unmap(beginPage, endPage);
std::fprintf(stderr, "Unmapped area at %zx, size %lx\n", address, size);
@ -5069,8 +5083,8 @@ bool amdgpu::device::AmdgpuDevice::handleFlip(
g_bridge->flipBuffer[memory.vmId] = bufferIndex;
g_bridge->flipArg[memory.vmId] = arg;
g_bridge->flipCount[memory.vmId] = g_bridge->flipCount[memory.vmId] + 1;
auto bufferInUse =
memory.getPointer<std::uint64_t>(g_bridge->bufferInUseAddress[memory.vmId]);
auto bufferInUse = memory.getPointer<std::uint64_t>(
g_bridge->bufferInUseAddress[memory.vmId]);
if (bufferInUse != nullptr) {
bufferInUse[bufferIndex] = 0;
}