diff --git a/Cargo.toml b/Cargo.toml index 7d0528c0..d1ab8747 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["./flipt-engine-ffi", "./flipt-engine-wasm-js", "./flipt-evaluation"] +members = ["./flipt-engine-ffi", "./flipt-engine-wasm", "./flipt-engine-wasm-js", "./flipt-evaluation"] [profile.release] strip = true diff --git a/flipt-client-go/evaluation.go b/flipt-client-go/evaluation.go index 7dca1466..28466d0b 100644 --- a/flipt-client-go/evaluation.go +++ b/flipt-client-go/evaluation.go @@ -1,6 +1,7 @@ package evaluation import ( + "bytes" "context" "encoding/json" "fmt" @@ -11,7 +12,7 @@ import ( ) var ( - //go:embed ext/flipt_evaluation_wasi.wasm + //go:embed ext/flipt_evaluation_wasm.wasm wasm []byte ) @@ -20,6 +21,8 @@ type Client struct { namespace string instance *wasmtime.Instance store *wasmtime.Store + memory *wasmtime.Memory + enginePtr int32 } // NewClient constructs a Client. @@ -54,15 +57,55 @@ func NewClient(opts ...clientOption) (*Client, error) { return nil, err } - f := instance.GetFunc(store, "new") + f := instance.GetFunc(store, "initialize_engine") if f == nil { - return nil, fmt.Errorf("new function not found") + return nil, fmt.Errorf("initialize_engine function not found") } - if _, err := f.Call(store, "default", "{}"); err != nil { + namespace := "default" + lenNamespace := len(namespace) + + payload := "{\"namespace\": {\"key\": \"default\"}, \"flags\": []}" + lenPayload := len(payload) + + // allocate memory for the strings + allocate := instance.GetFunc(store, "allocate") + if allocate == nil { + return nil, fmt.Errorf("allocate function not found") + } + + namespacePtr, err := allocate.Call(store, lenNamespace) + if err != nil { + return nil, err + } + payloadPtr, err := allocate.Call(store, lenPayload) + if err != nil { + return nil, err + } + + // get a pointer to the memory + memoryInstance := instance.GetExport(store, "memory").Memory() + if memoryInstance == nil { + return nil, fmt.Errorf("memory not found in WASM instance") + } + + // need to keep this around for the lifetime of the client + // so it doesn't get garbage collected + client.memory = memoryInstance + + // write namespace and payload to memory + data := memoryInstance.UnsafeData(store) + copy(data[uint32(namespacePtr.(int32)):], []byte(namespace)) + copy(data[uint32(payloadPtr.(int32)):], []byte(payload)) + + // initialize_engine + enginePtr, err := f.Call(store, namespacePtr, lenNamespace, payloadPtr, lenPayload) + if err != nil { return nil, err } + client.enginePtr = enginePtr.(int32) + client.instance = instance client.store = store @@ -123,19 +166,47 @@ func (e *Client) EvaluateVariant(_ context.Context, flagKey, entityID string, ev return nil, err } - f := e.instance.GetFunc(e.store, "evaluate-variant") + lenEreq := len(ereq) + // allocate memory for the strings + allocate := e.instance.GetFunc(e.store, "allocate") + if allocate == nil { + return nil, fmt.Errorf("allocate function not found") + } + + deallocate := e.instance.GetFunc(e.store, "deallocate") + if deallocate == nil { + return nil, fmt.Errorf("deallocate function not found") + } + + ereqPtr, err := allocate.Call(e.store, lenEreq) + if err != nil { + return nil, err + } + + defer deallocate.Call(e.store, ereqPtr, lenEreq) + + // copy the evaluation request to the memory + data := e.memory.UnsafeData(e.store) + copy(data[uint32(ereqPtr.(int32)):], ereq) + + f := e.instance.GetFunc(e.store, "evaluate_variant") if f == nil { - return nil, fmt.Errorf("evaluate-variant function not found") + return nil, fmt.Errorf("evaluate_variant function not found") } - result, err := f.Call(e.store, ereq) + resultPtr, err := f.Call(e.store, e.enginePtr, ereqPtr, lenEreq) if err != nil { return nil, err } - var vr *VariantResult + result := e.memory.UnsafeData(e.store)[uint32(resultPtr.(int32)):] + n := bytes.IndexByte(result, 0) + if n < 0 { + n = 0 + } - if err := json.Unmarshal([]byte(result.(string)), &vr); err != nil { + var vr *VariantResult + if err := json.Unmarshal(result[:n], &vr); err != nil { return nil, err } diff --git a/flipt-engine-ffi/Cargo.toml b/flipt-engine-ffi/Cargo.toml index 676309a5..8c6075f8 100644 --- a/flipt-engine-ffi/Cargo.toml +++ b/flipt-engine-ffi/Cargo.toml @@ -16,9 +16,7 @@ tokio = { version = "1.33.0", features = ["full"] } futures = "0.3" openssl = { version = "0.10", features = ["vendored"] } thiserror = "1.0.63" - -[dependencies.flipt-evaluation] -path = "../flipt-evaluation" +flipt-evaluation = { path = "../flipt-evaluation" } [dev-dependencies] mockall = "0.13.0" diff --git a/flipt-evaluation-wasi/.gitignore b/flipt-engine-wasm/.gitignore similarity index 100% rename from flipt-evaluation-wasi/.gitignore rename to flipt-engine-wasm/.gitignore diff --git a/flipt-evaluation-wasi/Cargo.toml b/flipt-engine-wasm/Cargo.toml similarity index 70% rename from flipt-evaluation-wasi/Cargo.toml rename to flipt-engine-wasm/Cargo.toml index 700323a3..5d397b5e 100644 --- a/flipt-evaluation-wasi/Cargo.toml +++ b/flipt-engine-wasm/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "flipt-evaluation-wasi" +name = "flipt-evaluation-wasm" version = "0.1.0" edition = "2021" @@ -8,10 +8,9 @@ edition = "2021" [lib] crate-type = ["cdylib"] -[build] -target = "wasm32-wasi" - [dependencies] -wit-bindgen = "0.29.0" +libc = "0.2.150" +serde = { version = "1.0.147", features = ["derive"] } serde_json = { version = "1.0.89", features = ["raw_value"] } +thiserror = "1.0.63" flipt-evaluation = { path = "../flipt-evaluation" } \ No newline at end of file diff --git a/flipt-engine-wasm/src/lib.rs b/flipt-engine-wasm/src/lib.rs new file mode 100644 index 00000000..336968f8 --- /dev/null +++ b/flipt-engine-wasm/src/lib.rs @@ -0,0 +1,308 @@ +use fliptevaluation::EvaluationRequest; +use fliptevaluation::{ + batch_evaluation, boolean_evaluation, error::Error, models::source, store::Snapshot, + variant_evaluation, +}; +use libc::c_void; +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; +use std::collections::HashMap; +use std::ffi::{CStr, CString}; +use std::os::raw::c_char; +use std::panic::catch_unwind; +use thiserror::Error; + +#[derive(Deserialize)] +struct WASMEvaluationRequest { + flag_key: String, + entity_id: String, + context: Option>, +} + +#[derive(Serialize)] +struct WASMResponse +where + T: Serialize, +{ + status: Status, + result: Option, + error_message: Option, +} + +impl From> for WASMResponse +where + T: Serialize, +{ + fn from(value: Result) -> Self { + match value { + Ok(result) => WASMResponse { + status: Status::Success, + result: Some(result), + error_message: None, + }, + Err(e) => WASMResponse { + status: Status::Failure, + result: None, + error_message: Some(e.to_string()), + }, + } + } +} + +#[derive(Serialize)] +enum Status { + #[serde(rename = "success")] + Success, + #[serde(rename = "failure")] + Failure, +} + +#[derive(Error, Debug)] +pub enum EngineError { + #[error("Invalid JSON: {0}")] + InvalidJson(#[from] serde_json::Error), + #[error("Error building snapshot: {0}")] + SnapshotBuildError(#[from] Error), + #[error("Null pointer error")] + NullPointer, +} + +fn result_to_json_ptr(result: Result) -> *mut c_char { + let ffi_response: WASMResponse = result.into(); + let json_string = serde_json::to_string(&ffi_response).unwrap(); + CString::new(json_string).unwrap().into_raw() +} + +pub struct Engine { + namespace: String, + store: Snapshot, +} + +impl Engine { + pub fn new(namespace: &str, data: &str) -> Result { + let doc: source::Document = serde_json::from_str(data)?; + let store = Snapshot::build(namespace, doc)?; + + Ok(Self { + namespace: namespace.to_string(), + store, + }) + } + + pub fn snapshot(&mut self, data: &str) -> Result<(), EngineError> { + let doc: source::Document = serde_json::from_str(data)?; + self.store = Snapshot::build(&self.namespace, doc)?; + Ok(()) + } + + pub fn evaluate_boolean(&self, request: &EvaluationRequest) -> Result { + let result = boolean_evaluation(&self.store, &self.namespace, &request)?; + serde_json::to_string(&result).map_err(|e| Error::InvalidJSON(e.to_string())) + } + + pub fn evaluate_variant(&self, request: &EvaluationRequest) -> Result { + let result = variant_evaluation(&self.store, &self.namespace, &request)?; + serde_json::to_string(&result).map_err(|e| Error::InvalidJSON(e.to_string())) + } + + pub fn evaluate_batch( + &self, + request: Vec, + ) -> Result { + let result = batch_evaluation(&self.store, &self.namespace, request)?; + serde_json::to_string(&result).map_err(|e| Error::InvalidJSON(e.to_string())) + } +} + +/// # Safety +/// +/// This function should not be called unless an Engine is initiated. It provides a helper +/// utility to retrieve an Engine instance for evaluation use. +unsafe fn get_engine<'a>(engine_ptr: *mut c_void) -> Result<&'a mut Engine, EngineError> { + if engine_ptr.is_null() { + Err(EngineError::NullPointer) + } else { + Ok(unsafe { &mut *(engine_ptr as *mut Engine) }) + } +} + +/// # Safety +/// +/// This function will initialize an Engine and return a pointer back to the caller. +#[no_mangle] +pub unsafe extern "C" fn initialize_engine( + namespace_ptr: *const u8, + namespace_len: usize, + payload_ptr: *const u8, + payload_len: usize, +) -> *mut c_void { + let result = catch_unwind(|| { + let namespace = + std::str::from_utf8_unchecked(std::slice::from_raw_parts(namespace_ptr, namespace_len)); + let payload = + std::str::from_utf8_unchecked(std::slice::from_raw_parts(payload_ptr, payload_len)); + + match Engine::new(namespace, payload) { + Ok(engine) => Box::into_raw(Box::new(engine)) as *mut c_void, + Err(e) => { + eprintln!("Error initializing engine: {}", e); + std::ptr::null_mut() + } + } + }); + + result.unwrap_or_else(|_| { + eprintln!("Panic occurred while initializing engine"); + std::ptr::null_mut() + }) +} + +/// # Safety +/// +/// This function will take in a pointer to the engine and return a variant evaluation response. +#[no_mangle] +pub unsafe extern "C" fn evaluate_variant( + engine_ptr: *mut c_void, + evaluation_request_ptr: *const u8, + evaluation_request_len: usize, +) -> *const c_char { + let e = get_engine(engine_ptr).unwrap(); + let evaluation_request = unsafe { + std::str::from_utf8_unchecked(std::slice::from_raw_parts( + evaluation_request_ptr, + evaluation_request_len, + )) + }; + + let request = get_evaluation_request(evaluation_request); + result_to_json_ptr(e.evaluate_variant(&request)) +} + +/// # Safety +/// +/// This function will take in a pointer to the engine and return a boolean evaluation response. +#[no_mangle] +pub unsafe extern "C" fn evaluate_boolean( + engine_ptr: *mut c_void, + evaluation_request_ptr: *const u8, + evaluation_request_len: usize, +) -> *const c_char { + let e = get_engine(engine_ptr).unwrap(); + let evaluation_request = unsafe { + std::str::from_utf8_unchecked(std::slice::from_raw_parts( + evaluation_request_ptr, + evaluation_request_len, + )) + }; + + let request = get_evaluation_request(evaluation_request); + result_to_json_ptr(e.evaluate_boolean(&request)) +} + +/// # Safety +/// +/// This function will take in a pointer to the engine and return a batch evaluation response. +#[no_mangle] +pub unsafe extern "C" fn evaluate_batch( + engine_ptr: *mut c_void, + batch_evaluation_request: *const c_char, +) -> *const c_char { + let e = get_engine(engine_ptr).unwrap(); + let requests = get_batch_evaluation_request(batch_evaluation_request); + result_to_json_ptr(e.evaluate_batch(requests)) +} + +// /// # Safety +// /// +// /// This function will take in a pointer to the engine and return a list of flags for the given namespace. +// #[no_mangle] +// pub unsafe extern "C" fn list_flags(engine_ptr: *mut c_void) -> *const c_char { +// let res = get_engine(engine_ptr).unwrap().list_flags(); + +// result_to_json_ptr(res) +// } + +/// # Safety +/// +/// This function will free the memory occupied by the engine. +#[no_mangle] +pub unsafe extern "C" fn destroy_engine(engine_ptr: *mut c_void) { + if engine_ptr.is_null() { + return; + } + + drop(Box::from_raw(engine_ptr as *mut Engine)); +} + +/// # Safety +/// +/// This function will allocate memory for the engine. +#[no_mangle] +pub extern "C" fn allocate(size: usize) -> *mut c_void { + let mut buf = vec![0; size]; + let ptr = buf.as_mut_ptr(); + std::mem::forget(buf); + ptr as *mut c_void +} + +/// # Safety +/// +/// This function will deallocate the memory for the engine. +#[no_mangle] +pub extern "C" fn deallocate(ptr: *mut c_void, size: usize) { + unsafe { + let buf = Vec::from_raw_parts(ptr, size, size); + std::mem::drop(buf); + } +} + +unsafe fn get_batch_evaluation_request( + batch_evaluation_request: *const c_char, +) -> Vec { + let evaluation_request_bytes = CStr::from_ptr(batch_evaluation_request).to_bytes(); + let bytes_str_repr = std::str::from_utf8(evaluation_request_bytes).unwrap(); + + let batch_eval_request: Vec = + serde_json::from_str(bytes_str_repr).unwrap(); + + let mut evaluation_requests: Vec = + Vec::with_capacity(batch_eval_request.len()); + for req in batch_eval_request { + let mut context_map: HashMap = HashMap::new(); + if let Some(context_value) = req.context { + for (key, value) in context_value { + if let serde_json::Value::String(val) = value { + context_map.insert(key, val); + } + } + } + + evaluation_requests.push(EvaluationRequest { + flag_key: req.flag_key, + entity_id: req.entity_id, + context: context_map, + }); + } + + evaluation_requests +} + +unsafe fn get_evaluation_request(evaluation_request: &str) -> EvaluationRequest { + let client_eval_request: WASMEvaluationRequest = + serde_json::from_str(evaluation_request).unwrap(); + + let mut context_map: HashMap = HashMap::new(); + if let Some(context_value) = client_eval_request.context { + for (key, value) in context_value { + if let serde_json::Value::String(val) = value { + context_map.insert(key, val); + } + } + } + + EvaluationRequest { + flag_key: client_eval_request.flag_key, + entity_id: client_eval_request.entity_id, + context: context_map, + } +} diff --git a/flipt-evaluation-wasi/src/lib.rs b/flipt-evaluation-wasi/src/lib.rs deleted file mode 100644 index 6eb241b3..00000000 --- a/flipt-evaluation-wasi/src/lib.rs +++ /dev/null @@ -1,71 +0,0 @@ -use std::cell::RefCell; - -use exports::flipt::evaluation::evaluator::{Guest, GuestSnapshot}; -use fliptevaluation::models::source; - -wit_bindgen::generate!({ - world: "flipt:evaluation/host", -}); - -struct GuestEvaluator; - -impl Guest for GuestEvaluator { - type Snapshot = Snapshot; -} - -export!(GuestEvaluator); - -struct Snapshot { - namespace: RefCell, - store: RefCell, -} - -impl GuestSnapshot for Snapshot { - fn new(namespace: String, data: String) -> Self { - let doc: source::Document = serde_json::from_str(&data).unwrap_or_default(); - let store = fliptevaluation::store::Snapshot::build(&namespace, doc) - .unwrap_or(fliptevaluation::store::Snapshot::empty(&namespace)); - Self { - namespace: RefCell::new(namespace), - store: RefCell::new(store), - } - } - - fn snapshot(&self, data: String) { - let namespace = self.namespace.borrow(); - let doc: source::Document = serde_json::from_str(&data).unwrap_or_default(); - let store = fliptevaluation::store::Snapshot::build(&namespace, doc) - .unwrap_or(fliptevaluation::store::Snapshot::empty(&namespace)); - self.store.replace(store); - } - - fn evaluate_variant(&self, request: String) -> Option { - let request: fliptevaluation::EvaluationRequest = serde_json::from_str(&request).unwrap(); - let response = fliptevaluation::variant_evaluation( - &*self.store.borrow(), - &self.namespace.borrow(), - &request, - ); - match response { - Ok(r) => Some(serde_json::to_string(&r).unwrap()), - Err(_e) => None, - } - } - - fn evaluate_boolean(&self, request: String) -> Option { - let request: fliptevaluation::EvaluationRequest = serde_json::from_str(&request).unwrap(); - let response = fliptevaluation::boolean_evaluation( - &*self.store.borrow(), - &self.namespace.borrow(), - &request, - ); - match response { - Ok(r) => Some(serde_json::to_string(&r).unwrap()), - Err(_e) => None, - } - } - - fn evaluate_batch(&self, _requests: String) -> Option { - todo!() - } -} diff --git a/flipt-evaluation-wasi/wit/host.wit b/flipt-evaluation-wasi/wit/host.wit deleted file mode 100644 index 68935080..00000000 --- a/flipt-evaluation-wasi/wit/host.wit +++ /dev/null @@ -1,15 +0,0 @@ -package flipt:evaluation; - -interface evaluator { - resource snapshot { - constructor(namespace: string, data: string); - snapshot: func(data: string); - evaluate-variant: func(request: string) -> option; - evaluate-boolean: func(request: string) -> option; - evaluate-batch: func(requests: string) -> option; - } -} - -world host { - export evaluator; -} \ No newline at end of file diff --git a/flipt-evaluation/src/error.rs b/flipt-evaluation/src/error.rs index 082b001d..ace2416b 100644 --- a/flipt-evaluation/src/error.rs +++ b/flipt-evaluation/src/error.rs @@ -1,6 +1,7 @@ +use serde::Serialize; use thiserror::Error; -#[derive(Error, Debug, Clone)] +#[derive(Error, Debug, Clone, Serialize)] pub enum Error { #[error("error parsing json: {0}")] InvalidJSON(String),