Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

solana-program: improve VoteState::deserialize_into() #2146

Merged
merged 5 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 83 additions & 40 deletions sdk/program/src/vote/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use {
sysvar::clock::Clock,
vote::{authorized_voters::AuthorizedVoters, error::VoteError},
},
bincode::{serialize_into, serialized_size, ErrorKind},
bincode::{serialize_into, ErrorKind},
serde_derive::{Deserialize, Serialize},
std::{collections::VecDeque, fmt::Debug, io::Cursor},
};
Expand Down Expand Up @@ -323,6 +323,7 @@ const MAX_ITEMS: usize = 32;

#[cfg_attr(feature = "frozen-abi", derive(AbiExample))]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct CircBuf<I> {
buf: [I; MAX_ITEMS],
/// next pointer
Expand Down Expand Up @@ -368,23 +369,6 @@ impl<I> CircBuf<I> {
}
}

#[cfg(test)]
impl<'a, I: Default + Copy> Arbitrary<'a> for CircBuf<I>
where
I: Arbitrary<'a>,
{
fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
let mut circbuf = Self::default();

let len = u.arbitrary_len::<I>()?;
for _ in 0..len {
circbuf.append(I::arbitrary(u)?);
}

Ok(circbuf)
}
}

#[cfg_attr(
feature = "frozen-abi",
frozen_abi(digest = "EeenjJaSrm9hRM39gK6raRNtzG61hnk7GciUCJJRDUSQ"),
Expand Down Expand Up @@ -475,8 +459,11 @@ impl VoteState {
3762 // see test_vote_state_size_of.
}

