Skip to content

Commit

Permalink
Merge pull request #44 from cowprotocol/refactoring
Browse files Browse the repository at this point in the history
Refactoring
  • Loading branch information
harisang authored Aug 28, 2024
2 parents bc574c3 + 28d58cd commit 4dec721
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 54 deletions.
18 changes: 16 additions & 2 deletions src/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,30 @@


def main() -> None:
# valid chain names: mainnet, xdai, arbitrum-one
# valid chain names: mainnet, xdai, arbitrum_one
chain_name = os.getenv("CHAIN_NAME")
if chain_name is None:
logger.error("CHAIN_NAME environment variable is not set.")
return

process_imbalances = True
process_fees = True
process_prices = True

web3, db_engine = initialize_connections()
blockchain = BlockchainData(web3)
db = Database(db_engine, chain_name)
processor = TransactionProcessor(blockchain, db, chain_name)

if chain_name == "arbitrum_one":
process_imbalances = False
process_prices = False

if chain_name == "xdai":
process_prices = False

processor = TransactionProcessor(
blockchain, db, chain_name, process_imbalances, process_fees, process_prices
)

start_block = processor.get_start_block()
processor.process(start_block)
Expand Down
1 change: 1 addition & 0 deletions src/sql/delete_entries_max_block.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ AND block_number >= :block_number;
DELETE FROM fees
WHERE chain_name = :chain_name
AND block_number >= :block_number;

COMMIT;
227 changes: 175 additions & 52 deletions src/transaction_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,25 @@
class TransactionProcessor:
"""Class processes transactions for the slippage project."""

def __init__(self, blockchain_data: BlockchainData, db: Database, chain_name: str):
def __init__(
self,
blockchain_data: BlockchainData,
db: Database,
chain_name: str,
process_imbalances: bool,
process_fees: bool,
process_prices: bool,
):
self.blockchain_data = blockchain_data
self.db = db
self.chain_name = chain_name
self.process_imbalances = process_imbalances
self.process_fees = process_fees
self.process_prices = process_prices

self.imbalances = RawTokenImbalances(self.blockchain_data.web3, self.chain_name)
self.price_providers = PriceFeed()
self.log_message: list[str] = []

