feat: accept arbitrary filters as parameter
Otto-AA committed Jul 23, 2024
1 parent 0ba48dd commit 20f26c6
Showing 3 changed files with 121 additions and 109 deletions.
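
For orientation before the diffs: with this commit, Miner.filter_candidates no longer takes a window_size argument and instead accepts a sequence of (name, filter) tuples, where each filter is a callable that takes a DB and returns an int (the count recorded in the miner's filter statistics). A minimal usage sketch, assuming an already-configured Miner instance named miner:

from tod_attack_miner.db.filters import (
    get_filters_duplicate_limits,
    get_filters_except_duplicate_limits,
)

# Compose the default filters: everything except the duplicate limits (block
# window of 3 blocks), followed by the duplicate-limit filters capped at 10.
filters = get_filters_except_duplicate_limits(3) + get_filters_duplicate_limits(10)
miner.filter_candidates(filters)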
8 changes: 7 additions & 1 deletion tests/integration/test_snapshots.py
@@ -2,6 +2,10 @@

import pytest
from tod_attack_miner.db.db import DB
from tod_attack_miner.db.filters import (
    get_filters_duplicate_limits,
    get_filters_except_duplicate_limits,
)
from tod_attack_miner.fetcher.fetcher import BlockRange
from tod_attack_miner.miner.miner import Miner

@@ -23,7 +27,9 @@ def test_tod_attack_miner_e2e(postgresql: Connection, snapshot: PyTestSnapshotTe
    miner.fetch(block_range.start, block_range.end)
    miner.find_collisions()
    miner.compute_skelcodes()
    miner.filter_candidates(window_size=3)
    miner.filter_candidates(
        get_filters_except_duplicate_limits(3) + get_filters_duplicate_limits(10)
    )

    candidates = miner.get_candidates()
    stats = miner.get_stats()
174 changes: 108 additions & 66 deletions tod_attack_miner/db/filters.py
@@ -1,3 +1,4 @@
from typing import Callable
import psycopg
import psycopg.sql
from tod_attack_miner.db.db import DB
@@ -14,15 +15,18 @@ def filter_block_producers(db: DB):
    return db.remove_candidates_without_collision()


def filter_block_window(db: DB, window_size: int | None) -> int:
    if window_size is None:
        return 0
    sql = psycopg.sql.SQL("""
        DELETE FROM collisions c
        WHERE block_dist >= {}""").format(window_size)
    with db._con.cursor() as cursor:
        cursor.execute(sql)
    return db.remove_candidates_without_collision()
def create_block_window_filter(window_size: int | None):
    def block_window_filter(db: DB):
        if window_size is None:
            return 0
        sql = psycopg.sql.SQL("""
            DELETE FROM collisions c
            WHERE block_dist >= {}""").format(window_size)
        with db._con.cursor() as cursor:
            cursor.execute(sql)
        return db.remove_candidates_without_collision()

    return block_window_filter


def filter_nonces(db: DB):
@@ -122,60 +126,98 @@ def filter_indirect_dependencies_recursive(db: DB):
    return deleted


def limit_collisions_per_address(db: DB, limit=10):
    sql = f"""
        DELETE FROM collisions c
        USING (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY SUBSTR(key, 1, 42) ORDER BY RANDOM()) AS n
            FROM collisions
        ) grouped
        WHERE c.tx_write_hash = grouped.tx_write_hash
        AND c.tx_access_hash = grouped.tx_access_hash
        AND c.type = grouped.type
        AND c.key = grouped.key
        AND n > {limit}
    """
    with db._con.cursor() as cursor:
        cursor.execute("SELECT setseed(0)")
        cursor.execute(sql)  # type: ignore
    return db.remove_candidates_without_collision()


def limit_collisions_per_code_hash(db: DB, limit=10):
    sql = f"""
        DELETE FROM collisions c
        USING (
            SELECT tx_write_hash, tx_access_hash, type, key, ROW_NUMBER() OVER (PARTITION BY hash ORDER BY RANDOM()) AS n
            FROM collisions
            INNER JOIN skeletons ON SUBSTR(collisions.key, 1, 42) = skeletons.addr
        ) grouped
        WHERE c.tx_write_hash = grouped.tx_write_hash
        AND c.tx_access_hash = grouped.tx_access_hash
        AND c.type = grouped.type
        AND c.key = grouped.key
        AND n > {limit}
    """
    with db._con.cursor() as cursor:
        cursor.execute("SELECT setseed(0)")
        cursor.execute(sql)  # type: ignore
    return db.remove_candidates_without_collision()


def limit_collisions_per_code_family(db: DB, limit=10):
    sql = f"""
        DELETE FROM collisions c
        USING (
            SELECT tx_write_hash, tx_access_hash, type, key, ROW_NUMBER() OVER (PARTITION BY family ORDER BY RANDOM()) AS n
            FROM collisions
            INNER JOIN skeletons ON SUBSTR(collisions.key, 1, 42) = skeletons.addr
        ) grouped
        WHERE c.tx_write_hash = grouped.tx_write_hash
        AND c.tx_access_hash = grouped.tx_access_hash
        AND c.type = grouped.type
        AND c.key = grouped.key
        AND n > {limit}
    """
    with db._con.cursor() as cursor:
        cursor.execute("SELECT setseed(0)")
        cursor.execute(sql)  # type: ignore
    return db.remove_candidates_without_collision()
def create_limit_collisions_per_address(limit: int):
    def limit_collisions_per_address(db: DB):
        sql = f"""
            DELETE FROM collisions c
            USING (
                SELECT *, ROW_NUMBER() OVER (PARTITION BY SUBSTR(key, 1, 42) ORDER BY RANDOM()) AS n
                FROM collisions
            ) grouped
            WHERE c.tx_write_hash = grouped.tx_write_hash
            AND c.tx_access_hash = grouped.tx_access_hash
            AND c.type = grouped.type
            AND c.key = grouped.key
            AND n > {limit}
        """
        with db._con.cursor() as cursor:
            cursor.execute("SELECT setseed(0)")
            cursor.execute(sql)  # type: ignore
        return db.remove_candidates_without_collision()

    return limit_collisions_per_address


def create_limit_collisions_per_code_hash(limit: int):
    def limit_collisions_per_code_hash(db: DB):
        sql = f"""
            DELETE FROM collisions c
            USING (
                SELECT tx_write_hash, tx_access_hash, type, key, ROW_NUMBER() OVER (PARTITION BY hash ORDER BY RANDOM()) AS n
                FROM collisions
                INNER JOIN skeletons ON SUBSTR(collisions.key, 1, 42) = skeletons.addr
            ) grouped
            WHERE c.tx_write_hash = grouped.tx_write_hash
            AND c.tx_access_hash = grouped.tx_access_hash
            AND c.type = grouped.type
            AND c.key = grouped.key
            AND n > {limit}
        """
        with db._con.cursor() as cursor:
            cursor.execute("SELECT setseed(0)")
            cursor.execute(sql)  # type: ignore
        return db.remove_candidates_without_collision()

    return limit_collisions_per_code_hash


def create_limit_collisions_per_code_family(limit=10):
    def limit_collisions_per_code_family(db: DB):
        sql = f"""
            DELETE FROM collisions c
            USING (
                SELECT tx_write_hash, tx_access_hash, type, key, ROW_NUMBER() OVER (PARTITION BY family ORDER BY RANDOM()) AS n
                FROM collisions
                INNER JOIN skeletons ON SUBSTR(collisions.key, 1, 42) = skeletons.addr
            ) grouped
            WHERE c.tx_write_hash = grouped.tx_write_hash
            AND c.tx_access_hash = grouped.tx_access_hash
            AND c.type = grouped.type
            AND c.key = grouped.key
            AND n > {limit}
        """
        with db._con.cursor() as cursor:
            cursor.execute("SELECT setseed(0)")
            cursor.execute(sql)  # type: ignore
        return db.remove_candidates_without_collision()

    return limit_collisions_per_code_family


def get_filters_except_duplicate_limits(
    window_size: int,
) -> list[tuple[str, Callable[[DB], int]]]:
    return [
        ("block_window", create_block_window_filter(window_size)),
        ("block_producers", filter_block_producers),
        ("nonces", filter_nonces),
        ("codes", filter_codes),
        ("indirect_dependencies_quick", filter_indirect_dependencies_quick),
        ("indirect_dependencies_recursive", filter_indirect_dependencies_recursive),
        ("same_sender", filter_same_sender),
        ("recipient_eth_transfer", filter_second_tx_ether_transfer),
    ]


def get_filters_duplicate_limits(limit: int) -> list[tuple[str, Callable[[DB], int]]]:
    return [
        ("limited_collisions_per_address", create_limit_collisions_per_address(limit)),
        (
            "limited_collisions_per_code_hash",
            create_limit_collisions_per_code_hash(limit),
        ),
        (
            "limited_collisions_per_code_family",
            create_limit_collisions_per_code_family(limit),
        ),
    ]
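
Because filters are now plain (name, callable) pairs, callers can mix their own filters into the lists returned above. A hypothetical sketch; the factory name, SQL condition, and threshold below are illustrative and not part of this commit:

from typing import Callable

from tod_attack_miner.db.db import DB
from tod_attack_miner.db.filters import (
    get_filters_duplicate_limits,
    get_filters_except_duplicate_limits,
)


def create_min_block_dist_filter(min_dist: int) -> Callable[[DB], int]:
    # Hypothetical filter factory in the style of create_block_window_filter:
    # drop collisions whose transactions are closer than min_dist blocks apart.
    def min_block_dist_filter(db: DB) -> int:
        with db._con.cursor() as cursor:
            cursor.execute(
                "DELETE FROM collisions c WHERE block_dist < %s", (min_dist,)
            )
        return db.remove_candidates_without_collision()

    return min_block_dist_filter


# Insert the custom filter between the default filters and the duplicate limits.
filters = (
    get_filters_except_duplicate_limits(3)
    + [("min_block_dist", create_min_block_dist_filter(1))]
    + get_filters_duplicate_limits(10)
)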
48 changes: 6 additions & 42 deletions tod_attack_miner/miner/miner.py
@@ -1,18 +1,5 @@
from typing import Sequence
from typing import Callable, Sequence
from tod_attack_miner.db.db import DB, Candidate
from tod_attack_miner.db.filters import (
    filter_codes,
    filter_nonces,
    filter_block_window,
    filter_indirect_dependencies_recursive,
    filter_indirect_dependencies_quick,
    filter_same_sender,
    filter_second_tx_ether_transfer,
    filter_block_producers,
    limit_collisions_per_address,
    limit_collisions_per_code_family,
    limit_collisions_per_code_hash,
)
from tod_attack_miner.fetcher.fetcher import BlockRange, fetch_block_range
from tod_attack_miner.rpc.rpc import RPC

@@ -35,35 +22,12 @@ def find_collisions(self) -> None:
        self.db.insert_candidates()
        self._original_collisions = self.db.get_collisions_stats()

    def filter_candidates(self, window_size: int | None) -> None:
    def filter_candidates(
        self, filters: Sequence[tuple[str, Callable[[DB], int]]]
    ) -> None:
        self._filter_stats["candidates"]["before_filters"] = self.db.count_candidates()
        self._filter_stats["filtered"]["block_window"] = filter_block_window(
            self.db, window_size
        )
        self._filter_stats["filtered"]["block_producers"] = filter_block_producers(
            self.db
        )
        self._filter_stats["filtered"]["nonces"] = filter_nonces(self.db)
        self._filter_stats["filtered"]["codes"] = filter_codes(self.db)
        self._filter_stats["filtered"]["indirect_dependencies_quick"] = (
            filter_indirect_dependencies_quick(self.db)
        )
        self._filter_stats["filtered"]["indirect_dependencies_recursive"] = (
            filter_indirect_dependencies_recursive(self.db)
        )
        self._filter_stats["filtered"]["same_sender"] = filter_same_sender(self.db)
        self._filter_stats["filtered"]["recipient_eth_transfer"] = (
            filter_second_tx_ether_transfer(self.db)
        )
        self._filter_stats["filtered"]["limited_collisions_per_address"] = (
            limit_collisions_per_address(self.db)
        )
        self._filter_stats["filtered"]["limited_collisions_per_code_hash"] = (
            limit_collisions_per_code_hash(self.db)
        )
        self._filter_stats["filtered"]["limited_collisions_per_code_family"] = (
            limit_collisions_per_code_family(self.db)
        )
        for name, filter in filters:
            self._filter_stats["filtered"][name] = filter(self.db)
        self._filter_stats["candidates"]["final"] = self.db.count_candidates()

    def count_candidates(self) -> int:
