diff --git a/Cargo.lock b/Cargo.lock index 59d32a97..d316fca3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1571,6 +1571,7 @@ dependencies = [ "mlua-luau-scheduler", "serde", "serde_json", + "thiserror", "tokio", ] diff --git a/crates/lune-std/Cargo.toml b/crates/lune-std/Cargo.toml index 07762b64..f6901e84 100644 --- a/crates/lune-std/Cargo.toml +++ b/crates/lune-std/Cargo.toml @@ -45,6 +45,8 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" tokio = { version = "1", default-features = false, features = ["fs", "sync"] } +thiserror = "1.0.63" + lune-utils = { version = "0.1.3", path = "../lune-utils" } lune-std-datetime = { optional = true, version = "0.1.3", path = "../lune-std-datetime" } diff --git a/crates/lune-std/src/globals/require/alias.rs b/crates/lune-std/src/globals/require/alias.rs deleted file mode 100644 index 924056b7..00000000 --- a/crates/lune-std/src/globals/require/alias.rs +++ /dev/null @@ -1,73 +0,0 @@ -use mlua::prelude::*; - -use lune_utils::path::{clean_path_and_make_absolute, diff_path, get_current_dir}; - -use crate::luaurc::LuauRc; - -use super::context::*; - -pub(super) async fn require<'lua, 'ctx>( - lua: &'lua Lua, - ctx: &'ctx RequireContext, - source: &str, - alias: &str, - path: &str, -) -> LuaResult> -where - 'lua: 'ctx, -{ - let alias = alias.to_ascii_lowercase(); - - let parent = clean_path_and_make_absolute(source) - .parent() - .expect("how did a root path end up here..") - .to_path_buf(); - - // Try to gather the first luaurc and / or error we - // encounter to display better error messages to users - let mut first_luaurc = None; - let mut first_error = None; - let predicate = |rc: &LuauRc| { - if first_luaurc.is_none() { - first_luaurc.replace(rc.clone()); - } - if let Err(e) = rc.validate() { - if first_error.is_none() { - first_error.replace(e); - } - false - } else { - rc.find_alias(&alias).is_some() - } - }; - - // Try to find a luaurc that contains the alias we're searching for - let luaurc = LuauRc::read_recursive(parent, predicate) - .await - .ok_or_else(|| { - if let Some(error) = first_error { - LuaError::runtime(format!("error while parsing .luaurc file: {error}")) - } else if let Some(luaurc) = first_luaurc { - LuaError::runtime(format!( - "failed to find alias '{alias}' - known aliases:\n{}", - luaurc - .aliases() - .iter() - .map(|(name, path)| format!(" {name} > {path}")) - .collect::>() - .join("\n") - )) - } else { - LuaError::runtime(format!("failed to find alias '{alias}' (no .luaurc)")) - } - })?; - - // We now have our aliased path, our path require function just needs it - // in a slightly different format with both absolute + relative to cwd - let abs_path = luaurc.find_alias(&alias).unwrap().join(path); - let rel_path = diff_path(&abs_path, get_current_dir()).ok_or_else(|| { - LuaError::runtime(format!("failed to find relative path for alias '{alias}'")) - })?; - - super::path::require_abs_rel(lua, ctx, abs_path, rel_path).await -} diff --git a/crates/lune-std/src/globals/require/context.rs b/crates/lune-std/src/globals/require/context.rs index 0355d270..753b3730 100644 --- a/crates/lune-std/src/globals/require/context.rs +++ b/crates/lune-std/src/globals/require/context.rs @@ -1,289 +1,276 @@ -use std::{ - collections::HashMap, - path::{Path, PathBuf}, - sync::Arc, -}; - +use crate::{library::StandardLibrary, luaurc::RequireAlias}; use mlua::prelude::*; -use mlua_luau_scheduler::LuaSchedulerExt; - +use mlua_luau_scheduler::{IntoLuaThread, LuaSchedulerExt}; +use std::{collections::HashMap, path::PathBuf, sync::Arc}; use tokio::{ - fs::read, + fs, sync::{ broadcast::{self, Sender}, - Mutex as AsyncMutex, + Mutex, }, }; -use lune_utils::path::{clean_path, clean_path_and_make_absolute}; - -use crate::library::LuneStandardLibrary; - -/** - Context containing cached results for all `require` operations. +use super::RequireError; - The cache uses absolute paths, so any given relative - path will first be transformed into an absolute path. -*/ -#[derive(Debug, Clone)] -pub(super) struct RequireContext { - libraries: Arc>>>, - results: Arc>>>, - pending: Arc>>>, +/// The private struct that's stored in mlua's app data container +#[derive(Debug, Default)] +struct RequireContextData<'a> { + std: HashMap<&'a str, HashMap<&'a str, Box>>, + std_cache: HashMap, + cache: Arc>>, + pending: Arc>>>, } +#[derive(Debug)] +pub struct RequireContext {} + impl RequireContext { /** - Creates a new require context for the given [`Lua`] struct. - - Note that this require context is global and only one require - context should be created per [`Lua`] struct, creating more - than one context may lead to undefined require-behavior. - */ - pub fn new() -> Self { - Self { - libraries: Arc::new(AsyncMutex::new(HashMap::new())), - results: Arc::new(AsyncMutex::new(HashMap::new())), - pending: Arc::new(AsyncMutex::new(HashMap::new())), + + # Errors + + - when `RequireContext::init` is called more than once on the same `Lua` instance + + */ + pub(crate) fn init(lua: &Lua) -> Result<(), RequireError> { + if lua.set_app_data(RequireContextData::default()).is_some() { + Err(RequireError::RequireContextInitCalledTwice) + } else { + Ok(()) } } - /** - Resolves the given `source` and `path` into require paths - to use, based on the current require context settings. - - This will resolve path segments such as `./`, `../`, ..., and - if the resolved path is not an absolute path, will create an - absolute path by prepending the current working directory. - */ - pub fn resolve_paths( - source: impl AsRef, - path: impl AsRef, - ) -> LuaResult<(PathBuf, PathBuf)> { - let path = PathBuf::from(source.as_ref()) - .parent() - .ok_or_else(|| LuaError::runtime("Failed to get parent path of source"))? - .join(path.as_ref()); - - let abs_path = clean_path_and_make_absolute(&path); - let rel_path = clean_path(path); - - Ok((abs_path, rel_path)) - } + pub(crate) fn std_exists(lua: &Lua, alias: &str) -> Result { + let data_ref = lua + .app_data_ref::() + .ok_or_else(|| RequireError::RequireContextNotFound)?; - /** - Checks if the given path has a cached require result. - */ - pub fn is_cached(&self, abs_path: impl AsRef) -> LuaResult { - let is_cached = self - .results - .try_lock() - .expect("RequireContext may not be used from multiple threads") - .contains_key(abs_path.as_ref()); - Ok(is_cached) + Ok(data_ref.std.contains_key(alias)) } - /** - Checks if the given path is currently being used in `require`. - */ - pub fn is_pending(&self, abs_path: impl AsRef) -> LuaResult { - let is_pending = self - .pending - .try_lock() - .expect("RequireContext may not be used from multiple threads") - .contains_key(abs_path.as_ref()); - Ok(is_pending) - } + pub(crate) fn require_std( + lua: &Lua, + require_alias: RequireAlias, + ) -> Result, RequireError> { + let data_ref = lua + .app_data_ref::() + .ok_or_else(|| RequireError::RequireContextNotFound)?; - /** - Gets the resulting value from the require cache. + if let Some(cached) = data_ref.std_cache.get(&require_alias) { + let multi_vec = lua.registry_value::>(cached)?; + + return Ok(LuaMultiValue::from_vec(multi_vec)); + } + + let libraries = data_ref + .std + .get(&require_alias.alias.as_str()) + .ok_or_else(|| RequireError::InvalidStdAlias(require_alias.alias.to_string()))?; + + let std = libraries.get(require_alias.path.as_str()).ok_or_else(|| { + RequireError::StdMemberNotFound( + require_alias.path.to_string(), + require_alias.alias.to_string(), + ) + })?; - Will panic if the path has not been cached, use [`is_cached`] first. - */ - pub fn get_from_cache<'lua>( - &self, + let multi = std.module(lua)?; + let mutli_clone = multi.clone(); + let multi_reg = lua.create_registry_value(mutli_clone.into_vec())?; + + drop(data_ref); + + let mut data = lua + .app_data_mut::() + .ok_or_else(|| RequireError::RequireContextNotFound)?; + + data.std_cache.insert(require_alias, multi_reg); + + Ok(multi) + } + + async fn wait_for_pending<'lua>( lua: &'lua Lua, - abs_path: impl AsRef, - ) -> LuaResult> { - let results = self - .results - .try_lock() - .expect("RequireContext may not be used from multiple threads"); - - let cached = results - .get(abs_path.as_ref()) - .expect("Path does not exist in results cache"); - match cached { - Err(e) => Err(e.clone()), - Ok(k) => { - let multi_vec = lua - .registry_value::>(k) - .expect("Missing require result in lua registry"); - Ok(LuaMultiValue::from_vec(multi_vec)) - } + path_abs: &'_ PathBuf, + ) -> Result<(), RequireError> { + let data_ref = lua + .app_data_ref::() + .ok_or_else(|| RequireError::RequireContextNotFound)?; + + let pending = data_ref.pending.try_lock()?; + + if let Some(sender) = pending.get(path_abs) { + let mut receiver = sender.subscribe(); + + // unlock mutex before using async + drop(pending); + + receiver.recv().await?; } + + Ok(()) } - /** - Waits for the resulting value from the require cache. + fn is_pending(lua: &Lua, path_abs: &PathBuf) -> Result { + let data_ref = lua + .app_data_ref::() + .ok_or_else(|| RequireError::RequireContextNotFound)?; - Will panic if the path has not been cached, use [`is_cached`] first. - */ - pub async fn wait_for_cache<'lua>( - &self, - lua: &'lua Lua, - abs_path: impl AsRef, - ) -> LuaResult> { - let mut thread_recv = { - let pending = self - .pending - .try_lock() - .expect("RequireContext may not be used from multiple threads"); - let thread_id = pending - .get(abs_path.as_ref()) - .expect("Path is not currently pending require"); - thread_id.subscribe() - }; + let pending = data_ref.pending.try_lock()?; + + Ok(pending.get(path_abs).is_some()) + } - thread_recv.recv().await.into_lua_err()?; + fn is_cached(lua: &Lua, path_abs: &PathBuf) -> Result { + let data_ref = lua + .app_data_ref::() + .ok_or_else(|| RequireError::RequireContextNotFound)?; - self.get_from_cache(lua, abs_path.as_ref()) + let cache = data_ref.cache.try_lock()?; + + Ok(cache.get(path_abs).is_some()) } - async fn load<'lua>( - &self, + async fn from_cache<'lua>( lua: &'lua Lua, - abs_path: impl AsRef, - rel_path: impl AsRef, - ) -> LuaResult { - let abs_path = abs_path.as_ref(); - let rel_path = rel_path.as_ref(); - - // Read the file at the given path, try to parse and - // load it into a new lua thread that we can schedule - let file_contents = read(&abs_path).await?; - let file_thread = lua - .load(file_contents) - .set_name(rel_path.to_string_lossy().to_string()); - - // Schedule the thread to run, wait for it to finish running - let thread_id = lua.push_thread_back(file_thread, ())?; - lua.track_thread(thread_id); - lua.wait_for_thread(thread_id).await; - let thread_res = lua.get_thread_result(thread_id).unwrap(); - - // Return the result of the thread, storing any lua value(s) in the registry - match thread_res { - Err(e) => Err(e), - Ok(v) => { - let multi_vec = v.into_vec(); - let multi_key = lua - .create_registry_value(multi_vec) - .expect("Failed to store require result in registry - out of memory"); - Ok(multi_key) + path_abs: &'_ PathBuf, + ) -> Result, RequireError> { + let data_ref = lua + .app_data_ref::() + .ok_or_else(|| RequireError::RequireContextNotFound)?; + + let cache = data_ref.cache.lock().await; + + match cache.get(path_abs) { + Some(cached) => { + let multi_vec = lua.registry_value::>(cached)?; + + Ok(LuaMultiValue::from_vec(multi_vec)) } + None => Err(RequireError::CacheNotFound( + path_abs.to_string_lossy().to_string(), + )), } } - /** - Loads (requires) the file at the given path. - */ - pub async fn load_with_caching<'lua>( - &self, - lua: &'lua Lua, - abs_path: impl AsRef, - rel_path: impl AsRef, - ) -> LuaResult> { - let abs_path = abs_path.as_ref(); - let rel_path = rel_path.as_ref(); - - // Set this abs path as currently pending - let (broadcast_tx, _) = broadcast::channel(1); - self.pending - .try_lock() - .expect("RequireContext may not be used from multiple threads") - .insert(abs_path.to_path_buf(), broadcast_tx); - - // Try to load at this abs path - let load_res = self.load(lua, abs_path, rel_path).await; - let load_val = match &load_res { - Err(e) => Err(e.clone()), - Ok(k) => { - let multi_vec = lua - .registry_value::>(k) - .expect("Failed to fetch require result from registry"); - Ok(LuaMultiValue::from_vec(multi_vec)) + pub(crate) async fn require( + lua: &Lua, + path_abs: PathBuf, + ) -> Result { + if Self::is_pending(lua, &path_abs)? { + Self::wait_for_pending(lua, &path_abs).await?; + return Self::from_cache(lua, &path_abs).await; + } else if Self::is_cached(lua, &path_abs)? { + return Self::from_cache(lua, &path_abs).await; + } + + // create a broadcast channel + { + let data_ref = lua + .app_data_ref::() + .ok_or_else(|| RequireError::RequireContextNotFound)?; + + let broadcast_tx = broadcast::Sender::new(1); + + { + let mut pending = data_ref.pending.try_lock()?; + pending.insert(path_abs.clone(), broadcast_tx); + } + } + + let content = match fs::read_to_string(&path_abs).await { + Ok(content) => content, + Err(err) => { + // this error is expected to happen in most cases + // because this function will be retried on the same path + // with different extensions when it fails here + + let data_ref = lua + .app_data_ref::() + .ok_or_else(|| RequireError::RequireContextNotFound)?; + + data_ref.pending.lock().await.remove(&path_abs); + + return Err(err.into()); } }; - // NOTE: We use the async lock and not try_lock here because - // some other thread may be wanting to insert into the require - // cache at the same time, and that's not an actual error case - self.results + let thread = lua + .load(&content) + .set_name(path_abs.to_string_lossy()) + .into_lua_thread(lua)?; + + let thread_id = lua.push_thread_back(thread, ())?; + lua.track_thread(thread_id); + lua.wait_for_thread(thread_id).await; + + let multi = lua + .get_thread_result(thread_id) + .ok_or_else(|| RequireError::ThreadReturnedNone)??; + + let multi_reg = lua.create_registry_value(multi.into_vec())?; + + let data_ref = lua + .app_data_ref::() + .ok_or_else(|| RequireError::RequireContextNotFound)?; + + data_ref + .cache .lock() .await - .insert(abs_path.to_path_buf(), load_res); + .insert(path_abs.clone(), multi_reg); - // Remove the pending thread id from the require context, - // broadcast a message to let any listeners know that this - // path has now finished the require process and is cached - let broadcast_tx = self + let broadcast_tx = data_ref .pending - .try_lock() - .expect("RequireContext may not be used from multiple threads") - .remove(abs_path) + .lock() + .await + .remove(&path_abs) .expect("Pending require broadcaster was unexpectedly removed"); + broadcast_tx.send(()).ok(); - load_val + Self::from_cache(lua, &path_abs).await } /** - Loads (requires) the library with the given name. - */ - pub fn load_library<'lua>( - &self, - lua: &'lua Lua, - name: impl AsRef, - ) -> LuaResult> { - let library: LuneStandardLibrary = match name.as_ref().parse() { - Err(e) => return Err(LuaError::runtime(e)), - Ok(b) => b, - }; - let mut cache = self - .libraries - .try_lock() - .expect("RequireContext may not be used from multiple threads"); - - if let Some(res) = cache.get(&library) { - return match res { - Err(e) => return Err(e.clone()), - Ok(key) => { - let multi_vec = lua - .registry_value::>(key) - .expect("Missing library result in lua registry"); - Ok(LuaMultiValue::from_vec(multi_vec)) - } - }; + add a standard library into the require function + + # Example + + ```rs + inject_std(lua, "lune", LuneStandardLibrary::Task)?; + ``` + + ```luau + -- luau + local task = require("@lune/task") + ``` + + # Errors + + - when `RequireStorage::init` isn't called + + */ + pub fn inject_std( + lua: &Lua, + alias: &'static str, + std: impl StandardLibrary + 'static, + ) -> Result<(), RequireError> { + let mut data = lua + .app_data_mut::() + .ok_or_else(|| RequireError::RequireContextNotFound)?; + + if let Some(map) = data.std.get_mut(alias) { + map.insert(std.name(), Box::new(std)); + } else { + let mut map: HashMap<&str, Box> = HashMap::new(); + + map.insert(std.name(), Box::new(std)); + + data.std.insert(alias, map); }; - let result = library.module(lua); - - cache.insert( - library, - match result.clone() { - Err(e) => Err(e), - Ok(multi) => { - let multi_vec = multi.into_vec(); - let multi_key = lua - .create_registry_value(multi_vec) - .expect("Failed to store require result in registry - out of memory"); - Ok(multi_key) - } - }, - ); - - result + Ok(()) } } diff --git a/crates/lune-std/src/globals/require/library.rs b/crates/lune-std/src/globals/require/library.rs deleted file mode 100644 index b47ea928..00000000 --- a/crates/lune-std/src/globals/require/library.rs +++ /dev/null @@ -1,14 +0,0 @@ -use mlua::prelude::*; - -use super::context::*; - -pub(super) fn require<'lua, 'ctx>( - lua: &'lua Lua, - ctx: &'ctx RequireContext, - name: &str, -) -> LuaResult> -where - 'lua: 'ctx, -{ - ctx.load_library(lua, name) -} diff --git a/crates/lune-std/src/globals/require/mod.rs b/crates/lune-std/src/globals/require/mod.rs index 3876e363..5b67c362 100644 --- a/crates/lune-std/src/globals/require/mod.rs +++ b/crates/lune-std/src/globals/require/mod.rs @@ -1,93 +1,114 @@ +use crate::{ + luaurc::{Luaurc, RequireAlias}, + path::get_parent_path, + LuneStandardLibrary, +}; +use lune_utils::path::clean_path_and_make_absolute; use mlua::prelude::*; +use path::append_extension; +use std::path::PathBuf; +use thiserror::Error; -use lune_utils::TableBuilder; +pub mod context; +mod path; -mod context; -use context::RequireContext; +#[derive(Error, Debug)] +pub enum RequireError { + #[error("failed to find RequireContextData in the app data container, make sure to call RequireContext::init first")] + RequireContextNotFound, + #[error("RequireContext::init has been called twice on the same Lua instance")] + RequireContextInitCalledTwice, + #[error("Can not require '{0}' as it does not exist")] + InvalidRequire(String), + #[error("Alias '{0}' does not point to a built-in standard library")] + InvalidStdAlias(String), + #[error("Library '{0}' does not point to a member of '{1}' standard libraries")] + StdMemberNotFound(String, String), + #[error("Thread result returned none")] + ThreadReturnedNone, + #[error("Could not get '{0}' from cache")] + CacheNotFound(String), + + #[error("IOError: {0}")] + IOError(#[from] std::io::Error), + #[error("TryLockError: {0}")] + TryLockError(#[from] tokio::sync::TryLockError), + #[error("BroadcastRecvError: {0}")] + BroadcastRecvError(#[from] tokio::sync::broadcast::error::RecvError), + #[error("LuaError: {0}")] + LuaError(#[from] mlua::Error), +} -mod alias; -mod library; -mod path; +/** +tries different extensions on the path and if all alternatives fail, we'll try to look for an init file + */ +async fn try_alternatives(lua: &Lua, require_path_abs: PathBuf) -> LuaResult { + for ext in ["lua", "luau"] { + // try the path with ext + let ext_path = append_extension(&require_path_abs, ext); + + match context::RequireContext::require(lua, ext_path).await { + Ok(res) => return Ok(res), + Err(err) => { + if !matches!(err, RequireError::IOError(_)) { + return Err(err).into_lua_err(); + }; + } + }; + } -const REQUIRE_IMPL: &str = r" -return require(source(), ...) -"; + for ext in ["lua", "luau"] { + // append init to path and try it with ext + let ext_path = append_extension(require_path_abs.join("init"), ext); + + match context::RequireContext::require(lua, ext_path).await { + Ok(res) => return Ok(res), + Err(err) => { + if !matches!(err, RequireError::IOError(_)) { + return Err(err).into_lua_err(); + }; + } + }; + } -pub fn create(lua: &Lua) -> LuaResult { - lua.set_app_data(RequireContext::new()); - - /* - Require implementation needs a few workarounds: - - - Async functions run outside of the lua resumption cycle, - so the current lua thread, as well as its stack/debug info - is not available, meaning we have to use a normal function - - - Using the async require function directly in another lua function - would mean yielding across the metamethod/c-call boundary, meaning - we have to first load our two functions into a normal lua chunk - and then load that new chunk into our final require function - - Also note that we inspect the stack at level 2: - - 1. The current c / rust function - 2. The wrapper lua chunk defined above - 3. The lua chunk we are require-ing from - */ - - let require_fn = lua.create_async_function(require)?; - let get_source_fn = lua.create_function(move |lua, (): ()| match lua.inspect_stack(2) { - None => Err(LuaError::runtime( - "Failed to get stack info for require source", - )), - Some(info) => match info.source().source { - None => Err(LuaError::runtime( - "Stack info is missing source for require", - )), - Some(source) => lua.create_string(source.as_bytes()), - }, - })?; - - let require_env = TableBuilder::new(lua)? - .with_value("source", get_source_fn)? - .with_value("require", require_fn)? - .build_readonly()?; - - lua.load(REQUIRE_IMPL) - .set_name("require") - .set_environment(require_env) - .into_function()? - .into_lua(lua) + Err(RequireError::InvalidRequire( + require_path_abs.to_string_lossy().to_string(), + )) + .into_lua_err() } -async fn require<'lua>( - lua: &'lua Lua, - (source, path): (LuaString<'lua>, LuaString<'lua>), -) -> LuaResult> { - let source = source - .to_str() - .into_lua_err() - .context("Failed to parse require source as string")? - .to_string(); - - let path = path - .to_str() - .into_lua_err() - .context("Failed to parse require path as string")? - .to_string(); - - let context = lua - .app_data_ref() - .expect("Failed to get RequireContext from app data"); - - if let Some(builtin_name) = path.strip_prefix("@lune/").map(str::to_ascii_lowercase) { - library::require(lua, &context, &builtin_name) - } else if let Some(aliased_path) = path.strip_prefix('@') { - let (alias, path) = aliased_path.split_once('/').ok_or(LuaError::runtime( - "Require with custom alias must contain '/' delimiter", - ))?; - alias::require(lua, &context, &source, alias, path).await +async fn lua_require(lua: &Lua, path: String) -> LuaResult { + let require_path_rel = PathBuf::from(path); + let require_alias = RequireAlias::from_path(&require_path_rel).into_lua_err()?; + + if let Some(require_alias) = require_alias { + if context::RequireContext::std_exists(lua, &require_alias.alias).into_lua_err()? { + context::RequireContext::require_std(lua, require_alias).into_lua_err() + } else { + let require_path_abs = clean_path_and_make_absolute( + Luaurc::resolve_path(lua, &require_alias) + .await + .into_lua_err()?, + ); + + try_alternatives(lua, require_path_abs).await + } } else { - path::require(lua, &context, &source, &path).await + let parent_path = get_parent_path(lua)?; + let require_path_abs = clean_path_and_make_absolute(parent_path.join(&require_path_rel)); + + try_alternatives(lua, require_path_abs).await } } + +pub fn create(lua: &Lua) -> LuaResult { + let f = lua.create_async_function(lua_require).into_lua_err()?; + + context::RequireContext::init(lua).into_lua_err()?; + + for std in LuneStandardLibrary::ALL { + context::RequireContext::inject_std(lua, "lune", *std).into_lua_err()?; + } + + f.into_lua(lua) +} diff --git a/crates/lune-std/src/globals/require/path.rs b/crates/lune-std/src/globals/require/path.rs index 1fabebf4..b1e932ea 100644 --- a/crates/lune-std/src/globals/require/path.rs +++ b/crates/lune-std/src/globals/require/path.rs @@ -1,117 +1,16 @@ -use std::path::{Path, PathBuf}; +use std::path::PathBuf; -use mlua::prelude::*; -use mlua::Error::ExternalError; +/** -use super::context::*; +adds extension to path without replacing it's current extensions -pub(super) async fn require<'lua, 'ctx>( - lua: &'lua Lua, - ctx: &'ctx RequireContext, - source: &str, - path: &str, -) -> LuaResult> -where - 'lua: 'ctx, -{ - let (abs_path, rel_path) = RequireContext::resolve_paths(source, path)?; - require_abs_rel(lua, ctx, abs_path, rel_path).await -} - -pub(super) async fn require_abs_rel<'lua, 'ctx>( - lua: &'lua Lua, - ctx: &'ctx RequireContext, - abs_path: PathBuf, // Absolute to filesystem - rel_path: PathBuf, // Relative to CWD (for displaying) -) -> LuaResult> -where - 'lua: 'ctx, -{ - // 1. Try to require the exact path - match require_inner(lua, ctx, &abs_path, &rel_path).await { - Ok(res) => return Ok(res), - Err(err) => { - if !is_file_not_found_error(&err) { - return Err(err); - } - } - } - - // 2. Try to require the path with an added "luau" extension - // 3. Try to require the path with an added "lua" extension - for extension in ["luau", "lua"] { - match require_inner( - lua, - ctx, - &append_extension(&abs_path, extension), - &append_extension(&rel_path, extension), - ) - .await - { - Ok(res) => return Ok(res), - Err(err) => { - if !is_file_not_found_error(&err) { - return Err(err); - } - } - } - } +### Example - // We didn't find any direct file paths, look - // for directories with "init" files in them... - let abs_init = abs_path.join("init"); - let rel_init = rel_path.join("init"); - - // 4. Try to require the init path with an added "luau" extension - // 5. Try to require the init path with an added "lua" extension - for extension in ["luau", "lua"] { - match require_inner( - lua, - ctx, - &append_extension(&abs_init, extension), - &append_extension(&rel_init, extension), - ) - .await - { - Ok(res) => return Ok(res), - Err(err) => { - if !is_file_not_found_error(&err) { - return Err(err); - } - } - } - } - - // Nothing left to try, throw an error - Err(LuaError::runtime(format!( - "No file exists at the path '{}'", - rel_path.display() - ))) -} +appending `.luau` to `path/path.config` will return `path/path.config.luau` -async fn require_inner<'lua, 'ctx>( - lua: &'lua Lua, - ctx: &'ctx RequireContext, - abs_path: impl AsRef, - rel_path: impl AsRef, -) -> LuaResult> -where - 'lua: 'ctx, -{ - let abs_path = abs_path.as_ref(); - let rel_path = rel_path.as_ref(); - - if ctx.is_cached(abs_path)? { - ctx.get_from_cache(lua, abs_path) - } else if ctx.is_pending(abs_path)? { - ctx.wait_for_cache(lua, &abs_path).await - } else { - ctx.load_with_caching(lua, &abs_path, &rel_path).await - } -} - -fn append_extension(path: impl Into, ext: &'static str) -> PathBuf { - let mut new = path.into(); + */ +pub fn append_extension(path: impl Into, ext: &'static str) -> PathBuf { + let mut new: PathBuf = path.into(); match new.extension() { // FUTURE: There's probably a better way to do this than converting to a lossy string Some(e) => new.set_extension(format!("{}.{ext}", e.to_string_lossy())), @@ -119,11 +18,3 @@ fn append_extension(path: impl Into, ext: &'static str) -> PathBuf { }; new } - -fn is_file_not_found_error(err: &LuaError) -> bool { - if let ExternalError(err) = err { - err.as_ref().downcast_ref::().is_some() - } else { - false - } -} diff --git a/crates/lune-std/src/lib.rs b/crates/lune-std/src/lib.rs index a29bef03..508a8919 100644 --- a/crates/lune-std/src/lib.rs +++ b/crates/lune-std/src/lib.rs @@ -6,10 +6,12 @@ mod global; mod globals; mod library; mod luaurc; +mod path; pub use self::global::LuneStandardGlobal; +pub use self::globals::require::context::RequireContext; pub use self::globals::version::set_global_version; -pub use self::library::LuneStandardLibrary; +pub use self::library::{LuneStandardLibrary, StandardLibrary}; /** Injects all standard globals into the given Lua state / VM. diff --git a/crates/lune-std/src/library.rs b/crates/lune-std/src/library.rs index 9a301f57..2097bb6d 100644 --- a/crates/lune-std/src/library.rs +++ b/crates/lune-std/src/library.rs @@ -1,7 +1,26 @@ -use std::str::FromStr; +use std::{fmt::Debug, str::FromStr}; use mlua::prelude::*; +pub trait StandardLibrary +where + Self: Debug, +{ + /** + Gets the name of the library, such as `datetime` or `fs`. + */ + fn name(&self) -> &'static str; + + /** + Creates the Lua module for the library. + + # Errors + + If the library could not be created. + */ + fn module<'lua>(&self, lua: &'lua Lua) -> LuaResult>; +} + /** A standard library provided by Lune. */ @@ -37,14 +56,13 @@ impl LuneStandardLibrary { #[cfg(feature = "stdio")] Self::Stdio, #[cfg(feature = "roblox")] Self::Roblox, ]; +} - /** - Gets the name of the library, such as `datetime` or `fs`. - */ +impl StandardLibrary for LuneStandardLibrary { #[must_use] #[rustfmt::skip] #[allow(unreachable_patterns)] - pub fn name(&self) -> &'static str { + fn name(&self) -> &'static str { match self { #[cfg(feature = "datetime")] Self::DateTime => "datetime", #[cfg(feature = "fs")] Self::Fs => "fs", @@ -61,16 +79,9 @@ impl LuneStandardLibrary { } } - /** - Creates the Lua module for the library. - - # Errors - - If the library could not be created. - */ #[rustfmt::skip] #[allow(unreachable_patterns)] - pub fn module<'lua>(&self, lua: &'lua Lua) -> LuaResult> { + fn module<'lua>(&self, lua: &'lua Lua) -> LuaResult> { let res: LuaResult = match self { #[cfg(feature = "datetime")] Self::DateTime => lune_std_datetime::module(lua), #[cfg(feature = "fs")] Self::Fs => lune_std_fs::module(lua), @@ -114,12 +125,7 @@ impl FromStr for LuneStandardLibrary { _ => { return Err(format!( - "Unknown standard library '{low}'\nValid libraries are: {}", - Self::ALL - .iter() - .map(Self::name) - .collect::>() - .join(", ") + "Unknown standard library '{low}'" )) } }) diff --git a/crates/lune-std/src/luaurc.rs b/crates/lune-std/src/luaurc.rs index 0eada593..ef61d127 100644 --- a/crates/lune-std/src/luaurc.rs +++ b/crates/lune-std/src/luaurc.rs @@ -1,18 +1,21 @@ +use crate::path::get_parent_path; +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; use std::{ collections::HashMap, - path::{Path, PathBuf, MAIN_SEPARATOR}, - sync::Arc, + env::current_dir, + path::{Path, PathBuf}, }; +use thiserror::Error; +use tokio::fs; -use serde::{Deserialize, Serialize}; -use serde_json::Value as JsonValue; -use tokio::fs::read; - -use lune_utils::path::{clean_path, clean_path_and_make_absolute}; - -const LUAURC_FILE: &str = ".luaurc"; +#[derive(Debug, Clone, Eq, Hash, PartialEq)] +pub struct RequireAlias { + pub alias: String, + pub path: String, +} -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] enum LuauLanguageMode { NoCheck, @@ -20,9 +23,8 @@ enum LuauLanguageMode { Strict, } -#[derive(Debug, Clone, Default, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -struct LuauRcConfig { +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Luaurc { #[serde(skip_serializing_if = "Option::is_none")] language_mode: Option, #[serde(skip_serializing_if = "Option::is_none")] @@ -34,131 +36,97 @@ struct LuauRcConfig { #[serde(skip_serializing_if = "Option::is_none")] globals: Option>, #[serde(skip_serializing_if = "Option::is_none")] - paths: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - aliases: Option>, + aliases: Option>, } -/** - A deserialized `.luaurc` file. - - Contains utility methods for validating and searching for aliases. -*/ -#[derive(Debug, Clone)] -pub struct LuauRc { - dir: Arc, - config: LuauRcConfig, +#[derive(Debug, Error)] +pub enum LuaurcError { + #[error("Require with alias doesn't contain '/'")] + UsedAliasWithoutSlash, + #[error("Failed to convert string to path")] + FailedStringToPathConversion, + #[error("Failed to find a path for alias '{0}' in .luaurc files")] + FailedToFindAlias(String), + #[error("Failed to parse {0}\nParserError: {1}")] + FailedToParse(PathBuf, serde_json::Error), + + #[error("IOError: {0}")] + IOError(#[from] std::io::Error), + #[error("LuaError: {0}")] + LuaError(#[from] mlua::Error), } -impl LuauRc { - /** - Reads a `.luaurc` file from the given directory. - - If the file does not exist, or if it is invalid, this function returns `None`. - */ - pub async fn read(dir: impl AsRef) -> Option { - let dir = clean_path_and_make_absolute(dir); - let path = dir.join(LUAURC_FILE); - let bytes = read(&path).await.ok()?; - let config = serde_json::from_slice(&bytes).ok()?; - Some(Self { - dir: dir.into(), - config, - }) - } - - /** - Reads a `.luaurc` file from the given directory, and then recursively searches - for a `.luaurc` file in the parent directories if a predicate is not satisfied. - - If no `.luaurc` file exists, or if they are invalid, this function returns `None`. - */ - pub async fn read_recursive( - dir: impl AsRef, - mut predicate: impl FnMut(&Self) -> bool, - ) -> Option { - let mut current = clean_path_and_make_absolute(dir); - loop { - if let Some(rc) = Self::read(¤t).await { - if predicate(&rc) { - return Some(rc); - } - } - if let Some(parent) = current.parent() { - current = parent.to_path_buf(); - } else { - return None; - } +impl RequireAlias { + /// Parses path into `RequireAlias` struct + /// + /// ### Examples + /// + /// `@lune/task` becomes `Some({ alias: "lune", path: "task" })` + /// + /// `../path/script` becomes `None` + pub fn from_path(path: &Path) -> Result, LuaurcError> { + if let Some(aliased_path) = path + .to_str() + .ok_or(LuaurcError::FailedStringToPathConversion)? + .strip_prefix('@') + { + let (alias, path) = aliased_path + .split_once('/') + .ok_or(LuaurcError::UsedAliasWithoutSlash)?; + + Ok(Some(RequireAlias { + alias: alias.to_string(), + path: path.to_string(), + })) + } else { + Ok(None) } } +} - /** - Validates that the `.luaurc` file is correct. +/** +# Errors - This primarily validates aliases since they are not - validated during creation of the [`LuauRc`] struct. +* when `serde_json` fails to deserialize content of the file - # Errors + */ +async fn parse_luaurc(_: &mlua::Lua, path: &PathBuf) -> Result, LuaurcError> { + if let Ok(content) = fs::read(path).await { + serde_json::from_slice(&content) + .map(Some) + .map_err(|err| LuaurcError::FailedToParse(path.clone(), err)) + } else { + Ok(None) + } +} - If an alias key is invalid. - */ - pub fn validate(&self) -> Result<(), String> { - if let Some(aliases) = &self.config.aliases { - for alias in aliases.keys() { - if !is_valid_alias_key(alias) { - return Err(format!("invalid alias key: {alias}")); +impl Luaurc { + /// Searches for .luaurc recursively + /// until an alias for the provided `RequireAlias` is found + pub async fn resolve_path<'lua>( + lua: &'lua mlua::Lua, + alias: &'lua RequireAlias, + ) -> Result { + let cwd = current_dir()?; + let parent = cwd.join(get_parent_path(lua)?); + let ancestors = parent.ancestors(); + + for path in ancestors { + if path.starts_with(&cwd) { + if let Some(luaurc) = parse_luaurc(lua, &path.join(".luaurc")).await? { + if let Some(aliases) = luaurc.aliases { + if let Some(alias_path) = aliases.get(&alias.alias) { + let resolved = path.join(alias_path.join(&alias.path)); + + return Ok(resolved); + } + } } + } else { + break; } } - Ok(()) - } - - /** - Gets a copy of all aliases in the `.luaurc` file. - - Will return an empty map if there are no aliases. - */ - #[must_use] - pub fn aliases(&self) -> HashMap { - self.config.aliases.clone().unwrap_or_default() - } - - /** - Finds an alias in the `.luaurc` file by name. - If the alias does not exist, this function returns `None`. - */ - #[must_use] - pub fn find_alias(&self, name: &str) -> Option { - self.config.aliases.as_ref().and_then(|aliases| { - aliases.iter().find_map(|(alias, path)| { - if alias - .trim_end_matches(MAIN_SEPARATOR) - .eq_ignore_ascii_case(name) - && is_valid_alias_key(alias) - { - Some(clean_path(self.dir.join(path))) - } else { - None - } - }) - }) + Err(LuaurcError::FailedToFindAlias(alias.alias.to_string())) } } - -fn is_valid_alias_key(alias: impl AsRef) -> bool { - let alias = alias.as_ref(); - if alias.is_empty() - || alias.starts_with('.') - || alias.starts_with("..") - || alias.chars().any(|c| c == MAIN_SEPARATOR) - { - false // Paths are not valid alias keys - } else { - alias.chars().all(is_valid_alias_char) - } -} - -fn is_valid_alias_char(c: char) -> bool { - c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '.' -} diff --git a/crates/lune-std/src/path.rs b/crates/lune-std/src/path.rs new file mode 100644 index 00000000..bf333095 --- /dev/null +++ b/crates/lune-std/src/path.rs @@ -0,0 +1,39 @@ +use std::path::PathBuf; + +/** + +return's the path of the script that called this function + + */ +pub fn get_script_path(lua: &mlua::Lua) -> Result { + let Some(debug) = lua.inspect_stack(2) else { + return Err(mlua::Error::runtime("Failed to inspect stack")); + }; + + match debug + .source() + .source + .map(|raw_source| PathBuf::from(raw_source.to_string())) + { + Some(script) => Ok(script), + None => Err(mlua::Error::runtime( + "Failed to get path of the script that called require", + )), + } +} + +/** + +return's the parent directory of the script that called this function + + */ +pub fn get_parent_path(lua: &mlua::Lua) -> Result { + let script = get_script_path(lua)?; + + match script.parent() { + Some(parent) => Ok(parent.to_path_buf()), + None => Err(mlua::Error::runtime( + "Failed to get parent of the script that called require", + )), + } +}