diff --git a/dtls/server/session.go b/dtls/server/session.go index 90061a51..5f0d9154 100644 --- a/dtls/server/session.go +++ b/dtls/server/session.go @@ -146,7 +146,7 @@ func (s *Session) Run(cc *client.Conn) (err error) { return fmt.Errorf("cannot read from connection: %w", err) } readBuf = readBuf[:readLen] - err = cc.Process(readBuf) + err = cc.Process(nil, readBuf) if err != nil { return err } diff --git a/message/pool/message.go b/message/pool/message.go index 22cf8865..04fcc58c 100644 --- a/message/pool/message.go +++ b/message/pool/message.go @@ -10,6 +10,7 @@ import ( multierror "github.com/hashicorp/go-multierror" "github.com/plgd-dev/go-coap/v3/message" "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/net" "go.uber.org/atomic" ) @@ -26,6 +27,7 @@ type Message struct { // Context context of request. ctx context.Context msg message.Message + controlMessage *net.ControlMessage // control message for UDP hijacked atomic.Bool isModified bool valueBuffer []byte @@ -73,6 +75,22 @@ func (r *Message) SetMessage(message message.Message) { r.isModified = true } +func (r *Message) SetControlMessage(cm *net.ControlMessage) { + r.controlMessage = cm +} + +func (r *Message) ControlMessage() *net.ControlMessage { + return r.controlMessage +} + +// UpsertControlMessage set value only when origin value is not set. +func (r *Message) UpsertControlMessage(cm *net.ControlMessage) { + if r.controlMessage != nil { + return + } + r.SetControlMessage(cm) +} + // SetMessageID only 0 to 2^16-1 are valid. func (r *Message) SetMessageID(mid int32) { r.msg.MessageID = mid @@ -120,6 +138,7 @@ func (r *Message) Reset() { r.valueBuffer = r.origValueBuffer r.body = nil r.isModified = false + r.controlMessage = nil if cap(r.bufferMarshal) > 1024 { r.bufferMarshal = make([]byte, 256) } @@ -568,6 +587,7 @@ func (r *Message) Clone(msg *Message) error { msg.ResetOptionsTo(r.Options()) msg.SetType(r.Type()) msg.SetMessageID(r.MessageID()) + msg.SetControlMessage(r.ControlMessage()) if r.Body() != nil { buf := bytes.NewBuffer(nil) diff --git a/net/connUDP.go b/net/connUDP.go index f749b22d..a241b5b9 100644 --- a/net/connUDP.go +++ b/net/connUDP.go @@ -24,8 +24,36 @@ type UDPConn struct { } type ControlMessage struct { - Src net.IP // source address, specifying only - IfIndex int // interface index, must be 1 <= value when specifying + // For connection oriented packetConn the ControlMessage fields are ignored, only linux supports set control message. + + Dst net.IP // destination address of the packet + Src net.IP // source address of the packet + IfIndex int // interface index, 0 means any interface +} + +func (c *ControlMessage) String() string { + if c == nil { + return "" + } + var sb strings.Builder + if c.Dst != nil { + sb.WriteString(fmt.Sprintf("Dst: %s, ", c.Dst)) + } + if c.Src != nil { + sb.WriteString(fmt.Sprintf("Src: %s, ", c.Src)) + } + if c.IfIndex >= 1 { + sb.WriteString(fmt.Sprintf("IfIndex: %d, ", c.IfIndex)) + } + return sb.String() +} + +// GetIfIndex returns the interface index of the network interface. 0 means no interface index specified. +func (c *ControlMessage) GetIfIndex() int { + if c == nil { + return 0 + } + return c.IfIndex } type packetConn interface { @@ -36,22 +64,37 @@ type packetConn interface { SetMulticastLoopback(on bool) error JoinGroup(ifi *net.Interface, group net.Addr) error LeaveGroup(ifi *net.Interface, group net.Addr) error + ReadFrom(b []byte) (n int, cm *ControlMessage, src net.Addr, err error) + SupportsControlMessage() bool + IsIPv6() bool } type packetConnIPv4 struct { - packetConnIPv4 *ipv4.PacketConn + packetConn *ipv4.PacketConn + supportsControlMessage bool } func newPacketConnIPv4(p *ipv4.PacketConn) *packetConnIPv4 { - return &packetConnIPv4{p} + if err := p.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface|ipv4.FlagSrc, true); err != nil { + return &packetConnIPv4{packetConn: p, supportsControlMessage: false} + } + return &packetConnIPv4{packetConn: p, supportsControlMessage: true} +} + +func (p *packetConnIPv4) SupportsControlMessage() bool { + return p.supportsControlMessage +} + +func (p *packetConnIPv4) IsIPv6() bool { + return false } func (p *packetConnIPv4) SetMulticastInterface(ifi *net.Interface) error { - return p.packetConnIPv4.SetMulticastInterface(ifi) + return p.packetConn.SetMulticastInterface(ifi) } func (p *packetConnIPv4) SetWriteDeadline(t time.Time) error { - return p.packetConnIPv4.SetWriteDeadline(t) + return p.packetConn.SetWriteDeadline(t) } func (p *packetConnIPv4) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n int, err error) { @@ -62,39 +105,83 @@ func (p *packetConnIPv4) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n IfIndex: cm.IfIndex, } } - return p.packetConnIPv4.WriteTo(b, c, dst) + return p.packetConn.WriteTo(b, c, dst) +} + +func (p *packetConnIPv4) ReadFrom(b []byte) (int, *ControlMessage, net.Addr, error) { + n, cm, src, err := p.packetConn.ReadFrom(b) + if err != nil { + return -1, nil, nil, err + } + var controlMessage *ControlMessage + if p.supportsControlMessage && cm != nil { + controlMessage = &ControlMessage{ + Dst: cm.Dst, + Src: cm.Src, + IfIndex: cm.IfIndex, + } + } + return n, controlMessage, src, err } func (p *packetConnIPv4) SetMulticastHopLimit(hoplim int) error { - return p.packetConnIPv4.SetMulticastTTL(hoplim) + return p.packetConn.SetMulticastTTL(hoplim) } func (p *packetConnIPv4) SetMulticastLoopback(on bool) error { - return p.packetConnIPv4.SetMulticastLoopback(on) + return p.packetConn.SetMulticastLoopback(on) } func (p *packetConnIPv4) JoinGroup(ifi *net.Interface, group net.Addr) error { - return p.packetConnIPv4.JoinGroup(ifi, group) + return p.packetConn.JoinGroup(ifi, group) } func (p *packetConnIPv4) LeaveGroup(ifi *net.Interface, group net.Addr) error { - return p.packetConnIPv4.LeaveGroup(ifi, group) + return p.packetConn.LeaveGroup(ifi, group) } type packetConnIPv6 struct { - packetConnIPv6 *ipv6.PacketConn + packetConn *ipv6.PacketConn + supportsControlMessage bool } func newPacketConnIPv6(p *ipv6.PacketConn) *packetConnIPv6 { - return &packetConnIPv6{p} + if err := p.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface|ipv6.FlagSrc, true); err != nil { + return &packetConnIPv6{packetConn: p, supportsControlMessage: false} + } + return &packetConnIPv6{packetConn: p, supportsControlMessage: true} +} + +func (p *packetConnIPv6) SupportsControlMessage() bool { + return p.supportsControlMessage +} + +func (p *packetConnIPv6) IsIPv6() bool { + return true } func (p *packetConnIPv6) SetMulticastInterface(ifi *net.Interface) error { - return p.packetConnIPv6.SetMulticastInterface(ifi) + return p.packetConn.SetMulticastInterface(ifi) } func (p *packetConnIPv6) SetWriteDeadline(t time.Time) error { - return p.packetConnIPv6.SetWriteDeadline(t) + return p.packetConn.SetWriteDeadline(t) +} + +func (p *packetConnIPv6) ReadFrom(b []byte) (int, *ControlMessage, net.Addr, error) { + n, cm, src, err := p.packetConn.ReadFrom(b) + if err != nil { + return -1, nil, nil, err + } + var controlMessage *ControlMessage + if p.supportsControlMessage && cm != nil { + controlMessage = &ControlMessage{ + Dst: cm.Dst, + Src: cm.Src, + IfIndex: cm.IfIndex, + } + } + return n, controlMessage, src, err } func (p *packetConnIPv6) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n int, err error) { @@ -105,27 +192,23 @@ func (p *packetConnIPv6) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n IfIndex: cm.IfIndex, } } - return p.packetConnIPv6.WriteTo(b, c, dst) + return p.packetConn.WriteTo(b, c, dst) } func (p *packetConnIPv6) SetMulticastHopLimit(hoplim int) error { - return p.packetConnIPv6.SetMulticastHopLimit(hoplim) + return p.packetConn.SetMulticastHopLimit(hoplim) } func (p *packetConnIPv6) SetMulticastLoopback(on bool) error { - return p.packetConnIPv6.SetMulticastLoopback(on) + return p.packetConn.SetMulticastLoopback(on) } func (p *packetConnIPv6) JoinGroup(ifi *net.Interface, group net.Addr) error { - return p.packetConnIPv6.JoinGroup(ifi, group) + return p.packetConn.JoinGroup(ifi, group) } func (p *packetConnIPv6) LeaveGroup(ifi *net.Interface, group net.Addr) error { - return p.packetConnIPv6.LeaveGroup(ifi, group) -} - -func (p *packetConnIPv6) SetControlMessage(on bool) error { - return p.packetConnIPv6.SetMulticastLoopback(on) + return p.packetConn.LeaveGroup(ifi, group) } // IsIPv6 return's true if addr is IPV6. @@ -158,26 +241,41 @@ func NewListenUDP(network, addr string, opts ...UDPOption) (*UDPConn, error) { return NewUDPConn(network, conn, opts...), nil } -// NewUDPConn creates connection over net.UDPConn. -func NewUDPConn(network string, c *net.UDPConn, opts ...UDPOption) *UDPConn { - cfg := DefaultUDPConnConfig - for _, o := range opts { - o.ApplyUDP(&cfg) - } - +func newPacketConn(c *net.UDPConn) (packetConn, error) { laddr := c.LocalAddr() if laddr == nil { - panic(fmt.Errorf("invalid UDP connection")) + return nil, fmt.Errorf("invalid UDP connection") } addr, ok := laddr.(*net.UDPAddr) if !ok { - panic(fmt.Errorf("invalid address type(%T), UDP address expected", laddr)) + return nil, fmt.Errorf("invalid address type(%T), UDP address expected", laddr) } + return newPacketConnWithAddr(addr, c) +} + +func newPacketConnWithAddr(addr *net.UDPAddr, c *net.UDPConn) (packetConn, error) { var pc packetConn + var err error if IsIPv6(addr.IP) { pc = newPacketConnIPv6(ipv6.NewPacketConn(c)) } else { pc = newPacketConnIPv4(ipv4.NewPacketConn(c)) + if err != nil { + return nil, fmt.Errorf("invalid UDPv4 connection: %w", err) + } + } + return pc, nil +} + +// NewUDPConn creates connection over net.UDPConn. +func NewUDPConn(network string, c *net.UDPConn, opts ...UDPOption) *UDPConn { + cfg := DefaultUDPConnConfig + for _, o := range opts { + o.ApplyUDP(&cfg) + } + pc, err := newPacketConn(c) + if err != nil { + panic(err) } return &UDPConn{ @@ -211,41 +309,48 @@ func (c *UDPConn) Close() error { return c.connection.Close() } -func (c *UDPConn) writeToAddr(iface *net.Interface, src *net.IP, multicastHopLimit int, raddr *net.UDPAddr, buffer []byte) error { - var pktSrc net.IP - var p packetConn - if IsIPv6(raddr.IP) { - p = newPacketConnIPv6(ipv6.NewPacketConn(c.connection)) - pktSrc = net.IPv6zero - } else { - p = newPacketConnIPv4(ipv4.NewPacketConn(c.connection)) - pktSrc = net.IPv4zero - } +func toPacketSrcIP(src *net.IP, p packetConn) net.IP { if src != nil { - pktSrc = *src + return *src + } + if p.IsIPv6() { + return net.IPv6zero } + return net.IPv4zero +} +func toControlMessage(p packetConn, iface *net.Interface, src *net.IP) *ControlMessage { + if iface != nil || src != nil { + ifaceIdx := 0 + if iface != nil { + ifaceIdx = iface.Index + } + return &ControlMessage{ + Src: toPacketSrcIP(src, p), + IfIndex: ifaceIdx, + } + } + return nil +} + +func (c *UDPConn) writeToAddr(iface *net.Interface, src *net.IP, multicastHopLimit int, raddr *net.UDPAddr, buffer []byte) error { if c.closed.Load() { return ErrConnectionIsClosed } + p, err := newPacketConnWithAddr(raddr, c.connection) + if err != nil { + return err + } if iface != nil { - if err := p.SetMulticastInterface(iface); err != nil { + if err = p.SetMulticastInterface(iface); err != nil { return err } } - if err := p.SetMulticastHopLimit(multicastHopLimit); err != nil { + if err = p.SetMulticastHopLimit(multicastHopLimit); err != nil { return err } - - var err error - if iface != nil || src != nil { - _, err = p.WriteTo(buffer, &ControlMessage{ - Src: pktSrc, - IfIndex: iface.Index, - }, raddr) - } else { - _, err = p.WriteTo(buffer, nil, raddr) - } + cm := toControlMessage(p, iface, src) + _, err = p.WriteTo(buffer, cm, raddr) return err } @@ -408,8 +513,18 @@ func (c *UDPConn) writeMulticast(ctx context.Context, raddr *net.UDPAddr, buffer return nil } +func (c *UDPConn) writeTo(raddr *net.UDPAddr, cm *ControlMessage, buffer []byte) (int, error) { + if !supportsOverrideRemoteAddr(c.connection) { + // If the remote address is set, we can use it as the destination address + // because the connection is already established. + // Note: Overwriting the destination address is only supported on Linux. + return c.connection.Write(buffer) + } + return c.packetConn.WriteTo(buffer, cm, raddr) +} + // WriteWithContext writes data with context. -func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, buffer []byte) error { +func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, cm *ControlMessage, buffer []byte) error { if raddr == nil { return fmt.Errorf("cannot write with context: invalid raddr") } @@ -422,7 +537,7 @@ func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, buff if c.closed.Load() { return ErrConnectionIsClosed } - n, err := WriteToUDP(c.connection, raddr, buffer) + n, err := c.writeTo(raddr, cm, buffer) if err != nil { return err } @@ -434,20 +549,23 @@ func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, buff } // ReadWithContext reads packet with context. -func (c *UDPConn) ReadWithContext(ctx context.Context, buffer []byte) (int, *net.UDPAddr, error) { +func (c *UDPConn) ReadWithContext(ctx context.Context, buffer []byte) (int, *ControlMessage, *net.UDPAddr, error) { select { case <-ctx.Done(): - return -1, nil, ctx.Err() + return -1, nil, nil, ctx.Err() default: } if c.closed.Load() { - return -1, nil, ErrConnectionIsClosed + return -1, nil, nil, ErrConnectionIsClosed } - n, s, err := c.connection.ReadFromUDP(buffer) + n, cm, srcAddr, err := c.packetConn.ReadFrom(buffer) if err != nil { - return -1, nil, fmt.Errorf("cannot read from udp connection: %w", err) + return -1, nil, nil, fmt.Errorf("cannot read from udp connection: %w", err) + } + if udpAdrr, ok := srcAddr.(*net.UDPAddr); ok { + return n, cm, udpAdrr, nil } - return n, s, err + return -1, nil, nil, fmt.Errorf("cannot read from udp connection: invalid srcAddr type %T", srcAddr) } // SetMulticastLoopback sets whether transmitted multicast packets diff --git a/net/connUDP_internal_test.go b/net/connUDP_internal_test.go new file mode 100644 index 00000000..a37650ca --- /dev/null +++ b/net/connUDP_internal_test.go @@ -0,0 +1,511 @@ +package net + +import ( + "context" + "net" + "strconv" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + udpNetwork = "udp" + udp4Network = "udp4" + udp6Network = "udp6" +) + +func TestUDPConnWriteWithContext(t *testing.T) { + peerAddr := "127.0.0.1:2154" + b, err := net.ResolveUDPAddr(udpNetwork, peerAddr) + require.NoError(t, err) + + ctxCanceled, ctxCancel := context.WithCancel(context.Background()) + ctxCancel() + + type args struct { + ctx context.Context + udpCtx *net.UDPAddr + buffer []byte + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "valid", + args: args{ + ctx: context.Background(), + udpCtx: b, + buffer: []byte("hello world"), + }, + }, + { + name: "cancelled", + args: args{ + ctx: ctxCanceled, + buffer: []byte("hello world"), + }, + wantErr: true, + }, + } + + a, err := net.ResolveUDPAddr(udpNetwork, "127.0.0.1:") + require.NoError(t, err) + l1, err := net.ListenUDP(udpNetwork, a) + require.NoError(t, err) + c1 := NewUDPConn(udpNetwork, l1, WithErrors(func(err error) { t.Log(err) })) + defer func() { + errC := c1.Close() + require.NoError(t, errC) + }() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + l2, err := net.ListenUDP(udpNetwork, b) + require.NoError(t, err) + c2 := NewUDPConn(udpNetwork, l2, WithErrors(func(err error) { t.Log(err) })) + defer func() { + errC := c2.Close() + require.NoError(t, errC) + }() + + go func() { + b := make([]byte, 1024) + _, _, _, errR := c2.ReadWithContext(ctx, b) + if errR != nil { + return + } + }() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err = c1.WriteWithContext(tt.args.ctx, tt.args.udpCtx, nil, tt.args.buffer) + + c1.LocalAddr() + c1.RemoteAddr() + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestUDPConnwriteMulticastWithContext(t *testing.T) { + peerAddr := "224.0.1.187:9999" + b, err := net.ResolveUDPAddr(udp4Network, peerAddr) + require.NoError(t, err) + + ctxCanceled, ctxCancel := context.WithCancel(context.Background()) + ctxCancel() + payload := []byte("hello world") + + ifs, err := net.Interfaces() + require.NoError(t, err) + var iface net.Interface + for _, i := range ifs { + if i.Flags&net.FlagMulticast == net.FlagMulticast && i.Flags&net.FlagUp == net.FlagUp { + iface = i + break + } + } + require.NotEmpty(t, iface) + + type args struct { + ctx context.Context + udpCtx *net.UDPAddr + buffer []byte + opts []MulticastOption + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "valid all interfaces", + args: args{ + ctx: context.Background(), + udpCtx: b, + buffer: payload, + opts: []MulticastOption{WithAllMulticastInterface()}, + }, + }, + { + name: "valid any interface", + args: args{ + ctx: context.Background(), + udpCtx: b, + buffer: payload, + opts: []MulticastOption{WithAnyMulticastInterface()}, + }, + }, + { + name: "valid first interface", + args: args{ + ctx: context.Background(), + udpCtx: b, + buffer: payload, + opts: []MulticastOption{WithMulticastInterface(iface)}, + }, + }, + { + name: "cancelled", + args: args{ + ctx: ctxCanceled, + udpCtx: b, + buffer: payload, + }, + wantErr: true, + }, + } + + listenAddr := ":" + strconv.Itoa(b.Port) + c, err := net.ResolveUDPAddr(udp4Network, listenAddr) + require.NoError(t, err) + l2, err := net.ListenUDP(udp4Network, c) + require.NoError(t, err) + c2 := NewUDPConn(udpNetwork, l2, WithErrors(func(err error) { t.Log(err) })) + defer func() { + errC := c2.Close() + require.NoError(t, errC) + }() + ifaces, err := net.Interfaces() + require.NoError(t, err) + for _, iface := range ifaces { + ifa := iface + err = c2.JoinGroup(&ifa, b) + if err != nil { + t.Logf("fmt cannot join group %v: %v", ifa.Name, err) + } + } + + err = c2.SetMulticastLoopback(true) + require.NoError(t, err) + + a, err := net.ResolveUDPAddr(udp4Network, "") + require.NoError(t, err) + l1, err := net.ListenUDP(udp4Network, a) + require.NoError(t, err) + c1 := NewUDPConn(udpNetwork, l1, WithErrors(func(err error) { t.Log(err) })) + defer func() { + errC := c1.Close() + require.NoError(t, errC) + }() + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + b := make([]byte, 1024) + n, _, _, errR := c2.ReadWithContext(ctx, b) + assert.NoError(t, errR) + if n > 0 { + b = b[:n] + assert.Equal(t, payload, b) + } + wg.Done() + }() + defer wg.Wait() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err = c1.WriteMulticast(tt.args.ctx, tt.args.udpCtx, tt.args.buffer, tt.args.opts...) + c1.LocalAddr() + c1.RemoteAddr() + + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} + +func TestControlMessageString(t *testing.T) { + tests := []struct { + name string + c *ControlMessage + want string + }{ + { + name: "nil", + c: nil, + want: "", + }, + { + name: "dst", + c: &ControlMessage{ + Dst: net.IPv4(192, 168, 1, 1), + }, + want: "Dst: 192.168.1.1, ", + }, + { + name: "src", + c: &ControlMessage{ + Src: net.IPv4(192, 168, 1, 2), + }, + want: "Src: 192.168.1.2, ", + }, + { + name: "ifIndex", + c: &ControlMessage{ + IfIndex: 1, + }, + want: "IfIndex: 1, ", + }, + { + name: "all", + c: &ControlMessage{ + Dst: net.IPv4(192, 168, 1, 1), + Src: net.IPv4(192, 168, 1, 2), + IfIndex: 1, + }, + want: "Dst: 192.168.1.1, Src: 192.168.1.2, IfIndex: 1, ", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, tt.c.String()) + }) + } +} + +func getIfaceAddr(t *testing.T, iface net.Interface, ipv4 bool) net.IP { + addrs, err := iface.Addrs() + require.NoError(t, err) + for _, addr := range addrs { + ip, _, err := net.ParseCIDR(addr.String()) + require.NoError(t, err) + if !ip.IsPrivate() { + continue + } + if ipv4 { + if ip.To4() != nil { + return ip + } + continue + } + return ip + } + return nil +} + +func TestUDPConnWriteToAddr(t *testing.T) { + ifaces, err := net.Interfaces() + require.NoError(t, err) + var iface net.Interface + for _, i := range ifaces { + if i.Flags&net.FlagUp == net.FlagUp && i.Flags&net.FlagMulticast == net.FlagMulticast && i.Flags&net.FlagLoopback != net.FlagLoopback { + iface = i + break + } + } + require.NotEmpty(t, iface) + type args struct { + iface *net.Interface + src net.IP + multicastHopLimit int + raddr *net.UDPAddr + buffer []byte + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "IPv4", + args: args{ + raddr: &net.UDPAddr{IP: getIfaceAddr(t, iface, true), Port: 1234}, + buffer: []byte("hello world"), + }, + }, + { + name: "IPv6", + args: args{ + raddr: &net.UDPAddr{IP: getIfaceAddr(t, iface, false), Port: 1234}, + buffer: []byte("hello world"), + }, + }, + { + name: "closed", + args: args{ + raddr: &net.UDPAddr{IP: getIfaceAddr(t, iface, true), Port: 1234}, + buffer: []byte("hello world"), + }, + wantErr: true, + }, + { + name: "with interface", + args: args{ + iface: &iface, + raddr: &net.UDPAddr{IP: getIfaceAddr(t, iface, true), Port: 1234}, + buffer: []byte("hello world"), + }, + }, + { + name: "with source", + args: args{ + src: net.IP{127, 0, 0, 1}, + raddr: &net.UDPAddr{IP: getIfaceAddr(t, iface, true), Port: 1234}, + buffer: []byte("hello world"), + }, + }, + { + name: "with multicast hop limit", + args: args{ + multicastHopLimit: 5, + raddr: &net.UDPAddr{IP: net.IPv4(224, 0, 0, 1), Port: 1234}, + buffer: []byte("hello world"), + }, + }, + { + name: "with interface and source", + args: args{ + iface: &iface, + src: getIfaceAddr(t, iface, true), + raddr: &net.UDPAddr{IP: getIfaceAddr(t, iface, true), Port: 1234}, + buffer: []byte("hello world"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + network := udp4Network + ip := getIfaceAddr(t, iface, true) + if IsIPv6(tt.args.src) { + network = udp6Network + ip = getIfaceAddr(t, iface, false) + } + p, err := net.ListenUDP(network, &net.UDPAddr{IP: ip, Port: 1235}) + require.NoError(t, err) + defer func() { + errC := p.Close() + require.NoError(t, errC) + }() + c := &UDPConn{ + connection: p, + } + if tt.wantErr { + c.closed.Store(true) + } + err = c.writeToAddr(tt.args.iface, &tt.args.src, tt.args.multicastHopLimit, tt.args.raddr, tt.args.buffer) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + +func TestPacketConnReadFrom(t *testing.T) { + readUDP4Conn, err := net.ListenUDP(udp4Network, &net.UDPAddr{Port: 1234}) + require.NoError(t, err) + defer func() { + errC := readUDP4Conn.Close() + require.NoError(t, errC) + }() + + require.Nil(t, readUDP4Conn.RemoteAddr()) + + writeUDP4Conn, err := net.DialUDP(udp4Network, nil, readUDP4Conn.LocalAddr().(*net.UDPAddr)) + require.NoError(t, err) + defer func() { + errC := writeUDP4Conn.Close() + require.NoError(t, errC) + }() + + require.NotNil(t, writeUDP4Conn.RemoteAddr()) + + readUDP6Conn, err := net.ListenUDP(udp6Network, &net.UDPAddr{Port: 1235}) + require.NoError(t, err) + defer func() { + errC := readUDP6Conn.Close() + require.NoError(t, errC) + }() + writeUDP6Conn, err := net.DialUDP(udp6Network, nil, readUDP6Conn.LocalAddr().(*net.UDPAddr)) + require.NoError(t, err) + defer func() { + errC := writeUDP6Conn.Close() + require.NoError(t, errC) + }() + + type fields struct { + packetConn *net.UDPConn + } + type args struct { + b []byte + } + tests := []struct { + name string + fields fields + args args + wantN int + wantErr bool + }{ + { + name: "valid UDP4", + fields: fields{ + packetConn: readUDP4Conn, + }, + args: args{ + b: []byte("hello world"), + }, + wantN: 11, + wantErr: false, + }, + { + name: "valid UDP6", + fields: fields{ + packetConn: readUDP6Conn, + }, + args: args{ + b: []byte("hello world"), + }, + wantN: 11, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p, err := newPacketConn(tt.fields.packetConn) + require.NoError(t, err) + if !tt.wantErr && tt.fields.packetConn == readUDP4Conn { + n, errW := writeUDP4Conn.Write(tt.args.b) + require.NoError(t, errW) + require.Equal(t, len(tt.args.b), n) + } + if !tt.wantErr && tt.fields.packetConn == readUDP6Conn { + n, errW := writeUDP6Conn.Write(tt.args.b) + require.NoError(t, errW) + require.Equal(t, len(tt.args.b), n) + } + gotN, gotCm, gotSrc, err := p.ReadFrom(tt.args.b) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + if p.SupportsControlMessage() { + require.NotNil(t, gotCm) + } else { + require.Nil(t, gotCm) + } + require.NotNil(t, gotSrc) + require.Equal(t, tt.wantN, gotN) + }) + } +} diff --git a/net/connUDP_test.go b/net/connUDP_test.go deleted file mode 100644 index 5cfe3c8d..00000000 --- a/net/connUDP_test.go +++ /dev/null @@ -1,227 +0,0 @@ -package net - -import ( - "context" - "net" - "strconv" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestUDPConnWriteWithContext(t *testing.T) { - peerAddr := "127.0.0.1:2154" - b, err := net.ResolveUDPAddr("udp", peerAddr) - require.NoError(t, err) - - ctxCanceled, ctxCancel := context.WithCancel(context.Background()) - ctxCancel() - - type args struct { - ctx context.Context - udpCtx *net.UDPAddr - buffer []byte - } - tests := []struct { - name string - args args - wantErr bool - }{ - { - name: "valid", - args: args{ - ctx: context.Background(), - udpCtx: b, - buffer: []byte("hello world"), - }, - }, - { - name: "cancelled", - args: args{ - ctx: ctxCanceled, - buffer: []byte("hello world"), - }, - wantErr: true, - }, - } - - a, err := net.ResolveUDPAddr("udp", "127.0.0.1:") - require.NoError(t, err) - l1, err := net.ListenUDP("udp", a) - require.NoError(t, err) - c1 := NewUDPConn("udp", l1, WithErrors(func(err error) { t.Log(err) })) - defer func() { - errC := c1.Close() - require.NoError(t, errC) - }() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - l2, err := net.ListenUDP("udp", b) - require.NoError(t, err) - c2 := NewUDPConn("udp", l2, WithErrors(func(err error) { t.Log(err) })) - defer func() { - errC := c2.Close() - require.NoError(t, errC) - }() - - go func() { - b := make([]byte, 1024) - _, _, errR := c2.ReadWithContext(ctx, b) - if errR != nil { - return - } - }() - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err = c1.WriteWithContext(tt.args.ctx, tt.args.udpCtx, tt.args.buffer) - - c1.LocalAddr() - c1.RemoteAddr() - - if tt.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestUDPConnwriteMulticastWithContext(t *testing.T) { - peerAddr := "224.0.1.187:9999" - b, err := net.ResolveUDPAddr("udp4", peerAddr) - require.NoError(t, err) - - ctxCanceled, ctxCancel := context.WithCancel(context.Background()) - ctxCancel() - payload := []byte("hello world") - - ifs, err := net.Interfaces() - require.NoError(t, err) - var iface net.Interface - for _, i := range ifs { - if i.Flags&net.FlagMulticast == net.FlagMulticast && i.Flags&net.FlagUp == net.FlagUp { - iface = i - break - } - } - require.NotEmpty(t, iface) - - type args struct { - ctx context.Context - udpCtx *net.UDPAddr - buffer []byte - opts []MulticastOption - } - tests := []struct { - name string - args args - wantErr bool - }{ - { - name: "valid all interfaces", - args: args{ - ctx: context.Background(), - udpCtx: b, - buffer: payload, - opts: []MulticastOption{WithAllMulticastInterface()}, - }, - }, - { - name: "valid any interface", - args: args{ - ctx: context.Background(), - udpCtx: b, - buffer: payload, - opts: []MulticastOption{WithAnyMulticastInterface()}, - }, - }, - { - name: "valid first interface", - args: args{ - ctx: context.Background(), - udpCtx: b, - buffer: payload, - opts: []MulticastOption{WithMulticastInterface(iface)}, - }, - }, - { - name: "cancelled", - args: args{ - ctx: ctxCanceled, - udpCtx: b, - buffer: payload, - }, - wantErr: true, - }, - } - - listenAddr := ":" + strconv.Itoa(b.Port) - c, err := net.ResolveUDPAddr("udp4", listenAddr) - require.NoError(t, err) - l2, err := net.ListenUDP("udp4", c) - require.NoError(t, err) - c2 := NewUDPConn("udp", l2, WithErrors(func(err error) { t.Log(err) })) - defer func() { - errC := c2.Close() - require.NoError(t, errC) - }() - ifaces, err := net.Interfaces() - require.NoError(t, err) - for _, iface := range ifaces { - ifa := iface - err = c2.JoinGroup(&ifa, b) - if err != nil { - t.Logf("fmt cannot join group %v: %v", ifa.Name, err) - } - } - - err = c2.SetMulticastLoopback(true) - require.NoError(t, err) - - a, err := net.ResolveUDPAddr("udp4", "") - require.NoError(t, err) - l1, err := net.ListenUDP("udp4", a) - require.NoError(t, err) - c1 := NewUDPConn("udp", l1, WithErrors(func(err error) { t.Log(err) })) - defer func() { - errC := c1.Close() - require.NoError(t, errC) - }() - require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) - defer cancel() - - var wg sync.WaitGroup - wg.Add(1) - go func() { - b := make([]byte, 1024) - n, _, errR := c2.ReadWithContext(ctx, b) - assert.NoError(t, errR) - if n > 0 { - b = b[:n] - assert.Equal(t, payload, b) - } - wg.Done() - }() - defer wg.Wait() - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err = c1.WriteMulticast(tt.args.ctx, tt.args.udpCtx, tt.args.buffer, tt.args.opts...) - c1.LocalAddr() - c1.RemoteAddr() - - if tt.wantErr { - assert.Error(t, err) - return - } - assert.NoError(t, err) - }) - } -} diff --git a/net/supportsOverrideRemoteAddr.go b/net/supportsOverrideRemoteAddr.go new file mode 100644 index 00000000..c0f99ad1 --- /dev/null +++ b/net/supportsOverrideRemoteAddr.go @@ -0,0 +1,9 @@ +//go:build !linux + +package net + +import "net" + +func supportsOverrideRemoteAddr(c *net.UDPConn) bool { + return c.RemoteAddr() == nil +} diff --git a/net/supportsOverrideRemoteAddr_linux.go b/net/supportsOverrideRemoteAddr_linux.go new file mode 100644 index 00000000..eba19e4a --- /dev/null +++ b/net/supportsOverrideRemoteAddr_linux.go @@ -0,0 +1,7 @@ +package net + +import "net" + +func supportsOverrideRemoteAddr(*net.UDPConn) bool { + return true +} diff --git a/net/udp.go b/net/udp.go deleted file mode 100644 index 0459266c..00000000 --- a/net/udp.go +++ /dev/null @@ -1,15 +0,0 @@ -package net - -import ( - "net" -) - -// WriteToUDP acts just like net.UDPConn.WriteTo(), but uses a *SessionUDP instead of a net.Addr. -func WriteToUDP(conn *net.UDPConn, raddr *net.UDPAddr, b []byte) (int, error) { - if conn.RemoteAddr() == nil { - // Connection remote address must be nil otherwise - // "WriteTo with pre-connected connection" will be thrown - return conn.WriteToUDP(b, raddr) - } - return conn.Write(b) -} diff --git a/udp/client/conn.go b/udp/client/conn.go index 8e20d871..a164c790 100644 --- a/udp/client/conn.go +++ b/udp/client/conn.go @@ -716,6 +716,7 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con }() resp := cc.AcquireMessage(cc.Context()) resp.SetToken(req.Token()) + ifIndex := req.ControlMessage().GetIfIndex() w := responsewriter.New(resp, cc, req.Options()...) defer func() { cc.ReleaseMessage(w.Message()) @@ -730,6 +731,7 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con // nothing to send return } + upsertInterfaceToMessage(w.Message(), ifIndex) errW := cc.writeMessageAsync(w.Message()) if errW != nil { cc.closeConnection() @@ -741,6 +743,15 @@ func (cc *Conn) handlePong(w *responsewriter.ResponseWriter[*Conn], r *pool.Mess cc.sendPong(w, r) } +func upsertInterfaceToMessage(m *pool.Message, ifIndex int) { + if ifIndex >= 1 { + cm := coapNet.ControlMessage{ + IfIndex: ifIndex, + } + m.UpsertControlMessage(&cm) + } +} + func (cc *Conn) handleSpecialMessages(r *pool.Message) bool { // ping request if r.Code() == codes.Empty && r.Type() == message.Confirmable && len(r.Token()) == 0 && len(r.Options()) == 0 && r.Body() == nil { @@ -752,6 +763,7 @@ func (cc *Conn) handleSpecialMessages(r *pool.Message) bool { elem.ReleaseMessage(cc) resp := cc.AcquireMessage(cc.Context()) resp.SetToken(r.Token()) + upsertInterfaceToMessage(resp, r.ControlMessage().GetIfIndex()) w := responsewriter.New(resp, cc, r.Options()...) defer func() { cc.ReleaseMessage(w.Message()) @@ -769,7 +781,7 @@ func (cc *Conn) handleSpecialMessages(r *pool.Message) bool { return false } -func (cc *Conn) Process(datagram []byte) error { +func (cc *Conn) Process(cm *coapNet.ControlMessage, datagram []byte) error { if uint32(len(datagram)) > cc.session.MaxMessageSize() { return fmt.Errorf("max message size(%v) was exceeded %v", cc.session.MaxMessageSize(), len(datagram)) } @@ -779,6 +791,7 @@ func (cc *Conn) Process(datagram []byte) error { cc.ReleaseMessage(req) return err } + req.SetControlMessage(cm) req.SetSequence(cc.Sequence()) cc.checkMyMessageID(req) cc.inactivityMonitor.Notify() diff --git a/udp/client_test.go b/udp/client_test.go index a3f5b7dc..5781c7a3 100644 --- a/udp/client_test.go +++ b/udp/client_test.go @@ -758,7 +758,7 @@ func TestClientKeepAliveMonitor(t *testing.T) { go func() { defer serverWg.Done() for { - _, _, errR := ld.ReadWithContext(ctx, make([]byte, 1024)) + _, _, _, errR := ld.ReadWithContext(ctx, make([]byte, 1024)) if errR != nil { if errors.Is(errR, net.ErrClosed) { return diff --git a/udp/server/discover.go b/udp/server/discover.go index aa41af4c..86f81980 100644 --- a/udp/server/discover.go +++ b/udp/server/discover.go @@ -75,7 +75,7 @@ func (s *Server) DiscoveryRequest(req *pool.Message, address string, receiverFun return err } } else { - err = c.WriteWithContext(req.Context(), addr, data) + err = c.WriteWithContext(req.Context(), addr, nil, data) if err != nil { return err } diff --git a/udp/server/server.go b/udp/server/server.go index 8bceeab5..0f512699 100644 --- a/udp/server/server.go +++ b/udp/server/server.go @@ -145,7 +145,7 @@ func (s *Server) Serve(l *coapNet.UDPConn) error { for { buf := m - n, raddr, err := l.ReadWithContext(s.ctx, buf) + n, cm, raddr, err := l.ReadWithContext(s.ctx, buf) if err != nil { wg.Wait() @@ -165,7 +165,7 @@ func (s *Server) Serve(l *coapNet.UDPConn) error { s.cfg.Errors(fmt.Errorf("%v: cannot get client connection: %w", raddr, err)) continue } - err = cc.Process(buf) + err = cc.Process(cm, buf) if err != nil { s.closeConnection(cc) s.cfg.Errors(fmt.Errorf("%v: cannot process packet: %w", cc.RemoteAddr(), err)) diff --git a/udp/server/session.go b/udp/server/session.go index 870ff53c..103deb6d 100644 --- a/udp/server/session.go +++ b/udp/server/session.go @@ -109,7 +109,7 @@ func (s *Session) WriteMessage(req *pool.Message) error { if err != nil { return fmt.Errorf("cannot marshal: %w", err) } - return s.connection.WriteWithContext(req.Context(), s.raddr, data) + return s.connection.WriteWithContext(req.Context(), s.raddr, req.ControlMessage(), data) } // WriteMulticastMessage sends multicast to the remote multicast address. @@ -135,12 +135,12 @@ func (s *Session) Run(cc *client.Conn) (err error) { m := make([]byte, s.mtu) for { buf := m - n, _, err := s.connection.ReadWithContext(s.Context(), buf) + n, cm, _, err := s.connection.ReadWithContext(s.Context(), buf) if err != nil { return err } buf = buf[:n] - err = cc.Process(buf) + err = cc.Process(cm, buf) if err != nil { return err }