From ecb6f5dd686b64a878c87dc318306a4f6160e000 Mon Sep 17 00:00:00 2001 From: Youngjoon Lee <5462944+youngjoon-lee@users.noreply.github.com> Date: Thu, 11 Jul 2024 11:23:42 +0900 Subject: [PATCH] refactor nomssip vs node encapsulation --- mixnet/node.py | 114 +++++++++++++++------------------------------- mixnet/nomssip.py | 103 +++++++++++++++++++++++++++++++++-------- 2 files changed, 121 insertions(+), 96 deletions(-) diff --git a/mixnet/node.py b/mixnet/node.py index f9a4849e..37385d44 100644 --- a/mixnet/node.py +++ b/mixnet/node.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -from enum import Enum from typing import TypeAlias from pysphinx.sphinx import ( @@ -12,7 +11,6 @@ ) from mixnet.config import GlobalConfig, NodeConfig -from mixnet.connection import DuplexConnection, MixSimplexConnection from mixnet.nomssip import Nomssip from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder @@ -32,44 +30,46 @@ class Node: nomssip: Nomssip reconstructor: MessageReconstructor broadcast_channel: BroadcastChannel - # The actual packet size is calculated based on the max length of mix path by Sphinx encoding - # when the node is initialized, so that it can be used to generate noise packets. - packet_size: int def __init__(self, config: NodeConfig, global_config: GlobalConfig): + sample_packet, _ = PacketBuilder.build_real_packets( + bytes(1), global_config.membership, global_config.max_mix_path_length + )[0] + self.config = config self.global_config = global_config - self.nomssip = Nomssip(config.nomssip, self.__process_msg) + self.nomssip = Nomssip( + Nomssip.Config( + global_config.transmission_rate_per_sec, + config.nomssip.peering_degree, + len(sample_packet.bytes()), + ), + self.__process_msg, + ) self.reconstructor = MessageReconstructor() self.broadcast_channel = asyncio.Queue() - sample_packet, _ = PacketBuilder.build_real_packets( - bytes(1), global_config.membership, self.global_config.max_mix_path_length - )[0] - self.packet_size = len(sample_packet.bytes()) - - async def __process_msg(self, msg: bytes) -> bytes | None: + async def __process_msg(self, msg: bytes) -> None: """ A handler to process messages received via gossip channel """ - flag, msg = Node.__parse_msg(msg) - match flag: - case MsgType.NOISE: - # Drop noise packet - return None - case MsgType.REAL: - # Handle the packet and gossip the result if needed. - sphinx_packet = SphinxPacket.from_bytes(msg) - new_sphinx_packet = await self.__process_sphinx_packet(sphinx_packet) - if new_sphinx_packet is None: - return None - return Node.__build_msg(MsgType.REAL, new_sphinx_packet.bytes()) + sphinx_packet = SphinxPacket.from_bytes(msg) + result = await self.__process_sphinx_packet(sphinx_packet) + match result: + case SphinxPacket(): + # Gossip the next Sphinx packet + await self.nomssip.gossip(result.bytes()) + case bytes(): + # Broadcast the message fully recovered from Sphinx packets + await self.broadcast_channel.put(result) + case None: + return async def __process_sphinx_packet( self, packet: SphinxPacket - ) -> SphinxPacket | None: + ) -> SphinxPacket | bytes | None: """ - Unwrap the Sphinx packet and process the next Sphinx packet or the payload. + Unwrap the Sphinx packet and process the next Sphinx packet or the payload if possible """ try: processed = packet.process(self.config.private_key) @@ -77,14 +77,14 @@ async def __process_sphinx_packet( case ProcessedForwardHopPacket(): return processed.next_packet case ProcessedFinalHopPacket(): - await self.__process_sphinx_payload(processed.payload) + return await self.__process_sphinx_payload(processed.payload) except ValueError: - # Return SphinxPacket as it is, if it cannot be unwrapped by the private key of this node. - return packet + # Return nothing, if it cannot be unwrapped by the private key of this node. + return None - async def __process_sphinx_payload(self, payload: Payload): + async def __process_sphinx_payload(self, payload: Payload) -> bytes | None: """ - Process the Sphinx payload and broadcast it if it is a real message. + Process the Sphinx payload if possible """ msg_with_flag = self.reconstructor.add( Fragment.from_bytes(payload.recover_plain_playload()) @@ -92,37 +92,18 @@ async def __process_sphinx_payload(self, payload: Payload): if msg_with_flag is not None: flag, msg = PacketBuilder.parse_msg_and_flag(msg_with_flag) if flag == MessageFlag.MESSAGE_FLAG_REAL: - await self.broadcast_channel.put(msg) + return msg + return None def connect(self, peer: Node): """ Establish a duplex connection with a peer node. """ - noise_msg = Node.__build_msg(MsgType.NOISE, bytes(self.packet_size)) inbound_conn, outbound_conn = asyncio.Queue(), asyncio.Queue() - # Register a duplex connection for its own use - self.nomssip.add_conn( - DuplexConnection( - inbound_conn, - MixSimplexConnection( - outbound_conn, - self.global_config.transmission_rate_per_sec, - noise_msg, - ), - ) - ) - # Register the same duplex connection for the peer - peer.nomssip.add_conn( - DuplexConnection( - outbound_conn, - MixSimplexConnection( - inbound_conn, - self.global_config.transmission_rate_per_sec, - noise_msg, - ), - ) - ) + self.nomssip.add_conn(inbound_conn, outbound_conn) + # Register a duplex connection for the peer + peer.nomssip.add_conn(outbound_conn, inbound_conn) async def send_message(self, msg: bytes): """ @@ -135,25 +116,4 @@ async def send_message(self, msg: bytes): self.global_config.membership, self.config.mix_path_length, ): - await self.nomssip.gossip(Node.__build_msg(MsgType.REAL, packet.bytes())) - - @staticmethod - def __build_msg(flag: MsgType, data: bytes) -> bytes: - """ - Prepend a flag to the message, right before sending it via network channel. - """ - return flag.value + data - - @staticmethod - def __parse_msg(data: bytes) -> tuple[MsgType, bytes]: - """ - Parse the message and extract the flag. - """ - if len(data) < 1: - raise ValueError("Invalid message format") - return (MsgType(data[:1]), data[1:]) - - -class MsgType(Enum): - REAL = b"\x00" - NOISE = b"\x01" + await self.nomssip.gossip(packet.bytes()) diff --git a/mixnet/nomssip.py b/mixnet/nomssip.py index a3244598..475fd9f5 100644 --- a/mixnet/nomssip.py +++ b/mixnet/nomssip.py @@ -1,9 +1,10 @@ import asyncio import hashlib +from dataclasses import dataclass +from enum import Enum from typing import Awaitable, Callable -from mixnet.config import NomssipConfig -from mixnet.connection import DuplexConnection +from mixnet.connection import DuplexConnection, MixSimplexConnection, SimplexConnection class Nomssip: @@ -12,31 +13,49 @@ class Nomssip: Peers are connected via DuplexConnection. """ - config: NomssipConfig + @dataclass + class Config: + transmission_rate_per_sec: int + peering_degree: int + msg_size: int + + config: Config conns: list[DuplexConnection] # A handler to process inbound messages. - handler: Callable[[bytes], Awaitable[bytes | None]] - # A set of message hashes to prevent processing the same message twice. - msg_cache: set[bytes] + handler: Callable[[bytes], Awaitable[None]] + # A set of packet hashes to prevent gossiping/processing the same packet twice. + packet_cache: set[bytes] def __init__( self, - config: NomssipConfig, - handler: Callable[[bytes], Awaitable[bytes | None]], + config: Config, + handler: Callable[[bytes], Awaitable[None]], ): self.config = config self.conns = [] self.handler = handler - self.msg_cache = set() + self.packet_cache = set() # A set just for gathering a reference of tasks to prevent them from being garbage collected. # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task self.tasks = set() - def add_conn(self, conn: DuplexConnection): + def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection): if len(self.conns) >= self.config.peering_degree: # For simplicity of the spec, reject the connection if the peering degree is reached. raise ValueError("The peering degree is reached.") + noise_packet = self.__build_packet( + self.PacketType.NOISE, bytes(self.config.msg_size) + ) + conn = DuplexConnection( + inbound, + MixSimplexConnection( + outbound, + self.config.transmission_rate_per_sec, + noise_packet, + ), + ) + self.conns.append(conn) task = asyncio.create_task(self.__process_inbound_conn(conn)) self.tasks.add(task) @@ -45,17 +64,63 @@ def add_conn(self, conn: DuplexConnection): async def __process_inbound_conn(self, conn: DuplexConnection): while True: - msg = await conn.recv() - # Don't process the same message twice. - msg_hash = hashlib.sha256(msg).digest() - if msg_hash in self.msg_cache: + packet = await conn.recv() + if self.__check_update_cache(packet): continue - self.msg_cache.add(msg_hash) - new_msg = await self.handler(msg) - if new_msg is not None: - await self.gossip(new_msg) + flag, msg = self.__parse_packet(packet) + match flag: + case self.PacketType.NOISE: + # Drop noise packet + continue + case self.PacketType.REAL: + await self.__gossip(packet) + await self.handler(msg) + + async def gossip(self, msg: bytes): + """ + Gossip a message to all connected peers if the message has not been gossiped yet. + """ + # The message size must be fixed. + assert len(msg) == self.config.msg_size + + packet = self.__build_packet(self.PacketType.REAL, msg) + if not self.__check_update_cache(packet): + await self.__gossip(packet) - async def gossip(self, packet: bytes): + async def __gossip(self, packet: bytes): + """ + An internal method to send a packet to all connected peers without checking the cache + """ for conn in self.conns: await conn.send(packet) + + def __check_update_cache(self, packet: bytes) -> bool: + """ + Add a message to the cache, and return True if the message was already in the cache. + """ + hash = hashlib.sha256(packet).digest() + if hash in self.packet_cache: + return True + self.packet_cache.add(hash) + return False + + class PacketType(Enum): + REAL = b"\x00" + NOISE = b"\x01" + + @staticmethod + def __build_packet(flag: PacketType, data: bytes) -> bytes: + """ + Prepend a flag to the message, right before sending it via network channel. + """ + return flag.value + data + + @staticmethod + def __parse_packet(data: bytes) -> tuple[PacketType, bytes]: + """ + Parse the message and extract the flag. + """ + if len(data) < 1: + raise ValueError("Invalid message format") + return (Nomssip.PacketType(data[:1]), data[1:])