Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/preload module #421

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions crates/runtime/src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod wasm3;
#[cfg(feature = "wasmer")]
mod wasmer;

use std::sync::Arc;
//use std::sync::Arc;

use anyhow::Error;

Expand All @@ -15,13 +15,6 @@ pub(crate) use self::wasmer::WasmerEngine;

/// A WebAssembly virtual machine that links Rune with
pub(crate) trait WebAssemblyEngine {
fn load(
wasm: &[u8],
callbacks: Arc<dyn crate::callbacks::Callbacks>,
) -> Result<Self, LoadError>
where
Self: Sized;

/// Call the `_manifest()` function to initialize the Rune graph.
fn init(&mut self) -> Result<(), Error>;

Expand Down
85 changes: 41 additions & 44 deletions crates/runtime/src/engine/wasm3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,52 +26,10 @@ pub struct Wasm3Engine {
}

impl Wasm3Engine {
/// Find a function in the wasm3 module and try to call it.
///
/// Sorry for the generics soup and the whole `apply` thing. The
/// `wasm3::Function` type doesn't actually have a single generic `call()`
/// method which accepts some arguments, instead they "overload"
/// `Function::call()` based on the argument type so you don't need to pass
/// in tuples.
fn call<Args, Ret>(
&mut self,
name: &str,
args: Args,
apply: impl FnOnce(Function<Args, Ret>, Args) -> Result<Ret, Wasm3Error>,
) -> Result<Ret, Error>
where
Ret: WasmType,
Args: WasmArgs,
{
let function: Function<Args, Ret> =
self.runtime.find_function(name).to_anyhow().with_context(
|| format!("Unable to find the \"{}()\" function", name),
)?;

match apply(function, args) {
Ok(ret) => Ok(ret),
// We know that host function errors will emit a trap and set
// last_error, so we can use that to try and give the user a more
// useful error message.
Err(Wasm3Error::Wasm3(e)) if e.is_trap(Trap::Abort) => {
match self.last_error.lock().unwrap().take() {
Some(e) => Err(e),
None => Err(Wasm3Error::Wasm3(e)).to_anyhow(),
}
},
Err(e) => Err(e).to_anyhow(),
}
}
}

impl WebAssemblyEngine for Wasm3Engine {
fn load(
pub(crate) fn load(
wasm: &[u8],
callbacks: Arc<dyn Callbacks>,
) -> Result<Self, LoadError>
where
Self: Sized,
{
) -> Result<Self, LoadError> {
let env = Environment::new().to_anyhow()?;
let host_functions =
Arc::new(Mutex::new(HostFunctions::new(Arc::clone(&callbacks))));
Expand Down Expand Up @@ -109,6 +67,45 @@ impl WebAssemblyEngine for Wasm3Engine {
})
}

/// Find a function in the wasm3 module and try to call it.
///
/// Sorry for the generics soup and the whole `apply` thing. The
/// `wasm3::Function` type doesn't actually have a single generic `call()`
/// method which accepts some arguments, instead they "overload"
/// `Function::call()` based on the argument type so you don't need to pass
/// in tuples.
fn call<Args, Ret>(
&mut self,
name: &str,
args: Args,
apply: impl FnOnce(Function<Args, Ret>, Args) -> Result<Ret, Wasm3Error>,
) -> Result<Ret, Error>
where
Ret: WasmType,
Args: WasmArgs,
{
let function: Function<Args, Ret> =
self.runtime.find_function(name).to_anyhow().with_context(
|| format!("Unable to find the \"{}()\" function", name),
)?;

match apply(function, args) {
Ok(ret) => Ok(ret),
// We know that host function errors will emit a trap and set
// last_error, so we can use that to try and give the user a more
// useful error message.
Err(Wasm3Error::Wasm3(e)) if e.is_trap(Trap::Abort) => {
match self.last_error.lock().unwrap().take() {
Some(e) => Err(e),
None => Err(Wasm3Error::Wasm3(e)).to_anyhow(),
}
},
Err(e) => Err(e).to_anyhow(),
}
}
}

impl WebAssemblyEngine for Wasm3Engine {
fn init(&mut self) -> Result<(), Error> {
let _: i32 = self.call("_manifest", (), |f, _| f.call())?;
let host_functions = self.host_functions.lock().unwrap();
Expand Down
45 changes: 26 additions & 19 deletions crates/runtime/src/engine/wasmer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,22 @@ pub struct WasmerEngine {
callbacks: Arc<dyn Callbacks>,
}

impl WebAssemblyEngine for WasmerEngine {
fn load(
impl WasmerEngine {
pub(crate) fn load(
wasm: &[u8],
callbacks: Arc<dyn Callbacks>,
) -> Result<Self, LoadError>
where
Self: Sized,
{
) -> Result<Self, LoadError> {
let store = Store::default();
let module = Module::from_binary(&store, wasm)?;
WasmerEngine::from_module(&store, &module, callbacks)
}

/// Create a new `WasmerEngine` instance with an existing preloaded module
pub(crate) fn from_module(
store: &Store,
module: &Module,
callbacks: Arc<dyn Callbacks>,
) -> Result<Self, LoadError> {
let host_functions =
Arc::new(Mutex::new(HostFunctions::new(callbacks.clone())));
let env = Env {
Expand All @@ -42,19 +47,19 @@ impl WebAssemblyEngine for WasmerEngine {

let imports = wasmer::imports! {
"env" => {
"_debug" => Function::new_native_with_env(&store, env.clone(), debug),
"request_capability" => Function::new_native_with_env(&store, env.clone(), request_capability),
"request_capability_set_param" => Function::new_native_with_env(&store, env.clone(), request_capability_set_param),
"request_provider_response" => Function::new_native_with_env(&store, env.clone(), request_provider_response),
"tfm_model_invoke" => Function::new_native_with_env(&store, env.clone(), tfm_model_invoke),
"tfm_preload_model" => Function::new_native_with_env(&store, env.clone(), tfm_preload_model),
"rune_model_load" => Function::new_native_with_env(&store, env.clone(), rune_model_load),
"rune_model_infer" => Function::new_native_with_env(&store, env.clone(), rune_model_infer),
"request_output" => Function::new_native_with_env(&store, env.clone(), request_output),
"consume_output" => Function::new_native_with_env(&store, env.clone(), consume_output),
"rune_resource_open" => Function::new_native_with_env(&store, env.clone(), rune_resource_open),
"rune_resource_read" => Function::new_native_with_env(&store, env.clone(), rune_resource_read),
"rune_resource_close" => Function::new_native_with_env(&store, env.clone(), rune_resource_close),
"_debug" => Function::new_native_with_env(store, env.clone(), debug),
"request_capability" => Function::new_native_with_env(store, env.clone(), request_capability),
"request_capability_set_param" => Function::new_native_with_env(store, env.clone(), request_capability_set_param),
"request_provider_response" => Function::new_native_with_env(store, env.clone(), request_provider_response),
"tfm_model_invoke" => Function::new_native_with_env(store, env.clone(), tfm_model_invoke),
"tfm_preload_model" => Function::new_native_with_env(store, env.clone(), tfm_preload_model),
"rune_model_load" => Function::new_native_with_env(store, env.clone(), rune_model_load),
"rune_model_infer" => Function::new_native_with_env(store, env.clone(), rune_model_infer),
"request_output" => Function::new_native_with_env(store, env.clone(), request_output),
"consume_output" => Function::new_native_with_env(store, env.clone(), consume_output),
"rune_resource_open" => Function::new_native_with_env(store, env.clone(), rune_resource_open),
"rune_resource_read" => Function::new_native_with_env(store, env.clone(), rune_resource_read),
"rune_resource_close" => Function::new_native_with_env(store, env.clone(), rune_resource_close),
}
};

Expand All @@ -66,7 +71,9 @@ impl WebAssemblyEngine for WasmerEngine {
callbacks,
})
}
}

impl WebAssemblyEngine for WasmerEngine {
fn init(&mut self) -> Result<(), Error> {
let manifest: NativeFunc<(), i32> = self
.instance
Expand Down
93 changes: 68 additions & 25 deletions crates/runtime/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ use wasmparser::{Parser, Payload};

use crate::{
callbacks::{Callbacks, Model, ModelMetadata, RuneGraph},
engine::{LoadError, WebAssemblyEngine},
engine::{WebAssemblyEngine},
outputs::{parse_outputs, OutputTensor},
NodeMetadata, Tensor,
};
#[allow(unused_imports)] // used with the "wasm3" or "wasmer" features
use crate::engine::LoadError;

/// A loaded Rune.
pub struct Runtime {
Expand All @@ -61,23 +63,46 @@ impl Runtime {
/// Load a Rune, using WASM3 for executing WebAssembly.
#[cfg(feature = "wasm3")]
pub fn wasm3(rune: &[u8]) -> Result<Self, LoadError> {
Runtime::load::<crate::engine::Wasm3Engine>(rune)
let state = State::from_wasm_binary(rune);
let state = Arc::new(state);
let callbacks = Arc::clone(&state) as Arc<dyn Callbacks>;
let mut engine = crate::engine::Wasm3Engine::load(rune, callbacks)?;

engine.init()?;

Ok(Runtime {
state,
engine: Box::new(engine),
})
}

/// Load a Rune, using Wasmer for executing WebAssembly.
#[cfg(feature = "wasmer")]
pub fn wasmer(rune: &[u8]) -> Result<Self, LoadError> {
Runtime::load::<crate::engine::WasmerEngine>(rune)
let state = State::from_wasm_binary(rune);
let state = Arc::new(state);
let callbacks = Arc::clone(&state) as Arc<dyn Callbacks>;
let mut engine = crate::engine::WasmerEngine::load(rune, callbacks)?;

engine.init()?;

Ok(Runtime {
state,
engine: Box::new(engine),
})
}

fn load<E>(rune: &[u8]) -> Result<Self, LoadError>
where
E: WebAssemblyEngine + 'static,
{
let state = State::with_embedded_resources(rune);
#[cfg(feature = "wasmer")]
pub fn wasmer_from_module(
store: &wasmer::Store,
module: &wasmer::Module,
) -> Result<Self, LoadError> {
let resource_sections = module.custom_sections(".rune_resource");
let state = State::new(resource_sections);
let state = Arc::new(state);
let callbacks = Arc::clone(&state) as Arc<dyn Callbacks>;
let mut engine = E::load(rune, callbacks)?;
let mut engine =
crate::engine::WasmerEngine::from_module(store, module, callbacks)?;

engine.init()?;

Expand Down Expand Up @@ -154,25 +179,43 @@ struct State {
resources: UnsafeCell<HashMap<String, Vec<u8>>>,
}

#[allow(dead_code)] // used with the "wasmer" and/or "wasm3" feature flags
impl State {
fn with_embedded_resources(wasm: &[u8]) -> Self {
/// Construct the `State` by extracting resources from a WebAssembly
/// binary's custom sections.
fn from_wasm_binary(wasm: &[u8]) -> Self {
let _s = State::default();

let resource_sections =
Parser::default().parse_all(wasm).filter_map(|p| match p {
Ok(Payload::CustomSection {
name: ".rune_resource",
data,
..
}) => Some(data),
_ => None,
});

State::new(resource_sections)
}

fn new<'a, A>(resource_sections: impl Iterator<Item = A>) -> Self
where
A: AsRef<[u8]>,
{
let s = State::default();

for payload in Parser::default().parse_all(wasm) {
if let Ok(Payload::CustomSection { name, mut data, .. }) = payload {
if name != ".rune_resource" {
continue;
}

while let Some((resource_name, value, rest)) =
hotg_rune_core::decode_inline_resource(data)
{
// Safety: fine because we are the only ones with access to
// State at the moment.
let resources = unsafe { s.resources() };
resources.insert(resource_name.to_string(), value.to_vec());
data = rest;
}
for section in resource_sections {
let mut section = section.as_ref();

while let Some((resource_name, value, rest)) =
hotg_rune_core::decode_inline_resource(section)
{
// Safety: fine because we are the only ones with access to
// State at the moment.
let resources = unsafe { s.resources() };
resources.insert(resource_name.to_string(), value.to_vec());
section = rest;
}
}

Expand Down
24 changes: 17 additions & 7 deletions integration-tests/src/assertions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ use std::{
process::Output,
};

use regex::Regex;

use anyhow::{Context, Error};
use regex::Regex;

pub trait Assertion: Debug {
fn check_for_errors(&self, output: &Output) -> Result<(), Error>;
Expand Down Expand Up @@ -60,7 +59,8 @@ impl Assertion for MatchStdioStream {
.context("Unable to parse output as UTF-8")?;

use super::*;
if !reduce_precision(output).contains(&reduce_precision(&self.expected)) {
if !reduce_precision(output).contains(&reduce_precision(&self.expected))
{
return Err(Error::from(MismatchedStdio {
expected: self.expected.clone(),
actual: output.to_string(),
Expand Down Expand Up @@ -185,7 +185,8 @@ fn reduce_precision(s: &str) -> String {
let number_of_decimals = m.len() - decimal_at - 1;

if number_of_decimals > precision {
result = result.replace(m, m.split_at(decimal_at + precision + 1).0);
result =
result.replace(m, m.split_at(decimal_at + precision + 1).0);
}
}

Expand All @@ -201,8 +202,17 @@ mod tests {
assert_eq!(reduce_precision("5"), "5");
assert_eq!(reduce_precision("abcdef.abcdef"), "abcdef.abcdef");
assert_eq!(reduce_precision("5.12345678"), "5.12345");
assert_eq!(reduce_precision("5.12345678 6.12345678"), "5.12345 6.12345");
assert_eq!(reduce_precision("5.12345678 6.12345678 7.1234567"), "5.12345 6.12345 7.12345");
assert_eq!(reduce_precision("abcdef 5.12345678 6.12345678"), "abcdef 5.12345 6.12345")
assert_eq!(
reduce_precision("5.12345678 6.12345678"),
"5.12345 6.12345"
);
assert_eq!(
reduce_precision("5.12345678 6.12345678 7.1234567"),
"5.12345 6.12345 7.12345"
);
assert_eq!(
reduce_precision("abcdef 5.12345678 6.12345678"),
"abcdef 5.12345 6.12345"
)
}
}