[orbis-kernel] Add basic EFAULT check

Check for valid pointer range
Replace some deprecated checks
This commit is contained in:
Ivan Chikish 2023-07-12 17:24:28 +03:00
parent 204336989d
commit 230193129f
5 changed files with 41 additions and 20 deletions

View file

@ -108,7 +108,9 @@ orbis::SysResult orbis::sys_nfssvc(Thread *thread, sint flag, caddr_t argp) {
}
orbis::SysResult orbis::sys_sysarch(Thread *thread, sint op, ptr<char> parms) {
if (op == 129) {
auto fs = uread((ptr<uint64_t>)parms);
uint64_t fs;
if (auto error = uread(fs, (ptr<uint64_t>)parms); error != ErrorCode{})
return error;
std::printf("sys_sysarch: set FS to 0x%zx\n", (std::size_t)fs);
thread->fsBase = fs;
return {};

View file

@ -601,7 +601,10 @@ orbis::SysResult orbis::sys_mdbg_service(Thread *thread, uint32_t op,
switch (op) {
case 1: {
auto prop = uread((ptr<mdbg_property>)arg0);
mdbg_property prop;
if (auto error = uread(prop, (ptr<mdbg_property>)arg0);
error != ErrorCode{})
return error;
ORBIS_LOG_WARNING(__FUNCTION__, prop.name, prop.addr_ptr, prop.areaSize);
break;
}

View file

@ -6,7 +6,9 @@
static orbis::ErrorCode ureadTimespec(orbis::timespec &ts,
orbis::ptr<orbis::timespec> addr) {
ts = uread(addr);
orbis::ErrorCode error = uread(ts, addr);
if (error != orbis::ErrorCode{})
return error;
if (ts.sec < 0 || ts.nsec < 0 || ts.nsec > 1000000000) {
return orbis::ErrorCode::INVAL;
}

View file

@ -216,7 +216,9 @@ static ErrorCode do_unlock_pp(Thread *thread, ptr<umutex> m, uint flags) {
orbis::ErrorCode orbis::umtx_trylock_umutex(Thread *thread, ptr<umutex> m) {
ORBIS_LOG_TRACE(__FUNCTION__, m);
uint flags = uread(&m->flags);
uint flags;
if (ErrorCode err = uread(flags, &m->flags); err != ErrorCode{})
return err;
switch (flags & (kUmutexPrioInherit | kUmutexPrioProtect)) {
case 0:
return do_lock_normal(thread, m, flags, 0, umutex_lock_mode::try_);
@ -231,7 +233,9 @@ orbis::ErrorCode orbis::umtx_trylock_umutex(Thread *thread, ptr<umutex> m) {
orbis::ErrorCode orbis::umtx_lock_umutex(Thread *thread, ptr<umutex> m,
std::uint64_t ut) {
ORBIS_LOG_TRACE(__FUNCTION__, m, ut);
uint flags = uread(&m->flags);
uint flags;
if (ErrorCode err = uread(flags, &m->flags); err != ErrorCode{})
return err;
switch (flags & (kUmutexPrioInherit | kUmutexPrioProtect)) {
case 0:
return do_lock_normal(thread, m, flags, ut, umutex_lock_mode::lock);
@ -245,7 +249,9 @@ orbis::ErrorCode orbis::umtx_lock_umutex(Thread *thread, ptr<umutex> m,
orbis::ErrorCode orbis::umtx_unlock_umutex(Thread *thread, ptr<umutex> m) {
ORBIS_LOG_TRACE(__FUNCTION__, m);
uint flags = uread(&m->flags);
uint flags;
if (ErrorCode err = uread(flags, &m->flags); err != ErrorCode{})
return err;
switch (flags & (kUmutexPrioInherit | kUmutexPrioProtect)) {
case 0:
return do_unlock_normal(thread, m, flags);
@ -268,7 +274,9 @@ orbis::ErrorCode orbis::umtx_cv_wait(Thread *thread, ptr<ucond> cv,
ptr<umutex> m, std::uint64_t ut,
ulong wflags) {
ORBIS_LOG_NOTICE(__FUNCTION__, thread, cv, m, ut, wflags);
const uint flags = uread(&cv->flags);
uint flags;
if (ErrorCode err = uread(flags, &m->flags); err != ErrorCode{})
return err;
if ((wflags & kCvWaitClockId) != 0) {
ORBIS_LOG_FATAL("umtx_cv_wait: CLOCK_ID unimplemented", wflags);
return ErrorCode::NOSYS;
@ -360,7 +368,9 @@ orbis::ErrorCode orbis::umtx_wake_private(Thread *thread, ptr<void> uaddr,
orbis::ErrorCode orbis::umtx_wait_umutex(Thread *thread, ptr<umutex> m,
std::uint64_t ut) {
ORBIS_LOG_TRACE(__FUNCTION__, m, ut);
uint flags = uread(&m->flags);
uint flags;
if (ErrorCode err = uread(flags, &m->flags); err != ErrorCode{})
return err;
switch (flags & (kUmutexPrioInherit | kUmutexPrioProtect)) {
case 0:
return do_lock_normal(thread, m, flags, ut, umutex_lock_mode::wait);
@ -378,7 +388,9 @@ orbis::ErrorCode orbis::umtx_wake_umutex(Thread *thread, ptr<umutex> m) {
if ((owner & ~kUmutexContested) != 0)
return {};
[[maybe_unused]] uint flags = uread(&m->flags);
[[maybe_unused]] uint flags;
if (ErrorCode err = uread(flags, &m->flags); err != ErrorCode{})
return err;
auto [chain, key, lock] = g_context.getUmtxChain1(thread->tproc->pid, m);
std::size_t count = chain.sleep_queue.count(key);

View file

@ -42,34 +42,36 @@ using caddr_t = ptr<char>;
[[nodiscard]] inline ErrorCode
ureadRaw(void *kernelAddress, ptr<const void> userAddress, size_t size) {
auto addr = reinterpret_cast<std::uintptr_t>(userAddress);
if (addr < 0x40000 || addr + size > 0x100'0000'0000 || addr + size < addr)
return ErrorCode::FAULT;
std::memcpy(kernelAddress, userAddress, size);
return {};
}
[[nodiscard]] inline ErrorCode
uwriteRaw(ptr<void> userAddress, const void *kernelAddress, size_t size) {
auto addr = reinterpret_cast<std::uintptr_t>(userAddress);
if (addr < 0x40000 || addr + size > 0x100'0000'0000 || addr + size < addr)
return ErrorCode::FAULT;
std::memcpy(userAddress, kernelAddress, size);
return {};
}
[[nodiscard]] inline ErrorCode ureadString(char *kernelAddress,
size_t kernelSize,
[[nodiscard]] inline ErrorCode ureadString(char *kernelAddress, size_t size,
ptr<const char> userAddress) {
std::strncpy(kernelAddress, userAddress, kernelSize);
if (kernelAddress[kernelSize - 1] != '\0') {
kernelAddress[kernelSize - 1] = '\0';
auto addr = reinterpret_cast<std::uintptr_t>(userAddress);
if (addr < 0x40000 || addr + size > 0x100'0000'0000 || addr + size < addr)
return ErrorCode::FAULT;
std::strncpy(kernelAddress, userAddress, size);
if (kernelAddress[size - 1] != '\0') {
kernelAddress[size - 1] = '\0';
return ErrorCode::NAMETOOLONG;
}
return {};
}
template <typename T> [[deprecated]] T uread(ptr<const T> pointer) {
T result{};
ureadRaw(&result, pointer, sizeof(T));
return result;
}
template <typename T> [[nodiscard]] ErrorCode uread(T &result, ptr<T> pointer) {
return ureadRaw(&result, pointer, sizeof(T));
}