Skip to content

Commit

Permalink
Safe zero copy implementation with no change to existing byte layout
Browse files Browse the repository at this point in the history
  • Loading branch information
jgur-psyops committed Aug 23, 2024
1 parent 24dfeb1 commit 6fbdd7d
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub fn lending_pool_handle_bankruptcy<'info>(
.lending_account
.balances
.iter_mut()
.find(|balance| balance.active && balance.bank_pk == bank_loader.key());
.find(|balance| balance.is_active() && balance.bank_pk == bank_loader.key());

check!(
lending_account_balance.is_some(),
Expand Down
43 changes: 27 additions & 16 deletions programs/marginfi/src/state/marginfi_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::{
};
use anchor_lang::prelude::*;
use anchor_spl::token_interface::Mint;
use bytemuck::{Pod, Zeroable};
use fixed::types::I80F48;
use std::{
cmp::{max, min},
Expand All @@ -25,7 +26,7 @@ use type_layout::TypeLayout;

assert_struct_size!(MarginfiAccount, 2304);
assert_struct_align!(MarginfiAccount, 8);
#[account(zero_copy(unsafe))]
#[account(zero_copy)]
#[repr(C)]
#[cfg_attr(
any(feature = "test", feature = "client"),
Expand All @@ -42,7 +43,8 @@ pub struct MarginfiAccount {
/// - DISABLED_FLAG = 1 << 0 = 1 - This flag indicates that the account is disabled,
/// and no further actions can be taken on it.
pub account_flags: u64, // 8
pub _padding: [u64; 63], // 504
pub _padding0: [u64; 32], // 504
pub _padding1: [u64; 31],
}

pub const DISABLED_FLAG: u64 = 1 << 0;
Expand All @@ -61,7 +63,7 @@ impl MarginfiAccount {
self.lending_account
.balances
.iter()
.filter(|b| b.active)
.filter(|b| b.is_active())
.count()
* 2 // TODO: Make account count oracle setup specific
}
Expand Down Expand Up @@ -169,7 +171,7 @@ impl<'info> BankAccountWithPriceFeed<'_, 'info> {
let active_balances = lending_account
.balances
.iter()
.filter(|balance| balance.active)
.filter(|balance| balance.is_active())
.collect::<Vec<_>>();

debug!("Expecting {} remaining accounts", active_balances.len() * 2);
Expand Down Expand Up @@ -709,20 +711,20 @@ const MAX_LENDING_ACCOUNT_BALANCES: usize = 16;

assert_struct_size!(LendingAccount, 1728);
assert_struct_align!(LendingAccount, 8);
#[zero_copy(unsafe)]
#[repr(C)]
#[cfg_attr(
any(feature = "test", feature = "client"),
derive(Debug, PartialEq, Eq, TypeLayout)
)]
#[derive(AnchorDeserialize, AnchorSerialize, Copy, Clone, Zeroable, Pod)]
pub struct LendingAccount {
pub balances: [Balance; MAX_LENDING_ACCOUNT_BALANCES], // 104 * 16 = 1664
pub _padding: [u64; 8], // 8 * 8 = 64
}

impl LendingAccount {
pub fn get_first_empty_balance(&self) -> Option<usize> {
self.balances.iter().position(|b| !b.active)
self.balances.iter().position(|b| !b.is_active())
}
}

Expand All @@ -731,24 +733,24 @@ impl LendingAccount {
pub fn get_balance(&self, bank_pk: &Pubkey) -> Option<&Balance> {
self.balances
.iter()
.find(|balance| balance.active && balance.bank_pk.eq(bank_pk))
.find(|balance| balance.is_active() && balance.bank_pk.eq(bank_pk))
}

pub fn get_active_balances_iter(&self) -> impl Iterator<Item = &Balance> {
self.balances.iter().filter(|b| b.active)
self.balances.iter().filter(|b| b.is_active())
}
}

