Skip to content

Commit

Permalink
refactor nomssip vs node encapsulation
Browse files Browse the repository at this point in the history
  • Loading branch information
youngjoon-lee committed Jul 11, 2024
1 parent 953b2d6 commit 010733c
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 93 deletions.
114 changes: 40 additions & 74 deletions mixnet/node.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
from enum import Enum
from typing import TypeAlias

from pysphinx.sphinx import (
Expand All @@ -12,7 +11,6 @@
)

from mixnet.config import GlobalConfig, NodeConfig
from mixnet.connection import DuplexConnection, MixSimplexConnection
from mixnet.nomssip import Nomssip
from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder

Expand All @@ -32,97 +30,86 @@ class Node:
nomssip: Nomssip
reconstructor: MessageReconstructor
broadcast_channel: BroadcastChannel
# The actual packet size is calculated based on the max length of mix path by Sphinx encoding
# when the node is initialized, so that it can be used to generate noise packets.
packet_size: int

def __init__(self, config: NodeConfig, global_config: GlobalConfig):
self.config = config
self.global_config = global_config
self.nomssip = Nomssip(config.nomssip, self.__process_msg)
self.nomssip = Nomssip(
Nomssip.Config(
global_config.transmission_rate_per_sec,
config.nomssip.peering_degree,
self.__calculate_message_size(global_config),
),
self.__process_msg,
)
self.reconstructor = MessageReconstructor()
self.broadcast_channel = asyncio.Queue()

@staticmethod
def __calculate_message_size(global_config: GlobalConfig) -> int:
"""
Calculate the actual message size to be gossiped, which depends on the maximum length of mix path.
"""
sample_packet, _ = PacketBuilder.build_real_packets(
bytes(1), global_config.membership, self.global_config.max_mix_path_length
bytes(1), global_config.membership, global_config.max_mix_path_length
)[0]
self.packet_size = len(sample_packet.bytes())
return len(sample_packet.bytes())

async def __process_msg(self, msg: bytes) -> bytes | None:
async def __process_msg(self, msg: bytes) -> None:
"""
A handler to process messages received via gossip channel
"""
flag, msg = Node.__parse_msg(msg)
match flag:
case MsgType.NOISE:
# Drop noise packet
return None
case MsgType.REAL:
# Handle the packet and gossip the result if needed.
sphinx_packet = SphinxPacket.from_bytes(msg)
new_sphinx_packet = await self.__process_sphinx_packet(sphinx_packet)
if new_sphinx_packet is None:
return None
return Node.__build_msg(MsgType.REAL, new_sphinx_packet.bytes())
sphinx_packet = SphinxPacket.from_bytes(msg)
result = await self.__process_sphinx_packet(sphinx_packet)
match result:
case SphinxPacket():
# Gossip the next Sphinx packet
await self.nomssip.gossip(result.bytes())
case bytes():
# Broadcast the message fully recovered from Sphinx packets
await self.broadcast_channel.put(result)
case None:
return

async def __process_sphinx_packet(
self, packet: SphinxPacket
) -> SphinxPacket | None:
) -> SphinxPacket | bytes | None:
"""
Unwrap the Sphinx packet and process the next Sphinx packet or the payload.
Unwrap the Sphinx packet and process the next Sphinx packet or the payload if possible
"""
try:
processed = packet.process(self.config.private_key)
match processed:
case ProcessedForwardHopPacket():
return processed.next_packet
case ProcessedFinalHopPacket():
await self.__process_sphinx_payload(processed.payload)
return await self.__process_sphinx_payload(processed.payload)
except ValueError:
# Return SphinxPacket as it is, if it cannot be unwrapped by the private key of this node.
return packet
# Return nothing, if it cannot be unwrapped by the private key of this node.
return None

async def __process_sphinx_payload(self, payload: Payload):
async def __process_sphinx_payload(self, payload: Payload) -> bytes | None:
"""
Process the Sphinx payload and broadcast it if it is a real message.
Process the Sphinx payload if possible
"""
msg_with_flag = self.reconstructor.add(
Fragment.from_bytes(payload.recover_plain_playload())
)
if msg_with_flag is not None:
flag, msg = PacketBuilder.parse_msg_and_flag(msg_with_flag)
if flag == MessageFlag.MESSAGE_FLAG_REAL:
await self.broadcast_channel.put(msg)
return msg
return None

def connect(self, peer: Node):
"""
Establish a duplex connection with a peer node.
"""
noise_msg = Node.__build_msg(MsgType.NOISE, bytes(self.packet_size))
inbound_conn, outbound_conn = asyncio.Queue(), asyncio.Queue()

# Register a duplex connection for its own use
self.nomssip.add_conn(
DuplexConnection(
inbound_conn,
MixSimplexConnection(
outbound_conn,
self.global_config.transmission_rate_per_sec,
noise_msg,
),
)
)
# Register the same duplex connection for the peer
peer.nomssip.add_conn(
DuplexConnection(
outbound_conn,
MixSimplexConnection(
inbound_conn,
self.global_config.transmission_rate_per_sec,
noise_msg,
),
)
)
self.nomssip.add_conn(inbound_conn, outbound_conn)
# Register a duplex connection for the peer
peer.nomssip.add_conn(outbound_conn, inbound_conn)

