diff --git a/mixnet/config.py b/mixnet/config.py index 1b37fbde..18a05042 100644 --- a/mixnet/config.py +++ b/mixnet/config.py @@ -49,19 +49,11 @@ class MixMembership: nodes: List[NodeInfo] - def generate_route(self, num_hops: int, last_mix: NodeInfo) -> list[NodeInfo]: + def generate_route(self, length: int) -> list[NodeInfo]: """ Generate a mix route for a Sphinx packet. - The pre-selected mix_destination is used as a last mix node in the route, - so that associated packets can be merged together into a original message. """ - return [*(self.choose() for _ in range(num_hops - 1)), last_mix] - - def choose(self) -> NodeInfo: - """ - Choose a mix node as a mix destination that will reconstruct a message from Sphinx packets. - """ - return random.choice(self.nodes) + return [random.choice(self.nodes) for _ in range(length)] @dataclass diff --git a/mixnet/node.py b/mixnet/node.py index 38f2b80e..1868b7b9 100644 --- a/mixnet/node.py +++ b/mixnet/node.py @@ -4,7 +4,6 @@ from typing import TypeAlias from pysphinx.sphinx import ( - Payload, ProcessedFinalHopPacket, ProcessedForwardHopPacket, SphinxPacket, @@ -12,7 +11,7 @@ from mixnet.config import GlobalConfig, NodeConfig from mixnet.nomssip import Nomssip -from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder +from mixnet.sphinx import SphinxPacketBuilder BroadcastChannel: TypeAlias = asyncio.Queue[bytes] @@ -28,7 +27,6 @@ class Node: config: NodeConfig global_config: GlobalConfig nomssip: Nomssip - reconstructor: MessageReconstructor broadcast_channel: BroadcastChannel def __init__(self, config: NodeConfig, global_config: GlobalConfig): @@ -42,7 +40,6 @@ def __init__(self, config: NodeConfig, global_config: GlobalConfig): ), self.__process_msg, ) - self.reconstructor = MessageReconstructor() self.broadcast_channel = asyncio.Queue() @staticmethod @@ -50,10 +47,10 @@ def __calculate_message_size(global_config: GlobalConfig) -> int: """ Calculate the actual message size to be gossiped, which depends on the maximum length of mix path. """ - sample_packet, _ = PacketBuilder.build_real_packets( + sample_sphinx_packet, _ = SphinxPacketBuilder.build( bytes(1), global_config.membership, global_config.max_mix_path_length - )[0] - return len(sample_packet.bytes()) + ) + return len(sample_sphinx_packet.bytes()) async def __process_msg(self, msg: bytes) -> None: """ @@ -83,24 +80,11 @@ async def __process_sphinx_packet( case ProcessedForwardHopPacket(): return processed.next_packet case ProcessedFinalHopPacket(): - return await self.__process_sphinx_payload(processed.payload) + return processed.payload.recover_plain_playload() except ValueError: # Return nothing, if it cannot be unwrapped by the private key of this node. return None - async def __process_sphinx_payload(self, payload: Payload) -> bytes | None: - """ - Process the Sphinx payload if possible - """ - msg_with_flag = self.reconstructor.add( - Fragment.from_bytes(payload.recover_plain_playload()) - ) - if msg_with_flag is not None: - flag, msg = PacketBuilder.parse_msg_and_flag(msg_with_flag) - if flag == MessageFlag.MESSAGE_FLAG_REAL: - return msg - return None - def connect(self, peer: Node): """ Establish a duplex connection with a peer node. @@ -117,9 +101,9 @@ async def send_message(self, msg: bytes): """ # Here, we handle the case in which a msg is split into multiple Sphinx packets. # But, in practice, we expect a message to be small enough to fit in a single Sphinx packet. - for packet, _ in PacketBuilder.build_real_packets( + sphinx_packet, _ = SphinxPacketBuilder.build( msg, self.global_config.membership, self.config.mix_path_length, - ): - await self.nomssip.gossip(packet.bytes()) + ) + await self.nomssip.gossip(sphinx_packet.bytes()) diff --git a/mixnet/packet.py b/mixnet/packet.py deleted file mode 100644 index 513b9f08..00000000 --- a/mixnet/packet.py +++ /dev/null @@ -1,206 +0,0 @@ -from __future__ import annotations - -import uuid -from dataclasses import dataclass -from enum import Enum -from itertools import batched -from typing import Dict, List, Self, Tuple, TypeAlias - -from pysphinx.payload import Payload -from pysphinx.sphinx import SphinxPacket - -from mixnet.config import MixMembership, NodeInfo - - -class MessageFlag(Enum): - MESSAGE_FLAG_REAL = b"\x00" - MESSAGE_FLAG_DROP_COVER = b"\x01" - - def bytes(self) -> bytes: - return bytes(self.value) - - -class PacketBuilder: - @staticmethod - def build_real_packets( - message: bytes, membership: MixMembership, path_len: int - ) -> List[Tuple[SphinxPacket, List[NodeInfo]]]: - return PacketBuilder.__build_packets( - MessageFlag.MESSAGE_FLAG_REAL, message, membership, path_len - ) - - @staticmethod - def build_drop_cover_packets( - message: bytes, membership: MixMembership, path_len: int - ) -> List[Tuple[SphinxPacket, List[NodeInfo]]]: - return PacketBuilder.__build_packets( - MessageFlag.MESSAGE_FLAG_DROP_COVER, message, membership, path_len - ) - - @staticmethod - def __build_packets( - flag: MessageFlag, message: bytes, membership: MixMembership, path_len: int - ) -> List[Tuple[SphinxPacket, List[NodeInfo]]]: - if path_len <= 0: - raise ValueError("path_len must be greater than 0") - - last_mix = membership.choose() - - msg_with_flag = flag.bytes() + message - # NOTE: We don't encrypt msg_with_flag for destination. - # If encryption is needed, a shared secret must be appended in front of the message along with the MessageFlag. - fragment_set = FragmentSet(msg_with_flag) - - out = [] - for fragment in fragment_set.fragments: - route = membership.generate_route(path_len, last_mix) - packet = SphinxPacket.build( - fragment.bytes(), - [mixnode.sphinx_node() for mixnode in route], - last_mix.sphinx_node(), - ) - out.append((packet, route)) - - return out - - @staticmethod - def parse_msg_and_flag(data: bytes) -> Tuple[MessageFlag, bytes]: - """Remove a MessageFlag from data""" - if len(data) < 1: - raise ValueError("data is too short") - - return (MessageFlag(data[0:1]), data[1:]) - - -# Unlikely, Nym uses i32 for FragmentSetId, which may cause more collisions. -# We will use UUID until figuring out why Nym uses i32. -FragmentSetId: TypeAlias = bytes # 128bit UUID v4 -FragmentId: TypeAlias = int # unsigned 8bit int in big endian - -FRAGMENT_SET_ID_LENGTH: int = 16 -FRAGMENT_ID_LENGTH: int = 1 - - -@dataclass -class FragmentHeader: - """ - Contain all information for reconstructing a message that was fragmented into the same FragmentSet. - """ - - set_id: FragmentSetId - total_fragments: FragmentId - fragment_id: FragmentId - - SIZE: int = FRAGMENT_SET_ID_LENGTH + FRAGMENT_ID_LENGTH * 2 - - @staticmethod - def max_total_fragments() -> int: - return 256 # because total_fragment is u8 - - def bytes(self) -> bytes: - return ( - self.set_id - + self.total_fragments.to_bytes(1) - + self.fragment_id.to_bytes(1) - ) - - @classmethod - def from_bytes(cls, data: bytes) -> Self: - if len(data) != cls.SIZE: - raise ValueError("Invalid data length", len(data)) - - return cls(data[:16], int.from_bytes(data[16:17]), int.from_bytes(data[17:18])) - - -@dataclass -class FragmentSet: - """ - Represent a set of Fragments that can be reconstructed to a single original message. - - Note that the maximum number of fragments in a FragmentSet is limited for now. - """ - - fragments: List[Fragment] - - MAX_FRAGMENTS: int = FragmentHeader.max_total_fragments() - - def __init__(self, message: bytes): - """ - Build a FragmentSet by chunking a message into Fragments. - """ - chunked_messages = chunks(message, Fragment.MAX_PAYLOAD_SIZE) - # For now, we don't support more than max_fragments() fragments. - # If needed, we can devise the FragmentSet chaining to support larger messages, like Nym. - if len(chunked_messages) > self.MAX_FRAGMENTS: - raise ValueError(f"Too long message: {len(chunked_messages)} chunks") - - set_id = uuid.uuid4().bytes - self.fragments = [ - Fragment(FragmentHeader(set_id, len(chunked_messages), i), chunk) - for i, chunk in enumerate(chunked_messages) - ] - - -@dataclass -class Fragment: - """Represent a piece of data that can be transformed to a single SphinxPacket""" - - header: FragmentHeader - body: bytes - - MAX_PAYLOAD_SIZE: int = Payload.max_plain_payload_size() - FragmentHeader.SIZE - - def bytes(self) -> bytes: - return self.header.bytes() + self.body - - @classmethod - def from_bytes(cls, data: bytes) -> Self: - header = FragmentHeader.from_bytes(data[: FragmentHeader.SIZE]) - body = data[FragmentHeader.SIZE :] - return cls(header, body) - - -@dataclass -class MessageReconstructor: - fragmentSets: Dict[FragmentSetId, FragmentSetReconstructor] - - def __init__(self): - self.fragmentSets = {} - - def add(self, fragment: Fragment) -> bytes | None: - if fragment.header.set_id not in self.fragmentSets: - self.fragmentSets[fragment.header.set_id] = FragmentSetReconstructor( - fragment.header.total_fragments - ) - - msg = self.fragmentSets[fragment.header.set_id].add(fragment) - if msg is not None: - del self.fragmentSets[fragment.header.set_id] - return msg - - -@dataclass -class FragmentSetReconstructor: - total_fragments: FragmentId - fragments: Dict[FragmentId, Fragment] - - def __init__(self, total_fragments: FragmentId): - self.total_fragments = total_fragments - self.fragments = {} - - def add(self, fragment: Fragment) -> bytes | None: - self.fragments[fragment.header.fragment_id] = fragment - if len(self.fragments) == self.total_fragments: - return self.build_message() - else: - return None - - def build_message(self) -> bytes: - message = b"" - for i in range(self.total_fragments): - message += self.fragments[FragmentId(i)].body - return message - - -def chunks(data: bytes, size: int) -> List[bytes]: - return list(map(bytes, batched(data, size))) diff --git a/mixnet/sphinx.py b/mixnet/sphinx.py new file mode 100644 index 00000000..c3a6bdf4 --- /dev/null +++ b/mixnet/sphinx.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from typing import List, Tuple + +from pysphinx.payload import Payload +from pysphinx.sphinx import SphinxPacket + +from mixnet.config import MixMembership, NodeInfo + + +class SphinxPacketBuilder: + @staticmethod + def build( + message: bytes, membership: MixMembership, path_len: int + ) -> Tuple[SphinxPacket, List[NodeInfo]]: + if path_len <= 0: + raise ValueError("path_len must be greater than 0") + if len(message) > Payload.max_plain_payload_size(): + raise ValueError("message is too long") + + route = membership.generate_route(path_len) + # We don't need the destination (defined in the Loopix Sphinx spec) + # because the last mix will broadcast the fully unwrapped message. + # Later, we will optimize the Sphinx according to our requirements. + dummy_destination = route[-1] + + packet = SphinxPacket.build( + message, + route=[mixnode.sphinx_node() for mixnode in route], + destination=dummy_destination.sphinx_node(), + ) + return (packet, route) diff --git a/mixnet/test_packet.py b/mixnet/test_packet.py deleted file mode 100644 index 453e648e..00000000 --- a/mixnet/test_packet.py +++ /dev/null @@ -1,103 +0,0 @@ -from random import randint -from typing import List -from unittest import TestCase - -from pysphinx.sphinx import ProcessedFinalHopPacket, SphinxPacket, X25519PrivateKey - -from mixnet.config import NodeInfo -from mixnet.packet import ( - Fragment, - MessageFlag, - MessageReconstructor, - PacketBuilder, -) -from mixnet.test_utils import init_mixnet_config - - -class TestPacket(TestCase): - def test_real_packet(self): - global_config, _, key_map = init_mixnet_config(10) - msg = self.random_bytes(3500) - packets_and_routes = PacketBuilder.build_real_packets( - msg, global_config.membership, 3 - ) - self.assertEqual(4, len(packets_and_routes)) - - reconstructor = MessageReconstructor() - self.assertIsNone( - reconstructor.add( - self.process_packet( - packets_and_routes[1][0], packets_and_routes[1][1], key_map - ) - ), - ) - self.assertIsNone( - reconstructor.add( - self.process_packet( - packets_and_routes[3][0], packets_and_routes[3][1], key_map - ) - ), - ) - self.assertIsNone( - reconstructor.add( - self.process_packet( - packets_and_routes[2][0], packets_and_routes[2][1], key_map - ) - ), - ) - msg_with_flag = reconstructor.add( - self.process_packet( - packets_and_routes[0][0], packets_and_routes[0][1], key_map - ) - ) - assert msg_with_flag is not None - self.assertEqual( - PacketBuilder.parse_msg_and_flag(msg_with_flag), - (MessageFlag.MESSAGE_FLAG_REAL, msg), - ) - - def test_cover_packet(self): - global_config, _, key_map = init_mixnet_config(10) - msg = b"cover" - packets_and_routes = PacketBuilder.build_drop_cover_packets( - msg, global_config.membership, 3 - ) - self.assertEqual(1, len(packets_and_routes)) - - reconstructor = MessageReconstructor() - msg_with_flag = reconstructor.add( - self.process_packet( - packets_and_routes[0][0], packets_and_routes[0][1], key_map - ) - ) - assert msg_with_flag is not None - self.assertEqual( - PacketBuilder.parse_msg_and_flag(msg_with_flag), - (MessageFlag.MESSAGE_FLAG_DROP_COVER, msg), - ) - - @staticmethod - def process_packet( - packet: SphinxPacket, - route: List[NodeInfo], - key_map: dict[bytes, X25519PrivateKey], - ) -> Fragment: - processed = packet.process(key_map[route[0].public_key.public_bytes_raw()]) - if isinstance(processed, ProcessedFinalHopPacket): - return Fragment.from_bytes(processed.payload.recover_plain_playload()) - else: - processed = processed - for node in route[1:]: - p = processed.next_packet.process( - key_map[node.public_key.public_bytes_raw()] - ) - if isinstance(p, ProcessedFinalHopPacket): - return Fragment.from_bytes(p.payload.recover_plain_playload()) - else: - processed = p - assert False - - @staticmethod - def random_bytes(size: int) -> bytes: - assert size >= 0 - return bytes([randint(0, 255) for _ in range(size)]) diff --git a/mixnet/test_sphinx.py b/mixnet/test_sphinx.py new file mode 100644 index 00000000..13e15dea --- /dev/null +++ b/mixnet/test_sphinx.py @@ -0,0 +1,39 @@ +from random import randint +from typing import cast +from unittest import TestCase + +from pysphinx.sphinx import ( + ProcessedFinalHopPacket, + ProcessedForwardHopPacket, +) + +from mixnet.sphinx import SphinxPacketBuilder +from mixnet.test_utils import init_mixnet_config + + +class TestSphinxPacketBuilder(TestCase): + def test_builder(self): + global_config, _, key_map = init_mixnet_config(10) + msg = self.random_bytes(500) + packet, route = SphinxPacketBuilder.build(msg, global_config.membership, 3) + self.assertEqual(3, len(route)) + + processed = packet.process(key_map[route[0].public_key.public_bytes_raw()]) + self.assertIsInstance(processed, ProcessedForwardHopPacket) + processed = cast(ProcessedForwardHopPacket, processed).next_packet.process( + key_map[route[1].public_key.public_bytes_raw()] + ) + self.assertIsInstance(processed, ProcessedForwardHopPacket) + processed = cast(ProcessedForwardHopPacket, processed).next_packet.process( + key_map[route[2].public_key.public_bytes_raw()] + ) + self.assertIsInstance(processed, ProcessedFinalHopPacket) + recovered = cast( + ProcessedFinalHopPacket, processed + ).payload.recover_plain_playload() + self.assertEqual(msg, recovered) + + @staticmethod + def random_bytes(size: int) -> bytes: + assert size >= 0 + return bytes([randint(0, 255) for _ in range(size)])