Skip to content

Commit

Permalink
Merge pull request flashbots#211 from flashbots/faster-writes
Browse files Browse the repository at this point in the history
Use COPY to speed up database writes for blocks and traces
  • Loading branch information
lukevs authored Jan 4, 2022
2 parents 873296a + 060cd74 commit 8d7ffa6
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 37 deletions.
18 changes: 7 additions & 11 deletions mev_inspect/crud/blocks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import datetime
from typing import List

from mev_inspect.db import write_as_csv
from mev_inspect.schemas.blocks import Block


Expand Down Expand Up @@ -28,16 +29,11 @@ def write_blocks(
db_session,
blocks: List[Block],
) -> None:
block_params = [
{
"block_number": block.block_number,
"block_timestamp": datetime.fromtimestamp(block.block_timestamp),
}
items_generator = (
(
block.block_number,
datetime.fromtimestamp(block.block_timestamp),
)
for block in blocks
]

db_session.execute(
"INSERT INTO blocks (block_number, block_timestamp) VALUES (:block_number, :block_timestamp)",
params=block_params,
)
db_session.commit()
write_as_csv(db_session, "blocks", items_generator)
57 changes: 32 additions & 25 deletions mev_inspect/crud/traces.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
from datetime import datetime, timezone
from typing import List

from mev_inspect.db import to_postgres_list, write_as_csv
from mev_inspect.models.traces import ClassifiedTraceModel
from mev_inspect.schemas.traces import ClassifiedTrace

Expand All @@ -26,30 +28,35 @@ def write_classified_traces(
db_session,
classified_traces: List[ClassifiedTrace],
) -> None:
models = []
for trace in classified_traces:
inputs_json = (json.loads(trace.json(include={"inputs"}))["inputs"],)
models.append(
ClassifiedTraceModel(
transaction_hash=trace.transaction_hash,
transaction_position=trace.transaction_position,
block_number=trace.block_number,
classification=trace.classification.value,
trace_type=trace.type.value,
trace_address=trace.trace_address,
protocol=str(trace.protocol),
abi_name=trace.abi_name,
function_name=trace.function_name,
function_signature=trace.function_signature,
inputs=inputs_json,
from_address=trace.from_address,
to_address=trace.to_address,
gas=trace.gas,
value=trace.value,
gas_used=trace.gas_used,
error=trace.error,
)
classified_at = datetime.now(timezone.utc)
items = (
(
classified_at,
trace.transaction_hash,
trace.block_number,
trace.classification.value,
trace.type.value,
str(trace.protocol),
trace.abi_name,
trace.function_name,
trace.function_signature,
_inputs_as_json(trace),
trace.from_address,
trace.to_address,
trace.gas,
trace.value,
trace.gas_used,
trace.error,
to_postgres_list(trace.trace_address),
trace.transaction_position,
)
for trace in classified_traces
)

db_session.bulk_save_objects(models)
db_session.commit()
write_as_csv(db_session, "classified_traces", items)


def _inputs_as_json(trace) -> str:
inputs = json.dumps(json.loads(trace.json(include={"inputs"}))["inputs"])
inputs_with_array = f"[{inputs}]"
return inputs_with_array
30 changes: 29 additions & 1 deletion mev_inspect/db.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
from typing import Optional
from typing import Any, Iterable, List, Optional

from sqlalchemy import create_engine, orm
from sqlalchemy.orm import sessionmaker

from mev_inspect.string_io import StringIteratorIO


def get_trace_database_uri() -> Optional[str]:
username = os.getenv("TRACE_DB_USER")
Expand Down Expand Up @@ -63,3 +65,29 @@ def get_trace_session() -> Optional[orm.Session]:
return Session()

return None


def write_as_csv(
db_session,
table_name: str,
items: Iterable[Iterable[Any]],
) -> None:
csv_iterator = StringIteratorIO(
("|".join(map(_clean_csv_value, item)) + "\n" for item in items)
)

with db_session.connection().connection.cursor() as cursor:
cursor.copy_from(csv_iterator, table_name, sep="|")


def _clean_csv_value(value: Optional[Any]) -> str:
if value is None:
return r"\N"
return str(value).replace("\n", "\\n")


def to_postgres_list(values: List[Any]) -> str:
if len(values) == 0:
return "{}"

return "{" + ",".join(map(str, values)) + "}"
40 changes: 40 additions & 0 deletions mev_inspect/string_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""This is taken from https://hakibenita.com/fast-load-data-python-postgresql"""

import io
from typing import Iterator, Optional


class StringIteratorIO(io.TextIOBase):
def __init__(self, iter: Iterator[str]):
self._iter = iter
self._buff = ""

def readable(self) -> bool:
return True

def _read1(self, n: Optional[int] = None) -> str:
while not self._buff:
try:
self._buff = next(self._iter)
except StopIteration:
break
ret = self._buff[:n]
self._buff = self._buff[len(ret) :]
return ret

def read(self, n: Optional[int] = None) -> str:
line = []
if n is None or n < 0:
while True:
m = self._read1()
if not m:
break
line.append(m)
else:
while n > 0:
m = self._read1(n)
if not m:
break
n -= len(m)
line.append(m)
return "".join(line)

0 comments on commit 8d7ffa6

Please sign in to comment.