async def send_message(self, msg: bytes):
"""
Expand All @@ -135,25 +122,4 @@ async def send_message(self, msg: bytes):
self.global_config.membership,
self.config.mix_path_length,
):
await self.nomssip.gossip(Node.__build_msg(MsgType.REAL, packet.bytes()))

@staticmethod
def __build_msg(flag: MsgType, data: bytes) -> bytes:
"""
Prepend a flag to the message, right before sending it via network channel.
"""
return flag.value + data

@staticmethod
def __parse_msg(data: bytes) -> tuple[MsgType, bytes]:
"""
Parse the message and extract the flag.
"""
if len(data) < 1:
raise ValueError("Invalid message format")
return (MsgType(data[:1]), data[1:])


class MsgType(Enum):
REAL = b"\x00"
NOISE = b"\x01"
await self.nomssip.gossip(packet.bytes())
103 changes: 84 additions & 19 deletions mixnet/nomssip.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import hashlib
from dataclasses import dataclass
from enum import Enum
from typing import Awaitable, Callable

from mixnet.config import NomssipConfig
from mixnet.connection import DuplexConnection
from mixnet.connection import DuplexConnection, MixSimplexConnection, SimplexConnection


class Nomssip:
Expand All @@ -12,31 +13,49 @@ class Nomssip:
Peers are connected via DuplexConnection.
"""

config: NomssipConfig
@dataclass
class Config:
transmission_rate_per_sec: int
peering_degree: int
msg_size: int

config: Config
conns: list[DuplexConnection]
# A handler to process inbound messages.
handler: Callable[[bytes], Awaitable[bytes | None]]
# A set of message hashes to prevent processing the same message twice.
msg_cache: set[bytes]
handler: Callable[[bytes], Awaitable[None]]
# A set of packet hashes to prevent gossiping/processing the same packet twice.
packet_cache: set[bytes]

def __init__(
self,
config: NomssipConfig,
handler: Callable[[bytes], Awaitable[bytes | None]],
config: Config,
handler: Callable[[bytes], Awaitable[None]],
):
self.config = config
self.conns = []
self.handler = handler
self.msg_cache = set()
self.packet_cache = set()
# A set just for gathering a reference of tasks to prevent them from being garbage collected.
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
self.tasks = set()

def add_conn(self, conn: DuplexConnection):
def add_conn(self, inbound: SimplexConnection, outbound: SimplexConnection):
if len(self.conns) >= self.config.peering_degree:
# For simplicity of the spec, reject the connection if the peering degree is reached.
raise ValueError("The peering degree is reached.")

noise_packet = self.__build_packet(
self.PacketType.NOISE, bytes(self.config.msg_size)
)
conn = DuplexConnection(
inbound,
MixSimplexConnection(
outbound,
self.config.transmission_rate_per_sec,
noise_packet,
),
)

self.conns.append(conn)
task = asyncio.create_task(self.__process_inbound_conn(conn))
self.tasks.add(task)
Expand All @@ -45,17 +64,63 @@ def add_conn(self, conn: DuplexConnection):

async def __process_inbound_conn(self, conn: DuplexConnection):
while True:
msg = await conn.recv()
# Don't process the same message twice.
msg_hash = hashlib.sha256(msg).digest()
if msg_hash in self.msg_cache:
packet = await conn.recv()
if self.__check_update_cache(packet):
continue
self.msg_cache.add(msg_hash)

new_msg = await self.handler(msg)
if new_msg is not None:
await self.gossip(new_msg)
flag, msg = self.__parse_packet(packet)
match flag:
case self.PacketType.NOISE:
# Drop noise packet
continue
case self.PacketType.REAL:
await self.__gossip(packet)
await self.handler(msg)

async def gossip(self, msg: bytes):
"""
Gossip a message to all connected peers if the message has not been gossiped yet.
"""
# The message size must be fixed.
assert len(msg) == self.config.msg_size

packet = self.__build_packet(self.PacketType.REAL, msg)
if not self.__check_update_cache(packet):
await self.__gossip(packet)

async def gossip(self, packet: bytes):
async def __gossip(self, packet: bytes):
"""
An internal method to send a packet to all connected peers without checking the cache
"""
for conn in self.conns:
await conn.send(packet)

def __check_update_cache(self, packet: bytes) -> bool:
"""
Add a message to the cache, and return True if the message was already in the cache.
"""
hash = hashlib.sha256(packet).digest()
if hash in self.packet_cache:
return True
self.packet_cache.add(hash)
return False

class PacketType(Enum):
REAL = b"\x00"
NOISE = b"\x01"

@staticmethod
def __build_packet(flag: PacketType, data: bytes) -> bytes:
"""
Prepend a flag to the message, right before sending it via network channel.
"""
return flag.value + data

@staticmethod
def __parse_packet(data: bytes) -> tuple[PacketType, bytes]:
"""
Parse the message and extract the flag.
"""
if len(data) < 1:
raise ValueError("Invalid message format")
return (Nomssip.PacketType(data[:1]), data[1:])

0 comments on commit 010733c

Please sign in to comment.