From 1b15ef4d13bfa1e2d1f61d87998c535d27cf92f4 Mon Sep 17 00:00:00 2001 From: DH Date: Mon, 7 Aug 2023 23:49:45 +0300 Subject: [PATCH] [amdgpu] scheduler: avoid dead lock on cpu workloads --- .../include/amdgpu/device/gpu-scheduler.hpp | 82 ++++++++++++------- 1 file changed, 53 insertions(+), 29 deletions(-) diff --git a/hw/amdgpu/device/include/amdgpu/device/gpu-scheduler.hpp b/hw/amdgpu/device/include/amdgpu/device/gpu-scheduler.hpp index 6233beec0..a080369ea 100644 --- a/hw/amdgpu/device/include/amdgpu/device/gpu-scheduler.hpp +++ b/hw/amdgpu/device/include/amdgpu/device/gpu-scheduler.hpp @@ -81,27 +81,40 @@ struct TaskChain { std::uint64_t add(std::uint64_t waitId, T &&task) { auto prevTaskId = getLastTaskId(); auto id = nextTaskId++; - auto cpuTask = - createCpuTask([=, task = std::forward(task), - self = Ref(this)](const AsyncTaskCtl &) mutable { - if (waitId != GpuTaskLayout::kInvalidId) { - if (self->semaphore.getCounterValue() < waitId) { - return TaskResult::Reschedule; - } + enum class State { + WaitTask, + PrevTask, + }; + auto cpuTask = createCpuTask([=, task = std::forward(task), + self = Ref(this), state = State::WaitTask]( + const AsyncTaskCtl &) mutable { + if (state == State::WaitTask) { + if (waitId != GpuTaskLayout::kInvalidId) { + if (self->semaphore.getCounterValue() < waitId) { + return TaskResult::Reschedule; } + } - auto result = task(); - if (result != TaskResult::Complete) { - return result; + auto result = task(); + + if (result != TaskResult::Complete) { + return result; + } + state = State::PrevTask; + } + + if (state == State::PrevTask) { + if (prevTaskId != GpuTaskLayout::kInvalidId && waitId != prevTaskId) { + if (self->semaphore.getCounterValue() < prevTaskId) { + return TaskResult::Reschedule; } + } - if (prevTaskId != GpuTaskLayout::kInvalidId && waitId != prevTaskId) { - self->wait(prevTaskId); - } + self->semaphore.signal(id); + } - self->semaphore.signal(id); - return TaskResult::Complete; - }); + return TaskResult::Complete; + }); getCpuScheduler().enqueue(std::move(cpuTask)); return id; } @@ -113,24 +126,35 @@ struct TaskChain { std::uint64_t add(std::uint64_t waitId, T &&task) { auto prevTaskId = getLastTaskId(); auto id = nextTaskId++; - auto cpuTask = - createCpuTask([=, task = std::forward(task), - self = Ref(this)](const AsyncTaskCtl &) mutable { - if (waitId != GpuTaskLayout::kInvalidId) { - if (self->semaphore.getCounterValue() < waitId) { - return TaskResult::Reschedule; - } + enum class State { + WaitTask, + PrevTask, + }; + auto cpuTask = createCpuTask([=, task = std::forward(task), + self = Ref(this), state = State::WaitTask]( + const AsyncTaskCtl &) mutable { + if (state == State::WaitTask) { + if (waitId != GpuTaskLayout::kInvalidId) { + if (self->semaphore.getCounterValue() < waitId) { + return TaskResult::Reschedule; } + } - task(); + task(); + state = State::PrevTask; + } - if (prevTaskId != GpuTaskLayout::kInvalidId && waitId != prevTaskId) { - self->wait(prevTaskId); + if (state == State::PrevTask) { + if (prevTaskId != GpuTaskLayout::kInvalidId && waitId != prevTaskId) { + if (self->semaphore.getCounterValue() < prevTaskId) { + return TaskResult::Reschedule; } + } - self->semaphore.signal(id); - return TaskResult::Complete; - }); + self->semaphore.signal(id); + } + return TaskResult::Complete; + }); getCpuScheduler().enqueue(std::move(cpuTask)); return id; }