Skip to content

Commit

Permalink
smlua improvements (#607)
Browse files Browse the repository at this point in the history
* smlua improvements

* fix non dev compile error

* fixes
  • Loading branch information
Isaac0-dev authored Jan 5, 2025
1 parent dff6634 commit 939218d
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 102 deletions.
3 changes: 0 additions & 3 deletions src/pc/lua/smlua.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "smlua.h"
#include "smlua_cobject_map.h"
#include "game/hardcoded.h"
#include "pc/mods/mods.h"
#include "pc/mods/mods_utils.h"
Expand Down Expand Up @@ -282,7 +281,6 @@ static void smlua_load_script(struct Mod* mod, struct ModFile* file, u16 remoteI

void smlua_init(void) {
smlua_shutdown();
smlua_pointer_user_data_init();

gLuaState = luaL_newstate();
lua_State* L = gLuaState;
Expand Down Expand Up @@ -362,7 +360,6 @@ void smlua_shutdown(void) {
smlua_text_utils_reset_all();
smlua_audio_utils_reset_all();
audio_custom_shutdown();
smlua_pointer_user_data_shutdown();
smlua_clear_hooks();
smlua_model_util_clear();
smlua_level_util_reset();
Expand Down
4 changes: 3 additions & 1 deletion src/pc/lua/smlua.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
#define LOG_LUA_LINE_WARNING(...) { if (!gLuaActiveMod->showedScriptWarning) { gLuaActiveMod->showedScriptWarning = true; smlua_mod_warning(); snprintf(gDjuiConsoleTmpBuffer, CONSOLE_MAX_TMP_BUFFER, __VA_ARGS__), sys_swap_backslashes(gDjuiConsoleTmpBuffer), djui_console_message_create(gDjuiConsoleTmpBuffer, CONSOLE_MESSAGE_WARNING); } }

#ifdef DEVELOPMENT
#define LUA_STACK_CHECK_BEGIN() int __LUA_STACK_TOP = lua_gettop(gLuaState)
#define LUA_STACK_CHECK_BEGIN_NUM(n) int __LUA_STACK_TOP = lua_gettop(gLuaState) + (n)
#define LUA_STACK_CHECK_BEGIN() LUA_STACK_CHECK_BEGIN_NUM(0)
#define LUA_STACK_CHECK_END() if ((__LUA_STACK_TOP) != lua_gettop(gLuaState)) { smlua_dump_stack(); fflush(stdout); } assert((__LUA_STACK_TOP) == lua_gettop(gLuaState))
#else
#define LUA_STACK_CHECK_BEGIN_NUM(n)
#define LUA_STACK_CHECK_BEGIN()
#define LUA_STACK_CHECK_END()
#endif
Expand Down
90 changes: 46 additions & 44 deletions src/pc/lua/smlua_cobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,18 @@
#include "object_fields.h"
#include "pc/djui/djui_hud_utils.h"
#include "pc/lua/smlua.h"
#include "pc/lua/smlua_cobject_map.h"
#include "pc/lua/utils/smlua_anim_utils.h"
#include "pc/lua/utils/smlua_collision_utils.h"
#include "pc/lua/utils/smlua_obj_utils.h"
#include "pc/mods/mods.h"

extern struct LuaObjectTable sLuaObjectTable[LOT_MAX];

int gSmLuaCObjects = 0;
int gSmLuaCPointers = 0;
int gSmLuaCObjectMetatable = 0;
int gSmLuaCPointerMetatable = 0;

struct LuaObjectField* smlua_get_object_field_from_ot(struct LuaObjectTable* ot, const char* key) {
// binary search
s32 min = 0;
Expand Down Expand Up @@ -324,26 +328,28 @@ struct LuaObjectField* smlua_get_custom_field(lua_State* L, u32 lot, int keyInde
/////////////////////

static int smlua__get_field(lua_State* L) {
LUA_STACK_CHECK_BEGIN();
LUA_STACK_CHECK_BEGIN_NUM(1);

CObject *cobj = lua_touserdata(L, 1);
const CObject *cobj = lua_touserdata(L, 1);
enum LuaObjectType lot = cobj->lot;
u64 pointer = (u64)(intptr_t) cobj->pointer;

const char *key = smlua_to_string(L, 2);
if (!gSmLuaConvertSuccess) {
const char *key = lua_tostring(L, 2);
if (!key) {
LOG_LUA_LINE("Tried to get a non-string field of cobject");
return 0;
}

// Legacy support
if (strcmp(key, "_pointer") == 0) {
lua_pushinteger(L, pointer);
return 1;
}
if (strcmp(key, "_lot") == 0) {
lua_pushinteger(L, cobj->lot);
return 1;
if (key[0] == '_') {
if (strcmp(key, "_lot") == 0) {
lua_pushinteger(L, lot);
return 1;
}
if (strcmp(key, "_pointer") == 0) {
lua_pushinteger(L, pointer);
return 1;
}
}

if (cobj->freed) {
Expand All @@ -360,8 +366,6 @@ static int smlua__get_field(lua_State* L) {
return 0;
}

LUA_STACK_CHECK_END();

u8* p = ((u8*)(intptr_t)pointer) + data->valueOffset;
switch (data->valueType) {
case LVT_BOOL: lua_pushboolean(L, *(u8* )p); break;
Expand Down Expand Up @@ -406,18 +410,19 @@ static int smlua__get_field(lua_State* L) {
return 0;
}

LUA_STACK_CHECK_END();
return 1;
}

static int smlua__set_field(lua_State* L) {
LUA_STACK_CHECK_BEGIN();

CObject *cobj = lua_touserdata(L, 1);
const CObject *cobj = lua_touserdata(L, 1);
enum LuaObjectType lot = cobj->lot;
u64 pointer = (u64)(intptr_t) cobj->pointer;

const char *key = smlua_to_string(L, 2);
if (!gSmLuaConvertSuccess) {
const char *key = lua_tostring(L, 2);
if (!key) {
LOG_LUA_LINE("Tried to set a non-string field of cobject");
return 0;
}
Expand Down Expand Up @@ -496,37 +501,27 @@ static int smlua__set_field(lua_State* L) {
}

int smlua__eq(lua_State *L) {
CObject *a = lua_touserdata(L, 1);
CObject *b = lua_touserdata(L, 2);
lua_pushboolean(L, a->lot == b->lot && a->pointer == b->pointer);
const CObject *a = lua_touserdata(L, 1);
const CObject *b = lua_touserdata(L, 2);
lua_pushboolean(L, a && b && a->lot == b->lot && a->pointer == b->pointer);
return 1;
}

int smlua__gc(lua_State *L) {
CObject *cobj = lua_touserdata(L, 1);
if (!cobj->freed) {
switch (cobj->lot) {
case LOT_SURFACE: {
smlua_pointer_user_data_delete((uintptr_t) cobj->pointer);
}
}
}
return 0;
}

static int smlua_cpointer_get(lua_State* L) {
CPointer *cptr = lua_touserdata(L, 1);
const char *key = smlua_to_string(L, 2);
const CPointer *cptr = lua_touserdata(L, 1);
const char *key = lua_tostring(L, 2);
if (key == NULL) { return 0; }

// Legacy support
if (strcmp(key, "_pointer") == 0) {
lua_pushinteger(L, (u64)(intptr_t) cptr->pointer);
return 1;
}
if (strcmp(key, "_lot") == 0) {
lua_pushinteger(L, cptr->lvt);
return 1;
if (key[0] == '_') {
if (strcmp(key, "_pointer") == 0) {
lua_pushinteger(L, (u64)(intptr_t) cptr->pointer);
return 1;
}
if (strcmp(key, "_lot") == 0) {
lua_pushinteger(L, cptr->lvt);
return 1;
}
}

return 0;
Expand All @@ -540,26 +535,33 @@ static int smlua_cpointer_set(UNUSED lua_State* L) { return 0; }
void smlua_cobject_init_globals(void) {
lua_State* L = gLuaState;

// Create object pools
lua_newtable(L);
gSmLuaCObjects = luaL_ref(L, LUA_REGISTRYINDEX);
lua_newtable(L);
gSmLuaCPointers = luaL_ref(L, LUA_REGISTRYINDEX);

// Create metatables
luaL_newmetatable(L, "CObject");
luaL_Reg cObjectMethods[] = {
{ "__index", smlua__get_field },
{ "__newindex", smlua__set_field },
{ "__eq", smlua__eq },
{ "__gc", smlua__gc },
{ "__metatable", NULL },
{ NULL, NULL }
};
luaL_setfuncs(L, cObjectMethods, 0);
lua_pop(L, 1);
gSmLuaCObjectMetatable = luaL_ref(L, LUA_REGISTRYINDEX);
luaL_newmetatable(L, "CPointer");
luaL_Reg cPointerMethods[] = {
{ "__index", smlua_cpointer_get },
{ "__newindex", smlua_cpointer_set },
{ "__eq", smlua__eq },
{ "__metatable", NULL },
{ NULL, NULL }
};
luaL_setfuncs(L, cPointerMethods, 0);
lua_pop(L, 1);
gSmLuaCPointerMetatable = luaL_ref(L, LUA_REGISTRYINDEX);

#define EXPOSE_GLOBAL_ARRAY(lot, ptr, iterator) \
{ \
Expand Down
5 changes: 5 additions & 0 deletions src/pc/lua/smlua_cobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ typedef struct {
bool freed;
} CPointer;

extern int gSmLuaCObjects;
extern int gSmLuaCPointers;
extern int gSmLuaCObjectMetatable;
extern int gSmLuaCPointerMetatable;

bool smlua_valid_lot(u16 lot);
bool smlua_valid_lvt(u16 lvt);
struct LuaObjectField* smlua_get_object_field_from_ot(struct LuaObjectTable* ot, const char* key);
Expand Down
32 changes: 0 additions & 32 deletions src/pc/lua/smlua_cobject_map.c

This file was deleted.

10 changes: 0 additions & 10 deletions src/pc/lua/smlua_cobject_map.h

This file was deleted.

65 changes: 53 additions & 12 deletions src/pc/lua/smlua_utils.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "smlua.h"
#include "smlua_cobject_map.h"
#include "pc/mods/mods.h"
#include "audio/external.h"

Expand Down Expand Up @@ -354,33 +353,60 @@ void smlua_push_object(lua_State* L, u16 lot, void* p) {
lua_pushnil(L);
return;
}
LUA_STACK_CHECK_BEGIN_NUM(1);

uintptr_t key = lot ^ (uintptr_t) p;
lua_rawgeti(L, LUA_REGISTRYINDEX, gSmLuaCObjects);
lua_pushinteger(L, key);
lua_gettable(L, -2);
if (lua_isuserdata(L, -1)) {
lua_remove(L, -2); // Remove gSmLuaCObjects table
return;
}
lua_pop(L, 1);

CObject *cobject = lua_newuserdata(L, sizeof(CObject));
cobject->pointer = p;
cobject->lot = lot;
cobject->freed = false;
luaL_getmetatable(L, "CObject");
lua_rawgeti(L, LUA_REGISTRYINDEX, gSmLuaCObjectMetatable);
lua_setmetatable(L, -2);
lua_pushinteger(L, key);
lua_pushvalue(L, -2); // Duplicate userdata
lua_settable(L, -4);
lua_remove(L, -2); // Remove gSmLuaCObjects table

switch (lot) {
case LOT_SURFACE: {
smlua_pointer_user_data_add((uintptr_t) p, cobject);
}
}
LUA_STACK_CHECK_END();
}

void smlua_push_pointer(lua_State* L, u16 lvt, void* p) {
if (p == NULL) {
lua_pushnil(L);
return;
}
LUA_STACK_CHECK_BEGIN_NUM(1);

uintptr_t key = lvt ^ (uintptr_t) p;
lua_rawgeti(L, LUA_REGISTRYINDEX, gSmLuaCPointers);
lua_pushinteger(L, key);
lua_gettable(L, -2);
if (lua_isuserdata(L, -1)) {
lua_remove(L, -2); // Remove gSmLuaCPointers table
return;
}
lua_pop(L, 1);

CPointer *cpointer = lua_newuserdata(L, sizeof(CPointer));
cpointer->pointer = p;
cpointer->lvt = lvt;
cpointer->freed = false;
luaL_getmetatable(L, "CPointer");
lua_rawgeti(L, LUA_REGISTRYINDEX, gSmLuaCPointerMetatable);
lua_setmetatable(L, -2);
lua_pushinteger(L, key);
lua_pushvalue(L, -2); // Duplicate userdata
lua_settable(L, -4);
lua_remove(L, -2); // Remove gSmLuaCPointers table
LUA_STACK_CHECK_END();
}

void smlua_push_integer_field(int index, const char* name, lua_Integer val) {
Expand Down Expand Up @@ -710,7 +736,7 @@ void smlua_logline(void) {
while (lua_getstack(L, level, &info)) {
lua_getinfo(L, "nSl", &info);

// Get the folder and file of the crash
// Get the folder and file
// in the format: "folder/file.lua"
const char* src = info.source;
int slashCount = 0;
Expand All @@ -733,13 +759,28 @@ void smlua_logline(void) {

// If an object is freed that Lua has a CObject to,
// Lua is able to use-after-free that pointer
// todo figure out a better way to do this
void smlua_free(void *ptr) {
if (ptr && gLuaState) {
CObject *obj = smlua_pointer_user_data_get((uintptr_t) ptr);
if (obj) {
lua_State *L = gLuaState;
LUA_STACK_CHECK_BEGIN();
u16 lot = LOT_SURFACE; // Assuming this is a surface
uintptr_t key = lot ^ (uintptr_t) ptr;
lua_rawgeti(L, LUA_REGISTRYINDEX, gSmLuaCObjects);
lua_pushinteger(L, key);
lua_gettable(L, -2);
CObject *obj = (CObject *) lua_touserdata(L, -1);
if (obj && obj->pointer == ptr) {
obj->freed = true;
smlua_pointer_user_data_delete((uintptr_t) ptr);
lua_pop(L, 1);
lua_pushinteger(L, key);
lua_pushnil(L);
lua_settable(L, -3);
} else {
lua_pop(L, 1);
}
lua_pop(L, 1);
LUA_STACK_CHECK_END();
}
free(ptr);
}

0 comments on commit 939218d

Please sign in to comment.