Skip to content

Commit

Permalink
Add threads, continuation support and yield/breaks
Browse files Browse the repository at this point in the history
  • Loading branch information
bjcscat committed Dec 6, 2024
1 parent dd37205 commit ec3d0f9
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 21 deletions.
210 changes: 189 additions & 21 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod compile;

mod ffi;
mod memory;
mod threads;
mod userdata;

use core::str;
Expand All @@ -12,6 +13,7 @@ use std::{
ffi::{c_int, c_uint, CString},
os::raw::c_void,
ptr::{null, null_mut},
rc::Rc,
slice,
};

Expand All @@ -20,6 +22,7 @@ use ffi::{
prelude::*,
};
use memory::{luau_alloc_cb, DefaultLuauAllocator};
use threads::{LuauThread, MainStateDeadError};
use userdata::{
drop_userdata, dtor_rs_luau_userdata_callback, Userdata, UserdataBorrowError, UserdataRef,
UserdataRefMut, UD_TAG,
Expand All @@ -37,6 +40,7 @@ macro_rules! luau_stack_precondition {
}

struct AssociatedData {
main_thread_rc: Rc<Cell<bool>>,
allocator: Box<dyn LuauAllocator>,
app_data: Option<Box<dyn Any>>,
}
Expand All @@ -56,6 +60,7 @@ pub struct Luau {
impl Luau {
unsafe fn new_state(allocator: impl LuauAllocator + 'static) -> *mut _LuaState {
let associated_data = Box::new(AssociatedData {
main_thread_rc: Rc::new(Cell::new(true)),
app_data: None,
allocator: Box::new(allocator),
});
Expand Down Expand Up @@ -98,6 +103,14 @@ impl Luau {
}
}

/// Creates a Luau struct from a raw state pointer
///
/// # Safety
/// The pointer must be a valid Luau state and must not alias a Luau struct
pub unsafe fn from_ptr_owned(state: *mut _LuaState) -> Self {
Self { owned: true, state }
}

const ASSOCIATED_DATA_ERROR: &str = "Expected associated data structure";

pub(crate) fn get_associated(&self) -> &AssociatedData {
Expand All @@ -109,12 +122,14 @@ impl Luau {
}
}

pub(crate) fn get_associated_mut(&mut self) -> &mut AssociatedData {
pub(crate) fn get_associated_mut(&self) -> *mut AssociatedData {
unsafe {
let mut ptr: *mut AssociatedData = null_mut();
lua_getallocf(self.state, &raw mut ptr as _);

ptr.as_mut().expect(Self::ASSOCIATED_DATA_ERROR)
assert!(!ptr.is_null(), "{}", Self::ASSOCIATED_DATA_ERROR);

ptr
}
}

Expand All @@ -125,16 +140,16 @@ impl Luau {
.and_then(|v| v.downcast_ref())
}

pub fn get_app_data_mut<T: Any>(&mut self) -> Option<&mut T> {
self.get_associated_mut()
pub fn get_app_data_mut<T: Any>(&self) -> Option<&mut T> {
unsafe { &mut *self.get_associated_mut() }
.app_data
.as_mut()
.and_then(|v| v.downcast_mut())
}

/// Sets the associated app data for the Luau state returning the previous value
pub fn set_app_data<T: Any>(&mut self, ud: Option<T>) -> Option<Box<dyn Any>> {
let associated = self.get_associated_mut();
pub fn set_app_data<T: Any>(&self, ud: Option<T>) -> Option<Box<dyn Any>> {
let associated = unsafe { &mut *self.get_associated_mut() };

if let Some(v) = ud {
let boxed_data = Box::new(v);
Expand All @@ -155,6 +170,34 @@ impl Luau {
unsafe { lua_gettop(self.state) }
}

/// Returns the status of the Luau state
pub fn status(&self) -> LuauStatus {
unsafe { lua_status(self.state) }
}

/// Yields the luau state with the number of results
///
/// Should be used as the end expression or a return from a function as this returns `-1`
pub fn yield_luau(&self, nresults: c_int) -> c_int {
assert!(
self.top() >= nresults,
"The number of yield returns must not exceed the stack size"
);

unsafe {
lua_yield(self.state, nresults)
}
}

/// Breaks the luau state for the purposes of a debug interrupt
///
/// Should be used as the end expression or a return from a function as this returns `-1`
pub fn break_luau(&self) -> c_int {
unsafe {
lua_break(self.state)
}
}

/// Returns the type of a luau value at `idx`
pub fn type_of(&self, idx: c_int) -> LuauType {
luau_stack_precondition!(self.check_index(idx));
Expand Down Expand Up @@ -216,11 +259,6 @@ impl Luau {
unsafe { lua_checkstack(self.state, sz) == 1 }
}

/// Returns the status of the Luau state
pub fn status(&self) -> LuaStatus {
unsafe { lua_status(self.state) }
}

#[inline]
pub fn registry(&self) -> c_int {
LUA_REGISTRYINDEX
Expand Down Expand Up @@ -713,6 +751,40 @@ impl Luau {
self.type_of(idx) == LuauType::LUA_TVECTOR
}

/// Returns true if the value at `idx` is a thread, false otherwise
pub fn is_thread(&self, idx: c_int) -> bool {
self.type_of(idx) == LuauType::LUA_TTHREAD
}

pub fn push_thread(&self) -> LuauThread {
unsafe {
let thread_ptr = lua_newthread(self.state);
LuauThread::from_ptr(thread_ptr, self.get_associated().main_thread_rc.clone())
}
}

pub fn get_thread(&self, idx: c_int) -> Option<LuauThread> {
let ptr = unsafe { lua_tothread(self.state, idx) };

if !ptr.is_null() {
unsafe {
Some(LuauThread::from_ptr(
ptr,
self.get_associated().main_thread_rc.clone(),
))
}
} else {
None
}
}

/// Resumes the given Luau thread with the number of arguments.
///
/// Will resume the function on the top of the given Luau thread's execution stack
pub fn resume(&self, luau_thread: LuauThread, nargs: c_int) -> LuauStatus {
unsafe { lua_resume(luau_thread.get_state().state, self.state, nargs) }
}

/// Returns true if the value at `idx` is a function, false otherwise
pub fn is_function(&self, idx: c_int) -> bool {
self.type_of(idx) == LuauType::LUA_TFUNCTION
Expand All @@ -726,9 +798,10 @@ impl Luau {
/// You will need to uphold all safety invariants with respect to the Luau VM in the user supplied `func`
pub unsafe fn push_raw_function(
&self,
func: unsafe extern "C-unwind" fn(*mut _LuaState) -> c_int,
func: CFunction,
debug_name: Option<&str>,
num_upvals: c_int,
continuation: Option<LuaContinuation>,
) {
luau_stack_precondition!(self.check_stack(1));

Expand All @@ -739,7 +812,7 @@ impl Luau {

// SAFETY: upvalue count and stack size are validated as a precondition and assert
unsafe {
lua_pushcclosure(
lua_pushcclosurek(
self.state,
func,
if let Some(name) = debug_name {
Expand All @@ -750,6 +823,74 @@ impl Luau {
null()
},
num_upvals,
continuation,
);
}
}

/// Pushes a Rust function into Luau with an associated continuation
///
/// This function wraps a Rust function to allow closures to capture values, to avoid this minor overhead you can use `push_function_raw`
pub fn push_function_continuation<
F: FnMut(Luau) -> c_int,
Cont: FnMut(Luau, LuauStatus) -> c_int,
>(
&self,
func: F,
debug_name: Option<&str>,
num_upvals: c_int,
cont: Cont,
) {
assert!(
self.top() >= num_upvals,
"The number of upvalues for a raw function must not exceed the stack length"
);

luau_stack_precondition!(self.check_stack(2));

struct CallState<F, Cont> {
func: F,
cont: Cont,
}

let call_state = Box::new(CallState { func, cont });

unsafe extern "C-unwind" fn invoke_fn<
F: FnMut(Luau) -> c_int,
Cont: FnMut(Luau, LuauStatus) -> c_int,
>(
state: *mut _LuaState,
) -> c_int {
let call_state =
lua_tolightuserdata(state, lua_upvalueindex(1)).cast::<CallState<F, Cont>>();

((*call_state).func)(Luau::from_ptr(state))
}

unsafe extern "C-unwind" fn invoke_continuation<
F: FnMut(Luau) -> c_int,
Cont: FnMut(Luau, LuauStatus) -> c_int,
>(
state: *mut _LuaState,
status: c_int,
) -> c_int {
let call_state =
lua_tolightuserdata(state, lua_upvalueindex(1)).cast::<CallState<F, Cont>>();

((*call_state).cont)(
Luau::from_ptr(state),
std::mem::transmute::<c_int, LuauStatus>(status),
)
}

unsafe {
lua_pushlightuserdata(self.state, Box::into_raw(call_state) as _);

self.push_raw_function(
invoke_fn::<F, Cont>,
debug_name,
1 + num_upvals,
Some(invoke_continuation::<F, Cont>),
);
}
}
Expand Down Expand Up @@ -783,12 +924,12 @@ impl Luau {
unsafe {
lua_pushlightuserdata(self.state, Box::into_raw(func_box) as _);

self.push_raw_function(invoke_fn::<F>, debug_name, 1 + num_upvals);
self.push_raw_function(invoke_fn::<F>, debug_name, 1 + num_upvals, None);
}
}

/// Calls the Luau function at the top of the stack returning the status of the Luau state when it returns
pub fn call(&self, nargs: c_int, nresults: c_int) -> LuaStatus {
pub fn call(&self, nargs: c_int, nresults: c_int) -> LuauStatus {
assert!(
self.is_function(-1),
"The value at top of stack must be a function"
Expand Down Expand Up @@ -855,14 +996,14 @@ unsafe extern "C-unwind" fn fatal_runtime_error_handler(state: *mut _LuaState) -
}

/// Final resting place for Luau code, we don't return from this.
unsafe extern "C-unwind" fn fatal_error_handler(state: *mut _LuaState, status: LuaStatus) {
unsafe extern "C-unwind" fn fatal_error_handler(state: *mut _LuaState, status: LuauStatus) {
match status {
// Unhandled runtime error
LuaStatus::LUA_ERRRUN => fatal_runtime_error_handler(state),
LuauStatus::LUA_ERRRUN => fatal_runtime_error_handler(state),
// memory allocation error, just die
LuaStatus::LUA_ERRMEM => std::process::abort(),
LuauStatus::LUA_ERRMEM => std::process::abort(),
// some error handling mechanism errored
LuaStatus::LUA_ERRERR => panic!("Error originating from error handling mechanism"),
LuauStatus::LUA_ERRERR => panic!("Error originating from error handling mechanism"),
// shouldnt be reachable
_ => unreachable!(),
};
Expand All @@ -888,6 +1029,9 @@ impl Drop for Luau {

let associated_owned = Box::from_raw(associated);

// mark main thread dead
associated_owned.main_thread_rc.set(false);

lua_close(self.state);

_ = associated_owned
Expand Down Expand Up @@ -1044,9 +1188,32 @@ mod tests {
assert_eq!(luau.top(), 0);
}

#[test]
fn threads() {
let luau = Luau::default();

let thread = luau.push_thread();
let thread_state = thread.get_state();

let mut was_called = false;

thread_state.push_function(
|_| {
was_called = true;
0
},
None,
0,
);

luau.resume(thread, 0);

assert!(was_called, "Expected thread function to be called");
}

#[test]
fn app_data() {
let mut luau = Luau::default();
let luau = Luau::default();

luau.set_app_data(Some(true));

Expand Down Expand Up @@ -1093,8 +1260,9 @@ mod tests {
}

unsafe {
luau.push_raw_function(test_extern_fn, Some("test"), 3);
luau.push_raw_function(test_extern_fn, Some("test"), 3, None);
}

luau.call(0, 0);
}

Expand Down
Loading

0 comments on commit ec3d0f9

Please sign in to comment.