diff --git a/crates/runtime/src/engine/mod.rs b/crates/runtime/src/engine/mod.rs index 798f0a9465..78927bbae3 100644 --- a/crates/runtime/src/engine/mod.rs +++ b/crates/runtime/src/engine/mod.rs @@ -4,7 +4,7 @@ mod wasm3; #[cfg(feature = "wasmer")] mod wasmer; -use std::sync::Arc; +//use std::sync::Arc; use anyhow::Error; @@ -15,13 +15,6 @@ pub(crate) use self::wasmer::WasmerEngine; /// A WebAssembly virtual machine that links Rune with pub(crate) trait WebAssemblyEngine { - fn load( - wasm: &[u8], - callbacks: Arc, - ) -> Result - where - Self: Sized; - /// Call the `_manifest()` function to initialize the Rune graph. fn init(&mut self) -> Result<(), Error>; diff --git a/crates/runtime/src/engine/wasm3.rs b/crates/runtime/src/engine/wasm3.rs index 8cbda51946..f7b7b5a124 100644 --- a/crates/runtime/src/engine/wasm3.rs +++ b/crates/runtime/src/engine/wasm3.rs @@ -26,52 +26,10 @@ pub struct Wasm3Engine { } impl Wasm3Engine { - /// Find a function in the wasm3 module and try to call it. - /// - /// Sorry for the generics soup and the whole `apply` thing. The - /// `wasm3::Function` type doesn't actually have a single generic `call()` - /// method which accepts some arguments, instead they "overload" - /// `Function::call()` based on the argument type so you don't need to pass - /// in tuples. - fn call( - &mut self, - name: &str, - args: Args, - apply: impl FnOnce(Function, Args) -> Result, - ) -> Result - where - Ret: WasmType, - Args: WasmArgs, - { - let function: Function = - self.runtime.find_function(name).to_anyhow().with_context( - || format!("Unable to find the \"{}()\" function", name), - )?; - - match apply(function, args) { - Ok(ret) => Ok(ret), - // We know that host function errors will emit a trap and set - // last_error, so we can use that to try and give the user a more - // useful error message. - Err(Wasm3Error::Wasm3(e)) if e.is_trap(Trap::Abort) => { - match self.last_error.lock().unwrap().take() { - Some(e) => Err(e), - None => Err(Wasm3Error::Wasm3(e)).to_anyhow(), - } - }, - Err(e) => Err(e).to_anyhow(), - } - } -} - -impl WebAssemblyEngine for Wasm3Engine { - fn load( + pub(crate) fn load( wasm: &[u8], callbacks: Arc, - ) -> Result - where - Self: Sized, - { + ) -> Result { let env = Environment::new().to_anyhow()?; let host_functions = Arc::new(Mutex::new(HostFunctions::new(Arc::clone(&callbacks)))); @@ -109,6 +67,45 @@ impl WebAssemblyEngine for Wasm3Engine { }) } + /// Find a function in the wasm3 module and try to call it. + /// + /// Sorry for the generics soup and the whole `apply` thing. The + /// `wasm3::Function` type doesn't actually have a single generic `call()` + /// method which accepts some arguments, instead they "overload" + /// `Function::call()` based on the argument type so you don't need to pass + /// in tuples. + fn call( + &mut self, + name: &str, + args: Args, + apply: impl FnOnce(Function, Args) -> Result, + ) -> Result + where + Ret: WasmType, + Args: WasmArgs, + { + let function: Function = + self.runtime.find_function(name).to_anyhow().with_context( + || format!("Unable to find the \"{}()\" function", name), + )?; + + match apply(function, args) { + Ok(ret) => Ok(ret), + // We know that host function errors will emit a trap and set + // last_error, so we can use that to try and give the user a more + // useful error message. + Err(Wasm3Error::Wasm3(e)) if e.is_trap(Trap::Abort) => { + match self.last_error.lock().unwrap().take() { + Some(e) => Err(e), + None => Err(Wasm3Error::Wasm3(e)).to_anyhow(), + } + }, + Err(e) => Err(e).to_anyhow(), + } + } +} + +impl WebAssemblyEngine for Wasm3Engine { fn init(&mut self) -> Result<(), Error> { let _: i32 = self.call("_manifest", (), |f, _| f.call())?; let host_functions = self.host_functions.lock().unwrap(); diff --git a/crates/runtime/src/engine/wasmer.rs b/crates/runtime/src/engine/wasmer.rs index 423aa633ab..832ed6168d 100644 --- a/crates/runtime/src/engine/wasmer.rs +++ b/crates/runtime/src/engine/wasmer.rs @@ -22,17 +22,22 @@ pub struct WasmerEngine { callbacks: Arc, } -impl WebAssemblyEngine for WasmerEngine { - fn load( +impl WasmerEngine { + pub(crate) fn load( wasm: &[u8], callbacks: Arc, - ) -> Result - where - Self: Sized, - { + ) -> Result { let store = Store::default(); let module = Module::from_binary(&store, wasm)?; + WasmerEngine::from_module(&store, &module, callbacks) + } + /// Create a new `WasmerEngine` instance with an existing preloaded module + pub(crate) fn from_module( + store: &Store, + module: &Module, + callbacks: Arc, + ) -> Result { let host_functions = Arc::new(Mutex::new(HostFunctions::new(callbacks.clone()))); let env = Env { @@ -42,19 +47,19 @@ impl WebAssemblyEngine for WasmerEngine { let imports = wasmer::imports! { "env" => { - "_debug" => Function::new_native_with_env(&store, env.clone(), debug), - "request_capability" => Function::new_native_with_env(&store, env.clone(), request_capability), - "request_capability_set_param" => Function::new_native_with_env(&store, env.clone(), request_capability_set_param), - "request_provider_response" => Function::new_native_with_env(&store, env.clone(), request_provider_response), - "tfm_model_invoke" => Function::new_native_with_env(&store, env.clone(), tfm_model_invoke), - "tfm_preload_model" => Function::new_native_with_env(&store, env.clone(), tfm_preload_model), - "rune_model_load" => Function::new_native_with_env(&store, env.clone(), rune_model_load), - "rune_model_infer" => Function::new_native_with_env(&store, env.clone(), rune_model_infer), - "request_output" => Function::new_native_with_env(&store, env.clone(), request_output), - "consume_output" => Function::new_native_with_env(&store, env.clone(), consume_output), - "rune_resource_open" => Function::new_native_with_env(&store, env.clone(), rune_resource_open), - "rune_resource_read" => Function::new_native_with_env(&store, env.clone(), rune_resource_read), - "rune_resource_close" => Function::new_native_with_env(&store, env.clone(), rune_resource_close), + "_debug" => Function::new_native_with_env(store, env.clone(), debug), + "request_capability" => Function::new_native_with_env(store, env.clone(), request_capability), + "request_capability_set_param" => Function::new_native_with_env(store, env.clone(), request_capability_set_param), + "request_provider_response" => Function::new_native_with_env(store, env.clone(), request_provider_response), + "tfm_model_invoke" => Function::new_native_with_env(store, env.clone(), tfm_model_invoke), + "tfm_preload_model" => Function::new_native_with_env(store, env.clone(), tfm_preload_model), + "rune_model_load" => Function::new_native_with_env(store, env.clone(), rune_model_load), + "rune_model_infer" => Function::new_native_with_env(store, env.clone(), rune_model_infer), + "request_output" => Function::new_native_with_env(store, env.clone(), request_output), + "consume_output" => Function::new_native_with_env(store, env.clone(), consume_output), + "rune_resource_open" => Function::new_native_with_env(store, env.clone(), rune_resource_open), + "rune_resource_read" => Function::new_native_with_env(store, env.clone(), rune_resource_read), + "rune_resource_close" => Function::new_native_with_env(store, env.clone(), rune_resource_close), } }; @@ -66,7 +71,9 @@ impl WebAssemblyEngine for WasmerEngine { callbacks, }) } +} +impl WebAssemblyEngine for WasmerEngine { fn init(&mut self) -> Result<(), Error> { let manifest: NativeFunc<(), i32> = self .instance diff --git a/crates/runtime/src/runtime.rs b/crates/runtime/src/runtime.rs index 923ceb8478..b4cd73cfa8 100644 --- a/crates/runtime/src/runtime.rs +++ b/crates/runtime/src/runtime.rs @@ -46,10 +46,12 @@ use wasmparser::{Parser, Payload}; use crate::{ callbacks::{Callbacks, Model, ModelMetadata, RuneGraph}, - engine::{LoadError, WebAssemblyEngine}, + engine::{WebAssemblyEngine}, outputs::{parse_outputs, OutputTensor}, NodeMetadata, Tensor, }; +#[allow(unused_imports)] // used with the "wasm3" or "wasmer" features +use crate::engine::LoadError; /// A loaded Rune. pub struct Runtime { @@ -61,23 +63,46 @@ impl Runtime { /// Load a Rune, using WASM3 for executing WebAssembly. #[cfg(feature = "wasm3")] pub fn wasm3(rune: &[u8]) -> Result { - Runtime::load::(rune) + let state = State::from_wasm_binary(rune); + let state = Arc::new(state); + let callbacks = Arc::clone(&state) as Arc; + let mut engine = crate::engine::Wasm3Engine::load(rune, callbacks)?; + + engine.init()?; + + Ok(Runtime { + state, + engine: Box::new(engine), + }) } /// Load a Rune, using Wasmer for executing WebAssembly. #[cfg(feature = "wasmer")] pub fn wasmer(rune: &[u8]) -> Result { - Runtime::load::(rune) + let state = State::from_wasm_binary(rune); + let state = Arc::new(state); + let callbacks = Arc::clone(&state) as Arc; + let mut engine = crate::engine::WasmerEngine::load(rune, callbacks)?; + + engine.init()?; + + Ok(Runtime { + state, + engine: Box::new(engine), + }) } - fn load(rune: &[u8]) -> Result - where - E: WebAssemblyEngine + 'static, - { - let state = State::with_embedded_resources(rune); + #[cfg(feature = "wasmer")] + pub fn wasmer_from_module( + store: &wasmer::Store, + module: &wasmer::Module, + ) -> Result { + let resource_sections = module.custom_sections(".rune_resource"); + let state = State::new(resource_sections); let state = Arc::new(state); let callbacks = Arc::clone(&state) as Arc; - let mut engine = E::load(rune, callbacks)?; + let mut engine = + crate::engine::WasmerEngine::from_module(store, module, callbacks)?; engine.init()?; @@ -154,25 +179,43 @@ struct State { resources: UnsafeCell>>, } +#[allow(dead_code)] // used with the "wasmer" and/or "wasm3" feature flags impl State { - fn with_embedded_resources(wasm: &[u8]) -> Self { + /// Construct the `State` by extracting resources from a WebAssembly + /// binary's custom sections. + fn from_wasm_binary(wasm: &[u8]) -> Self { + let _s = State::default(); + + let resource_sections = + Parser::default().parse_all(wasm).filter_map(|p| match p { + Ok(Payload::CustomSection { + name: ".rune_resource", + data, + .. + }) => Some(data), + _ => None, + }); + + State::new(resource_sections) + } + + fn new<'a, A>(resource_sections: impl Iterator) -> Self + where + A: AsRef<[u8]>, + { let s = State::default(); - for payload in Parser::default().parse_all(wasm) { - if let Ok(Payload::CustomSection { name, mut data, .. }) = payload { - if name != ".rune_resource" { - continue; - } - - while let Some((resource_name, value, rest)) = - hotg_rune_core::decode_inline_resource(data) - { - // Safety: fine because we are the only ones with access to - // State at the moment. - let resources = unsafe { s.resources() }; - resources.insert(resource_name.to_string(), value.to_vec()); - data = rest; - } + for section in resource_sections { + let mut section = section.as_ref(); + + while let Some((resource_name, value, rest)) = + hotg_rune_core::decode_inline_resource(section) + { + // Safety: fine because we are the only ones with access to + // State at the moment. + let resources = unsafe { s.resources() }; + resources.insert(resource_name.to_string(), value.to_vec()); + section = rest; } } diff --git a/integration-tests/src/assertions.rs b/integration-tests/src/assertions.rs index d62e6e153c..2422a98c36 100644 --- a/integration-tests/src/assertions.rs +++ b/integration-tests/src/assertions.rs @@ -4,9 +4,8 @@ use std::{ process::Output, }; -use regex::Regex; - use anyhow::{Context, Error}; +use regex::Regex; pub trait Assertion: Debug { fn check_for_errors(&self, output: &Output) -> Result<(), Error>; @@ -60,7 +59,8 @@ impl Assertion for MatchStdioStream { .context("Unable to parse output as UTF-8")?; use super::*; - if !reduce_precision(output).contains(&reduce_precision(&self.expected)) { + if !reduce_precision(output).contains(&reduce_precision(&self.expected)) + { return Err(Error::from(MismatchedStdio { expected: self.expected.clone(), actual: output.to_string(), @@ -185,7 +185,8 @@ fn reduce_precision(s: &str) -> String { let number_of_decimals = m.len() - decimal_at - 1; if number_of_decimals > precision { - result = result.replace(m, m.split_at(decimal_at + precision + 1).0); + result = + result.replace(m, m.split_at(decimal_at + precision + 1).0); } } @@ -201,8 +202,17 @@ mod tests { assert_eq!(reduce_precision("5"), "5"); assert_eq!(reduce_precision("abcdef.abcdef"), "abcdef.abcdef"); assert_eq!(reduce_precision("5.12345678"), "5.12345"); - assert_eq!(reduce_precision("5.12345678 6.12345678"), "5.12345 6.12345"); - assert_eq!(reduce_precision("5.12345678 6.12345678 7.1234567"), "5.12345 6.12345 7.12345"); - assert_eq!(reduce_precision("abcdef 5.12345678 6.12345678"), "abcdef 5.12345 6.12345") + assert_eq!( + reduce_precision("5.12345678 6.12345678"), + "5.12345 6.12345" + ); + assert_eq!( + reduce_precision("5.12345678 6.12345678 7.1234567"), + "5.12345 6.12345 7.12345" + ); + assert_eq!( + reduce_precision("abcdef 5.12345678 6.12345678"), + "abcdef 5.12345 6.12345" + ) } }