assert_struct_size!(Balance, 104);
assert_struct_align!(Balance, 8);
#[zero_copy(unsafe)]
#[repr(C)]
#[cfg_attr(
any(feature = "test", feature = "client"),
derive(Debug, PartialEq, Eq, TypeLayout)
)]
#[derive(AnchorDeserialize, AnchorSerialize, Copy, Clone, Zeroable, Pod)]
pub struct Balance {
pub active: bool,
pub active: u8,
pub bank_pk: Pubkey,
pub _pad0: [u8; 7],
pub asset_shares: WrappedI80F48,
Expand All @@ -759,6 +761,14 @@ pub struct Balance {
}

impl Balance {
pub fn is_active(&self) -> bool {
self.active != 0
}

pub fn set_active(&mut self, value: bool) {
self.active = value as u8;
}

/// Check whether a balance is empty while accounting for any rounding errors
/// that might have occured during depositing/withdrawing.
#[inline]
Expand Down Expand Up @@ -820,7 +830,7 @@ impl Balance {

pub fn empty_deactivated() -> Self {
Balance {
active: false,
active: 0,
bank_pk: Pubkey::default(),
_pad0: [0; 7],
asset_shares: WrappedI80F48::from(I80F48::ZERO),
Expand All @@ -847,7 +857,7 @@ impl<'a> BankAccountWrapper<'a> {
let balance = lending_account
.balances
.iter_mut()
.find(|balance| balance.active && balance.bank_pk.eq(bank_pk))
.find(|balance| balance.is_active() && balance.bank_pk.eq(bank_pk))
.ok_or_else(|| error!(MarginfiError::BankAccountNotFound))?;

Ok(Self { balance, bank })
Expand All @@ -863,7 +873,7 @@ impl<'a> BankAccountWrapper<'a> {
let balance_index = lending_account
.balances
.iter()
.position(|balance| balance.active && balance.bank_pk.eq(bank_pk));
.position(|balance| balance.is_active() && balance.bank_pk.eq(bank_pk));

match balance_index {
Some(balance_index) => {
Expand All @@ -880,7 +890,7 @@ impl<'a> BankAccountWrapper<'a> {
.ok_or_else(|| error!(MarginfiError::LendingAccountBalanceSlotsFull))?;

lending_account.balances[empty_index] = Balance {
active: true,
active: 1,
bank_pk: *bank_pk,
_pad0: [0; 7],
asset_shares: I80F48::ZERO.into(),
Expand Down Expand Up @@ -1410,7 +1420,7 @@ mod test {
authority: authority.into(),
lending_account: LendingAccount {
balances: [Balance {
active: true,
active: 1,
bank_pk: bank_pk.into(),
_pad0: [0; 7],
asset_shares: WrappedI80F48::default(),
Expand All @@ -1422,7 +1432,8 @@ mod test {
_padding: [0; 8],
},
account_flags: TRANSFER_AUTHORITY_ALLOWED_FLAG,
_padding: [0; 63],
_padding0: [0; 32],
_padding1: [0; 31],
};

assert!(acc.get_flag(TRANSFER_AUTHORITY_ALLOWED_FLAG));
Expand Down
34 changes: 20 additions & 14 deletions programs/marginfi/src/state/marginfi_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::{
use anchor_lang::prelude::borsh;
use anchor_lang::prelude::*;
use anchor_spl::token_interface::*;
use bytemuck::{Pod, Zeroable};
use fixed::types::I80F48;
use pyth_sdk_solana::{state::SolanaPriceAccount, PriceFeed};
use pyth_solana_receiver_sdk::price_update::FeedId;
Expand Down Expand Up @@ -129,13 +130,12 @@ impl From<InterestRateConfig> for InterestRateConfigCompact {
}
}

#[zero_copy]
#[repr(C)]
#[cfg_attr(
any(feature = "test", feature = "client"),
derive(PartialEq, Eq, TypeLayout)
)]
#[derive(Default, Debug)]
#[derive(Default, Debug, Copy, Clone, AnchorSerialize, AnchorDeserialize, Zeroable, Pod)]
pub struct InterestRateConfig {
// Curve Params
pub optimal_utilization_rate: WrappedI80F48,
Expand Down Expand Up @@ -282,7 +282,7 @@ pub struct InterestRateConfigOpt {

assert_struct_size!(Bank, 1856);
assert_struct_align!(Bank, 8);
#[account(zero_copy(unsafe))]
#[account(zero_copy)]
#[repr(C)]
#[cfg_attr(
any(feature = "test", feature = "client"),
Expand Down Expand Up @@ -921,12 +921,14 @@ fn calc_interest_payment_for_period(apr: I80F48, time_delta: u64, value: I80F48)

#[repr(u8)]
#[cfg_attr(any(feature = "test", feature = "client"), derive(PartialEq, Eq))]
#[derive(Copy, Clone, Debug, AnchorSerialize, AnchorDeserialize)]
#[derive(Debug, Clone, Copy, AnchorDeserialize, AnchorSerialize)]
pub enum BankOperationalState {
Paused,
Operational,
ReduceOnly,
Paused = 0,
Operational = 1,
ReduceOnly = 2,
}
unsafe impl Zeroable for BankOperationalState {}
unsafe impl Pod for BankOperationalState {}

#[cfg(feature = "client")]
impl Display for BankOperationalState {
Expand All @@ -942,15 +944,17 @@ impl Display for BankOperationalState {
#[repr(u8)]
#[derive(Copy, Clone, Debug, AnchorSerialize, AnchorDeserialize, PartialEq, Eq)]
pub enum RiskTier {
Collateral,
Collateral = 0,
/// ## Isolated Risk
/// Assets in this trance can be borrowed only in isolation.
/// They can't be borrowed together with other assets.
///
/// For example, if users has USDC, and wants to borrow XYZ which is isolated,
/// they can't borrow XYZ together with SOL, only XYZ alone.
Isolated,
Isolated = 1,
}
unsafe impl Zeroable for RiskTier {}
unsafe impl Pod for RiskTier {}

#[repr(C)]
#[cfg_attr(
Expand Down Expand Up @@ -1019,7 +1023,8 @@ impl From<BankConfigCompact> for BankConfig {
_pad1: [0; 7],
total_asset_value_init_limit: config.total_asset_value_init_limit,
oracle_max_age: config.oracle_max_age,
_padding: [0; 38],
_padding0: [0; 6],
_padding1: [0; 32],
}
}
}
Expand Down Expand Up @@ -1047,13 +1052,12 @@ impl From<BankConfig> for BankConfigCompact {

assert_struct_size!(BankConfig, 544);
assert_struct_align!(BankConfig, 8);
#[zero_copy(unsafe)]
#[repr(C)]
#[cfg_attr(
any(feature = "test", feature = "client"),
derive(PartialEq, Eq, TypeLayout)
)]
#[derive(Debug)]
#[derive(Debug, Clone, Copy, AnchorDeserialize, AnchorSerialize, Zeroable, Pod)]
/// TODO: Convert weights to (u64, u64) to avoid precision loss (maybe?)
pub struct BankConfig {
pub asset_weight_init: WrappedI80F48,
Expand Down Expand Up @@ -1092,7 +1096,8 @@ pub struct BankConfig {
/// Time window in seconds for the oracle price feed to be considered live.
pub oracle_max_age: u16,

pub _padding: [u8; 38],
pub _padding0: [u8; 6],
pub _padding1: [u8; 32],
}

impl Default for BankConfig {
Expand All @@ -1113,7 +1118,8 @@ impl Default for BankConfig {
_pad1: [0; 7],
total_asset_value_init_limit: TOTAL_ASSET_VALUE_INIT_LIMIT_INACTIVE,
oracle_max_age: 0,
_padding: [0; 38],
_padding0: [0; 6],
_padding1: [0; 32],
}
}
}
Expand Down
13 changes: 8 additions & 5 deletions programs/marginfi/src/state/price.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{cell::Ref, cmp::min};

use anchor_lang::prelude::*;
use bytemuck::{Pod, Zeroable};
use enum_dispatch::enum_dispatch;
use fixed::types::I80F48;
use pyth_sdk_solana::{state::SolanaPriceAccount, Price, PriceFeed};
Expand Down Expand Up @@ -30,12 +31,14 @@ use pyth_solana_receiver_sdk::PYTH_PUSH_ORACLE_ID;
#[cfg_attr(any(feature = "test", feature = "client"), derive(PartialEq, Eq))]
#[derive(Copy, Clone, Debug, AnchorSerialize, AnchorDeserialize)]
pub enum OracleSetup {
None,
PythLegacy,
SwitchboardV2,
PythPushOracle,
SwitchboardPull,
None = 0,
PythLegacy = 1,
SwitchboardV2 = 2,
PythPushOracle = 3,
SwitchboardPull = 4,
}
unsafe impl Zeroable for OracleSetup {}
unsafe impl Pod for OracleSetup {}

#[derive(Copy, Clone, Debug)]
pub enum PriceBias {
Expand Down
22 changes: 13 additions & 9 deletions programs/marginfi/tests/misc/regression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ async fn account_field_values_reg() -> anyhow::Result<()> {
pubkey!("Dq7wypbedtaqQK9QqEFvfrxc4ppfRGXCeTVd7ee7n2jw")
);
assert_eq!(account.account_flags, 0);
assert_eq!(account._padding, [0; 63]);
assert_eq!(account._padding0, [0; 32]);
assert_eq!(account._padding1, [0; 31]);

let balance_1 = account.lending_account.balances[0];
assert!(balance_1.active);
assert!(balance_1.is_active());
assert_eq!(
balance_1.bank_pk,
pubkey!("2s37akK2eyBbp8DZgCm7RtsaEz8eJP3Nxd4urLHQv7yB")
Expand All @@ -70,7 +71,7 @@ async fn account_field_values_reg() -> anyhow::Result<()> {
assert_eq!(balance_1._padding, [0; 1]);

let balance_2 = account.lending_account.balances[1];
assert!(balance_2.active);
assert!(balance_2.is_active());
assert_eq!(
balance_2.bank_pk,
pubkey!("CCKtUs6Cgwo4aaQUmBPmyoApH2gUDErxNZCAntD6LYGh")
Expand Down Expand Up @@ -117,10 +118,11 @@ async fn account_field_values_reg() -> anyhow::Result<()> {
pubkey!("3T1kGHp7CrdeW9Qj1t8NMc2Ks233RyvzVhoaUPWoBEFK")
);
assert_eq!(account.account_flags, 0);
assert_eq!(account._padding, [0; 63]);
assert_eq!(account._padding0, [0; 32]);
assert_eq!(account._padding1, [0; 31]);

let balance_1 = account.lending_account.balances[0];
assert!(balance_1.active);
assert!(balance_1.is_active());
assert_eq!(
balance_1.bank_pk,
pubkey!("6hS9i46WyTq1KXcoa2Chas2Txh9TJAVr6n1t3tnrE23K")
Expand All @@ -145,7 +147,7 @@ async fn account_field_values_reg() -> anyhow::Result<()> {
assert_eq!(balance_1._padding, [0; 1]);

let balance_2 = account.lending_account.balances[1];
assert!(!balance_2.active);
assert!(!balance_2.is_active());
assert_eq!(
balance_2.bank_pk,
pubkey!("11111111111111111111111111111111")
Expand Down Expand Up @@ -192,10 +194,11 @@ async fn account_field_values_reg() -> anyhow::Result<()> {
pubkey!("7hmfVTuXc7HeX3YQjpiCXGVQuTeXonzjp795jorZukVR")
);
assert_eq!(account.account_flags, 0);
assert_eq!(account._padding, [0; 63]);
assert_eq!(account._padding0, [0; 32]);
assert_eq!(account._padding1, [0; 31]);

let balance_1 = account.lending_account.balances[0];
assert!(!balance_1.active);
assert!(!balance_1.is_active());
assert_eq!(
balance_1.bank_pk,
pubkey!("11111111111111111111111111111111")
Expand Down Expand Up @@ -638,7 +641,8 @@ async fn bank_field_values_reg() -> anyhow::Result<()> {
assert_eq!(bank.config._pad1, [0; 7]);
assert_eq!(bank.config.total_asset_value_init_limit, 0);
assert_eq!(bank.config.oracle_max_age, 300);
assert_eq!(bank.config._padding, [0; 38]);
assert_eq!(bank.config._padding0, [0; 6]);
assert_eq!(bank.config._padding1, [0; 32]);

assert_eq!(bank.flags, 2);

Expand Down
2 changes: 1 addition & 1 deletion programs/marginfi/tests/user_actions/create_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async fn marginfi_account_create_success() -> anyhow::Result<()> {
.lending_account
.balances
.iter()
.all(|bank| !bank.active));
.all(|bank| !bank.is_active()));

Ok(())
}
Loading

0 comments on commit 6fbdd7d

Please sign in to comment.