diff --git a/mixnet/node.py b/mixnet/node.py index e914a68f..fd8b6774 100644 --- a/mixnet/node.py +++ b/mixnet/node.py @@ -15,8 +15,7 @@ from mixnet.config import GlobalConfig, NodeConfig from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder -NetworkPacket: TypeAlias = SphinxPacket | bytes -NetworkPacketQueue: TypeAlias = asyncio.Queue[NetworkPacket] +NetworkPacketQueue: TypeAlias = asyncio.Queue[bytes] Connection: TypeAlias = NetworkPacketQueue BroadcastChannel: TypeAlias = asyncio.Queue[bytes] @@ -39,7 +38,7 @@ def __init__(self, config: NodeConfig, global_config: GlobalConfig): async def __process_sphinx_packet( self, packet: SphinxPacket - ) -> NetworkPacket | None: + ) -> SphinxPacket | None: try: processed = packet.process(self.config.private_key) match processed: @@ -83,19 +82,19 @@ async def send_message(self, msg: bytes): for packet, _ in PacketBuilder.build_real_packets( msg, self.global_config.membership ): - await self.mixgossip_channel.gossip(packet) + await self.mixgossip_channel.gossip(FLAG_REAL + packet.bytes()) class MixGossipChannel: peering_degree: int conns: list[DuplexConnection] - handler: Callable[[SphinxPacket], Awaitable[NetworkPacket | None]] - msg_cache: set[NetworkPacket] + handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]] + msg_cache: set[bytes] def __init__( self, peer_degree: int, - handler: Callable[[SphinxPacket], Awaitable[NetworkPacket | None]], + handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]], ): self.peering_degree = peer_degree self.conns = [] @@ -118,26 +117,24 @@ def add_conn(self, conn: DuplexConnection): async def __process_inbound_conn(self, conn: DuplexConnection): while True: - elem = await conn.recv() - # In practice, data transmitted through connections is going to be always 'bytes'. - # But here, we use the SphinxPacket type explicitly for simplicity - # without implementing serde for SphinxPacket. - if isinstance(elem, bytes): - assert elem == build_noise_packet() - # Drop packet + msg = await conn.recv() + if msg[:1] != FLAG_REAL: + # Drop noise packet continue - elif isinstance(elem, SphinxPacket): - # Don't process the same message twice. - msg_hash = hashlib.sha256(elem.bytes()).digest() - if msg_hash in self.msg_cache: - continue - self.msg_cache.add(msg_hash) - # Handle the packet and gossip the result if needed. - net_packet = await self.handler(elem) - if net_packet is not None: - await self.gossip(net_packet) - - async def gossip(self, packet: NetworkPacket): + + # Don't process the same message twice. + msg_hash = hashlib.sha256(msg).digest() + if msg_hash in self.msg_cache: + continue + self.msg_cache.add(msg_hash) + + # Handle the packet and gossip the result if needed. + sphinx_packet = SphinxPacket.from_bytes(msg[1:]) + new_sphinx_packet = await self.handler(sphinx_packet) + if new_sphinx_packet is not None: + await self.gossip(FLAG_REAL + new_sphinx_packet.bytes()) + + async def gossip(self, packet: bytes): for conn in self.conns: await conn.send(packet) @@ -150,10 +147,10 @@ def __init__(self, inbound: Connection, outbound: MixOutboundConnection): self.inbound = inbound self.outbound = outbound - async def recv(self) -> NetworkPacket: + async def recv(self) -> bytes: return await self.inbound.get() - async def send(self, packet: NetworkPacket): + async def send(self, packet: bytes): await self.outbound.send(packet) @@ -178,9 +175,13 @@ async def __run(self): elem = self.queue.get_nowait() await self.conn.put(elem) - async def send(self, elem: NetworkPacket): + async def send(self, elem: bytes): await self.queue.put(elem) +FLAG_REAL = b"\x00" +FLAG_NOISE = b"\x01" + + def build_noise_packet() -> bytes: - return bytes(DEFAULT_PAYLOAD_SIZE) + return FLAG_NOISE + bytes(DEFAULT_PAYLOAD_SIZE)