diff --git a/programs/bpf_loader/src/syscalls/mem_ops.rs b/programs/bpf_loader/src/syscalls/mem_ops.rs index 0367c03d006ad8..e6db37ccaa29bf 100644 --- a/programs/bpf_loader/src/syscalls/mem_ops.rs +++ b/programs/bpf_loader/src/syscalls/mem_ops.rs @@ -1,5 +1,6 @@ use { super::*, + solana_program_runtime::invoke_context::SerializedAccountMetadata, solana_rbpf::{error::EbpfError, memory_region::MemoryRegion}, std::slice, }; @@ -77,7 +78,9 @@ declare_builtin_function!( cmp_result_addr, invoke_context.get_check_aligned(), )?; - *cmp_result = memcmp_non_contiguous(s1_addr, s2_addr, n, memory_mapping)?; + let syscall_context = invoke_context.get_syscall_context()?; + + *cmp_result = memcmp_non_contiguous(s1_addr, s2_addr, n, &syscall_context.accounts_metadata, memory_mapping)?; } else { let s1 = translate_slice::( memory_mapping, @@ -128,7 +131,9 @@ declare_builtin_function!( .get_feature_set() .is_active(&solana_feature_set::bpf_account_data_direct_mapping::id()) { - memset_non_contiguous(dst_addr, c as u8, n, memory_mapping) + let syscall_context = invoke_context.get_syscall_context()?; + + memset_non_contiguous(dst_addr, c as u8, n, &syscall_context.accounts_metadata, memory_mapping) } else { let s = translate_slice_mut::( memory_mapping, @@ -153,7 +158,15 @@ fn memmove( .get_feature_set() .is_active(&solana_feature_set::bpf_account_data_direct_mapping::id()) { - memmove_non_contiguous(dst_addr, src_addr, n, memory_mapping) + let syscall_context = invoke_context.get_syscall_context()?; + + memmove_non_contiguous( + dst_addr, + src_addr, + n, + &syscall_context.accounts_metadata, + memory_mapping, + ) } else { let dst_ptr = translate_slice_mut::( memory_mapping, @@ -179,6 +192,7 @@ fn memmove_non_contiguous( dst_addr: u64, src_addr: u64, n: u64, + accounts: &[SerializedAccountMetadata], memory_mapping: &MemoryMapping, ) -> Result { let reverse = dst_addr.wrapping_sub(src_addr) < n; @@ -188,6 +202,7 @@ fn memmove_non_contiguous( AccessType::Store, dst_addr, n, + accounts, memory_mapping, reverse, |src_host_addr, dst_host_addr, chunk_len| { @@ -214,6 +229,7 @@ fn memcmp_non_contiguous( src_addr: u64, dst_addr: u64, n: u64, + accounts: &[SerializedAccountMetadata], memory_mapping: &MemoryMapping, ) -> Result { let memcmp_chunk = |s1_addr, s2_addr, chunk_len| { @@ -237,6 +253,7 @@ fn memcmp_non_contiguous( AccessType::Load, dst_addr, n, + accounts, memory_mapping, false, memcmp_chunk, @@ -274,9 +291,11 @@ fn memset_non_contiguous( dst_addr: u64, c: u8, n: u64, + accounts: &[SerializedAccountMetadata], memory_mapping: &MemoryMapping, ) -> Result { - let dst_chunk_iter = MemoryChunkIterator::new(memory_mapping, AccessType::Store, dst_addr, n)?; + let dst_chunk_iter = + MemoryChunkIterator::new(memory_mapping, accounts, AccessType::Store, dst_addr, n)?; for item in dst_chunk_iter { let (dst_region, dst_vm_addr, dst_len) = item?; let dst_host_addr = Result::from(dst_region.vm_to_host(dst_vm_addr, dst_len as u64))?; @@ -292,6 +311,7 @@ fn iter_memory_pair_chunks( dst_access: AccessType, dst_addr: u64, n_bytes: u64, + accounts: &[SerializedAccountMetadata], memory_mapping: &MemoryMapping, reverse: bool, mut fun: F, @@ -301,10 +321,10 @@ where F: FnMut(*const u8, *const u8, usize) -> Result, { let mut src_chunk_iter = - MemoryChunkIterator::new(memory_mapping, src_access, src_addr, n_bytes) + MemoryChunkIterator::new(memory_mapping, accounts, src_access, src_addr, n_bytes) .map_err(EbpfError::from)?; let mut dst_chunk_iter = - MemoryChunkIterator::new(memory_mapping, dst_access, dst_addr, n_bytes) + MemoryChunkIterator::new(memory_mapping, accounts, dst_access, dst_addr, n_bytes) .map_err(EbpfError::from)?; let mut src_chunk = None; @@ -392,17 +412,21 @@ where struct MemoryChunkIterator<'a> { memory_mapping: &'a MemoryMapping<'a>, + accounts: &'a [SerializedAccountMetadata], access_type: AccessType, initial_vm_addr: u64, vm_addr_start: u64, // exclusive end index (start + len, so one past the last valid address) vm_addr_end: u64, len: u64, + account_index: usize, + is_account: Option, } impl<'a> MemoryChunkIterator<'a> { fn new( memory_mapping: &'a MemoryMapping, + accounts: &'a [SerializedAccountMetadata], access_type: AccessType, vm_addr: u64, len: u64, @@ -413,13 +437,17 @@ impl<'a> MemoryChunkIterator<'a> { len, "unknown", ))?; + Ok(MemoryChunkIterator { memory_mapping, + accounts, access_type, initial_vm_addr: vm_addr, len, vm_addr_start: vm_addr, vm_addr_end, + account_index: 0, + is_account: None, }) } @@ -460,6 +488,36 @@ impl<'a> Iterator for MemoryChunkIterator<'a> { } }; + let region_is_account; + + loop { + if let Some(account) = self.accounts.get(self.account_index) { + let account_addr = account.vm_data_addr; + let resize_addr = account_addr.saturating_add(account.original_data_len as u64); + + if resize_addr < region.vm_addr { + // region is after this account, move on next one + self.account_index = self.account_index.saturating_add(1); + } else { + region_is_account = + region.vm_addr == account_addr || region.vm_addr == resize_addr; + break; + } + } else { + // address is after all the accounts + region_is_account = false; + break; + } + } + + if let Some(is_account) = self.is_account { + if is_account != region_is_account { + return Some(Err(SyscallError::InvalidLength.into())); + } + } else { + self.is_account = Some(region_is_account); + } + let vm_addr = self.vm_addr_start; let chunk_len = if region.vm_addr_end <= self.vm_addr_end { @@ -536,7 +594,7 @@ mod tests { let memory_mapping = MemoryMapping::new(vec![], &config, &SBPFVersion::V2).unwrap(); let mut src_chunk_iter = - MemoryChunkIterator::new(&memory_mapping, AccessType::Load, 0, 1).unwrap(); + MemoryChunkIterator::new(&memory_mapping, &[], AccessType::Load, 0, 1).unwrap(); src_chunk_iter.next().unwrap().unwrap(); } @@ -550,7 +608,7 @@ mod tests { let memory_mapping = MemoryMapping::new(vec![], &config, &SBPFVersion::V2).unwrap(); let mut src_chunk_iter = - MemoryChunkIterator::new(&memory_mapping, AccessType::Load, u64::MAX, 1).unwrap(); + MemoryChunkIterator::new(&memory_mapping, &[], AccessType::Load, u64::MAX, 1).unwrap(); src_chunk_iter.next().unwrap().unwrap(); } @@ -569,9 +627,14 @@ mod tests { .unwrap(); // check oob at the lower bound on the first next() - let mut src_chunk_iter = - MemoryChunkIterator::new(&memory_mapping, AccessType::Load, MM_PROGRAM_START - 1, 42) - .unwrap(); + let mut src_chunk_iter = MemoryChunkIterator::new( + &memory_mapping, + &[], + AccessType::Load, + MM_PROGRAM_START - 1, + 42, + ) + .unwrap(); assert_matches!( src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(), EbpfError::AccessViolation(AccessType::Load, addr, 42, "unknown") if *addr == MM_PROGRAM_START - 1 @@ -580,7 +643,7 @@ mod tests { // check oob at the upper bound. Since the memory mapping isn't empty, // this always happens on the second next(). let mut src_chunk_iter = - MemoryChunkIterator::new(&memory_mapping, AccessType::Load, MM_PROGRAM_START, 43) + MemoryChunkIterator::new(&memory_mapping, &[], AccessType::Load, MM_PROGRAM_START, 43) .unwrap(); assert!(src_chunk_iter.next().unwrap().is_ok()); assert_matches!( @@ -590,7 +653,7 @@ mod tests { // check oob at the upper bound on the first next_back() let mut src_chunk_iter = - MemoryChunkIterator::new(&memory_mapping, AccessType::Load, MM_PROGRAM_START, 43) + MemoryChunkIterator::new(&memory_mapping, &[], AccessType::Load, MM_PROGRAM_START, 43) .unwrap() .rev(); assert_matches!( @@ -599,10 +662,15 @@ mod tests { ); // check oob at the upper bound on the 2nd next_back() - let mut src_chunk_iter = - MemoryChunkIterator::new(&memory_mapping, AccessType::Load, MM_PROGRAM_START - 1, 43) - .unwrap() - .rev(); + let mut src_chunk_iter = MemoryChunkIterator::new( + &memory_mapping, + &[], + AccessType::Load, + MM_PROGRAM_START - 1, + 43, + ) + .unwrap() + .rev(); assert!(src_chunk_iter.next().unwrap().is_ok()); assert_matches!( src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(), @@ -625,15 +693,25 @@ mod tests { .unwrap(); // check lower bound - let mut src_chunk_iter = - MemoryChunkIterator::new(&memory_mapping, AccessType::Load, MM_PROGRAM_START - 1, 1) - .unwrap(); + let mut src_chunk_iter = MemoryChunkIterator::new( + &memory_mapping, + &[], + AccessType::Load, + MM_PROGRAM_START - 1, + 1, + ) + .unwrap(); assert!(src_chunk_iter.next().unwrap().is_err()); // check upper bound - let mut src_chunk_iter = - MemoryChunkIterator::new(&memory_mapping, AccessType::Load, MM_PROGRAM_START + 42, 1) - .unwrap(); + let mut src_chunk_iter = MemoryChunkIterator::new( + &memory_mapping, + &[], + AccessType::Load, + MM_PROGRAM_START + 42, + 1, + ) + .unwrap(); assert!(src_chunk_iter.next().unwrap().is_err()); for (vm_addr, len) in [ @@ -645,7 +723,7 @@ mod tests { ] { for rev in [true, false] { let iter = - MemoryChunkIterator::new(&memory_mapping, AccessType::Load, vm_addr, len) + MemoryChunkIterator::new(&memory_mapping, &[], AccessType::Load, vm_addr, len) .unwrap(); let res = if rev { to_chunk_vec(iter.rev()) @@ -690,7 +768,7 @@ mod tests { ] { for rev in [false, true] { let iter = - MemoryChunkIterator::new(&memory_mapping, AccessType::Load, vm_addr, len) + MemoryChunkIterator::new(&memory_mapping, &[], AccessType::Load, vm_addr, len) .unwrap(); let res = if rev { expected.reverse(); @@ -730,6 +808,7 @@ mod tests { AccessType::Load, MM_PROGRAM_START + 8, 8, + &[], &memory_mapping, false, |_src, _dst, _len| Ok::<_, Error>(0), @@ -745,6 +824,7 @@ mod tests { AccessType::Load, MM_PROGRAM_START + 2, 3, + &[], &memory_mapping, false, |_src, _dst, _len| Ok::<_, Error>(0), @@ -772,7 +852,14 @@ mod tests { ) .unwrap(); - memmove_non_contiguous(MM_PROGRAM_START, MM_PROGRAM_START + 8, 4, &memory_mapping).unwrap(); + memmove_non_contiguous( + MM_PROGRAM_START, + MM_PROGRAM_START + 8, + 4, + &[], + &memory_mapping, + ) + .unwrap(); } #[test_case(&[], (0, 0, 0); "no regions")] @@ -819,6 +906,7 @@ mod tests { MM_PROGRAM_START + dst_offset as u64, MM_PROGRAM_START + src_offset as u64, len as u64, + &[], &memory_mapping, ) .unwrap(); @@ -850,7 +938,7 @@ mod tests { .unwrap(); assert_eq!( - memset_non_contiguous(MM_PROGRAM_START, 0x33, 9, &memory_mapping).unwrap(), + memset_non_contiguous(MM_PROGRAM_START, 0x33, 9, &[], &memory_mapping).unwrap(), 0 ); } @@ -878,7 +966,7 @@ mod tests { .unwrap(); assert_eq!( - memset_non_contiguous(MM_PROGRAM_START + 1, 0x55, 7, &memory_mapping).unwrap(), + memset_non_contiguous(MM_PROGRAM_START + 1, 0x55, 7, &[], &memory_mapping).unwrap(), 0 ); assert_eq!(&mem1, &[0x11]); @@ -909,8 +997,14 @@ mod tests { // non contiguous src assert_eq!( - memcmp_non_contiguous(MM_PROGRAM_START, MM_PROGRAM_START + 9, 9, &memory_mapping) - .unwrap(), + memcmp_non_contiguous( + MM_PROGRAM_START, + MM_PROGRAM_START + 9, + 9, + &[], + &memory_mapping + ) + .unwrap(), 0 ); @@ -920,6 +1014,7 @@ mod tests { MM_PROGRAM_START + 10, MM_PROGRAM_START + 1, 8, + &[], &memory_mapping ) .unwrap(), @@ -932,6 +1027,7 @@ mod tests { MM_PROGRAM_START + 1, MM_PROGRAM_START + 11, 5, + &[], &memory_mapping ) .unwrap(), diff --git a/programs/sbf/Cargo.lock b/programs/sbf/Cargo.lock index 5d32b4f3d7dca6..656ef7afd2596b 100644 --- a/programs/sbf/Cargo.lock +++ b/programs/sbf/Cargo.lock @@ -6744,6 +6744,13 @@ dependencies = [ "solana-program", ] +[[package]] +name = "solana-sbf-rust-account-mem" +version = "2.2.0" +dependencies = [ + "solana-program", +] + [[package]] name = "solana-sbf-rust-alloc" version = "2.2.0" diff --git a/programs/sbf/Cargo.toml b/programs/sbf/Cargo.toml index 7d48d1ab8efca2..4c30264b957af8 100644 --- a/programs/sbf/Cargo.toml +++ b/programs/sbf/Cargo.toml @@ -143,6 +143,7 @@ name = "bpf_loader" members = [ "rust/128bit", "rust/128bit_dep", + "rust/account_mem", "rust/alloc", "rust/alt_bn128", "rust/alt_bn128_compression", diff --git a/programs/sbf/rust/account_mem/Cargo.toml b/programs/sbf/rust/account_mem/Cargo.toml new file mode 100644 index 00000000000000..6a4f56be0a2669 --- /dev/null +++ b/programs/sbf/rust/account_mem/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "solana-sbf-rust-account-mem" +version = { workspace = true } +description = { workspace = true } +authors = { workspace = true } +repository = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +edition = { workspace = true } + +[dependencies] +solana-program = { workspace = true } + +[lib] +crate-type = ["cdylib"] diff --git a/programs/sbf/rust/account_mem/src/lib.rs b/programs/sbf/rust/account_mem/src/lib.rs new file mode 100644 index 00000000000000..c873adec86373d --- /dev/null +++ b/programs/sbf/rust/account_mem/src/lib.rs @@ -0,0 +1,116 @@ +//! Test mem functions + +use solana_program::{ + account_info::AccountInfo, + entrypoint::ProgramResult, + program_error::ProgramError, + program_memory::{sol_memcmp, sol_memcpy, sol_memmove, sol_memset}, + pubkey::Pubkey, +}; + +solana_program::entrypoint_no_alloc!(process_instruction); +pub fn process_instruction( + _program_id: &Pubkey, + accounts: &[AccountInfo], + instruction_data: &[u8], +) -> ProgramResult { + let mut buf = [0u8; 2048]; + + let account = accounts.last().ok_or(ProgramError::NotEnoughAccountKeys)?; + let data_len = account.try_borrow_data()?.len().wrapping_add(10240); + let data_ptr = account.try_borrow_mut_data()?.as_mut_ptr(); + // re-create slice with resize area + let data = unsafe { std::slice::from_raw_parts_mut(data_ptr, data_len) }; + + let mut too_early = |before: usize| -> &mut [u8] { + let data = data.as_mut_ptr().wrapping_sub(before); + + unsafe { std::slice::from_raw_parts_mut(data, data_len) } + }; + + match instruction_data[0] { + 0 => { + // memcmp overlaps end + sol_memcmp(&buf, &data[data_len.saturating_sub(8)..], 16); + } + 1 => { + // memcmp overlaps end + sol_memcmp(&data[data_len.saturating_sub(7)..], &buf, 15); + } + 2 => { + // memcmp overlaps begining + #[allow(clippy::manual_memcpy)] + for i in 0..500 { + buf[i] = too_early(8)[i]; + } + + sol_memcmp(too_early(8), &buf, 500); + } + 3 => { + // memcmp overlaps begining + #[allow(clippy::manual_memcpy)] + for i in 0..12 { + buf[i] = too_early(9)[i]; + } + + sol_memcmp(&buf, too_early(9), 12); + } + 4 => { + // memset overlaps end of account + sol_memset(&mut data[data_len.saturating_sub(2)..], 0, 3); + } + 5 => { + // memset overlaps begin of account area + sol_memset(too_early(2), 3, 3); + } + 6 => { + // memcpy src overlaps end of account + sol_memcpy(&mut buf, &data[data_len.saturating_sub(3)..], 10); + } + 7 => { + // memmov src overlaps end of account + unsafe { + sol_memmove( + buf.as_mut_ptr(), + data[data_len.saturating_sub(3)..].as_ptr(), + 10, + ) + }; + } + 8 => { + // memcpy src overlaps begin of account + sol_memcpy(&mut buf, too_early(3), 10); + } + 9 => { + // memmov src overlaps begin of account + unsafe { sol_memmove(buf.as_mut_ptr(), too_early(3).as_ptr(), 10) }; + } + + 10 => { + // memcpy dst overlaps end of account + sol_memcpy(&mut data[data_len.saturating_sub(3)..], &buf, 10); + } + 11 => { + // memmov dst overlaps end of account + unsafe { + sol_memmove( + data[data_len.saturating_sub(3)..].as_mut_ptr(), + buf.as_ptr(), + 10, + ) + }; + } + 12 => { + // memcpy dst overlaps begin of account + sol_memcpy(too_early(3), &buf, 10); + } + 13 => { + // memmov dst overlaps begin of account + unsafe { sol_memmove(too_early(3).as_mut_ptr(), buf.as_ptr(), 10) }; + } + + _ => {} + } + + Ok(()) +} diff --git a/programs/sbf/tests/programs.rs b/programs/sbf/tests/programs.rs index 55fb0ce9169b69..f4421d19716aef 100644 --- a/programs/sbf/tests/programs.rs +++ b/programs/sbf/tests/programs.rs @@ -17,8 +17,9 @@ use { solana_compute_budget::compute_budget::ComputeBudget, solana_feature_set::{self as feature_set, FeatureSet}, solana_ledger::token_balances::collect_token_balances, - solana_program_runtime::invoke_context::mock_process_instruction, - solana_rbpf::vm::ContextObject, + solana_program_runtime::{ + invoke_context::mock_process_instruction, solana_rbpf::vm::ContextObject, + }, solana_runtime::{ bank::{Bank, TransactionBalancesSet}, bank_client::BankClient, @@ -5440,3 +5441,65 @@ fn test_function_call_args() { assert_eq!(decoded.many_args_1, verify_many_args(&input_data)); assert_eq!(decoded.many_args_2, verify_many_args(&input_data)); } + +#[test] +#[cfg(feature = "sbf_rust")] +fn test_mem_syscalls_overlap_account_begin_or_end() { + solana_logger::setup(); + + for direct_mapping in [false, true] { + let GenesisConfigInfo { + genesis_config, + mint_keypair, + .. + } = create_genesis_config(100_123_456_789); + + let mut bank = Bank::new_for_tests(&genesis_config); + let mut feature_set = FeatureSet::all_enabled(); + if !direct_mapping { + feature_set.deactivate(&feature_set::bpf_account_data_direct_mapping::id()); + } + + let account_keypair = Keypair::new(); + + bank.feature_set = Arc::new(feature_set); + let (bank, bank_forks) = bank.wrap_with_bank_forks_for_tests(); + let mut bank_client = BankClient::new_shared(bank); + let authority_keypair = Keypair::new(); + + let (bank, program_id) = load_upgradeable_program_and_advance_slot( + &mut bank_client, + bank_forks.as_ref(), + &mint_keypair, + &authority_keypair, + "solana_sbf_rust_account_mem", + ); + + let mint_pubkey = mint_keypair.pubkey(); + let account_metas = vec![ + AccountMeta::new(mint_pubkey, true), + AccountMeta::new_readonly(program_id, false), + AccountMeta::new(account_keypair.pubkey(), false), + ]; + + let account = AccountSharedData::new(42, 1024, &program_id); + bank.store_account(&account_keypair.pubkey(), &account); + + for instr in 0..=13 { + println!("Testing direct_mapping:{direct_mapping} instruction:{instr}"); + let instruction = + Instruction::new_with_bytes(program_id, &[instr], account_metas.clone()); + + let message = Message::new(&[instruction], Some(&mint_pubkey)); + let tx = Transaction::new(&[&mint_keypair], message.clone(), bank.last_blockhash()); + let (result, _, logs) = process_transaction_and_record_inner(&bank, tx); + + if direct_mapping { + assert!(logs.last().unwrap().ends_with(" failed: InvalidLength")); + } else if result.is_err() { + // without direct mapping, we should never get the InvalidLength error + assert!(!logs.last().unwrap().ends_with(" failed: InvalidLength")); + } + } + } +} diff --git a/sdk/program-memory/src/lib.rs b/sdk/program-memory/src/lib.rs index 737ff1a0c72329..edbd546ecdd149 100644 --- a/sdk/program-memory/src/lib.rs +++ b/sdk/program-memory/src/lib.rs @@ -128,7 +128,7 @@ pub fn sol_memcpy(dst: &mut [u8], src: &[u8], n: usize) { /// /// [`ptr::copy`]: https://doc.rust-lang.org/std/ptr/fn.copy.html #[inline] -pub unsafe fn sol_memmove(dst: *mut u8, src: *mut u8, n: usize) { +pub unsafe fn sol_memmove(dst: *mut u8, src: *const u8, n: usize) { #[cfg(target_os = "solana")] syscalls::sol_memmove_(dst, src, n as u64);