Skip to content

Commit

Permalink
Added UFixedPoint type for unsigned FixedPoint operations. Currently …
Browse files Browse the repository at this point in the history
…support Q128.128 FP.

Added cumulative seconds per total staked history value, and getter for that value
Added method that allows to calculate how much staked seconds would staked token amount generate
in between 2 timestamps: `calculate_staked_seconds_for_amount_between`
Added tiny amount of tests
  • Loading branch information
baitcode committed Dec 13, 2024
1 parent 5f32326 commit b5ac50b
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 26 deletions.
3 changes: 3 additions & 0 deletions src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ mod interfaces {
}
mod utils {
pub(crate) mod exp2;
pub(crate) mod fp;
#[cfg(test)]
pub(crate) mod fp_test;
}

#[cfg(test)]
Expand Down
82 changes: 56 additions & 26 deletions src/staker.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use starknet::{ContractAddress};
use crate::utils::fp::{UFixedPoint};


#[starknet::interface]
pub trait IStaker<TContractState> {
Expand Down Expand Up @@ -54,25 +56,31 @@ pub trait IStaker<TContractState> {
) -> u128;

// Gets the cumulative staked amount * per second staked for the given timestamp and account.
fn get_staked_seconds_at(self: @TContractState, staker: ContractAddress, timestamp: u64) -> u128;
fn get_cumulative_seconds_per_total_staked_at(self: @TContractState, timestamp: u64) -> UFixedPoint;

// Gets the cumulative staked amount * per second staked for the given timestamp and account.
fn calculate_staked_seconds_for_amount_between(self: @TContractState, token_amount: u128, start_at: u64, end_at: u64) -> u128;
}