def get_start_block(self) -> int:
"""
Expand Down Expand Up @@ -84,71 +97,181 @@ def process(self, start_block: int) -> None:

def process_single_transaction(
self, tx_hash: str, auction_id: int, block_number: int
):
"""Function processes a single tx to find imbalances, prices."""
) -> None:
"""Function processes a single tx to find imbalances, fees, prices including writing to database."""
self.log_message = []
try:
token_imbalances = self.imbalances.compute_imbalances(tx_hash)
except Exception as e:
logger.error(f"Failed to compute imbalances for transaction {tx_hash}: {e}")
return
# Compute Raw Token Imbalances
if self.process_imbalances:
token_imbalances = self.process_token_imbalances(
tx_hash, auction_id, block_number
)

# Compute Fees
if self.process_fees:
protocol_fees, network_fees = self.process_fees_for_transaction(
tx_hash, auction_id, block_number
)

# Compute Prices
if self.process_prices:
prices = self.process_prices_for_tokens(
token_imbalances, protocol_fees, network_fees, block_number, tx_hash
)

# Write to database iff no errors in either computations
if (
(not self.process_imbalances)
and (not self.process_fees)
and (not self.process_prices)
):
return

if self.process_imbalances and token_imbalances:
self.handle_imbalances(
token_imbalances, tx_hash, auction_id, block_number
)

log_message: list[str] = []
log_message.append(f"Token Imbalances on {self.chain_name} for tx {tx_hash}:")
for token_address, imbalance in token_imbalances.items():
# write imbalance to table if it's non-zero
if imbalance != 0:
self.db.write_token_imbalances(
tx_hash, auction_id, block_number, token_address, imbalance
if self.process_fees:
self.handle_fees(
protocol_fees, network_fees, auction_id, block_number, tx_hash
)
log_message.append(f"Token: {token_address}, Imbalance: {imbalance}")

protocol_fees, network_fees = batch_fee_imbalances(HexBytes(tx_hash))
self.handle_fees(protocol_fees, network_fees, auction_id, block_number, tx_hash)
slippage = calculate_slippage(token_imbalances, protocol_fees, network_fees)
if self.process_prices and prices:
self.handle_prices(prices, tx_hash, block_number)

logger.info("\n".join(self.log_message))

for token_address in slippage.keys():
# fetch price for tokens with non-zero imbalance and write to table
if slippage[token_address] != 0:
price_data = self.price_providers.get_price(
set_params(token_address, block_number, tx_hash)
except Exception as err:
logger.error(f"An Error occurred: {err}")
return

def process_token_imbalances(
self, tx_hash: str, auction_id: int, block_number: int
) -> dict[str, int]:
"""Process token imbalances for a given transaction and return imbalances."""
try:
token_imbalances = self.imbalances.compute_imbalances(tx_hash)
if token_imbalances:
self.log_message.append(
f"Token Imbalances on {self.chain_name} for tx {tx_hash}:"
)
if price_data:
price, source = price_data
self.db.write_prices(
source, block_number, tx_hash, token_address, price
return token_imbalances
except Exception as e:
logger.error(f"Failed to compute imbalances for transaction {tx_hash}: {e}")
return {}

def process_fees_for_transaction(
self, tx_hash: str, auction_id: int, block_number: int
) -> tuple[dict[str, int], dict[str, int]]:
"""Process and return protocol and network fees for a given transaction."""
try:
protocol_fees, network_fees = batch_fee_imbalances(HexBytes(tx_hash))
return protocol_fees, network_fees
except Exception as e:
logger.error(f"Failed to process fees for transaction {tx_hash}: {e}")
return {}, {}

def process_prices_for_tokens(
self,
token_imbalances: dict[str, int],
protocol_fees: dict[str, int],
network_fees: dict[str, int],
block_number: int,
tx_hash: str,
) -> dict[str, tuple[float, str]]:
"""Compute prices for tokens with non-null imbalances."""
prices = {}
try:
slippage = calculate_slippage(token_imbalances, protocol_fees, network_fees)
for token_address in slippage.keys():
if slippage[token_address] != 0:
price_data = self.price_providers.get_price(
set_params(token_address, block_number, tx_hash)
)
log_message.append(f"Token: {token_address}, Price: {price} ETH")
if price_data:
price, source = price_data
prices[token_address] = (price, source)
except Exception as e:
logger.error(f"Failed to process prices for transaction {tx_hash}: {e}")

logger.info("\n".join(log_message))
return prices

def handle_imbalances(
self,
token_imbalances: dict[str, int],
tx_hash: str,
auction_id: int,
block_number: int,
) -> None:
"""Function loops over non-null raw imbalances and writes them to the database."""
try:
for token_address, imbalance in token_imbalances.items():
if imbalance != 0:
self.db.write_token_imbalances(
tx_hash,
auction_id,
block_number,
token_address,
imbalance,
)
self.log_message.append(
f"Token: {token_address}, Imbalance: {imbalance}"
)
except Exception as err:
logger.error(f"Error: {err}")

def handle_fees(
self, protocol_fees, network_fees, auction_id, block_number, tx_hash
):
self,
protocol_fees: dict[str, int],
network_fees: dict[str, int],
auction_id: int,
block_number: int,
tx_hash: str,
) -> None:
"""This function loops over (token, fee) and calls write_fees to write to table."""
# Write protocol fees
for token_address, fee_amount in protocol_fees.items():
self.db.write_fees(
chain_name=self.chain_name,
auction_id=auction_id,
block_number=block_number,
tx_hash=tx_hash,
token_address=token_address,
fee_amount=float(fee_amount),
fee_type="protocol",
)
try:
# Write protocol fees
for token_address, fee_amount in protocol_fees.items():
self.db.write_fees(
chain_name=self.chain_name,
auction_id=auction_id,
block_number=block_number,
tx_hash=tx_hash,
token_address=token_address,
fee_amount=float(fee_amount),
fee_type="protocol",
)

# Write network fees
for token_address, fee_amount in network_fees.items():
self.db.write_fees(
chain_name=self.chain_name,
auction_id=auction_id,
block_number=block_number,
tx_hash=tx_hash,
token_address=token_address,
fee_amount=float(fee_amount),
fee_type="network",
# Write network fees
for token_address, fee_amount in network_fees.items():
self.db.write_fees(
chain_name=self.chain_name,
auction_id=auction_id,
block_number=block_number,
tx_hash=tx_hash,
token_address=token_address,
fee_amount=float(fee_amount),
fee_type="network",
)
except Exception as e:
logger.error(
f"Failed to write fees to database for transaction {tx_hash}: {e}"
)

def handle_prices(
self, prices: dict[str, tuple[float, str]], tx_hash: str, block_number: int
) -> None:
"""Function writes prices to table per token."""
try:
for token_address, (price, source) in prices.items():
self.db.write_prices(
source, block_number, tx_hash, token_address, price
)
self.log_message.append(f"Token: {token_address}, Price: {price} ETH")
except Exception as err:
logger.error(f"Error: {err}")


def calculate_slippage(
token_imbalances: dict[str, int],
Expand Down

0 comments on commit 4dec721

Please sign in to comment.