From bf3fed23de22a8012c0d77e24736f55cf33f36f0 Mon Sep 17 00:00:00 2001 From: A_A <21040751+Otto-AA@users.noreply.github.com> Date: Wed, 3 Jul 2024 13:17:54 +0200 Subject: [PATCH] feat: filter candidates --- .../snapshots/snap_test_snapshots.py | 20 +++- tests/integration/test_snapshots.py | 5 +- tod_attack_miner/cli.py | 7 +- tod_attack_miner/db/db.py | 37 +++++--- tod_attack_miner/db/filters.py | 91 +++++++++++++++++++ tod_attack_miner/fetcher/fetcher.py | 3 +- tod_attack_miner/miner/miner.py | 36 ++++++-- 7 files changed, 169 insertions(+), 30 deletions(-) diff --git a/tests/integration/snapshots/snap_test_snapshots.py b/tests/integration/snapshots/snap_test_snapshots.py index 384c36b..f3ebec6 100644 --- a/tests/integration/snapshots/snap_test_snapshots.py +++ b/tests/integration/snapshots/snap_test_snapshots.py @@ -7,7 +7,7 @@ snapshots = Snapshot() -snapshots["test_tod_attack_miner_e2e attacks"] = 3113 +snapshots["test_tod_attack_miner_e2e num_candidates"] = 694 snapshots["test_tod_attack_miner_e2e stats"] = { "accesses": {"balance": 4663, "code": 2248, "nonce": 4332, "storage": 8239}, @@ -23,9 +23,21 @@ ("0x9430801ebaf509ad49202aabc5f5bc6fd8a3daf8", 38), ("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48", 31), ], - "addresses_est_total": [(233,)], - "candidates": [(2193,)], - "candidates_original": [(264234,)], + "addresses_est_total": 233, + "candidates": 694, + "candidates_filters": { + "candidates": { + "before_filters": 2193, + "final": 694, + "original_without_same_value": 264234, + }, + "filtered": { + "indirect_dependencies_quick": 766, + "indirect_dependencies_recursive": 425, + "recipient_eth_transfer": 245, + "same_sender": 63, + }, + }, "conflicts": {"balance": 1526, "nonce": 879, "storage": 708}, "state_diffs": {"balance": 2577, "code": 3, "nonce": 880, "storage": 2594}, } diff --git a/tests/integration/test_snapshots.py b/tests/integration/test_snapshots.py index 2b9328f..6eefc32 100644 --- a/tests/integration/test_snapshots.py +++ 
b/tests/integration/test_snapshots.py @@ -22,9 +22,10 @@ def test_tod_attack_miner_e2e(postgresql: Connection, snapshot: PyTestSnapshotTe miner.fetch(block_range.start, block_range.end) miner.find_conflicts() + miner.filter_candidates() + candidates = miner.get_candidates() stats = miner.get_stats() - attacks = miner.get_conflicts(block_range.start, block_range.end) + snapshot.assert_match(len(candidates), "num_candidates") snapshot.assert_match(stats, "stats") - snapshot.assert_match(len(attacks), "attacks") diff --git a/tod_attack_miner/cli.py b/tod_attack_miner/cli.py index 8120e76..def2cfc 100644 --- a/tod_attack_miner/cli.py +++ b/tod_attack_miner/cli.py @@ -36,7 +36,10 @@ def main(): ) as conn: miner = Miner(RPC(args.archive_node_provider), DB(conn)) - if not args.stats_only: + if args.stats_only: + print(json.dumps(miner.get_stats())) + else: miner.fetch(int(args.from_block), int(args.to_block)) miner.find_conflicts() - print(json.dumps(miner.get_stats())) + miner.filter_candidates() + print(f"Found {miner.count_candidates()} candidates") diff --git a/tod_attack_miner/db/db.py b/tod_attack_miner/db/db.py index aa304f0..94b8ac3 100644 --- a/tod_attack_miner/db/db.py +++ b/tod_attack_miner/db/db.py @@ -6,7 +6,7 @@ from tod_attack_miner.rpc.types import BlockWithTransactions, TxPrestate, TxStateDiff _TABLES = { - "transactions": "(hash TEXT PRIMARY KEY, block_number INTEGER, tx_index INTEGER)", + "transactions": "(hash TEXT PRIMARY KEY, block_number INTEGER, tx_index INTEGER, sender TEXT)", "accesses": "(block_number INTEGER, tx_index INTEGER, tx_hash TEXT, type TEXT, key TEXT, value TEXT)", "state_diffs": "(block_number INTEGER, tx_index INTEGER, tx_hash TEXT, type TEXT, key TEXT, pre_value TEXT, post_value TEXT)", "collisions": "(tx_write_hash TEXT, tx_access_hash TEXT, type TEXT, key TEXT, block_dist INTEGER)", @@ -15,7 +15,9 @@ # TODO: check if indexes are worth it _INDEXES = { "accesses_type_key": "accesses(type, key, value)", + "accesses_tx_hash": 
"accesses(tx_hash)", "state_diffs_type_key": "state_diffs(type, key, post_value)", + "transactions_hash_sender": "transactions(hash, sender)", } ACCESS_TYPE = ( @@ -49,11 +51,6 @@ def _setup_tables(self): def insert_prestate(self, block_number: int, tx_index: int, prestate: TxPrestate): with self._con.cursor() as cursor: - cursor.execute( - psycopg.sql.SQL("INSERT INTO transactions VALUES (%s, %s, %s)"), - (prestate["txHash"], block_number, tx_index), - ) - accesses: list[tuple[int, int, str, ACCESS_TYPE, str, str]] = [] for addr, state in prestate["result"].items(): if (balance := state.get("balance")) is not None: @@ -231,14 +228,24 @@ def count_candidates_original(self): GROUP BY accesses.tx_hash, state_diffs.tx_hash ) x """ - return cursor.execute(sql).fetchall() + return cursor.execute(sql).fetchall()[0][0] def count_candidates(self): with self._con.cursor() as cursor: - return cursor.execute("SELECT COUNT(*) FROM candidates").fetchall() + return cursor.execute("SELECT COUNT(*) FROM candidates").fetchall()[0][0] def insert_block(self, block: BlockWithTransactions): - pass + tx_values = [ + (tx["hash"], block["number"], tx["transactionIndex"], tx["from"]) + for tx in block["transactions"] + ] + + with self._con.cursor() as cursor: + cursor.executemany( + psycopg.sql.SQL("INSERT INTO transactions VALUES (%s, %s, %s, %s)"), + tx_values, + ) + self._con.commit() def _get_collisions(self): cursor = self._con.cursor() @@ -260,11 +267,11 @@ def _get_collisions(self): ] return result_dicts - def get_collisions(self) -> Sequence[tuple[str, str, tuple[str, str, int]]]: - return [ - (c["tx_write"], c["tx_access"], (c["type"], c["key"], c["block_dist"])) - for c in self._get_collisions() - ] + def get_candidates(self) -> Sequence[tuple[str, str]]: + with self._con.cursor() as cursor: + return cursor.execute( + "SELECT tx_write_hash, tx_access_hash FROM candidates" + ).fetchall() def get_accesses_stats(self): return dict( @@ -304,7 +311,7 @@ def 
get_unique_addresses_total(self): SELECT COUNT(DISTINCT SUBSTR(key, 1, 42)) FROM collisions """ - return cursor.execute(sql).fetchall() + return cursor.execute(sql).fetchall()[0][0] def hash_code(code: str) -> str: diff --git a/tod_attack_miner/db/filters.py b/tod_attack_miner/db/filters.py index e69de29..3205093 100644 --- a/tod_attack_miner/db/filters.py +++ b/tod_attack_miner/db/filters.py @@ -0,0 +1,91 @@ +import psycopg +import psycopg.sql +from tod_attack_miner.db.db import DB + + +def filter_same_sender(db: DB): + sql = """ +DELETE FROM candidates +USING transactions a, transactions b +WHERE a.sender = b.sender + AND candidates.tx_access_hash = a.hash + AND candidates.tx_write_hash = b.hash +""" + with db._con.cursor() as cursor: + cursor.execute(sql) + deleted = cursor.rowcount + db._con.commit() + return deleted + + +def filter_second_tx_ether_transfer(db: DB): + sql = """ +DELETE FROM candidates +WHERE NOT EXISTS ( + SELECT 1 + FROM accesses + WHERE candidates.tx_access_hash = accesses.tx_hash + AND accesses.type = 'code' +) +""" + with db._con.cursor() as cursor: + cursor.execute(sql) + deleted = cursor.rowcount + db._con.commit() + return deleted + + +def filter_indirect_dependencies_quick(db: DB): + with db._con.cursor() as cursor: + sql = """ +DELETE FROM candidates d +USING candidates a, candidates b +/* (A, X) (X, B). candidates d is (A, B). candidates a is (A, X). 
candidates b is (X, B) */ +/* A = A AND B = B */ +WHERE d.tx_write_hash = a.tx_write_hash AND d.tx_access_hash = b.tx_access_hash +/* B != X AND A != X */ + AND d.tx_access_hash != a.tx_access_hash AND d.tx_write_hash != b.tx_write_hash +/* X = X */ + AND a.tx_access_hash = b.tx_write_hash +""" + cursor.execute(sql) + deleted = cursor.rowcount + db._con.commit() + return deleted + + +def filter_indirect_dependencies_recursive(db: DB): + sql = psycopg.sql.SQL(""" +WITH RECURSIVE depends_on(tx_a, tx_b, min_block_number, min_tx_index, tx_x) AS ( + SELECT tx_write_hash, tx_access_hash, block_number, tx_index, tx_access_hash + FROM candidates + INNER JOIN transactions ON hash = tx_write_hash + UNION + SELECT tx_a, tx_b, min_block_number, min_tx_index, tx_write_hash + FROM depends_on, candidates + INNER JOIN transactions ON hash = tx_write_hash + WHERE depends_on.tx_x = tx_access_hash + AND (block_number > min_block_number + OR block_number = min_block_number AND tx_index > min_tx_index) +), +indirect_dependencies_candidates AS ( + SELECT tx_a, tx_b + FROM depends_on + INNER JOIN candidates ON tx_access_hash = tx_x + WHERE tx_a = tx_write_hash + AND tx_b != tx_x + + LIMIT {} +) +DELETE FROM candidates +USING indirect_dependencies_candidates +WHERE tx_write_hash = tx_a AND tx_access_hash = tx_b""").format(10000) + finished = False + deleted = 0 + while not finished: + with db._con.cursor() as cursor: + cursor.execute(sql) + finished = cursor.rowcount == 0 + deleted += cursor.rowcount + db._con.commit() + return deleted diff --git a/tod_attack_miner/fetcher/fetcher.py b/tod_attack_miner/fetcher/fetcher.py index e6cf194..4726f30 100644 --- a/tod_attack_miner/fetcher/fetcher.py +++ b/tod_attack_miner/fetcher/fetcher.py @@ -27,7 +27,8 @@ def __post_init__(self): def fetch_block_range(rpc: RPC, db: DB, block_range: BlockRange): for block_number in ( bar := tqdm( - range(block_range.start, block_range.end + 1), desc="Fetch prestate traces" + range(block_range.start, 
block_range.end + 1), + desc="Fetch traces and block metadata", ) ): bar.set_postfix_str(f"block {block_number}") diff --git a/tod_attack_miner/miner/miner.py b/tod_attack_miner/miner/miner.py index d6c3ac4..1b57347 100644 --- a/tod_attack_miner/miner/miner.py +++ b/tod_attack_miner/miner/miner.py @@ -1,5 +1,11 @@ from typing import Sequence from tod_attack_miner.db.db import DB +from tod_attack_miner.db.filters import ( + filter_indirect_dependencies_recursive, + filter_indirect_dependencies_quick, + filter_same_sender, + filter_second_tx_ether_transfer, +) from tod_attack_miner.fetcher.fetcher import BlockRange, fetch_block_range from tod_attack_miner.rpc.rpc import RPC @@ -8,6 +14,7 @@ class Miner: def __init__(self, rpc: RPC, db: DB) -> None: self.rpc = rpc self.db = db + self._filter_stats = {"candidates": {}, "filtered": {}} def fetch(self, start: int, end: int) -> None: fetch_block_range(self.rpc, self.db, BlockRange(start, end)) @@ -16,18 +23,35 @@ def find_conflicts(self) -> None: self.db.insert_conflicts() self.db.insert_candidates() - def get_conflicts( - self, start: int, end: int - ) -> Sequence[tuple[str, str, tuple[str, str, int]]]: - # TODO: only get attacks in the specified range - return self.db.get_collisions() + def filter_candidates(self) -> None: + self._filter_stats["candidates"]["before_filters"] = self.db.count_candidates() + self._filter_stats["filtered"]["indirect_dependencies_quick"] = ( + filter_indirect_dependencies_quick(self.db) + ) + self._filter_stats["filtered"]["indirect_dependencies_recursive"] = ( + filter_indirect_dependencies_recursive(self.db) + ) + self._filter_stats["filtered"]["same_sender"] = filter_same_sender(self.db) + self._filter_stats["filtered"]["recipient_eth_transfer"] = ( + filter_second_tx_ether_transfer(self.db) + ) + self._filter_stats["candidates"]["final"] = self.db.count_candidates() + + def count_candidates(self) -> int: + return self.db.count_candidates() + + def get_candidates(self) -> 
Sequence[tuple[str, str]]: + return self.db.get_candidates() def get_stats(self): + self._filter_stats["candidates"]["original_without_same_value"] = ( + self.db.count_candidates_original() + ) return { "accesses": self.db.get_accesses_stats(), "state_diffs": self.db.get_state_diffs_stats(), "conflicts": self.db.get_conflicts_stats(), - "candidates_original": self.db.count_candidates_original(), + "candidates_filters": self._filter_stats, "candidates": self.db.count_candidates(), "addresses_est": self.db.get_unique_addresses_stats(), "addresses_est_total": self.db.get_unique_addresses_total(),