diff --git a/src/lib.cairo b/src/lib.cairo index 5c79614..e73c083 100644 --- a/src/lib.cairo +++ b/src/lib.cairo @@ -24,6 +24,9 @@ mod interfaces { } mod utils { pub(crate) mod exp2; + pub(crate) mod fp; + #[cfg(test)] + pub(crate) mod fp_test; } #[cfg(test)] diff --git a/src/staker.cairo b/src/staker.cairo index 5dfbd5d..d9739af 100644 --- a/src/staker.cairo +++ b/src/staker.cairo @@ -1,4 +1,6 @@ use starknet::{ContractAddress}; +use crate::utils::fp::{UFixedPoint}; + #[starknet::interface] pub trait IStaker { @@ -54,25 +56,31 @@ pub trait IStaker { ) -> u128; // Gets the cumulative staked amount * per second staked for the given timestamp and account. - fn get_staked_seconds_at(self: @TContractState, staker: ContractAddress, timestamp: u64) -> u128; + fn get_cumulative_seconds_per_total_staked_at(self: @TContractState, timestamp: u64) -> UFixedPoint; + // Gets the cumulative staked amount * per second staked for the given timestamp and account. + fn calculate_staked_seconds_for_amount_between(self: @TContractState, token_amount: u128, start_at: u64, end_at: u64) -> u128; } + #[starknet::contract] pub mod Staker { + use starknet::storage::StorageAsPath; +use super::super::utils::fp::UFixedPointTrait; use core::num::traits::zero::{Zero}; use governance::interfaces::erc20::{IERC20Dispatcher, IERC20DispatcherTrait}; use starknet::storage::{ - Map, StorageMapReadAccess, StorageMapWriteAccess, StoragePathEntry, + Map, StorageMapReadAccess, StorageMapWriteAccess, StoragePathEntry, StoragePath, StoragePointerReadAccess, StoragePointerWriteAccess, Vec, VecTrait, MutableVecTrait, }; + use crate::utils::fp::{UFixedPoint, UFixedPointZero}; use starknet::{ get_block_timestamp, get_caller_address, get_contract_address, - storage_access::{StorePacking}, + storage_access::{StorePacking}, ContractAddress, }; - use super::{ContractAddress, IStaker}; + use super::{IStaker}; #[derive(Copy, Drop, PartialEq, Debug)] @@ -106,7 +114,7 @@ pub mod Staker { struct StakingLogRecord { timestamp: u64, total_staked: u128, - cumulative_staked_seconds: u128 + cumulative_seconds_per_total_staked: UFixedPoint, } #[storage] @@ -117,7 +125,8 @@ pub mod Staker { amount_delegated: Map, delegated_cumulative_num_snapshots: Map, delegated_cumulative_snapshot: Map>, - staking_log: Map>, + + staking_log: Vec, } #[constructor] @@ -227,8 +236,8 @@ pub mod Staker { } } - fn log_change(ref self: ContractState, staker: ContractAddress, amount: u128, is_add: bool) { - let log = self.staking_log.entry(staker); + fn log_change(ref self: ContractState, amount: u128, is_add: bool) { + let log = self.staking_log.as_path(); if log.len() == 0 { // Add the first record. If withdrawal, then it's underflow. @@ -238,7 +247,7 @@ pub mod Staker { StakingLogRecord { timestamp: get_block_timestamp(), total_staked: amount, - cumulative_staked_seconds: 0, + cumulative_seconds_per_total_staked: 0_u64.into(), } ); @@ -260,7 +269,7 @@ pub mod Staker { // Might be zero let seconds_diff = (get_block_timestamp() - last_record.timestamp) / 1000; - let staked_seconds = last_record.total_staked * seconds_diff.into(); // staked seconds + let staked_seconds: UFixedPoint = seconds_diff.into() / last_record.total_staked.into(); // staked seconds let total_staked = if is_add { // overflow check @@ -277,16 +286,17 @@ pub mod Staker { StakingLogRecord { timestamp: get_block_timestamp(), total_staked: total_staked, - cumulative_staked_seconds: last_record.cumulative_staked_seconds + staked_seconds, + cumulative_seconds_per_total_staked: ( + last_record.cumulative_seconds_per_total_staked + staked_seconds + ), } ); } - fn find_in_change_log(self: @ContractState, staker: ContractAddress, timestamp: u64) -> Option { + fn find_in_change_log(self: @ContractState, timestamp: u64) -> Option { // Find first log record in an array whos timestamp is less or equal to timestamp. // Uses binary search. - - let log = self.staking_log.entry(staker); + let log = self.staking_log.as_path(); if log.len() == 0 { return Option::None; @@ -296,9 +306,9 @@ pub mod Staker { let mut right = log.len() - 1; // To avoid reading from the storage multiple times. - let mut result_ptr = Option::None; + let mut result_ptr: Option> = Option::None; - while left <= right { + while (left <= right) { let center = (right + left) / 2; let record = log.at(center); @@ -307,7 +317,7 @@ pub mod Staker { left = center + 1; } else { right = center - 1; - } + }; }; if let Option::Some(result) = result_ptr { @@ -359,11 +369,33 @@ pub mod Staker { .amount_delegated .write(delegate, self.insert_snapshot(delegate, get_block_timestamp()) + amount); - self.log_change(from, amount, true); + self.log_change(amount, true); self.emit(Staked { from, delegate, amount }); } + fn calculate_staked_seconds_for_amount_between( + self: @ContractState, + token_amount: u128, + start_at: u64, end_at: u64 + ) -> u128 { + let user_cumulative_at_start: UFixedPoint = if let Option::Some(log_record) = self.find_in_change_log(start_at) { + log_record.cumulative_seconds_per_total_staked + } else { + Zero::zero() + }; + + let user_cumulative_at_end: UFixedPoint = if let Option::Some(log_record) = self.find_in_change_log(start_at) { + log_record.cumulative_seconds_per_total_staked + } else { + Zero::zero() + }; + + let res = (user_cumulative_at_end - user_cumulative_at_start) * token_amount.into(); + + return res.get_integer(); + } + fn withdraw( ref self: ContractState, delegate: ContractAddress, recipient: ContractAddress, ) { @@ -389,7 +421,7 @@ pub mod Staker { .write(delegate, self.insert_snapshot(delegate, get_block_timestamp()) - amount); assert(self.token.read().transfer(recipient, amount.into()), 'TRANSFER_FAILED'); - self.log_change(from, amount, false); + self.log_change(amount, false); self.emit(Withdrawn { from, delegate, to: recipient, amount }); } @@ -445,15 +477,13 @@ pub mod Staker { self.get_average_delegated(delegate, now - period, now) } - fn get_staked_seconds_at( - self: @ContractState, staker: ContractAddress, timestamp: u64, - ) -> u128 { - if let Option::Some(log_record) = self.find_in_change_log(staker, timestamp) { + fn get_cumulative_seconds_per_total_staked_at(self: @ContractState, timestamp: u64) -> UFixedPoint { + if let Option::Some(log_record) = self.find_in_change_log(timestamp) { let seconds_diff = (timestamp - log_record.timestamp) / 1000; - let staked_seconds = log_record.total_staked * seconds_diff.into(); // staked seconds - return log_record.cumulative_staked_seconds + staked_seconds; + let staked_seconds: UFixedPoint = seconds_diff.into() / log_record.total_staked.into(); // staked seconds + return log_record.cumulative_seconds_per_total_staked + staked_seconds; } else { - return 0; + return 0_u64.into(); } } } diff --git a/src/utils/fp.cairo b/src/utils/fp.cairo new file mode 100644 index 0000000..da9518a --- /dev/null +++ b/src/utils/fp.cairo @@ -0,0 +1,186 @@ +use starknet::storage_access::{StorePacking}; +use core::num::traits::{WideMul, Zero}; +use core::integer::{u512, u512_safe_div_rem_by_u256}; + +// 128.128 +#[derive(Drop, Copy, PartialEq)] +pub struct UFixedPoint { + pub(crate) value: u512 +} + +pub impl UFixedPointStorePacking of StorePacking { + fn pack(value: UFixedPoint) -> u256 { + value.into() + } + + fn unpack(value: u256) -> UFixedPoint { + value.into() + } +} + +pub impl UFixedPointZero of Zero { + fn zero() -> UFixedPoint { + UFixedPoint { + value: u512 { + limb0: 0, + limb1: 0, + limb2: 0, + limb3: 0, + } + } + } + + fn is_zero(self: @UFixedPoint) -> bool { + self.value.limb0 == @0 && + self.value.limb1 == @0 && + self.value.limb2 == @0 && + self.value.limb3 == @0 + } + + fn is_non_zero(self: @UFixedPoint) -> bool { !self.is_zero() } +} + +impl UFixedPointSerde of core::serde::Serde { + fn serialize(self: @UFixedPoint, ref output: Array) { + let value: u256 = (*self).try_into().unwrap(); + Serde::serialize(@value, ref output) + } + + fn deserialize(ref serialized: Span) -> Option { + let value: u256 = Serde::deserialize(ref serialized)?; + Option::Some(value.into()) + } +} + +pub(crate) impl U64IntoUFixedPoint of Into { + fn into(self: u64) -> UFixedPoint { + UFixedPoint { + value: u512 { + limb0: 0, // fractional + limb1: self.into(), // integer + limb2: 0, + limb3: 0, + } + } + } +} + +pub(crate) impl U128IntoUFixedPoint of Into { + fn into(self: u128) -> UFixedPoint { + UFixedPoint { + value: u512 { + limb0: 0, // fractional + limb1: self.into(), // integer + limb2: 0, + limb3: 0, + } + } + } +} + +pub(crate) impl U256IntoUFixedPoint of Into { + fn into(self: u256) -> UFixedPoint { + UFixedPoint { + value: u512 { + limb0: self.low, // fractional + limb1: self.high, // integer + limb2: 0, + limb3: 0, + } + } + } +} + +#[generate_trait] +pub impl UFixedPointImpl of UFixedPointTrait { + fn get_integer(self: UFixedPoint) -> u128 { + self.value.limb1 + } + + fn get_fractional(self: UFixedPoint) -> u128 { + self.value.limb0 + } +} + +#[generate_trait] +impl UFixedPointShiftImpl of BitShiftImpl { + + fn bitshift_128_up(self: UFixedPoint) -> UFixedPoint { + UFixedPoint { + value: u512 { + limb0: 0, + limb1: self.value.limb0, + limb2: self.value.limb1, + limb3: self.value.limb2, + } + } + } + + fn bitshift_128_down(self: UFixedPoint) -> UFixedPoint { + UFixedPoint { + value: u512 { + limb0: self.value.limb1, + limb1: self.value.limb2, + limb2: self.value.limb3, + limb3: 0, + } + } + } +} + +pub(crate) impl FixedPointIntoU256 of Into { + fn into(self: UFixedPoint) -> u256 { self.value.try_into().unwrap() } +} + +pub impl UFpImplAdd of Add { + fn add(lhs: UFixedPoint, rhs: UFixedPoint) -> UFixedPoint { + let sum: u256 = rhs.into() + lhs.into(); + UFixedPoint { + value: u512 { + limb0: sum.low, + limb1: sum.high, + limb2: 0, + limb3: 0 + } + } + } +} + +pub impl UFpImplSub of Sub { + fn sub(lhs: UFixedPoint, rhs: UFixedPoint) -> UFixedPoint { + let sum: u256 = rhs.into() - lhs.into(); + UFixedPoint { + value: u512 { + limb0: sum.low, + limb1: sum.high, + limb2: 0, + limb3: 0 + } + } + } +} + +// 20100 +pub impl UFpImplMul of Mul { + fn mul(lhs: UFixedPoint, rhs: UFixedPoint) -> UFixedPoint { + let left: u256 = lhs.into(); + let right: u256 = rhs.into(); + + let z = left.wide_mul(right); + + UFixedPoint { value: z }.bitshift_128_down() + } +} + +pub impl UFpImplDiv of Div { + fn div(lhs: UFixedPoint, rhs: UFixedPoint) -> UFixedPoint { + let rhs: u256 = rhs.into(); + + let (result, _) = u512_safe_div_rem_by_u256( + lhs.bitshift_128_up().value, + rhs.try_into().unwrap(), + ); + + UFixedPoint { value: result } + } +} \ No newline at end of file diff --git a/src/utils/fp_test.cairo b/src/utils/fp_test.cairo new file mode 100644 index 0000000..3f296a3 --- /dev/null +++ b/src/utils/fp_test.cairo @@ -0,0 +1,45 @@ +use crate::utils::fp::{UFixedPoint}; + + +#[test] +fn test_add() { + let f1 : UFixedPoint = 0xFFFFFFFFFFFFFFFF_u64.into(); + let f2 : UFixedPoint = 1_u64.into(); + let res = f1 + f2; + let z: u256 = res.into(); + assert(z.low == 0, 'low 0'); + assert(z.high == 18446744073709551616, 'high 18446744073709551616'); +} + +#[test] +fn test_mul() { + let f1 : UFixedPoint = 7_u64.into(); + let f2 : UFixedPoint = 7_u64.into(); + let res = f1 * f2; + let z: u256 = res.into(); + assert(z.low == 0, 'low 0'); + assert(z.high == 49, 'high 49'); +} + +#[test] +fn test_div() { + let f1 : UFixedPoint = 7_u64.into(); + let f2 : UFixedPoint = 56_u64.into(); + let res: u256 = (f2 / f1).into(); + assert(res.high == 8, 'high 8'); + assert(res.low == 0, 'low 0'); +} + +#[test] +fn test_comlex() { + let f2 : UFixedPoint = 2_u64.into(); + let f05: UFixedPoint = 1_u64.into() / f2; + let uf05: u256 = f05.into(); + let f7 : UFixedPoint = 7_u64.into(); + let f175 : UFixedPoint = 17_u64.into() + f05; + let res: u256 = (f175 / f7).into(); + assert(res.high == 2, 'high 2'); + assert(res.low == uf05.low, 'low 0.5'); +} + +// TODO(baitcode): more tests needed \ No newline at end of file