Skip to content

Commit

Permalink
allow caller contract to specify reentry rules
Browse files Browse the repository at this point in the history
  • Loading branch information
heytdep committed Nov 7, 2024
1 parent 22f5770 commit d3bc92b
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 17 deletions.
2 changes: 1 addition & 1 deletion rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[toolchain]
channel = "stable"
channel = "1.81"
targets = ["wasm32-unknown-unknown"]
components = ["rustc", "cargo", "rustfmt", "clippy", "rust-src"]
17 changes: 15 additions & 2 deletions soroban-env-common/env.json
Original file line number Diff line number Diff line change
Expand Up @@ -1580,7 +1580,7 @@
],
"return": "Val",
"docs": "Calls a function in another contract with arguments contained in vector `args`. If the call is successful, returns the result of the called function. Traps otherwise. This functions enables re-entrancy in the immediate cross-contract call.",
"min_supported_protocol": 22
"min_supported_protocol": 21
},
{
"export": "2",
Expand All @@ -1601,7 +1601,20 @@
],
"return": "Val",
"docs": "Calls a function in another contract with arguments contained in vector `args`, returning either the result of the called function or an `Error` if the called function failed. The returned error is either a custom `ContractError` that the called contract returns explicitly, or an error with type `Context` and code `InvalidAction` in case of any other error in the called contract (such as a host function failure that caused a trap). `try_call` might trap in a few scenarios where the error can't be meaningfully recovered from, such as running out of budget. This functions enables re-entrancy in the immediate cross-contract call.",
"min_supported_protocol": 22
"min_supported_protocol": 21
},
{
"export": "3",
"name": "set_reentrant",
"args": [
{
"name": "enabled",
"type": "Bool"
}
],
"return": "Void",
"docs": "Enables the current contract to specify the reentrancy rules.",
"min_supported_protocol": 21
}
]
},
Expand Down
14 changes: 14 additions & 0 deletions soroban-env-host/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ struct HostImpl {
#[doc(hidden)]
#[cfg(any(test, feature = "recording_mode"))]
need_to_build_module_cache: RefCell<bool>,

// Enables calling modules that link functions that call with reentry.
enable_reentrant: RefCell<bool>,
}

// Host is a newtype on Rc<HostImpl> so we can impl Env for it below.
Expand Down Expand Up @@ -382,6 +385,8 @@ impl Host {
suppress_diagnostic_events: RefCell::new(false),
#[cfg(any(test, feature = "recording_mode"))]
need_to_build_module_cache: RefCell::new(false),

enable_reentrant: RefCell::new(false),
}))
}

Expand Down Expand Up @@ -2407,6 +2412,15 @@ impl VmCallerEnv for Host {
self.try_call_with_params(contract_address, func, args, call_params)
}

fn set_reentrant(
&self,
_vmcaller: &mut VmCaller<Host>,
enabled: Bool,
) -> Result<Void, HostError> {
*self.0.enable_reentrant.borrow_mut() = enabled.try_into()?;
Ok(Void::from(()))
}

// endregion: "call" module functions
// region: "buf" module functions

Expand Down
24 changes: 20 additions & 4 deletions soroban-env-host/src/host/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ impl Frame {
}

