diff --git a/mixnet/README.md b/mixnet/README.md deleted file mode 100644 index 0cc18d93..00000000 --- a/mixnet/README.md +++ /dev/null @@ -1,21 +0,0 @@ -# Mixnet Specification - -This is the executable specification of Mixnet, which can be used as a networking layer of the Nomos network. - -![](structure.png) - -## Public Components - -- [`mixnet.py`](mixnet.py): A public interface of the Mixnet layer, which can be used by upper layers -- [`robustness.py`](robustness.py): A public interface of the Robustness layer, which can be on top of the Mixnet layer and used by upper layers - -## Private Components - -There are two primary components in the Mixnet layer. - -- [`client.py`](client.py): A mix client interface, which splits a message into Sphinx packets, sends packets to mix nodes, and receives messages via gossip. Also, this emits cover packets periodically. -- [`node.py`](node.py): A mix node interface, which receives Sphinx packets from other mix nodes, processes packets, and forwards packets to other mix nodes. This works only when selected by the topology construction. - -Each component receives a new topology from the Robustness layer. - -There is no interaction between mix client and mix node components. diff --git a/mixnet/bls.py b/mixnet/bls.py deleted file mode 100644 index a4278b08..00000000 --- a/mixnet/bls.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import TypeAlias - -import blspy - -from mixnet.utils import random_bytes - -BlsPrivateKey: TypeAlias = blspy.PrivateKey -BlsPublicKey: TypeAlias = blspy.G1Element - - -def generate_bls() -> BlsPrivateKey: - seed = random_bytes(32) - return blspy.BasicSchemeMPL.key_gen(seed) diff --git a/mixnet/client.py b/mixnet/client.py deleted file mode 100644 index 8d21cf3c..00000000 --- a/mixnet/client.py +++ /dev/null @@ -1,118 +0,0 @@ -from __future__ import annotations - -import asyncio -from contextlib import suppress -from typing import Self - -from mixnet.config import MixClientConfig, MixnetTopology -from mixnet.node import PacketQueue -from mixnet.packet import PacketBuilder -from mixnet.poisson import poisson_interval_sec - - -class MixClient: - config: MixClientConfig - real_packet_queue: PacketQueue - outbound_socket: PacketQueue - task: asyncio.Task # A reference just to prevent task from being garbage collected - - @classmethod - async def new( - cls, - config: MixClientConfig, - ) -> Self: - self = cls() - self.config = config - self.real_packet_queue = asyncio.Queue() - self.outbound_socket = asyncio.Queue() - self.task = asyncio.create_task(self.__run()) - return self - - def set_topology(self, topology: MixnetTopology) -> None: - """ - Replace the old topology with the new topology received - - In real implementations, this method may be integrated in a long-running task. - Here in the spec, this method has been simplified as a setter, assuming the single-thread test environment. - """ - self.config.topology = topology - - # Only for testing - def get_topology(self) -> MixnetTopology: - return self.config.topology - - async def send_message(self, msg: bytes) -> None: - packets_and_routes = PacketBuilder.build_real_packets(msg, self.config.topology) - for packet, route in packets_and_routes: - await self.real_packet_queue.put((route[0].addr, packet)) - - def subscribe_messages(self) -> "asyncio.Queue[bytes]": - """ - Subscribe messages, which went through mix nodes and were broadcasted via gossip - """ - return asyncio.Queue() - - async def __run(self): - """ - Emit packets at the Poisson emission_rate_per_min. 
- - If a real packet is scheduled to be sent, this thread sends the real packet to the mixnet, - and schedules redundant real packets to be emitted in the next turns. - - If no real packet is not scheduled, this thread emits a cover packet according to the emission_rate_per_min. - """ - - redundant_real_packet_queue: PacketQueue = asyncio.Queue() - - emission_notifier_queue = asyncio.Queue() - _ = asyncio.create_task( - self.__emission_notifier( - self.config.emission_rate_per_min, emission_notifier_queue - ) - ) - - while True: - # Wait until the next emission time - _ = await emission_notifier_queue.get() - try: - await self.__emit(self.config.redundancy, redundant_real_packet_queue) - finally: - # Python convention: indicate that the previously enqueued task has been processed - emission_notifier_queue.task_done() - - async def __emit( - self, - redundancy: int, # b in the spec - redundant_real_packet_queue: PacketQueue, - ): - if not redundant_real_packet_queue.empty(): - addr, packet = redundant_real_packet_queue.get_nowait() - await self.outbound_socket.put((addr, packet)) - return - - if not self.real_packet_queue.empty(): - addr, packet = self.real_packet_queue.get_nowait() - # Schedule redundant real packets - for _ in range(redundancy - 1): - redundant_real_packet_queue.put_nowait((addr, packet)) - await self.outbound_socket.put((addr, packet)) - - packets_and_routes = PacketBuilder.build_drop_cover_packets( - b"drop cover", self.config.topology - ) - # We have a for loop here, but we expect that the total num of packets is 1 - # because the dummy message is short. - for packet, route in packets_and_routes: - await self.outbound_socket.put((route[0].addr, packet)) - - async def __emission_notifier( - self, emission_rate_per_min: int, queue: asyncio.Queue - ): - while True: - await asyncio.sleep(poisson_interval_sec(emission_rate_per_min)) - queue.put_nowait(None) - - async def cancel(self) -> None: - self.task.cancel() - with suppress(asyncio.CancelledError): - await self.task diff --git a/mixnet/config.py b/mixnet/config.py index 5b1a2d10..749038a9 100644 --- a/mixnet/config.py +++ b/mixnet/config.py @@ -2,110 +2,69 @@ import random from dataclasses import dataclass -from typing import List, TypeAlias +from typing import List from cryptography.hazmat.primitives.asymmetric.x25519 import ( X25519PrivateKey, X25519PublicKey, ) -from pysphinx.node import Node - -from mixnet.bls import BlsPrivateKey, BlsPublicKey -from mixnet.fisheryates import FisherYates +from pysphinx.sphinx import Node as SphinxNode @dataclass -class MixnetConfig: - topology_config: MixnetTopologyConfig - mixclient_config: MixClientConfig - mixnode_config: MixNodeConfig - +class GlobalConfig: + """ + Global parameters used across all nodes in the network + """ -@dataclass -class MixnetTopologyConfig: - mixnode_candidates: List[MixNodeInfo] - size: MixnetTopologySize - entropy: bytes + membership: MixMembership + transmission_rate_per_sec: int # Global Transmission Rate + # TODO: use these two to make the size of Sphinx packet constant + max_message_size: int + max_mix_path_length: int @dataclass -class MixClientConfig: - emission_rate_per_min: int # Poisson rate parameter: lambda - redundancy: int - topology: MixnetTopology +class NodeConfig: + """ + Node-specific parameters + """ + + private_key: X25519PrivateKey + mix_path_length: int + nomssip: NomssipConfig @dataclass -class MixNodeConfig: - encryption_private_key: X25519PrivateKey - delay_rate_per_min: int # Poisson rate parameter: mu +class NomssipConfig: + 
# The target number of peers a node should maintain in its p2p network + peering_degree: int @dataclass -class MixnetTopology: - # In production, this can be a 1-D array, which is accessible by indexes. - # Here, we use a 2-D array for readability. - layers: List[List[MixNodeInfo]] - - def __init__( - self, - config: MixnetTopologyConfig, - ) -> None: - """ - Build a new topology deterministically using an entropy and a given set of candidates. - """ - shuffled = FisherYates.shuffle(config.mixnode_candidates, config.entropy) - sampled = shuffled[: config.size.num_total_mixnodes()] +class MixMembership: + """ + A list of public information of nodes in the network. + We assume that this list is known to all nodes in the network. + """ - layers = [] - for layer_id in range(config.size.num_layers): - start = layer_id * config.size.num_mixnodes_per_layer - layer = sampled[start : start + config.size.num_mixnodes_per_layer] - layers.append(layer) - self.layers = layers - - def generate_route(self, mix_destination: MixNodeInfo) -> list[MixNodeInfo]: - """ - Generate a mix route for a Sphinx packet. - The pre-selected mix_destination is used as a last mix node in the route, - so that associated packets can be merged together into a original message. - """ - route = [random.choice(layer) for layer in self.layers[:-1]] - route.append(mix_destination) - return route + nodes: List[NodeInfo] - def choose_mix_destination(self) -> MixNodeInfo: + def generate_route(self, length: int) -> list[NodeInfo]: """ - Choose a mix node from the last mix layer as a mix destination - that will reconstruct a message from Sphinx packets. + Choose `length` nodes with replacement as a mix route. """ - return random.choice(self.layers[-1]) - - -@dataclass -class MixnetTopologySize: - num_layers: int - num_mixnodes_per_layer: int - - def num_total_mixnodes(self) -> int: - return self.num_layers * self.num_mixnodes_per_layer - - -# 32-byte that represents an IP address and a port of a mix node. -NodeAddress: TypeAlias = bytes + return random.choices(self.nodes, k=length) @dataclass -class MixNodeInfo: - identity_private_key: BlsPrivateKey - encryption_private_key: X25519PrivateKey - addr: NodeAddress - - def identity_public_key(self) -> BlsPublicKey: - return self.identity_private_key.get_g1() +class NodeInfo: + """ + Public information of a node to be shared to all nodes in the network + """ - def encryption_public_key(self) -> X25519PublicKey: - return self.encryption_private_key.public_key() + public_key: X25519PublicKey - def sphinx_node(self) -> Node: - return Node(self.encryption_private_key, self.addr) + def sphinx_node(self) -> SphinxNode: + dummy_node_addr = bytes(32) + return SphinxNode(self.public_key, dummy_node_addr) diff --git a/mixnet/connection.py b/mixnet/connection.py new file mode 100644 index 00000000..5fc319d8 --- /dev/null +++ b/mixnet/connection.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import asyncio + +NetworkPacketQueue = asyncio.Queue[bytes] +SimplexConnection = NetworkPacketQueue + + +class DuplexConnection: + """ + A duplex connection in which data can be transmitted and received simultaneously in both directions. + This is to mimic duplex communication in a real network (such as TCP or QUIC). 
+ """ + + def __init__(self, inbound: SimplexConnection, outbound: MixSimplexConnection): + self.inbound = inbound + self.outbound = outbound + + async def recv(self) -> bytes: + return await self.inbound.get() + + async def send(self, packet: bytes): + await self.outbound.send(packet) + + +class MixSimplexConnection: + """ + Wraps a SimplexConnection to add a transmission rate and noise to the connection. + """ + + def __init__( + self, conn: SimplexConnection, transmission_rate_per_sec: int, noise_msg: bytes + ): + self.queue = asyncio.Queue() + self.conn = conn + self.transmission_rate_per_sec = transmission_rate_per_sec + self.noise_msg = noise_msg + self.task = asyncio.create_task(self.__run()) + + async def __run(self): + while True: + await asyncio.sleep(1 / self.transmission_rate_per_sec) + # TODO: temporal mixing + if self.queue.empty(): + # To guarantee GTR, send noise if there is no message to send + msg = self.noise_msg + else: + msg = self.queue.get_nowait() + await self.conn.put(msg) + + async def send(self, msg: bytes): + await self.queue.put(msg) diff --git a/mixnet/fisheryates.py b/mixnet/fisheryates.py deleted file mode 100644 index 70c92ffc..00000000 --- a/mixnet/fisheryates.py +++ /dev/null @@ -1,21 +0,0 @@ -import random -from typing import List - - -class FisherYates: - @staticmethod - def shuffle(elements: List, entropy: bytes) -> List: - """ - Fisher-Yates shuffling algorithm. - In Python, random.shuffle implements the Fisher-Yates shuffling. - https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle - https://softwareengineering.stackexchange.com/a/215780 - :param elements: elements to be shuffled - :param entropy: a seed for deterministic sampling - """ - out = elements.copy() - random.seed(a=entropy, version=2) - random.shuffle(out) - # reset seed - random.seed() - return out diff --git a/mixnet/mixnet.py b/mixnet/mixnet.py deleted file mode 100644 index dedc9efd..00000000 --- a/mixnet/mixnet.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -import asyncio -from contextlib import suppress -from typing import Self, TypeAlias - -from mixnet.client import MixClient -from mixnet.config import MixnetConfig, MixnetTopology, MixnetTopologyConfig -from mixnet.node import MixNode - -EntropyQueue: TypeAlias = "asyncio.Queue[bytes]" - - -class Mixnet: - topology_config: MixnetTopologyConfig - - mixclient: MixClient - mixnode: MixNode - entropy_queue: EntropyQueue - task: asyncio.Task # A reference just to prevent task from being garbage collected - - @classmethod - async def new( - cls, - config: MixnetConfig, - entropy_queue: EntropyQueue, - ) -> Self: - self = cls() - self.topology_config = config.topology_config - self.mixclient = await MixClient.new(config.mixclient_config) - self.mixnode = await MixNode.new(config.mixnode_config) - self.entropy_queue = entropy_queue - self.task = asyncio.create_task(self.__consume_entropy()) - return self - - async def publish_message(self, msg: bytes) -> None: - await self.mixclient.send_message(msg) - - def subscribe_messages(self) -> "asyncio.Queue[bytes]": - return self.mixclient.subscribe_messages() - - async def __consume_entropy( - self, - ) -> None: - while True: - entropy = await self.entropy_queue.get() - self.topology_config.entropy = entropy - - topology = MixnetTopology(self.topology_config) - self.mixclient.set_topology(topology) - - async def cancel(self) -> None: - self.task.cancel() - with suppress(asyncio.CancelledError): - await self.task - - await self.mixclient.cancel() - await 
self.mixnode.cancel() - - # Only for testing - def get_topology(self) -> MixnetTopology: - return self.mixclient.get_topology() diff --git a/mixnet/node.py b/mixnet/node.py index 8ab4b77c..06d36cc7 100644 --- a/mixnet/node.py +++ b/mixnet/node.py @@ -1,107 +1,106 @@ from __future__ import annotations import asyncio -from contextlib import suppress -from typing import Self, Tuple, TypeAlias +from typing import TypeAlias -from cryptography.hazmat.primitives.asymmetric.x25519 import ( - X25519PrivateKey, -) from pysphinx.sphinx import ( - Payload, ProcessedFinalHopPacket, ProcessedForwardHopPacket, SphinxPacket, - UnknownHeaderTypeError, ) -from mixnet.config import MixNodeConfig, NodeAddress -from mixnet.poisson import poisson_interval_sec +from mixnet.config import GlobalConfig, NodeConfig +from mixnet.nomssip import Nomssip +from mixnet.sphinx import SphinxPacketBuilder -PacketQueue: TypeAlias = "asyncio.Queue[Tuple[NodeAddress, SphinxPacket]]" -PacketPayloadQueue: TypeAlias = ( - "asyncio.Queue[Tuple[NodeAddress, SphinxPacket | Payload]]" -) +BroadcastChannel: TypeAlias = asyncio.Queue[bytes] -class MixNode: +class Node: """ - A class handling incoming packets with delays - - This class is defined separated with the MixNode class, - in order to define the MixNode as a simple dataclass for clarity. + This represents any node in the network, which: + - generates/gossips mix messages (Sphinx packets) + - performs cryptographic mix (unwrapping Sphinx packets) + - generates noise """ - config: MixNodeConfig - inbound_socket: PacketQueue - outbound_socket: PacketPayloadQueue - task: asyncio.Task # A reference just to prevent task from being garbage collected - - @classmethod - async def new( - cls, - config: MixNodeConfig, - ) -> Self: - self = cls() + def __init__(self, config: NodeConfig, global_config: GlobalConfig): self.config = config - self.inbound_socket = asyncio.Queue() - self.outbound_socket = asyncio.Queue() - self.task = asyncio.create_task(self.__run()) - return self - - async def __run(self): + self.global_config = global_config + self.nomssip = Nomssip( + Nomssip.Config( + global_config.transmission_rate_per_sec, + config.nomssip.peering_degree, + self.__calculate_message_size(global_config), + ), + self.__process_msg, + ) + self.broadcast_channel = asyncio.Queue() + + @staticmethod + def __calculate_message_size(global_config: GlobalConfig) -> int: """ - Read SphinxPackets from inbound socket and spawn a thread for each packet to process it. - - This thread approximates a M/M/inf queue. + Calculate the actual message size to be gossiped, which depends on the maximum length of mix path. """ - - # A set just for gathering a reference of tasks to prevent them from being garbage collected. - # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task - self.tasks = set() - - while True: - _, packet = await self.inbound_socket.get() - task = asyncio.create_task( - self.__process_packet( - packet, - self.config.encryption_private_key, - self.config.delay_rate_per_min, - ) - ) - self.tasks.add(task) - # To discard the task from the set automatically when it is done. 
- task.add_done_callback(self.tasks.discard) - - async def __process_packet( - self, - packet: SphinxPacket, - encryption_private_key: X25519PrivateKey, - delay_rate_per_min: int, # Poisson rate parameter: mu - ): + sample_sphinx_packet, _ = SphinxPacketBuilder.build( + bytes(global_config.max_message_size), + global_config, + global_config.max_mix_path_length, + ) + return len(sample_sphinx_packet.bytes()) + + async def __process_msg(self, msg: bytes) -> None: """ - Process a single packet with a delay that follows exponential distribution, - and forward it to the next mix node or the mix destination - - This thread is a single server (worker) in a M/M/inf queue that MixNodeRunner approximates. + A handler to process messages received via gossip channel """ - delay_sec = poisson_interval_sec(delay_rate_per_min) - await asyncio.sleep(delay_sec) - - processed = packet.process(encryption_private_key) - match processed: - case ProcessedForwardHopPacket(): - await self.outbound_socket.put( - (processed.next_node_address, processed.next_packet) - ) - case ProcessedFinalHopPacket(): - await self.outbound_socket.put( - (processed.destination_node_address, processed.payload) - ) - case _: - raise UnknownHeaderTypeError + sphinx_packet = SphinxPacket.from_bytes(msg) + result = await self.__process_sphinx_packet(sphinx_packet) + match result: + case SphinxPacket(): + # Gossip the next Sphinx packet + await self.nomssip.gossip(result.bytes()) + case bytes(): + # Broadcast the message fully recovered from Sphinx packets + await self.broadcast_channel.put(result) + case None: + return + + async def __process_sphinx_packet( + self, packet: SphinxPacket + ) -> SphinxPacket | bytes | None: + """ + Unwrap the Sphinx packet and process the next Sphinx packet or the payload if possible + """ + try: + processed = packet.process(self.config.private_key) + match processed: + case ProcessedForwardHopPacket(): + return processed.next_packet + case ProcessedFinalHopPacket(): + return processed.payload.recover_plain_playload() + except ValueError: + # Return nothing, if it cannot be unwrapped by the private key of this node. + return None + + def connect(self, peer: Node): + """ + Establish a duplex connection with a peer node. + """ + inbound_conn, outbound_conn = asyncio.Queue(), asyncio.Queue() + # Register a duplex connection for its own use + self.nomssip.add_conn(inbound_conn, outbound_conn) + # Register a duplex connection for the peer + peer.nomssip.add_conn(outbound_conn, inbound_conn) - async def cancel(self) -> None: - self.task.cancel() - with suppress(asyncio.CancelledError): - await self.task + async def send_message(self, msg: bytes): + """ + Build a Sphinx packet and gossip it to all connected peers. + """ + # Here, we handle the case in which a msg is split into multiple Sphinx packets. + # But, in practice, we expect a message to be small enough to fit in a single Sphinx packet. 
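+        # (As written, SphinxPacketBuilder.build raises ValueError for messages larger
+        # than max_message_size rather than fragmenting them into multiple packets.)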
+ sphinx_packet, _ = SphinxPacketBuilder.build( + msg, + self.global_config, + self.config.mix_path_length, + ) + await self.nomssip.gossip(sphinx_packet.bytes()) diff --git a/mixnet/nomssip.py b/mixnet/nomssip.py new file mode 100644 index 00000000..f0fc4e63 --- /dev/null +++ b/mixnet/nomssip.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import asyncio +import hashlib +from dataclasses import dataclass +from enum import Enum +from typing import Awaitable, Callable, Self + +from mixnet.connection import DuplexConnection, MixSimplexConnection, SimplexConnection + + +class Nomssip: + """ + A NomMix gossip channel that broadcasts messages to all connected peers. + Peers are connected via DuplexConnection. + """ + + @dataclass + class Config: + transmission_rate_per_sec: int + peering_degree: int + msg_size: int + + def __init__( + self, + config: Config, + handler: Callable[[bytes], Awaitable[None]], + ): + self.config = config + self.conns: list[DuplexConnection] = [] + # A handler to process inbound messages. + self.handler = handler + self.packet_cache: set[bytes] = set() + # A set just for gathering a reference of tasks to prevent them from being garbage collected. + # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task + self.tasks: set[asyncio.Task] = set() + + def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection): + if len(self.conns) >= self.config.peering_degree: + # For simplicity of the spec, reject the connection if the peering degree is reached. + raise ValueError("The peering degree is reached.") + + noise_packet = FlaggedPacket( + FlaggedPacket.Flag.NOISE, bytes(self.config.msg_size) + ).bytes() + conn = DuplexConnection( + inbound, + MixSimplexConnection( + outbound, + self.config.transmission_rate_per_sec, + noise_packet, + ), + ) + + self.conns.append(conn) + task = asyncio.create_task(self.__process_inbound_conn(conn)) + self.tasks.add(task) + # To discard the task from the set automatically when it is done. + task.add_done_callback(self.tasks.discard) + + async def __process_inbound_conn(self, conn: DuplexConnection): + while True: + packet = await conn.recv() + if self.__check_update_cache(packet): + continue + + packet = FlaggedPacket.from_bytes(packet) + match packet.flag: + case FlaggedPacket.Flag.NOISE: + # Drop noise packet + continue + case FlaggedPacket.Flag.REAL: + await self.__gossip_flagged_packet(packet) + await self.handler(packet.message) + + async def gossip(self, msg: bytes): + """ + Gossip a message to all connected peers with prepending a message flag + """ + # The message size must be fixed. + assert len(msg) == self.config.msg_size + + packet = FlaggedPacket(FlaggedPacket.Flag.REAL, msg) + await self.__gossip_flagged_packet(packet) + + async def __gossip_flagged_packet(self, packet: FlaggedPacket): + """ + An internal method to send a flagged packet to all connected peers + """ + for conn in self.conns: + await conn.send(packet.bytes()) + + def __check_update_cache(self, packet: bytes) -> bool: + """ + Add a message to the cache, and return True if the message was already in the cache. 
+ """ + hash = hashlib.sha256(packet).digest() + if hash in self.packet_cache: + return True + self.packet_cache.add(hash) + return False + + +class FlaggedPacket: + class Flag(Enum): + REAL = b"\x00" + NOISE = b"\x01" + + def __init__(self, flag: Flag, message: bytes): + self.flag = flag + self.message = message + + def bytes(self) -> bytes: + return self.flag.value + self.message + + @classmethod + def from_bytes(cls, packet: bytes) -> Self: + """ + Parse a flagged packet from bytes + """ + if len(packet) < 1: + raise ValueError("Invalid message format") + return cls(cls.Flag(packet[:1]), packet[1:]) diff --git a/mixnet/packet.py b/mixnet/packet.py deleted file mode 100644 index 71fe6878..00000000 --- a/mixnet/packet.py +++ /dev/null @@ -1,203 +0,0 @@ -from __future__ import annotations - -import uuid -from dataclasses import dataclass -from enum import Enum -from itertools import batched -from typing import Dict, List, Self, Tuple, TypeAlias - -from pysphinx.payload import Payload -from pysphinx.sphinx import SphinxPacket - -from mixnet.config import MixnetTopology, MixNodeInfo - - -class MessageFlag(Enum): - MESSAGE_FLAG_REAL = b"\x00" - MESSAGE_FLAG_DROP_COVER = b"\x01" - - def bytes(self) -> bytes: - return bytes(self.value) - - -class PacketBuilder: - @staticmethod - def build_real_packets( - message: bytes, topology: MixnetTopology - ) -> List[Tuple[SphinxPacket, List[MixNodeInfo]]]: - return PacketBuilder.__build_packets( - MessageFlag.MESSAGE_FLAG_REAL, message, topology - ) - - @staticmethod - def build_drop_cover_packets( - message: bytes, topology: MixnetTopology - ) -> List[Tuple[SphinxPacket, List[MixNodeInfo]]]: - return PacketBuilder.__build_packets( - MessageFlag.MESSAGE_FLAG_DROP_COVER, message, topology - ) - - @staticmethod - def __build_packets( - flag: MessageFlag, message: bytes, topology: MixnetTopology - ) -> List[Tuple[SphinxPacket, List[MixNodeInfo]]]: - destination = topology.choose_mix_destination() - - msg_with_flag = flag.bytes() + message - # NOTE: We don't encrypt msg_with_flag for destination. - # If encryption is needed, a shared secret must be appended in front of the message along with the MessageFlag. - fragment_set = FragmentSet(msg_with_flag) - - out = [] - for fragment in fragment_set.fragments: - route = topology.generate_route(destination) - packet = SphinxPacket.build( - fragment.bytes(), - [mixnode.sphinx_node() for mixnode in route], - destination.sphinx_node(), - ) - out.append((packet, route)) - - return out - - @staticmethod - def parse_msg_and_flag(data: bytes) -> Tuple[MessageFlag, bytes]: - """Remove a MessageFlag from data""" - if len(data) < 1: - raise ValueError("data is too short") - - return (MessageFlag(data[0:1]), data[1:]) - - -# Unlikely, Nym uses i32 for FragmentSetId, which may cause more collisions. -# We will use UUID until figuring out why Nym uses i32. -FragmentSetId: TypeAlias = bytes # 128bit UUID v4 -FragmentId: TypeAlias = int # unsigned 8bit int in big endian - -FRAGMENT_SET_ID_LENGTH: int = 16 -FRAGMENT_ID_LENGTH: int = 1 - - -@dataclass -class FragmentHeader: - """ - Contain all information for reconstructing a message that was fragmented into the same FragmentSet. 
- """ - - set_id: FragmentSetId - total_fragments: FragmentId - fragment_id: FragmentId - - SIZE: int = FRAGMENT_SET_ID_LENGTH + FRAGMENT_ID_LENGTH * 2 - - @staticmethod - def max_total_fragments() -> int: - return 256 # because total_fragment is u8 - - def bytes(self) -> bytes: - return ( - self.set_id - + self.total_fragments.to_bytes(1) - + self.fragment_id.to_bytes(1) - ) - - @classmethod - def from_bytes(cls, data: bytes) -> Self: - if len(data) != cls.SIZE: - raise ValueError("Invalid data length", len(data)) - - return cls(data[:16], int.from_bytes(data[16:17]), int.from_bytes(data[17:18])) - - -@dataclass -class FragmentSet: - """ - Represent a set of Fragments that can be reconstructed to a single original message. - - Note that the maximum number of fragments in a FragmentSet is limited for now. - """ - - fragments: List[Fragment] - - MAX_FRAGMENTS: int = FragmentHeader.max_total_fragments() - - def __init__(self, message: bytes): - """ - Build a FragmentSet by chunking a message into Fragments. - """ - chunked_messages = chunks(message, Fragment.MAX_PAYLOAD_SIZE) - # For now, we don't support more than max_fragments() fragments. - # If needed, we can devise the FragmentSet chaining to support larger messages, like Nym. - if len(chunked_messages) > self.MAX_FRAGMENTS: - raise ValueError(f"Too long message: {len(chunked_messages)} chunks") - - set_id = uuid.uuid4().bytes - self.fragments = [ - Fragment(FragmentHeader(set_id, len(chunked_messages), i), chunk) - for i, chunk in enumerate(chunked_messages) - ] - - -@dataclass -class Fragment: - """Represent a piece of data that can be transformed to a single SphinxPacket""" - - header: FragmentHeader - body: bytes - - MAX_PAYLOAD_SIZE: int = Payload.max_plain_payload_size() - FragmentHeader.SIZE - - def bytes(self) -> bytes: - return self.header.bytes() + self.body - - @classmethod - def from_bytes(cls, data: bytes) -> Self: - header = FragmentHeader.from_bytes(data[: FragmentHeader.SIZE]) - body = data[FragmentHeader.SIZE :] - return cls(header, body) - - -@dataclass -class MessageReconstructor: - fragmentSets: Dict[FragmentSetId, FragmentSetReconstructor] - - def __init__(self): - self.fragmentSets = {} - - def add(self, fragment: Fragment) -> bytes | None: - if fragment.header.set_id not in self.fragmentSets: - self.fragmentSets[fragment.header.set_id] = FragmentSetReconstructor( - fragment.header.total_fragments - ) - - msg = self.fragmentSets[fragment.header.set_id].add(fragment) - if msg is not None: - del self.fragmentSets[fragment.header.set_id] - return msg - - -@dataclass -class FragmentSetReconstructor: - total_fragments: FragmentId - fragments: Dict[FragmentId, Fragment] - - def __init__(self, total_fragments: FragmentId): - self.total_fragments = total_fragments - self.fragments = {} - - def add(self, fragment: Fragment) -> bytes | None: - self.fragments[fragment.header.fragment_id] = fragment - if len(self.fragments) == self.total_fragments: - return self.build_message() - else: - return None - - def build_message(self) -> bytes: - message = b"" - for i in range(self.total_fragments): - message += self.fragments[FragmentId(i)].body - return message - - -def chunks(data: bytes, size: int) -> List[bytes]: - return list(map(bytes, batched(data, size))) diff --git a/mixnet/poisson.py b/mixnet/poisson.py deleted file mode 100644 index 86a4b66a..00000000 --- a/mixnet/poisson.py +++ /dev/null @@ -1,13 +0,0 @@ -import numpy - - -def poisson_interval_sec(rate_per_min: int) -> float: - # If events occur in a Poisson distribution 
with rate_per_min, - # the interval between events follows the exponential distribution - # with the rate_per_min (i.e. with the scale 1/rate_per_min). - interval_min = numpy.random.exponential(scale=1 / rate_per_min, size=1)[0] - return interval_min * 60 - - -def poisson_mean_interval_sec(rate_per_min: int) -> float: - return 1 / rate_per_min * 60 diff --git a/mixnet/sphinx.py b/mixnet/sphinx.py new file mode 100644 index 00000000..a3bab4a3 --- /dev/null +++ b/mixnet/sphinx.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import List, Tuple + +from pysphinx.sphinx import SphinxPacket + +from mixnet.config import GlobalConfig, NodeInfo + + +class SphinxPacketBuilder: + @staticmethod + def build( + message: bytes, global_config: GlobalConfig, path_len: int + ) -> Tuple[SphinxPacket, List[NodeInfo]]: + if path_len <= 0: + raise ValueError("path_len must be greater than 0") + if len(message) > global_config.max_message_size: + raise ValueError("message is too long") + + route = global_config.membership.generate_route(path_len) + # We don't need the destination (defined in the Loopix Sphinx spec) + # because the last mix will broadcast the fully unwrapped message. + # Later, we will optimize the Sphinx according to our requirements. + dummy_destination = route[-1] + + packet = SphinxPacket.build( + message, + route=[mixnode.sphinx_node() for mixnode in route], + destination=dummy_destination.sphinx_node(), + ) + return (packet, route) diff --git a/mixnet/structure.png b/mixnet/structure.png deleted file mode 100644 index 994313c5..00000000 Binary files a/mixnet/structure.png and /dev/null differ diff --git a/mixnet/test_client.py b/mixnet/test_client.py deleted file mode 100644 index 43274662..00000000 --- a/mixnet/test_client.py +++ /dev/null @@ -1,45 +0,0 @@ -from datetime import datetime -from unittest import IsolatedAsyncioTestCase - -import numpy - -from mixnet.client import MixClient -from mixnet.poisson import poisson_mean_interval_sec -from mixnet.test_utils import ( - init_mixnet_config, - with_test_timeout, -) -from mixnet.utils import random_bytes - - -class TestMixClient(IsolatedAsyncioTestCase): - @with_test_timeout(100) - async def test_mixclient(self): - config = init_mixnet_config().mixclient_config - config.emission_rate_per_min = 30 - config.redundancy = 3 - - mixclient = await MixClient.new(config) - try: - # Send a 3500-byte msg, expecting that it is split into at least two packets - await mixclient.send_message(random_bytes(3500)) - - # Calculate intervals between packet emissions from the mix client - intervals = [] - ts = datetime.now() - for _ in range(30): - _ = await mixclient.outbound_socket.get() - now = datetime.now() - intervals.append((now - ts).total_seconds()) - ts = now - - # Check if packets were emitted at the Poisson emission_rate - # If emissions follow the Poisson distribution with a rate `lambda`, - # a mean interval between emissions must be `1/lambda`. 
- self.assertAlmostEqual( - float(numpy.mean(intervals)), - poisson_mean_interval_sec(config.emission_rate_per_min), - delta=1.0, - ) - finally: - await mixclient.cancel() diff --git a/mixnet/test_fisheryates.py b/mixnet/test_fisheryates.py deleted file mode 100644 index a32554c6..00000000 --- a/mixnet/test_fisheryates.py +++ /dev/null @@ -1,21 +0,0 @@ -from unittest import TestCase - -from mixnet.fisheryates import FisherYates - - -class TestFisherYates(TestCase): - def test_shuffle(self): - entropy = b"hello" - elems = [1, 2, 3, 4, 5] - - shuffled1 = FisherYates.shuffle(elems, entropy) - self.assertEqual(sorted(elems), sorted(shuffled1)) - - # shuffle again with the same entropy - shuffled2 = FisherYates.shuffle(elems, entropy) - self.assertEqual(shuffled1, shuffled2) - - # shuffle with a different entropy - shuffled3 = FisherYates.shuffle(elems, b"world") - self.assertNotEqual(shuffled1, shuffled3) - self.assertEqual(sorted(elems), sorted(shuffled3)) diff --git a/mixnet/test_mixnet.py b/mixnet/test_mixnet.py deleted file mode 100644 index 9a0e4cb1..00000000 --- a/mixnet/test_mixnet.py +++ /dev/null @@ -1,20 +0,0 @@ -import asyncio -from unittest import IsolatedAsyncioTestCase - -from mixnet.mixnet import Mixnet -from mixnet.test_utils import init_mixnet_config - - -class TestMixnet(IsolatedAsyncioTestCase): - async def test_topology_from_robustness(self): - config = init_mixnet_config() - entropy_queue = asyncio.Queue() - - mixnet = await Mixnet.new(config, entropy_queue) - try: - old_topology = config.mixclient_config.topology - await entropy_queue.put(b"new entropy") - await asyncio.sleep(1) - self.assertNotEqual(old_topology, mixnet.get_topology()) - finally: - await mixnet.cancel() diff --git a/mixnet/test_node.py b/mixnet/test_node.py index 26ab0c7b..f4ba644a 100644 --- a/mixnet/test_node.py +++ b/mixnet/test_node.py @@ -1,117 +1,38 @@ import asyncio -from datetime import datetime from unittest import IsolatedAsyncioTestCase -import numpy -from pysphinx.sphinx import SphinxPacket - -from mixnet.node import MixNode, NodeAddress, PacketQueue -from mixnet.packet import PacketBuilder -from mixnet.poisson import poisson_interval_sec, poisson_mean_interval_sec +from mixnet.node import Node from mixnet.test_utils import ( init_mixnet_config, - with_test_timeout, ) -class TestMixNodeRunner(IsolatedAsyncioTestCase): - @with_test_timeout(180) - async def test_mixnode_emission_rate(self): - """ - Test if MixNodeRunner works as a M/M/inf queue. - - If inputs are arrived at Poisson rate `lambda`, - and if processing is delayed according to an exponential distribution with a rate `mu`, - the rate of outputs should be `lambda`. - """ - config = init_mixnet_config() - config.mixclient_config.emission_rate_per_min = 120 # lambda (= 2msg/sec) - config.mixnode_config.delay_rate_per_min = 30 # mu (= 2s delay on average) - - packet, route = PacketBuilder.build_real_packets( - b"msg", config.mixclient_config.topology - )[0] - - # Start only the first mix node for testing - config.mixnode_config.encryption_private_key = route[0].encryption_private_key - mixnode = await MixNode.new(config.mixnode_config) - try: - # Send packets to the first mix node in a Poisson distribution - packet_count = 100 - # This queue is just for counting how many packets have been sent so far. 
- sent_packet_queue: PacketQueue = asyncio.Queue() - sender_task = asyncio.create_task( - self.send_packets( - mixnode.inbound_socket, - packet, - route[0].addr, - packet_count, - config.mixclient_config.emission_rate_per_min, - sent_packet_queue, - ) - ) +class TestNode(IsolatedAsyncioTestCase): + async def test_node(self): + global_config, node_configs, _ = init_mixnet_config(10) + nodes = [Node(node_config, global_config) for node_config in node_configs] + for i, node in enumerate(nodes): try: - # Calculate intervals between outputs and gather num_jobs in the first mix node. - intervals = [] - num_jobs = [] - ts = datetime.now() - for _ in range(packet_count): - _ = await mixnode.outbound_socket.get() - now = datetime.now() - intervals.append((now - ts).total_seconds()) - - # Calculate the current # of jobs staying in the mix node - num_packets_emitted_from_mixnode = len(intervals) - num_packets_sent_to_mixnode = sent_packet_queue.qsize() - num_jobs.append( - num_packets_sent_to_mixnode - num_packets_emitted_from_mixnode - ) - - ts = now - - # Remove the first interval that would be much larger than other intervals, - # because of the delay in mix node. - intervals = intervals[1:] - num_jobs = num_jobs[1:] - - # Check if the emission rate of the first mix node is the same as - # the emission rate of the message sender, but with a delay. - # If outputs follow the Poisson distribution with a rate `lambda`, - # a mean interval between outputs must be `1/lambda`. - self.assertAlmostEqual( - float(numpy.mean(intervals)), - poisson_mean_interval_sec( - config.mixclient_config.emission_rate_per_min - ), - delta=1.0, - ) - # If runner is a M/M/inf queue, - # a mean number of jobs being processed/scheduled in the runner must be `lambda/mu`. - self.assertAlmostEqual( - float(numpy.mean(num_jobs)), - round( - config.mixclient_config.emission_rate_per_min - / config.mixnode_config.delay_rate_per_min - ), - delta=1.5, - ) - finally: - await sender_task - finally: - await mixnode.cancel() - - @staticmethod - async def send_packets( - inbound_socket: PacketQueue, - packet: SphinxPacket, - node_addr: NodeAddress, - cnt: int, - rate_per_min: int, - # For testing purpose, to inform the caller how many packets have been sent to the inbound_socket - sent_packet_queue: PacketQueue, - ): - for _ in range(cnt): - # Since the task is not heavy, just sleep for seconds instead of using emission_notifier - await asyncio.sleep(poisson_interval_sec(rate_per_min)) - await inbound_socket.put((node_addr, packet)) - await sent_packet_queue.put((node_addr, packet)) + node.connect(nodes[(i + 1) % len(nodes)]) + except ValueError as e: + print(e) + + await nodes[0].send_message(b"block selection") + + timeout = 15 + for _ in range(timeout): + broadcasted_msgs = [] + for node in nodes: + if not node.broadcast_channel.empty(): + broadcasted_msgs.append(node.broadcast_channel.get_nowait()) + + if len(broadcasted_msgs) == 0: + await asyncio.sleep(1) + else: + # We expect only one node to broadcast the message. 
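+                # (Only the final mix in the route can unwrap the last Sphinx layer and recover the plaintext.)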
+ assert len(broadcasted_msgs) == 1 + self.assertEqual(b"block selection", broadcasted_msgs[0]) + return + self.fail("timeout") + + # TODO: check noise diff --git a/mixnet/test_packet.py b/mixnet/test_packet.py deleted file mode 100644 index d1d517a6..00000000 --- a/mixnet/test_packet.py +++ /dev/null @@ -1,78 +0,0 @@ -from typing import List -from unittest import TestCase - -from pysphinx.sphinx import ProcessedFinalHopPacket, SphinxPacket - -from mixnet.config import MixNodeInfo -from mixnet.packet import ( - Fragment, - MessageFlag, - MessageReconstructor, - PacketBuilder, -) -from mixnet.test_utils import init_mixnet_config -from mixnet.utils import random_bytes - - -class TestPacket(TestCase): - def test_real_packet(self): - topology = init_mixnet_config().mixclient_config.topology - msg = random_bytes(3500) - packets_and_routes = PacketBuilder.build_real_packets(msg, topology) - self.assertEqual(4, len(packets_and_routes)) - - reconstructor = MessageReconstructor() - self.assertIsNone( - reconstructor.add( - self.process_packet(packets_and_routes[1][0], packets_and_routes[1][1]) - ), - ) - self.assertIsNone( - reconstructor.add( - self.process_packet(packets_and_routes[3][0], packets_and_routes[3][1]) - ), - ) - self.assertIsNone( - reconstructor.add( - self.process_packet(packets_and_routes[2][0], packets_and_routes[2][1]) - ), - ) - msg_with_flag = reconstructor.add( - self.process_packet(packets_and_routes[0][0], packets_and_routes[0][1]) - ) - assert msg_with_flag is not None - self.assertEqual( - PacketBuilder.parse_msg_and_flag(msg_with_flag), - (MessageFlag.MESSAGE_FLAG_REAL, msg), - ) - - def test_cover_packet(self): - topology = init_mixnet_config().mixclient_config.topology - msg = b"cover" - packets_and_routes = PacketBuilder.build_drop_cover_packets(msg, topology) - self.assertEqual(1, len(packets_and_routes)) - - reconstructor = MessageReconstructor() - msg_with_flag = reconstructor.add( - self.process_packet(packets_and_routes[0][0], packets_and_routes[0][1]) - ) - assert msg_with_flag is not None - self.assertEqual( - PacketBuilder.parse_msg_and_flag(msg_with_flag), - (MessageFlag.MESSAGE_FLAG_DROP_COVER, msg), - ) - - @staticmethod - def process_packet(packet: SphinxPacket, route: List[MixNodeInfo]) -> Fragment: - processed = packet.process(route[0].encryption_private_key) - if isinstance(processed, ProcessedFinalHopPacket): - return Fragment.from_bytes(processed.payload.recover_plain_playload()) - else: - processed = processed - for node in route[1:]: - p = processed.next_packet.process(node.encryption_private_key) - if isinstance(p, ProcessedFinalHopPacket): - return Fragment.from_bytes(p.payload.recover_plain_playload()) - else: - processed = p - assert False diff --git a/mixnet/test_sphinx.py b/mixnet/test_sphinx.py new file mode 100644 index 00000000..3be3d229 --- /dev/null +++ b/mixnet/test_sphinx.py @@ -0,0 +1,39 @@ +from random import randint +from typing import cast +from unittest import TestCase + +from pysphinx.sphinx import ( + ProcessedFinalHopPacket, + ProcessedForwardHopPacket, +) + +from mixnet.sphinx import SphinxPacketBuilder +from mixnet.test_utils import init_mixnet_config + + +class TestSphinxPacketBuilder(TestCase): + def test_builder(self): + global_config, _, key_map = init_mixnet_config(10) + msg = self.random_bytes(500) + packet, route = SphinxPacketBuilder.build(msg, global_config, 3) + self.assertEqual(3, len(route)) + + processed = packet.process(key_map[route[0].public_key.public_bytes_raw()]) + self.assertIsInstance(processed, 
ProcessedForwardHopPacket) + processed = cast(ProcessedForwardHopPacket, processed).next_packet.process( + key_map[route[1].public_key.public_bytes_raw()] + ) + self.assertIsInstance(processed, ProcessedForwardHopPacket) + processed = cast(ProcessedForwardHopPacket, processed).next_packet.process( + key_map[route[2].public_key.public_bytes_raw()] + ) + self.assertIsInstance(processed, ProcessedFinalHopPacket) + recovered = cast( + ProcessedFinalHopPacket, processed + ).payload.recover_plain_playload() + self.assertEqual(msg, recovered) + + @staticmethod + def random_bytes(size: int) -> bytes: + assert size >= 0 + return bytes([randint(0, 255) for _ in range(size)]) diff --git a/mixnet/test_utils.py b/mixnet/test_utils.py index e3ba2607..b552f54f 100644 --- a/mixnet/test_utils.py +++ b/mixnet/test_utils.py @@ -1,46 +1,36 @@ -import asyncio - from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey -from mixnet.bls import generate_bls from mixnet.config import ( - MixClientConfig, - MixNodeConfig, - MixnetConfig, - MixNodeInfo, - MixnetTopology, - MixnetTopologyConfig, - MixnetTopologySize, + GlobalConfig, + MixMembership, + NodeConfig, + NodeInfo, + NomssipConfig, ) -from mixnet.utils import random_bytes - - -def with_test_timeout(t): - def wrapper(coroutine): - async def run(*args, **kwargs): - async with asyncio.timeout(t): - return await coroutine(*args, **kwargs) - return run - return wrapper - - -def init_mixnet_config() -> MixnetConfig: - topology_config = MixnetTopologyConfig( - [ - MixNodeInfo( - generate_bls(), - X25519PrivateKey.generate(), - random_bytes(32), - ) - for _ in range(12) - ], - MixnetTopologySize(3, 3), - b"entropy", - ) - mixclient_config = MixClientConfig(30, 3, MixnetTopology(topology_config)) - mixnode_config = MixNodeConfig( - topology_config.mixnode_candidates[0].encryption_private_key, 30 +def init_mixnet_config( + num_nodes: int, +) -> tuple[GlobalConfig, list[NodeConfig], dict[bytes, X25519PrivateKey]]: + max_mix_path_length = 3 + gossip_config = NomssipConfig(peering_degree=6) + node_configs = [ + NodeConfig(X25519PrivateKey.generate(), max_mix_path_length, gossip_config) + for _ in range(num_nodes) + ] + global_config = GlobalConfig( + MixMembership( + [ + NodeInfo(node_config.private_key.public_key()) + for node_config in node_configs + ] + ), + transmission_rate_per_sec=3, + max_message_size=512, + max_mix_path_length=max_mix_path_length, ) - return MixnetConfig(topology_config, mixclient_config, mixnode_config) + key_map = { + node_config.private_key.public_key().public_bytes_raw(): node_config.private_key + for node_config in node_configs + } + return (global_config, node_configs, key_map) diff --git a/mixnet/utils.py b/mixnet/utils.py deleted file mode 100644 index 6b45176c..00000000 --- a/mixnet/utils.py +++ /dev/null @@ -1,6 +0,0 @@ -from random import randint - - -def random_bytes(size: int) -> bytes: - assert size >= 0 - return bytes([randint(0, 255) for _ in range(size)]) diff --git a/requirements.txt b/requirements.txt index 38ae3bd6..84adb0e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ cffi==1.16.0 cryptography==41.0.7 numpy==1.26.3 pycparser==2.21 -pysphinx==0.0.1 +pysphinx==0.0.3 scipy==1.11.4 black==23.12.1 sympy==1.12 @@ -12,4 +12,4 @@ toml==0.10.2 # used for noir portalocker==2.8.2 # portable file locking keum==0.2.0 # for CL's use of more obscure curves poseidon-hash==0.1.4 # used as the algebraic hash in CL -hypothesis==6.103.0 +hypothesis==6.103.0 \ No newline at end of file
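
Illustrative usage of the new Node API introduced in this diff: a minimal sketch that mirrors mixnet/test_node.py, wiring a few nodes into a ring, sending one message, and waiting for the fully unwrapped plaintext to appear on a broadcast channel. It is not part of the patch itself; it assumes the modules are importable as mixnet.node and mixnet.test_utils exactly as added above, that pysphinx==0.0.3 from requirements.txt is installed, and that it runs inside an asyncio event loop.

import asyncio

from mixnet.node import Node
from mixnet.test_utils import init_mixnet_config


async def demo() -> None:
    # Shared global parameters plus one per-node config, as built in mixnet/test_utils.py.
    global_config, node_configs, _ = init_mixnet_config(num_nodes=5)
    nodes = [Node(cfg, global_config) for cfg in node_configs]

    # Wire the nodes into a ring. Node.connect registers a duplex connection on both
    # peers and raises ValueError once the configured peering degree is exhausted.
    for i, node in enumerate(nodes):
        node.connect(nodes[(i + 1) % len(nodes)])

    # Wrap the message in a Sphinx packet and gossip it through the mix network.
    await nodes[0].send_message(b"block selection")

    # The node that unwraps the final hop puts the recovered plaintext on its
    # broadcast_channel; poll all nodes until the message shows up.
    for _ in range(15):
        for node in nodes:
            if not node.broadcast_channel.empty():
                print("recovered:", node.broadcast_channel.get_nowait())
                return
        await asyncio.sleep(1)
    raise TimeoutError("message was never broadcast")


if __name__ == "__main__":
    asyncio.run(demo())

As in the test, exactly one node is expected to recover the message, because only the final mix in the randomly chosen route can unwrap the last Sphinx layer; all other traffic on the connections is noise or intermediate Sphinx packets emitted at the global transmission rate.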