#[starknet::contract]
pub mod Staker {
use starknet::storage::StorageAsPath;
use super::super::utils::fp::UFixedPointTrait;
use core::num::traits::zero::{Zero};
use governance::interfaces::erc20::{IERC20Dispatcher, IERC20DispatcherTrait};
use starknet::storage::{
Map, StorageMapReadAccess, StorageMapWriteAccess, StoragePathEntry,
Map, StorageMapReadAccess, StorageMapWriteAccess, StoragePathEntry, StoragePath,
StoragePointerReadAccess, StoragePointerWriteAccess,
Vec, VecTrait, MutableVecTrait,
};
use crate::utils::fp::{UFixedPoint, UFixedPointZero};

use starknet::{
get_block_timestamp, get_caller_address, get_contract_address,
storage_access::{StorePacking},
storage_access::{StorePacking}, ContractAddress,
};
use super::{ContractAddress, IStaker};
use super::{IStaker};


#[derive(Copy, Drop, PartialEq, Debug)]
Expand Down Expand Up @@ -106,7 +114,7 @@ pub mod Staker {
struct StakingLogRecord {
timestamp: u64,
total_staked: u128,
cumulative_staked_seconds: u128
cumulative_seconds_per_total_staked: UFixedPoint,
}

#[storage]
Expand All @@ -117,7 +125,8 @@ pub mod Staker {
amount_delegated: Map<ContractAddress, u128>,
delegated_cumulative_num_snapshots: Map<ContractAddress, u64>,
delegated_cumulative_snapshot: Map<ContractAddress, Map<u64, DelegatedSnapshot>>,
staking_log: Map<ContractAddress, Vec<StakingLogRecord>>,

staking_log: Vec<StakingLogRecord>,
}

#[constructor]
Expand Down Expand Up @@ -227,8 +236,8 @@ pub mod Staker {
}
}

fn log_change(ref self: ContractState, staker: ContractAddress, amount: u128, is_add: bool) {
let log = self.staking_log.entry(staker);
fn log_change(ref self: ContractState, amount: u128, is_add: bool) {
let log = self.staking_log.as_path();

if log.len() == 0 {
// Add the first record. If withdrawal, then it's underflow.
Expand All @@ -238,7 +247,7 @@ pub mod Staker {
StakingLogRecord {
timestamp: get_block_timestamp(),
total_staked: amount,
cumulative_staked_seconds: 0,
cumulative_seconds_per_total_staked: 0_u64.into(),
}
);

Expand All @@ -260,7 +269,7 @@ pub mod Staker {
// Might be zero
let seconds_diff = (get_block_timestamp() - last_record.timestamp) / 1000;

let staked_seconds = last_record.total_staked * seconds_diff.into(); // staked seconds
let staked_seconds: UFixedPoint = seconds_diff.into() / last_record.total_staked.into(); // staked seconds

let total_staked = if is_add {
// overflow check
Expand All @@ -277,16 +286,17 @@ pub mod Staker {
StakingLogRecord {
timestamp: get_block_timestamp(),
total_staked: total_staked,
cumulative_staked_seconds: last_record.cumulative_staked_seconds + staked_seconds,
cumulative_seconds_per_total_staked: (
last_record.cumulative_seconds_per_total_staked + staked_seconds
),
}
);
}

fn find_in_change_log(self: @ContractState, staker: ContractAddress, timestamp: u64) -> Option<StakingLogRecord> {
fn find_in_change_log(self: @ContractState, timestamp: u64) -> Option<StakingLogRecord> {
// Find first log record in an array whos timestamp is less or equal to timestamp.
// Uses binary search.

let log = self.staking_log.entry(staker);
let log = self.staking_log.as_path();

if log.len() == 0 {
return Option::None;
Expand All @@ -296,9 +306,9 @@ pub mod Staker {
let mut right = log.len() - 1;

// To avoid reading from the storage multiple times.
let mut result_ptr = Option::None;
let mut result_ptr: Option<StoragePath<StakingLogRecord>> = Option::None;

while left <= right {
while (left <= right) {
let center = (right + left) / 2;
let record = log.at(center);

Expand All @@ -307,7 +317,7 @@ pub mod Staker {
left = center + 1;
} else {
right = center - 1;
}
};
};

if let Option::Some(result) = result_ptr {
Expand Down Expand Up @@ -359,11 +369,33 @@ pub mod Staker {
.amount_delegated
.write(delegate, self.insert_snapshot(delegate, get_block_timestamp()) + amount);

self.log_change(from, amount, true);
self.log_change(amount, true);

self.emit(Staked { from, delegate, amount });
}

fn calculate_staked_seconds_for_amount_between(
self: @ContractState,
token_amount: u128,
start_at: u64, end_at: u64
) -> u128 {
let user_cumulative_at_start: UFixedPoint = if let Option::Some(log_record) = self.find_in_change_log(start_at) {
log_record.cumulative_seconds_per_total_staked
} else {
Zero::zero()
};

let user_cumulative_at_end: UFixedPoint = if let Option::Some(log_record) = self.find_in_change_log(start_at) {
log_record.cumulative_seconds_per_total_staked
} else {
Zero::zero()
};

let res = (user_cumulative_at_end - user_cumulative_at_start) * token_amount.into();

return res.get_integer();
}

fn withdraw(
ref self: ContractState, delegate: ContractAddress, recipient: ContractAddress,
) {
Expand All @@ -389,7 +421,7 @@ pub mod Staker {
.write(delegate, self.insert_snapshot(delegate, get_block_timestamp()) - amount);
assert(self.token.read().transfer(recipient, amount.into()), 'TRANSFER_FAILED');

self.log_change(from, amount, false);
self.log_change(amount, false);

self.emit(Withdrawn { from, delegate, to: recipient, amount });
}
Expand Down Expand Up @@ -445,15 +477,13 @@ pub mod Staker {
self.get_average_delegated(delegate, now - period, now)
}

fn get_staked_seconds_at(
self: @ContractState, staker: ContractAddress, timestamp: u64,
) -> u128 {
if let Option::Some(log_record) = self.find_in_change_log(staker, timestamp) {
fn get_cumulative_seconds_per_total_staked_at(self: @ContractState, timestamp: u64) -> UFixedPoint {
if let Option::Some(log_record) = self.find_in_change_log(timestamp) {
let seconds_diff = (timestamp - log_record.timestamp) / 1000;
let staked_seconds = log_record.total_staked * seconds_diff.into(); // staked seconds
return log_record.cumulative_staked_seconds + staked_seconds;
let staked_seconds: UFixedPoint = seconds_diff.into() / log_record.total_staked.into(); // staked seconds
return log_record.cumulative_seconds_per_total_staked + staked_seconds;
} else {
return 0;
return 0_u64.into();
}
}
}
Expand Down
186 changes: 186 additions & 0 deletions src/utils/fp.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
use starknet::storage_access::{StorePacking};
use core::num::traits::{WideMul, Zero};
use core::integer::{u512, u512_safe_div_rem_by_u256};

// 128.128
#[derive(Drop, Copy, PartialEq)]
pub struct UFixedPoint {
pub(crate) value: u512
}

pub impl UFixedPointStorePacking of StorePacking<UFixedPoint, u256> {
fn pack(value: UFixedPoint) -> u256 {
value.into()
}

fn unpack(value: u256) -> UFixedPoint {
value.into()
}
}

pub impl UFixedPointZero of Zero<UFixedPoint> {
fn zero() -> UFixedPoint {
UFixedPoint {
value: u512 {
limb0: 0,
limb1: 0,
limb2: 0,
limb3: 0,
}
}
}

fn is_zero(self: @UFixedPoint) -> bool {
self.value.limb0 == @0 &&
self.value.limb1 == @0 &&
self.value.limb2 == @0 &&
self.value.limb3 == @0
}

fn is_non_zero(self: @UFixedPoint) -> bool { !self.is_zero() }
}

impl UFixedPointSerde of core::serde::Serde<UFixedPoint> {
fn serialize(self: @UFixedPoint, ref output: Array<felt252>) {
let value: u256 = (*self).try_into().unwrap();
Serde::serialize(@value, ref output)
}

fn deserialize(ref serialized: Span<felt252>) -> Option<UFixedPoint> {
let value: u256 = Serde::deserialize(ref serialized)?;
Option::Some(value.into())
}
}

pub(crate) impl U64IntoUFixedPoint of Into<u64, UFixedPoint> {
fn into(self: u64) -> UFixedPoint {
UFixedPoint {
value: u512 {
limb0: 0, // fractional
limb1: self.into(), // integer
limb2: 0,
limb3: 0,
}
}
}
}

pub(crate) impl U128IntoUFixedPoint of Into<u128, UFixedPoint> {
fn into(self: u128) -> UFixedPoint {
UFixedPoint {
value: u512 {
limb0: 0, // fractional
limb1: self.into(), // integer
limb2: 0,
limb3: 0,
}
}
}
}

pub(crate) impl U256IntoUFixedPoint of Into<u256, UFixedPoint> {
fn into(self: u256) -> UFixedPoint {
UFixedPoint {
value: u512 {
limb0: self.low, // fractional
limb1: self.high, // integer
limb2: 0,
limb3: 0,
}
}
}
}

#[generate_trait]
pub impl UFixedPointImpl of UFixedPointTrait {
fn get_integer(self: UFixedPoint) -> u128 {
self.value.limb1
}

fn get_fractional(self: UFixedPoint) -> u128 {
self.value.limb0
}
}

#[generate_trait]
impl UFixedPointShiftImpl of BitShiftImpl {

fn bitshift_128_up(self: UFixedPoint) -> UFixedPoint {
UFixedPoint {
value: u512 {
limb0: 0,
limb1: self.value.limb0,
limb2: self.value.limb1,
limb3: self.value.limb2,
}
}
}

fn bitshift_128_down(self: UFixedPoint) -> UFixedPoint {
UFixedPoint {
value: u512 {
limb0: self.value.limb1,
limb1: self.value.limb2,
limb2: self.value.limb3,
limb3: 0,
}
}
}
}

pub(crate) impl FixedPointIntoU256 of Into<UFixedPoint, u256> {
fn into(self: UFixedPoint) -> u256 { self.value.try_into().unwrap() }
}

pub impl UFpImplAdd of Add<UFixedPoint> {
fn add(lhs: UFixedPoint, rhs: UFixedPoint) -> UFixedPoint {
let sum: u256 = rhs.into() + lhs.into();
UFixedPoint {
value: u512 {
limb0: sum.low,
limb1: sum.high,
limb2: 0,
limb3: 0
}
}
}
}

pub impl UFpImplSub of Sub<UFixedPoint> {
fn sub(lhs: UFixedPoint, rhs: UFixedPoint) -> UFixedPoint {
let sum: u256 = rhs.into() - lhs.into();
UFixedPoint {
value: u512 {
limb0: sum.low,
limb1: sum.high,
limb2: 0,
limb3: 0
}
}
}
}

// 20100
pub impl UFpImplMul of Mul<UFixedPoint> {
fn mul(lhs: UFixedPoint, rhs: UFixedPoint) -> UFixedPoint {
let left: u256 = lhs.into();
let right: u256 = rhs.into();

let z = left.wide_mul(right);

UFixedPoint { value: z }.bitshift_128_down()
}
}

pub impl UFpImplDiv of Div<UFixedPoint> {
fn div(lhs: UFixedPoint, rhs: UFixedPoint) -> UFixedPoint {
let rhs: u256 = rhs.into();

let (result, _) = u512_safe_div_rem_by_u256(
lhs.bitshift_128_up().value,
rhs.try_into().unwrap(),
);

UFixedPoint { value: result }
}
}
Loading

0 comments on commit b5ac50b

Please sign in to comment.