impl Host {
pub(crate) fn get_reentrancy_flag(&self) -> Result<bool, HostError> {
Ok(*self.0.enable_reentrant.borrow())
}

/// Returns if the host currently has a frame on the stack.
///
/// A frame being on the stack usually indicates that a contract is currently
Expand Down Expand Up @@ -686,7 +690,7 @@ impl Host {
let args_vec = args.to_vec();
match &instance.executable {
ContractExecutable::Wasm(wasm_hash) => {
let vm = self.instantiate_vm(id, wasm_hash)?;
let vm = self.instantiate_vm(id, wasm_hash, true)?;
let relative_objects = Vec::new();
self.with_frame(
Frame::ContractVM {
Expand All @@ -709,7 +713,12 @@ impl Host {
}
}

fn instantiate_vm(&self, id: &Hash, wasm_hash: &Hash) -> Result<Rc<Vm>, HostError> {
fn instantiate_vm(
&self,
id: &Hash,
wasm_hash: &Hash,
reentry_guard: bool,
) -> Result<Rc<Vm>, HostError> {
#[cfg(any(test, feature = "recording_mode"))]
{
if !self.in_storage_recording_mode()? {
Expand Down Expand Up @@ -802,7 +811,14 @@ impl Host {
#[cfg(not(any(test, feature = "recording_mode")))]
let cost_mode = crate::vm::ModuleParseCostMode::Normal;

Vm::new_with_cost_inputs(self, contract_id, code.as_slice(), costs, cost_mode)
Vm::new_with_cost_inputs(
self,
contract_id,
code.as_slice(),
costs,
cost_mode,
reentry_guard,
)
}

pub(crate) fn get_contract_protocol_version(
Expand All @@ -817,7 +833,7 @@ impl Host {
let instance = self.retrieve_contract_instance_from_storage(&storage_key)?;
match &instance.executable {
ContractExecutable::Wasm(wasm_hash) => {
let vm = self.instantiate_vm(contract_id, wasm_hash)?;
let vm = self.instantiate_vm(contract_id, wasm_hash, false)?;
Ok(vm.module.proto_version)
}
ContractExecutable::StellarAsset => self.get_ledger_protocol_version(),
Expand Down
1 change: 1 addition & 0 deletions soroban-env-host/src/test/lifecycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2416,6 +2416,7 @@ mod cap_xx_opt_in_reentry {
let contract_id_b = host.register_test_contract_wasm(SIMPLE_REENTRY_CONTRACT_B);
host.enable_debug().unwrap();
let args = test_vec![&host, contract_id_b].into();

call_contract(&host, contract_id_a, args);

let event_body = ContractEventBody::V0(ContractEventV0 {
Expand Down
24 changes: 20 additions & 4 deletions soroban-env-host/src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,20 @@ impl Host {
pub(crate) fn make_linker(
engine: &wasmi::Engine,
symbols: &BTreeSet<(&str, &str)>,
enable_reentrant_linking: bool,
) -> Result<Linker<Host>, HostError> {
let mut linker = Linker::new(&engine);
for hf in HOST_FUNCTIONS {
if !enable_reentrant_linking {
if symbols.contains(&("d", "1")) || symbols.contains(&("d", "2")) {
return Err(crate::Error::from_type_and_code(
ScErrorType::WasmVm,
ScErrorCode::ArithDomain,
)
.try_into()?);
}
}

if symbols.contains(&(hf.mod_str, hf.fn_str)) {
(hf.wrap)(&mut linker).map_err(|le| wasmi::Error::Linker(le))?;
}
Expand Down Expand Up @@ -257,7 +268,7 @@ impl Vm {
if let Some(linker) = &*host.try_borrow_linker()? {
Self::instantiate(host, contract_id, parsed_module, linker)
} else {
let linker = parsed_module.make_linker(host)?;
let linker = parsed_module.make_linker(host, true)?;
Self::instantiate(host, contract_id, parsed_module, &linker)
}
}
Expand Down Expand Up @@ -286,13 +297,16 @@ impl Vm {
let cost_inputs = VersionedContractCodeCostInputs::V0 {
wasm_bytes: wasm.len(),
};
Self::new_with_cost_inputs(

let vm = Self::new_with_cost_inputs(
host,
contract_id,
wasm,
cost_inputs,
ModuleParseCostMode::Normal,
)
false,
);
vm
}

pub(crate) fn new_with_cost_inputs(
Expand All @@ -301,11 +315,13 @@ impl Vm {
wasm: &[u8],
cost_inputs: VersionedContractCodeCostInputs,
cost_mode: ModuleParseCostMode,
reentry_guard: bool,
) -> Result<Rc<Self>, HostError> {
let _span = tracy_span!("Vm::new");
VmInstantiationTimer::new(host.clone());
let parsed_module = Self::parse_module(host, wasm, cost_inputs, cost_mode)?;
let linker = parsed_module.make_linker(host)?;
let linker = parsed_module.make_linker(host, reentry_guard)?;

Self::instantiate(host, contract_id, parsed_module, &linker)
}

Expand Down
5 changes: 4 additions & 1 deletion soroban-env-host/src/vm/module_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ impl ModuleCache {
}

pub fn make_linker(&self, host: &Host) -> Result<wasmi::Linker<Host>, HostError> {
self.with_import_symbols(host, |symbols| Host::make_linker(&self.engine, symbols))
let enable_reentrant_linking = host.get_reentrancy_flag()?;
self.with_import_symbols(host, |symbols| {
Host::make_linker(&self.engine, symbols, enable_reentrant_linking)
})
}

pub fn get_module(
Expand Down
14 changes: 12 additions & 2 deletions soroban-env-host/src/vm/parsed_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,19 @@ impl ParsedModule {
callback(&symbols)
}

pub fn make_linker(&self, host: &Host) -> Result<wasmi::Linker<Host>, HostError> {
pub fn make_linker(
&self,
host: &Host,
reentry_guard: bool,
) -> Result<wasmi::Linker<Host>, HostError> {
self.with_import_symbols(host, |symbols| {
Host::make_linker(self.module.engine(), symbols)
let enable_reentrant_linking = if reentry_guard {
host.get_reentrancy_flag()?
} else {
true
};

Host::make_linker(self.module.engine(), symbols, enable_reentrant_linking)
})
}

Expand Down
Binary file modified soroban-test-wasms/wasm-workspace/opt/23/example_reentry_a.wasm
Binary file not shown.
13 changes: 10 additions & 3 deletions soroban-test-wasms/wasm-workspace/reentry_a/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@ use soroban_sdk::{contract, contractimpl, Env, Address, Symbol, TryIntoVal, Vec,
#[link(wasm_import_module = "d")]
extern "C" {
#[allow(improper_ctypes)]
#[link_name = "1"]
pub fn call_reentrant(contract: i64, func: i64, args: i64, ) -> i64;
#[link_name = "_"]
pub fn call_contract(contract: i64, func: i64, args: i64, ) -> i64;

#[allow(improper_ctypes)]
#[link_name = "3"]
pub fn set_reentrant(enabled: i64, ) -> i64;
}

#[contract]
Expand All @@ -19,9 +23,12 @@ impl Contract {
let called_val = called.as_val().get_payload() as i64;
let func_val = func.as_val().get_payload() as i64;
let args_val = args.as_val().get_payload() as i64;

let set_reentrant_flag = Val::from_bool(true).as_val().get_payload() as i64;

unsafe {
call_reentrant(called_val, func_val, args_val);
set_reentrant(set_reentrant_flag);
call_contract(called_val, func_val, args_val);
};
}

Expand Down

0 comments on commit d3bc92b

Please sign in to comment.