From d44598201f5e89177b089758c00e7a94bc1c9ac2 Mon Sep 17 00:00:00 2001 From: Mateo Date: Thu, 12 Dec 2024 15:05:11 +0100 Subject: [PATCH] wip --- .../driver/src/domain/competition/auction.rs | 39 +++++-- .../src/domain/competition/bad_tokens.rs | 105 +++++++++++++----- crates/driver/src/domain/competition/mod.rs | 19 +++- .../src/domain/competition/order/mod.rs | 4 +- crates/driver/src/infra/api/mod.rs | 46 ++++++-- .../driver/src/infra/api/routes/solve/mod.rs | 2 +- crates/driver/src/infra/config/file/load.rs | 5 + crates/driver/src/infra/config/file/mod.rs | 60 +++++++++- crates/driver/src/infra/config/mod.rs | 3 +- crates/driver/src/infra/solver/mod.rs | 24 +++- crates/driver/src/run.rs | 1 + 11 files changed, 249 insertions(+), 59 deletions(-) diff --git a/crates/driver/src/domain/competition/auction.rs b/crates/driver/src/domain/competition/auction.rs index 9b83bf6e19..3c957aed46 100644 --- a/crates/driver/src/domain/competition/auction.rs +++ b/crates/driver/src/domain/competition/auction.rs @@ -1,5 +1,5 @@ use { - super::{bad_tokens, order, Order}, + super::{order, Order}, crate::{ domain::{ competition::{self, auction, sorting}, @@ -14,12 +14,11 @@ use { futures::future::{join_all, BoxFuture, FutureExt, Shared}, itertools::Itertools, model::{order::OrderKind, signature::Signature}, - shared::{ - bad_token::trace_call::TraceCallDetectorRaw, - signature_validator::{Contracts, SignatureValidating}, - }, + shared::signature_validator::{Contracts, SignatureValidating}, std::{ collections::{HashMap, HashSet}, + future::Future, + pin::Pin, sync::{Arc, Mutex}, }, thiserror::Error, @@ -76,6 +75,32 @@ impl Auction { }) } + /// Filter the orders according to the funcion `filter_fn` provided. + /// The function `filter_fn` must return an `Option`, with `None` + /// indicating that the order has to be filtered. + /// This is needed due to the lack of `filter()` async closure support. + pub async fn filter_orders(&mut self, filter_fn: F) + where + F: Fn( + competition::Order, + ) -> Pin> + Send>> + + Send, + { + let futures = self + .orders + .drain(..) + .map(|order| { + let filter_fn = &filter_fn; + async move { filter_fn(order).await } + }) + .collect::>(); + self.orders = futures::future::join_all(futures) + .await + .into_iter() + .flatten() + .collect(); + } + /// [`None`] if this auction applies to a quote. See /// [`crate::domain::quote`]. pub fn id(&self) -> Option { @@ -482,10 +507,6 @@ impl AuctionProcessor { Self(Arc::new(Mutex::new(Inner { auction: Id(0), fut: futures::future::pending().boxed().shared(), - bad_token_detector: TraceCallDetectorRaw::new( - eth.web3().clone(), - eth.contracts().settlement().address(), - ), eth, order_sorting_strategies, signature_validator, diff --git a/crates/driver/src/domain/competition/bad_tokens.rs b/crates/driver/src/domain/competition/bad_tokens.rs index 04608ee12c..34abea5d27 100644 --- a/crates/driver/src/domain/competition/bad_tokens.rs +++ b/crates/driver/src/domain/competition/bad_tokens.rs @@ -1,11 +1,11 @@ use { super::Order, crate::{ - domain::{self, eth}, - infra, + domain::{competition::Auction, eth}, + infra::{self, config::file::BadTokenDetectionCache}, }, - anyhow::Result, - dashmap::{DashMap, Entry, OccupiedEntry, VacantEntry}, + dashmap::{DashMap, Entry}, + futures::FutureExt, model::interaction::InteractionData, shared::bad_token::{trace_call::TraceCallDetectorRaw, TokenQuality}, std::{ @@ -42,7 +42,7 @@ pub struct Detector { hardcoded: HashMap, /// cache which is shared and updated by multiple bad token detection /// mechanisms - cache: Cache, + cache: Arc, simulation_detector: Option, metrics: Option, } @@ -65,22 +65,73 @@ impl Detector { self } - pub fn filter_unsupported_orders(&self, mut orders: Vec) -> Vec { + pub fn with_cache(mut self, cache: Arc) -> Self { + self.cache = cache; + self + } + + /// Filter all unsupported orders within an Auction + pub async fn filter_unsupported_orders_in_auction( + self: Arc, + mut auction: Auction, + ) -> Auction { let now = Instant::now(); - // group by sell tokens? - // future calling `determine_sell_token_quality()` for all of orders + let self_clone = self.clone(); + + auction + .filter_orders(move |order| { + { + let self_clone = self_clone.clone(); + async move { + // We first check the token quality: + // - If both tokens are supported, the order does is not filtered + // - If any of the order tokens is unsupported, the order is filtered + // - If the token quality cannot be determined: call + // `determine_sell_token_quality()` to execute the simulation + // All of these operations are done within the same `.map()` in order to + // avoid iterating twice over the orders vector + let tokens_quality = [order.sell.token, order.buy.token] + .iter() + .map(|token| self_clone.get_token_quality(*token, now)) + .collect::>(); + let both_tokens_supported = tokens_quality + .iter() + .all(|token_quality| *token_quality == Some(Quality::Supported)); + let any_token_unsupported = tokens_quality + .iter() + .any(|token_quality| *token_quality == Some(Quality::Unsupported)); + + // @TODO: remove the bad tokens from the tokens field? + + // If both tokens are supported, the order does is not filtered + if both_tokens_supported { + return Some(order); + } - orders.retain(|o| { - [o.sell.token, o.buy.token].iter().all(|token| { - self.get_token_quality(*token, now) - .is_none_or(|q| q == Quality::Supported) + // If any of the order tokens is unsupported, the order is filtered + if any_token_unsupported { + return None; + } + + // If the token quality cannot be determined: call + // `determine_sell_token_quality()` to execute the simulation + if self_clone.determine_sell_token_quality(&order, now).await + == Some(Quality::Supported) + { + return Some(order); + } + + None + } + } + .boxed() }) - }); + .await; self.cache.evict_outdated_entries(); - orders + auction } fn get_token_quality(&self, token: eth::TokenAddress, now: Instant) -> Option { @@ -99,12 +150,11 @@ impl Detector { None } - pub async fn determine_sell_token_quality( - &self, - detector: &TraceCallDetectorRaw, - order: &Order, - now: Instant, - ) -> Option { + async fn determine_sell_token_quality(&self, order: &Order, now: Instant) -> Option { + let Some(detector) = self.simulation_detector.as_ref() else { + return None; + }; + if let Some(quality) = self.cache.get_quality(order.sell.token, now) { return Some(quality); } @@ -122,7 +172,7 @@ impl Detector { match detector .test_transfer( - order.trader().0 .0, + eth::Address::from(order.trader()).0, token.0 .0, order.sell.amount.0, &pre_interactions, @@ -164,8 +214,6 @@ pub struct Cache { cache: DashMap, /// entries older than this get ignored and evicted max_age: Duration, - /// evicts entries when the cache grows beyond this size - max_size: usize, } struct CacheEntry { @@ -177,18 +225,17 @@ struct CacheEntry { impl Default for Cache { fn default() -> Self { - Self::new(Duration::from_secs(60 * 10), 1000) + Self::new(&BadTokenDetectionCache::default()) } } impl Cache { /// Creates a new instance which evicts cached values after a period of /// time. - pub fn new(max_age: Duration, max_size: usize) -> Self { + pub fn new(bad_token_detection_cache: &BadTokenDetectionCache) -> Self { Self { - max_age, - max_size, - cache: Default::default(), + max_age: bad_token_detection_cache.max_age, + cache: DashMap::with_capacity(bad_token_detection_cache.max_size), } } @@ -243,7 +290,7 @@ impl Cache { struct Metrics {} impl Metrics { - fn get_quality(&self, token: eth::TokenAddress) -> Option { + fn get_quality(&self, _token: eth::TokenAddress) -> Option { todo!() } } diff --git a/crates/driver/src/domain/competition/mod.rs b/crates/driver/src/domain/competition/mod.rs index ddebeff32b..8b30fcd0ba 100644 --- a/crates/driver/src/domain/competition/mod.rs +++ b/crates/driver/src/domain/competition/mod.rs @@ -22,7 +22,7 @@ use { std::{ cmp::Reverse, collections::{HashMap, HashSet, VecDeque}, - sync::Mutex, + sync::{Arc, Mutex}, }, tap::TapFallible, }; @@ -53,15 +53,22 @@ pub struct Competition { pub mempools: Mempools, /// Cached solutions with the most recent solutions at the front. pub settlements: Mutex>, - // TODO: single type should have the feature set to simulate - pub bad_tokens: bad_tokens::Detector, + pub bad_tokens: Option>, } impl Competition { /// Solve an auction as part of this competition. - pub async fn solve(&self, auction: &Auction) -> Result, Error> { - // 1. simulate sell tokens - // 2. filter bad tokens + pub async fn solve(&self, mut auction: Auction) -> Result, Error> { + // filter orders in auction which contain a bad tokens if the bad token + // detection is configured + if let Some(bad_tokens) = self.bad_tokens.as_ref() { + auction = bad_tokens + .clone() + .filter_unsupported_orders_in_auction(auction) + .await; + } + // Enforces Auction not to be consumed by making it as a shared reference + let auction = &auction; let liquidity = match self.solver.liquidity() { solver::Liquidity::Fetch => { diff --git a/crates/driver/src/domain/competition/order/mod.rs b/crates/driver/src/domain/competition/order/mod.rs index f833b48de4..3d2857be84 100644 --- a/crates/driver/src/domain/competition/order/mod.rs +++ b/crates/driver/src/domain/competition/order/mod.rs @@ -371,8 +371,8 @@ impl From for BuyTokenBalance { } /// The address which placed the order. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Into)] -pub struct Trader(pub eth::Address); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Into, From)] +pub struct Trader(eth::Address); /// A just-in-time order. JIT orders are added at solving time by the solver to /// generate a more optimal solution for the auction. Very similar to a regular diff --git a/crates/driver/src/infra/api/mod.rs b/crates/driver/src/infra/api/mod.rs index 37375236ed..11145c2f84 100644 --- a/crates/driver/src/infra/api/mod.rs +++ b/crates/driver/src/infra/api/mod.rs @@ -1,9 +1,17 @@ use { crate::{ - domain::{self, competition::bad_tokens, Mempools}, + domain::{ + self, + competition::{ + bad_tokens, + bad_tokens::{Cache, Quality}, + }, + eth, + Mempools, + }, infra::{ self, - config::file::OrderPriorityStrategy, + config::file::{BadTokenDetectionCache, OrderPriorityStrategy}, liquidity, solver::{Solver, Timeouts}, tokens, @@ -13,7 +21,7 @@ use { }, error::Error, futures::Future, - std::{net::SocketAddr, sync::Arc}, + std::{collections::HashMap, net::SocketAddr, sync::Arc}, tokio::sync::oneshot, }; @@ -32,6 +40,7 @@ pub struct Api { /// If this channel is specified, the bound address will be sent to it. This /// allows the driver to bind to 0.0.0.0:0 during testing. pub addr_sender: Option>, + pub bad_token_detection_cache: BadTokenDetectionCache, } impl Api { @@ -52,8 +61,9 @@ impl Api { let tokens = tokens::Fetcher::new(&self.eth); let pre_processor = domain::competition::AuctionProcessor::new(&self.eth, order_priority_strategies); - let trace_detector = bad_tokens::SimulationDetector::new(&self.eth); - let miep = bad_tokens::Detector::default().register_cache(trace_detector.cache().clone()); + + // TODO: create a struct wrapper to handle this under the hood + let trace_detector = Arc::new(Cache::new(&self.bad_token_detection_cache)); // Add the metrics and healthz endpoints. app = routes::metrics(app); @@ -72,8 +82,28 @@ impl Api { let router = routes::reveal(router); let router = routes::settle(router); - let miep = - bad_tokens::Detector::default().register_cache(trace_detector.cache().clone()); + let bad_tokens = solver.bad_token_detector().and_then(|bad_token_detector| { + // maybe make this as part of the bad token builder? + let config = bad_token_detector + .unsupported_tokens + .iter() + .map(|token| (eth::TokenAddress::from(*token), Quality::Unsupported)) + .chain( + bad_token_detector + .allowed_tokens + .iter() + .map(|token| (eth::TokenAddress::from(*token), Quality::Supported)), + ) + .collect::>(); + + Some(Arc::new( + // maybe do proper builder pattern here? + bad_tokens::Detector::default() + .with_simulation_detector(&self.eth.clone()) + .with_config(config) + .with_cache(trace_detector.clone()), + )) + }); let router = router.with_state(State(Arc::new(Inner { eth: self.eth.clone(), @@ -85,7 +115,7 @@ impl Api { simulator: self.simulator.clone(), mempools: self.mempools.clone(), settlements: Default::default(), - bad_tokens: miep, + bad_tokens, }, liquidity: self.liquidity.clone(), tokens: tokens.clone(), diff --git a/crates/driver/src/infra/api/routes/solve/mod.rs b/crates/driver/src/infra/api/routes/solve/mod.rs index eccafb8ead..072de4bceb 100644 --- a/crates/driver/src/infra/api/routes/solve/mod.rs +++ b/crates/driver/src/infra/api/routes/solve/mod.rs @@ -36,7 +36,7 @@ async fn route( .pre_processor() .prioritize(auction, &competition.solver.account().address()) .await; - let result = competition.solve(&auction).await; + let result = competition.solve(auction).await; observe::solved(state.solver().name(), &result); Ok(axum::Json(dto::Solved::new(result?, &competition.solver))) }; diff --git a/crates/driver/src/infra/config/file/load.rs b/crates/driver/src/infra/config/file/load.rs index b235106448..ed22caa278 100644 --- a/crates/driver/src/infra/config/file/load.rs +++ b/crates/driver/src/infra/config/file/load.rs @@ -94,6 +94,10 @@ pub async fn load(chain: chain::Id, path: &Path) -> infra::Config { solver_native_token: config.manage_native_token.to_domain(), quote_tx_origin: config.quote_tx_origin.map(eth::Address), response_size_limit_max_bytes: config.response_size_limit_max_bytes, + bad_token_detector: config + .bad_token_detector + .filter(|bad_token_detector| bad_token_detector.enabled) + .map(Into::into), } })) .await, @@ -340,5 +344,6 @@ pub async fn load(chain: chain::Id, path: &Path) -> infra::Config { gas_estimator: config.gas_estimator, order_priority_strategies: config.order_priority_strategies, archive_node_url: config.archive_node_url, + bad_token_detection_cache: config.bad_token_detection_cache, } } diff --git a/crates/driver/src/infra/config/file/mod.rs b/crates/driver/src/infra/config/file/mod.rs index b461883636..eac55ce173 100644 --- a/crates/driver/src/infra/config/file/mod.rs +++ b/crates/driver/src/infra/config/file/mod.rs @@ -5,7 +5,10 @@ use { serde::{Deserialize, Serialize}, serde_with::serde_as, solver::solver::Arn, - std::{collections::HashMap, time::Duration}, + std::{ + collections::{HashMap, HashSet}, + time::Duration, + }, }; mod load; @@ -65,6 +68,9 @@ struct Config { /// Archive node URL used to index CoW AMM archive_node_url: Option, + + /// Cache configuration for the bad tokend detection + bad_token_detection_cache: BadTokenDetectionCache, } #[serde_as] @@ -260,6 +266,10 @@ struct SolverConfig { /// Maximum HTTP response size the driver will accept in bytes. #[serde(default = "default_response_size_limit_max_bytes")] response_size_limit_max_bytes: usize, + + /// Bad token detector configuration + #[serde(default)] + bad_token_detector: Option, } #[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Serialize)] @@ -651,6 +661,54 @@ fn default_order_priority_strategies() -> Vec { ] } +/// Bad token detector configuration +#[derive(Clone, Debug, Deserialize)] +#[serde(rename_all = "kebab-case", deny_unknown_fields)] +pub struct BadTokenDetector { + /// Whether or not the bad token detector is enabled + #[serde(default = "bool::default")] + pub enabled: bool, + /// List of tokens which will be directly allowed, no detection will be run + /// on them + #[serde(default = "HashSet::new")] + pub allowed_tokens: HashSet, + /// List of tokens which will be directly unsupported + #[serde(default = "HashSet::new")] + pub unsupported_tokens_tokens: HashSet, +} + fn default_max_order_age() -> Option { Some(Duration::from_secs(300)) } + +/// Cache configuration for the bad token detection +#[derive(Clone, Debug, Deserialize)] +#[serde(rename_all = "kebab-case", deny_unknown_fields)] +pub struct BadTokenDetectionCache { + /// Entries older than `max_age` will get ignored and evicted + #[serde( + with = "humantime_serde", + default = "default_bad_token_detection_cache_max_age" + )] + pub max_age: Duration, + /// Maximum number of tokens the cache can have + #[serde(default = "default_bad_token_detection_cache_max_size")] + pub max_size: usize, +} + +impl Default for BadTokenDetectionCache { + fn default() -> Self { + Self { + max_age: default_bad_token_detection_cache_max_age(), + max_size: default_bad_token_detection_cache_max_size(), + } + } +} + +fn default_bad_token_detection_cache_max_age() -> Duration { + Duration::from_secs(600) +} + +fn default_bad_token_detection_cache_max_size() -> usize { + 1000 +} diff --git a/crates/driver/src/infra/config/mod.rs b/crates/driver/src/infra/config/mod.rs index 94ef8821a5..8fdbcfde95 100644 --- a/crates/driver/src/infra/config/mod.rs +++ b/crates/driver/src/infra/config/mod.rs @@ -3,7 +3,7 @@ use { domain::eth, infra::{ blockchain, - config::file::{GasEstimatorType, OrderPriorityStrategy}, + config::file::{BadTokenDetectionCache, GasEstimatorType, OrderPriorityStrategy}, liquidity, mempool, simulator, @@ -28,4 +28,5 @@ pub struct Config { pub contracts: blockchain::contracts::Addresses, pub order_priority_strategies: Vec, pub archive_node_url: Option, + pub bad_token_detection_cache: BadTokenDetectionCache, } diff --git a/crates/driver/src/infra/solver/mod.rs b/crates/driver/src/infra/solver/mod.rs index ba7332dd65..357ac82d68 100644 --- a/crates/driver/src/infra/solver/mod.rs +++ b/crates/driver/src/infra/solver/mod.rs @@ -12,7 +12,7 @@ use { }, infra::{ blockchain::Ethereum, - config::file::FeeHandler, + config::{self, file::FeeHandler}, persistence::{Persistence, S3}, }, util, @@ -21,7 +21,7 @@ use { derive_more::{From, Into}, num::BigRational, reqwest::header::HeaderName, - std::collections::HashMap, + std::collections::{HashMap, HashSet}, tap::TapFallible, thiserror::Error, tracing::Instrument, @@ -123,6 +123,7 @@ pub struct Config { /// Which `tx.origin` is required to make quote verification pass. pub quote_tx_origin: Option, pub response_size_limit_max_bytes: usize, + pub bad_token_detector: Option, } impl Solver { @@ -151,6 +152,10 @@ impl Solver { }) } + pub fn bad_token_detector(&self) -> Option<&BadTokenDetector> { + self.config.bad_token_detector.as_ref() + } + pub fn persistence(&self) -> Persistence { self.persistence.clone() } @@ -277,6 +282,21 @@ pub enum SolutionMerging { Forbidden, } +#[derive(Debug, Clone)] +pub struct BadTokenDetector { + pub allowed_tokens: HashSet, + pub unsupported_tokens: HashSet, +} + +impl From for BadTokenDetector { + fn from(value: config::file::BadTokenDetector) -> Self { + Self { + allowed_tokens: value.allowed_tokens, + unsupported_tokens: value.unsupported_tokens_tokens, + } + } +} + #[derive(Debug, Error)] pub enum Error { #[error("HTTP error: {0:?}")] diff --git a/crates/driver/src/run.rs b/crates/driver/src/run.rs index 68dd25b9d5..177cc9fe71 100644 --- a/crates/driver/src/run.rs +++ b/crates/driver/src/run.rs @@ -69,6 +69,7 @@ async fn run_with(args: cli::Args, addr_sender: Option