Skip to content

Commit

Permalink
sphinx serde
Browse files Browse the repository at this point in the history
  • Loading branch information
youngjoon-lee committed Jun 28, 2024
1 parent 3a70343 commit 1506b00
Showing 1 changed file with 33 additions and 30 deletions.
63 changes: 33 additions & 30 deletions mixnet/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import hashlib
from enum import Enum
from typing import Awaitable, Callable, TypeAlias

from pysphinx.payload import DEFAULT_PAYLOAD_SIZE
Expand All @@ -15,8 +16,7 @@
from mixnet.config import GlobalConfig, NodeConfig
from mixnet.packet import Fragment, MessageFlag, MessageReconstructor, PacketBuilder

NetworkPacket: TypeAlias = SphinxPacket | bytes
NetworkPacketQueue: TypeAlias = asyncio.Queue[NetworkPacket]
NetworkPacketQueue: TypeAlias = asyncio.Queue[bytes]
Connection: TypeAlias = NetworkPacketQueue
BroadcastChannel: TypeAlias = asyncio.Queue[bytes]

Expand All @@ -39,7 +39,7 @@ def __init__(self, config: NodeConfig, global_config: GlobalConfig):

async def __process_sphinx_packet(
self, packet: SphinxPacket
) -> NetworkPacket | None:
) -> SphinxPacket | None:
try:
processed = packet.process(self.config.private_key)
match processed:
Expand Down Expand Up @@ -83,19 +83,19 @@ async def send_message(self, msg: bytes):
for packet, _ in PacketBuilder.build_real_packets(
msg, self.global_config.membership
):
await self.mixgossip_channel.gossip(packet)
await self.mixgossip_channel.gossip(MsgType.REAL.value + packet.bytes())


class MixGossipChannel:
peering_degree: int
conns: list[DuplexConnection]
handler: Callable[[SphinxPacket], Awaitable[NetworkPacket | None]]
msg_cache: set[NetworkPacket]
handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]]
msg_cache: set[bytes]

def __init__(
self,
peer_degree: int,
handler: Callable[[SphinxPacket], Awaitable[NetworkPacket | None]],
handler: Callable[[SphinxPacket], Awaitable[SphinxPacket | None]],
):
self.peering_degree = peer_degree
self.conns = []
Expand All @@ -118,26 +118,24 @@ def add_conn(self, conn: DuplexConnection):

async def __process_inbound_conn(self, conn: DuplexConnection):
while True:
elem = await conn.recv()
# In practice, data transmitted through connections is going to be always 'bytes'.
# But here, we use the SphinxPacket type explicitly for simplicity
# without implementing serde for SphinxPacket.
if isinstance(elem, bytes):
assert elem == build_noise_packet()
# Drop packet
msg = await conn.recv()
if msg[:1] != MsgType.REAL.value:
# Drop noise packet
continue
elif isinstance(elem, SphinxPacket):
# Don't process the same message twice.
msg_hash = hashlib.sha256(elem.bytes()).digest()
if msg_hash in self.msg_cache:
continue
self.msg_cache.add(msg_hash)
# Handle the packet and gossip the result if needed.
net_packet = await self.handler(elem)
if net_packet is not None:
await self.gossip(net_packet)

async def gossip(self, packet: NetworkPacket):

# Don't process the same message twice.
msg_hash = hashlib.sha256(msg).digest()
if msg_hash in self.msg_cache:
continue
self.msg_cache.add(msg_hash)

# Handle the packet and gossip the result if needed.
sphinx_packet = SphinxPacket.from_bytes(msg[1:])
new_sphinx_packet = await self.handler(sphinx_packet)
if new_sphinx_packet is not None:
await self.gossip(MsgType.REAL.value + new_sphinx_packet.bytes())

async def gossip(self, packet: bytes):
for conn in self.conns:
await conn.send(packet)

Expand All @@ -150,10 +148,10 @@ def __init__(self, inbound: Connection, outbound: MixOutboundConnection):
self.inbound = inbound
self.outbound = outbound

async def recv(self) -> NetworkPacket:
async def recv(self) -> bytes:
return await self.inbound.get()

async def send(self, packet: NetworkPacket):
async def send(self, packet: bytes):
await self.outbound.send(packet)


Expand All @@ -178,9 +176,14 @@ async def __run(self):
elem = self.queue.get_nowait()
await self.conn.put(elem)

async def send(self, elem: NetworkPacket):
async def send(self, elem: bytes):
await self.queue.put(elem)


class MsgType(Enum):
REAL = b"\x00"
NOISE = b"\x01"


def build_noise_packet() -> bytes:
return bytes(DEFAULT_PAYLOAD_SIZE)
return MsgType.NOISE.value + bytes(DEFAULT_PAYLOAD_SIZE)

0 comments on commit 1506b00

Please sign in to comment.