diff --git a/server/conn.go b/server/conn.go index b62ad27..f6ce2b9 100644 --- a/server/conn.go +++ b/server/conn.go @@ -263,38 +263,44 @@ func (c *Conn) read() (pk any, err error) { case <-c.closed: return nil, net.ErrClosed default: - payload, err := c.reader.ReadPacket() - if err != nil { - return nil, err - } + } - if payload[0] != packetDecodeNeeded && payload[0] != packetDecodeNotNeeded { - return nil, fmt.Errorf("unknown decode byte marker %v", payload[0]) - } + payload, err := c.reader.ReadPacket() + if err != nil { + return nil, err + } - decompressed, err := snappy.Decode(nil, payload[1:]) - if err != nil { - return nil, err - } + if payload[0] != packetDecodeNeeded && payload[0] != packetDecodeNotNeeded { + return nil, fmt.Errorf("unknown decode byte marker %v", payload[0]) + } - if payload[0] == packetDecodeNotNeeded { - return decompressed, nil - } + decompressed, err := snappy.Decode(nil, payload[1:]) + if err != nil { + return nil, err + } - buf := bytes.NewBuffer(decompressed) - header := &packet.Header{} - if err := header.Read(buf); err != nil { - return nil, err - } + if payload[0] == packetDecodeNotNeeded { + return decompressed, nil + } + + buf := bytes.NewBuffer(decompressed) + header := &packet.Header{} + if err := header.Read(buf); err != nil { + return nil, err + } - factory, ok := c.pool[header.PacketID] - if !ok { - return nil, fmt.Errorf("unknown packet ID %v", header.PacketID) + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic while decoding packet %v: %v", header.PacketID, r) } - pk = factory() - pk.(packet.Packet).Marshal(c.protocol.NewReader(buf, c.shieldID, false)) - return pk, nil + }() + factory, ok := c.pool[header.PacketID] + if !ok { + return nil, fmt.Errorf("unknown packet ID %v", header.PacketID) } + pk = factory() + pk.(packet.Packet).Marshal(c.protocol.NewReader(buf, c.shieldID, false)) + return pk, nil } // deferPacket defers a packet to be returned later in ReadPacket().