// we retain bincode deserialize for not(target_os = "solana")
// because the hand-written parser does not support V0_23_5
// NOTE we retain `bincode::deserialize` for `not(target_os = "solana")` pending testing on mainnet-beta
// once that testing is done, `VoteState::deserialize_into` may be used for all targets
// conversion of V0_23_5 to current must be handled specially, however
// because it inserts a null voter into `authorized_voters`
// which `VoteStateVersions::is_uninitialized` erroneously reports as initialized
pub fn deserialize(input: &[u8]) -> Result<Self, InstructionError> {
#[cfg(not(target_os = "solana"))]
{
Expand All @@ -492,37 +479,39 @@ impl VoteState {
}
}

/// Deserializes the input buffer into the provided `VoteState`
/// Deserializes the input `VoteStateVersions` buffer directly into a provided `VoteState` struct
///
/// This function exists to deserialize `VoteState` in a BPF context without going above
/// the compute limit, and must be kept up to date with `bincode::deserialize`.
/// In a BPF context, V0_23_5 is not supported, but in non-BPF, all versions are supported for
/// compatibility with `bincode::deserialize`
pub fn deserialize_into(
input: &[u8],
vote_state: &mut VoteState,
) -> Result<(), InstructionError> {
let minimum_size =
serialized_size(vote_state).map_err(|_| InstructionError::InvalidAccountData)?;
if (input.len() as u64) < minimum_size {
return Err(InstructionError::InvalidAccountData);
}

let mut cursor = Cursor::new(input);

let variant = read_u32(&mut cursor)?;
match variant {
// V0_23_5. not supported; these should not exist on mainnet
0 => Err(InstructionError::InvalidAccountData),
// V0_23_5. not supported for bpf targets; these should not exist on mainnet
// supported for non-bpf targets for backwards compatibility
0 => {
#[cfg(not(target_os = "solana"))]
{
*vote_state = bincode::deserialize::<VoteStateVersions>(input)
.map(|versioned| versioned.convert_to_current())
.map_err(|_| InstructionError::InvalidAccountData)?;

Ok(())
}
#[cfg(target_os = "solana")]
Err(InstructionError::InvalidAccountData)
}
// V1_14_11. substantially different layout and data from V0_23_5
1 => deserialize_vote_state_into(&mut cursor, vote_state, false),
// Current. the only difference from V1_14_11 is the addition of a slot-latency to each vote
2 => deserialize_vote_state_into(&mut cursor, vote_state, true),
_ => Err(InstructionError::InvalidAccountData),
}?;

if cursor.position() > input.len() as u64 {
return Err(InstructionError::InvalidAccountData);
}

Ok(())
}

Expand Down Expand Up @@ -1089,7 +1078,7 @@ pub mod serde_tower_sync {

#[cfg(test)]
mod tests {
use {super::*, itertools::Itertools, rand::Rng};
use {super::*, bincode::serialized_size, itertools::Itertools, rand::Rng};

#[test]
fn test_vote_serialize() {
Expand Down Expand Up @@ -1147,16 +1136,70 @@ mod tests {
assert_eq!(e, InstructionError::InvalidAccountData);

// variant
let serialized_len_x4 = serialized_size(&test_vote_state).unwrap() * 4;
let serialized_len_x4 = serialized_size(&VoteState::default()).unwrap() * 4;
let mut rng = rand::thread_rng();
for _ in 0..1000 {
let raw_data_length = rng.gen_range(1..serialized_len_x4);
let raw_data: Vec<u8> = (0..raw_data_length).map(|_| rng.gen::<u8>()).collect();
let mut raw_data: Vec<u8> = (0..raw_data_length).map(|_| rng.gen::<u8>()).collect();

// pure random data will ~never have a valid enum tag, so lets help it out
if raw_data_length >= 4 && rng.gen::<bool>() {
let tag = rng.gen::<u8>() % 3;
raw_data[0] = tag;
raw_data[1] = 0;
raw_data[2] = 0;
raw_data[3] = 0;
}

// it is extremely improbable, though theoretically possible, for random bytes to be syntactically valid
// so we only check that the deserialize function does not panic
// so we only check that the parser does not panic and that it succeeds or fails exactly in line with bincode
let mut test_vote_state = VoteState::default();
let test_res = VoteState::deserialize_into(&raw_data, &mut test_vote_state);
let bincode_res = bincode::deserialize::<VoteStateVersions>(&raw_data)
.map(|versioned| versioned.convert_to_current());

if test_res.is_err() {
assert!(bincode_res.is_err());
} else {
assert_eq!(test_vote_state, bincode_res.unwrap());
}
}
}

#[test]
fn test_vote_deserialize_into_ill_sized() {
// provide 4x the minimum struct size in bytes to ensure we typically touch every field
let struct_bytes_x4 = std::mem::size_of::<VoteState>() * 4;
for _ in 0..1000 {
let raw_data: Vec<u8> = (0..struct_bytes_x4).map(|_| rand::random::<u8>()).collect();
let mut unstructured = Unstructured::new(&raw_data);

let original_vote_state_versions =
VoteStateVersions::arbitrary(&mut unstructured).unwrap();
let original_buf = bincode::serialize(&original_vote_state_versions).unwrap();

let mut truncated_buf = original_buf.clone();
let mut expanded_buf = original_buf.clone();

truncated_buf.resize(original_buf.len() - 8, 0);
expanded_buf.resize(original_buf.len() + 8, 0);

// truncated fails
let mut test_vote_state = VoteState::default();
let _ = VoteState::deserialize_into(&raw_data, &mut test_vote_state);
let test_res = VoteState::deserialize_into(&truncated_buf, &mut test_vote_state);
let bincode_res = bincode::deserialize::<VoteStateVersions>(&truncated_buf)
.map(|versioned| versioned.convert_to_current());

assert!(test_res.is_err());
assert!(bincode_res.is_err());

// expanded succeeds
let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&expanded_buf, &mut test_vote_state).unwrap();
let bincode_res = bincode::deserialize::<VoteStateVersions>(&expanded_buf)
.map(|versioned| versioned.convert_to_current());

assert_eq!(test_vote_state, bincode_res.unwrap());
}
}

Expand Down
45 changes: 45 additions & 0 deletions sdk/program/src/vote/state/vote_state_0_23_5.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#![allow(clippy::arithmetic_side_effects)]
use super::*;
#[cfg(test)]
use arbitrary::{Arbitrary, Unstructured};

const MAX_ITEMS: usize = 32;

#[derive(Debug, Default, Serialize, Deserialize, PartialEq, Eq, Clone)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct VoteState0_23_5 {
/// the node that votes in this account
pub node_pubkey: Pubkey,
Expand Down Expand Up @@ -35,6 +38,7 @@ pub struct VoteState0_23_5 {
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct CircBuf<I> {
pub buf: [I; MAX_ITEMS],
/// next pointer
Expand All @@ -59,3 +63,44 @@ impl<I> CircBuf<I> {
self.buf[self.idx] = item;
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_vote_deserialize_0_23_5() {
// base case
let target_vote_state = VoteState0_23_5::default();
let target_vote_state_versions = VoteStateVersions::V0_23_5(Box::new(target_vote_state));
let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap();

let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap();

assert_eq!(
target_vote_state_versions.convert_to_current(),
test_vote_state
);

// variant
// provide 4x the minimum struct size in bytes to ensure we typically touch every field
let struct_bytes_x4 = std::mem::size_of::<VoteState0_23_5>() * 4;
for _ in 0..100 {
let raw_data: Vec<u8> = (0..struct_bytes_x4).map(|_| rand::random::<u8>()).collect();
let mut unstructured = Unstructured::new(&raw_data);

let arbitrary_vote_state = VoteState0_23_5::arbitrary(&mut unstructured).unwrap();
let target_vote_state_versions =
VoteStateVersions::V0_23_5(Box::new(arbitrary_vote_state));

let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap();
let target_vote_state = target_vote_state_versions.convert_to_current();

let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap();

assert_eq!(target_vote_state, test_vote_state);
}
}
}
41 changes: 41 additions & 0 deletions sdk/program/src/vote/state/vote_state_1_14_11.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,44 @@ impl From<VoteState> for VoteState1_14_11 {
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_vote_deserialize_1_14_11() {
// base case
let target_vote_state = VoteState1_14_11::default();
let target_vote_state_versions = VoteStateVersions::V1_14_11(Box::new(target_vote_state));
let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap();

let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap();

assert_eq!(
target_vote_state_versions.convert_to_current(),
test_vote_state
);

// variant
// provide 4x the minimum struct size in bytes to ensure we typically touch every field
let struct_bytes_x4 = std::mem::size_of::<VoteState1_14_11>() * 4;
for _ in 0..1000 {
let raw_data: Vec<u8> = (0..struct_bytes_x4).map(|_| rand::random::<u8>()).collect();
let mut unstructured = Unstructured::new(&raw_data);

let arbitrary_vote_state = VoteState1_14_11::arbitrary(&mut unstructured).unwrap();
let target_vote_state_versions =
VoteStateVersions::V1_14_11(Box::new(arbitrary_vote_state));

let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap();
let target_vote_state = target_vote_state_versions.convert_to_current();

let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap();

assert_eq!(target_vote_state, test_vote_state);
}
}
}
49 changes: 9 additions & 40 deletions sdk/program/src/vote/state/vote_state_deserialize.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use {
crate::{
instruction::InstructionError,
pubkey::Pubkey,
serialize_utils::cursor::*,
vote::state::{BlockTimestamp, LandedVote, Lockout, VoteState, MAX_ITEMS},
},
bincode::serialized_size,
std::io::Cursor,
};

Expand Down Expand Up @@ -67,46 +65,17 @@ fn read_prior_voters_into<T: AsRef<[u8]>>(
cursor: &mut Cursor<T>,
vote_state: &mut VoteState,
) -> Result<(), InstructionError> {
// record our position at the start of the struct
let prior_voters_position = cursor.position();

// `serialized_size()` must be used over `mem::size_of()` because of alignment
let is_empty_position = serialized_size(&vote_state.prior_voters)
.ok()
.and_then(|v| v.checked_add(prior_voters_position))
.and_then(|v| v.checked_sub(1))
.ok_or(InstructionError::InvalidAccountData)?;

// move to the end, to check if we need to parse the data
cursor.set_position(is_empty_position);

// if empty, we already read past the end of this struct and need to do no further work
// otherwise we go back to the start and proceed to decode the data
let is_empty = read_bool(cursor)?;
if !is_empty {
cursor.set_position(prior_voters_position);

let mut encountered_null_voter = false;
for i in 0..MAX_ITEMS {
let prior_voter = read_pubkey(cursor)?;
let from_epoch = read_u64(cursor)?;
let until_epoch = read_u64(cursor)?;
let item = (prior_voter, from_epoch, until_epoch);

if item == (Pubkey::default(), 0, 0) {
encountered_null_voter = true;
} else if encountered_null_voter {
// `prior_voters` should never be sparse
return Err(InstructionError::InvalidAccountData);
} else {
vote_state.prior_voters.buf[i] = item;
}
}

vote_state.prior_voters.idx = read_u64(cursor)? as usize;
vote_state.prior_voters.is_empty = read_bool(cursor)?;
for i in 0..MAX_ITEMS {
let prior_voter = read_pubkey(cursor)?;
let from_epoch = read_u64(cursor)?;
let until_epoch = read_u64(cursor)?;

vote_state.prior_voters.buf[i] = (prior_voter, from_epoch, until_epoch);
}

vote_state.prior_voters.idx = read_u64(cursor)? as usize;
vote_state.prior_voters.is_empty = read_bool(cursor)?;

Ok(())
}

Expand Down