[amdgpu] scheduler: avoid dead lock on cpu workloads

This commit is contained in:
DH 2023-08-07 23:49:45 +03:00
parent f5949e5f65
commit 1b15ef4d13

View file

@ -81,27 +81,40 @@ struct TaskChain {
std::uint64_t add(std::uint64_t waitId, T &&task) {
auto prevTaskId = getLastTaskId();
auto id = nextTaskId++;
auto cpuTask =
createCpuTask([=, task = std::forward<T>(task),
self = Ref(this)](const AsyncTaskCtl &) mutable {
if (waitId != GpuTaskLayout::kInvalidId) {
if (self->semaphore.getCounterValue() < waitId) {
return TaskResult::Reschedule;
}
enum class State {
WaitTask,
PrevTask,
};
auto cpuTask = createCpuTask([=, task = std::forward<T>(task),
self = Ref(this), state = State::WaitTask](
const AsyncTaskCtl &) mutable {
if (state == State::WaitTask) {
if (waitId != GpuTaskLayout::kInvalidId) {
if (self->semaphore.getCounterValue() < waitId) {
return TaskResult::Reschedule;
}
}
auto result = task();
if (result != TaskResult::Complete) {
return result;
auto result = task();
if (result != TaskResult::Complete) {
return result;
}
state = State::PrevTask;
}
if (state == State::PrevTask) {
if (prevTaskId != GpuTaskLayout::kInvalidId && waitId != prevTaskId) {
if (self->semaphore.getCounterValue() < prevTaskId) {
return TaskResult::Reschedule;
}
}
if (prevTaskId != GpuTaskLayout::kInvalidId && waitId != prevTaskId) {
self->wait(prevTaskId);
}
self->semaphore.signal(id);
}
self->semaphore.signal(id);
return TaskResult::Complete;
});
return TaskResult::Complete;
});
getCpuScheduler().enqueue(std::move(cpuTask));
return id;
}
@ -113,24 +126,35 @@ struct TaskChain {
std::uint64_t add(std::uint64_t waitId, T &&task) {
auto prevTaskId = getLastTaskId();
auto id = nextTaskId++;
auto cpuTask =
createCpuTask([=, task = std::forward<T>(task),
self = Ref(this)](const AsyncTaskCtl &) mutable {
if (waitId != GpuTaskLayout::kInvalidId) {
if (self->semaphore.getCounterValue() < waitId) {
return TaskResult::Reschedule;
}
enum class State {
WaitTask,
PrevTask,
};
auto cpuTask = createCpuTask([=, task = std::forward<T>(task),
self = Ref(this), state = State::WaitTask](
const AsyncTaskCtl &) mutable {
if (state == State::WaitTask) {
if (waitId != GpuTaskLayout::kInvalidId) {
if (self->semaphore.getCounterValue() < waitId) {
return TaskResult::Reschedule;
}
}
task();
task();
state = State::PrevTask;
}
if (prevTaskId != GpuTaskLayout::kInvalidId && waitId != prevTaskId) {
self->wait(prevTaskId);
if (state == State::PrevTask) {
if (prevTaskId != GpuTaskLayout::kInvalidId && waitId != prevTaskId) {
if (self->semaphore.getCounterValue() < prevTaskId) {
return TaskResult::Reschedule;
}
}
self->semaphore.signal(id);
return TaskResult::Complete;
});
self->semaphore.signal(id);
}
return TaskResult::Complete;
});
getCpuScheduler().enqueue(std::move(cpuTask));
return id;
}