From b45c61a605880c0c8cdd4c9d51594458f4dc6be8 Mon Sep 17 00:00:00 2001 From: MysterD Date: Sun, 23 Jan 2022 16:35:43 -0800 Subject: [PATCH] Added Lua allow-list for cobject pointers --- build-windows-visual-studio/sm64ex.vcxproj | 2 + .../sm64ex.vcxproj.filters | 6 ++ src/pc/lua/smlua.c | 3 + src/pc/lua/smlua.h | 2 + src/pc/lua/smlua_cobject.c | 10 +++ src/pc/lua/smlua_cobject_allowlist.c | 68 +++++++++++++++++++ src/pc/lua/smlua_cobject_allowlist.h | 9 +++ src/pc/lua/smlua_utils.c | 13 +++- 8 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 src/pc/lua/smlua_cobject_allowlist.c create mode 100644 src/pc/lua/smlua_cobject_allowlist.h diff --git a/build-windows-visual-studio/sm64ex.vcxproj b/build-windows-visual-studio/sm64ex.vcxproj index f67eaf21..e5f56dbc 100644 --- a/build-windows-visual-studio/sm64ex.vcxproj +++ b/build-windows-visual-studio/sm64ex.vcxproj @@ -513,6 +513,7 @@ + @@ -968,6 +969,7 @@ + diff --git a/build-windows-visual-studio/sm64ex.vcxproj.filters b/build-windows-visual-studio/sm64ex.vcxproj.filters index 69402c77..ae381596 100644 --- a/build-windows-visual-studio/sm64ex.vcxproj.filters +++ b/build-windows-visual-studio/sm64ex.vcxproj.filters @@ -4860,6 +4860,9 @@ Source Files\src\pc\djui\panel + + Source Files\src\pc\lua + @@ -6001,5 +6004,8 @@ Source Files\src\pc\djui\panel + + Source Files\src\pc\lua + \ No newline at end of file diff --git a/src/pc/lua/smlua.c b/src/pc/lua/smlua.c index 83c3808f..f93e4d22 100644 --- a/src/pc/lua/smlua.c +++ b/src/pc/lua/smlua.c @@ -65,6 +65,8 @@ static void smlua_init_mario_states(void) { void smlua_init(void) { smlua_shutdown(); + smlua_cobject_allowlist_init(); + gLuaState = luaL_newstate(); lua_State* L = gLuaState; @@ -108,6 +110,7 @@ void smlua_update(void) { } void smlua_shutdown(void) { + smlua_cobject_allowlist_shutdown(); lua_State* L = gLuaState; if (L != NULL) { lua_close(L); diff --git a/src/pc/lua/smlua.h b/src/pc/lua/smlua.h index d404ec4c..04e8afc0 100644 --- a/src/pc/lua/smlua.h +++ b/src/pc/lua/smlua.h @@ -5,9 +5,11 @@ #include #include +#include #include "types.h" #include "smlua_cobject.h" +#include "smlua_cobject_allowlist.h" #include "smlua_utils.h" #include "smlua_functions.h" #include "smlua_functions_autogen.h" diff --git a/src/pc/lua/smlua_cobject.c b/src/pc/lua/smlua_cobject.c index 0b07e659..2e32ab4a 100644 --- a/src/pc/lua/smlua_cobject.c +++ b/src/pc/lua/smlua_cobject.c @@ -313,6 +313,11 @@ static int smlua__get_field(lua_State* L) { return 0; } + if (!smlua_cobject_allowlist_contains(lot, pointer)) { + LOG_LUA("_get_field received a pointer not in allow list. '%u', '%llu", lot, (u64)pointer); + return 0; + } + struct LuaObjectField* data = smlua_get_object_field(&sLuaObjectTable[lot], key); if (data == NULL) { LOG_LUA("_get_field on invalid key '%s', lot '%d'", key, lot); @@ -353,6 +358,11 @@ static int smlua__set_field(lua_State* L) { return 0; } + if (!smlua_cobject_allowlist_contains(lot, pointer)) { + LOG_LUA("_set_field received a pointer not in allow list. '%u', '%llu", lot, (u64)pointer); + return 0; + } + struct LuaObjectField* data = smlua_get_object_field(&sLuaObjectTable[lot], key); if (data == NULL) { LOG_LUA("_set_field on invalid key '%s'", key); diff --git a/src/pc/lua/smlua_cobject_allowlist.c b/src/pc/lua/smlua_cobject_allowlist.c new file mode 100644 index 00000000..e131a563 --- /dev/null +++ b/src/pc/lua/smlua_cobject_allowlist.c @@ -0,0 +1,68 @@ +#include +#include "smlua.h" + +#pragma pack(1) +struct CObjectAllowListNode { + u64 pointer; + struct CObjectAllowListNode* next; +}; + +static struct CObjectAllowListNode* sAllowList[LOT_MAX] = { 0 }; +static u16 sCachedAllowed[LOT_MAX] = { 0 }; + +void smlua_cobject_allowlist_init(void) { + smlua_cobject_allowlist_shutdown(); +} + +void smlua_cobject_allowlist_shutdown(void) { + for (int i = 0; i < LOT_MAX; i++) { + sCachedAllowed[i] = 0; + struct CObjectAllowListNode* node = sAllowList[i]; + while (node != NULL) { + struct CObjectAllowListNode* nextNode = node->next; + free(node); + node = nextNode; + } + sAllowList[i] = NULL; + } +} + +void smlua_cobject_allowlist_add(enum LuaObjectType objectType, u64 pointer) { + if (pointer == 0) { return; } + if (objectType == LOT_NONE || objectType >= LOT_MAX) { return; } + + if (sCachedAllowed[objectType] == pointer) { return; } + sCachedAllowed[objectType] = pointer; + + struct CObjectAllowListNode* curNode = sAllowList[objectType]; + struct CObjectAllowListNode* prevNode = NULL; + while (curNode != NULL) { + if (pointer == curNode->pointer) { return; } + if (pointer < curNode->pointer) { break; } + prevNode = curNode; + curNode = curNode->next; + } + + struct CObjectAllowListNode* node = malloc(sizeof(struct CObjectAllowListNode)); + node->pointer = pointer; + node->next = curNode; + if (prevNode == NULL) { + sAllowList[objectType] = node; + } else { + prevNode->next = node; + } +} + +bool smlua_cobject_allowlist_contains(enum LuaObjectType objectType, u64 pointer) { + if (pointer == 0) { return false; } + if (objectType == LOT_NONE || objectType >= LOT_MAX) { return false; } + if (sCachedAllowed[objectType] == pointer) { return true; } + + struct CObjectAllowListNode* node = sAllowList[objectType]; + while (node != NULL) { + if (pointer == node->pointer) { return true; } + if (pointer < node->pointer) { return false; } + node = node->next; + } + return false; +} \ No newline at end of file diff --git a/src/pc/lua/smlua_cobject_allowlist.h b/src/pc/lua/smlua_cobject_allowlist.h new file mode 100644 index 00000000..68501991 --- /dev/null +++ b/src/pc/lua/smlua_cobject_allowlist.h @@ -0,0 +1,9 @@ +#ifndef SMLUA_COBJECT_ALLOWLIST_H +#define SMLUA_COBJECT_ALLOWLIST_H + +void smlua_cobject_allowlist_init(void); +void smlua_cobject_allowlist_shutdown(void); +void smlua_cobject_allowlist_add(enum LuaObjectType objectType, u64 pointer); +bool smlua_cobject_allowlist_contains(enum LuaObjectType objectType, u64 pointer); + +#endif \ No newline at end of file diff --git a/src/pc/lua/smlua_utils.c b/src/pc/lua/smlua_utils.c index f3cced4c..f04ca5e9 100644 --- a/src/pc/lua/smlua_utils.c +++ b/src/pc/lua/smlua_utils.c @@ -71,6 +71,7 @@ void* smlua_to_cobject(lua_State* L, int index, enum LuaObjectType lot) { return 0; } + // get LOT lua_getfield(L, index, "_lot"); enum LuaObjectType objLot = smlua_to_integer(L, -1); lua_pop(L, 1); @@ -83,11 +84,18 @@ void* smlua_to_cobject(lua_State* L, int index, enum LuaObjectType lot) { return NULL; } + // get pointer lua_getfield(L, index, "_pointer"); void* pointer = (void*)smlua_to_integer(L, -1); lua_pop(L, 1); if (!gSmLuaConvertSuccess) { return NULL; } - // TODO: check address whitelists + + // check allowlist + if (!smlua_cobject_allowlist_contains(lot, (u64)pointer)) { + LOG_LUA("LUA: smlua_to_cobject received a pointer not in allow list. '%u', '%llu", lot, (u64)pointer); + gSmLuaConvertSuccess = false; + return NULL; + } if (pointer == NULL) { LOG_LUA("LUA: smlua_to_cobject received null pointer."); @@ -107,6 +115,9 @@ void smlua_push_object(lua_State* L, enum LuaObjectType lot, void* p) { lua_pushnil(L); return; } + // add to allowlist + smlua_cobject_allowlist_add(lot, (u64)p); + lua_newtable(L); int t = lua_gettop(L); smlua_push_integer_field(t, "_lot", lot);