Skip to content

Commit

Permalink
Merge pull request #33 from cxz66666/udp-fix
Browse files Browse the repository at this point in the history
bugfix: 修复UDP转发功能
  • Loading branch information
thinkgos authored Sep 11, 2023
2 parents ea3e2a0 + a447d34 commit ef8e9fe
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 81 deletions.
18 changes: 8 additions & 10 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
# https://github.com/golangci/golangci/wiki/Configuration
linters-settings:
depguard:
list-type: blacklist
packages:
# logging is allowed only by logutils.Log, logrus
# is allowed to use only in logutils package
- github.com/sirupsen/logrus
packages-with-error-message:
- github.com/sirupsen/logrus: "logging is allowed only by logutils.Log"
rules:
main:
deny:
- pkg: "github.com/sirupsen/logrus"
desc: not allowed
exhaustive:
default-signifies-exhaustive: false
gci:
Expand Down Expand Up @@ -73,7 +71,7 @@ linters:
disable-all: true
enable:
- bodyclose
- deadcode
# - deadcode
- depguard
- dogsled
- errcheck
Expand All @@ -95,13 +93,13 @@ linters:
- rowserrcheck
- scopelint
- staticcheck
- structcheck
# - structcheck
- stylecheck
- typecheck
- unconvert
- unparam
- unused
- varcheck
# - varcheck
- whitespace
- unparam
# - interfacer
Expand Down
110 changes: 55 additions & 55 deletions handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package socks5

import (
"context"
"errors"
"fmt"
"io"
"net"
Expand All @@ -21,9 +22,9 @@ type Request struct {
statute.Request
// AuthContext provided during negotiation
AuthContext *AuthContext
// LocalAddr of the the network server listen
// LocalAddr of the network server listen
LocalAddr net.Addr
// RemoteAddr of the the network that sent the request
// RemoteAddr of the network that sent the request
RemoteAddr net.Addr
// DestAddr of the actual destination (might be affected by rewrite)
DestAddr *statute.AddrSpec
Expand Down Expand Up @@ -159,6 +160,8 @@ func (sf *Server) handleBind(_ context.Context, writer io.Writer, _ *Request) er
}

// handleAssociate is used to handle a connect command
//
//nolint:unparam
func (sf *Server) handleAssociate(ctx context.Context, writer io.Writer, request *Request) error {
// Attempt to connect
dial := sf.dial
Expand All @@ -167,33 +170,15 @@ func (sf *Server) handleAssociate(ctx context.Context, writer io.Writer, request
return net.Dial(net_, addr)
}
}

target, err := dial(ctx, "udp", request.DestAddr.String())
if err != nil {
msg := err.Error()
resp := statute.RepHostUnreachable
if strings.Contains(msg, "refused") {
resp = statute.RepConnectionRefused
} else if strings.Contains(msg, "network is unreachable") {
resp = statute.RepNetworkUnreachable
}
if err := SendReply(writer, resp, nil); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("connect to %v failed, %v", request.RawDestAddr, err)
}
defer target.Close()

bindLn, err := net.ListenUDP("udp", nil)
if err != nil {
if err := SendReply(writer, statute.RepServerFailure, nil); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("listen udp failed, %v", err)
}
defer bindLn.Close()

sf.logger.Errorf("target addr %v, listen addr: %s", target.RemoteAddr(), bindLn.LocalAddr())
sf.logger.Errorf("client want to used addr %v, listen addr: %s", request.DestAddr, bindLn.LocalAddr())
// send BND.ADDR and BND.PORT, client used
if err = SendReply(writer, statute.RepSuccess, bindLn.LocalAddr()); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
Expand All @@ -204,71 +189,87 @@ func (sf *Server) handleAssociate(ctx context.Context, writer io.Writer, request
conns := sync.Map{}
bufPool := sf.bufferPool.Get()
defer func() {
target.Close()
bindLn.Close()
sf.bufferPool.Put(bufPool)
bindLn.Close()
conns.Range(func(key, value any) bool {
if connTarget, ok := value.(net.Conn); !ok {
sf.logger.Errorf("conns has illegal item %v:%v", key, value)
} else {
connTarget.Close()
}
return true
})
}()
for {
n, srcAddr, err := bindLn.ReadFrom(bufPool[:cap(bufPool)])
n, srcAddr, err := bindLn.ReadFromUDP(bufPool[:cap(bufPool)])
if err != nil {
if err == io.EOF {
return
}
if strings.Contains(err.Error(), "use of closed network connection") {
sf.logger.Errorf("read data from bind listen address %s failed, %v", bindLn.LocalAddr(), err)
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
return
}
continue
}

pk, err := statute.ParseDatagram(bufPool[:n])
if err != nil {
continue
}

if _, ok := conns.LoadOrStore(srcAddr.String(), struct{}{}); !ok {
// check src addr whether equal requst.DestAddr
srcEqual := ((request.DestAddr.IP.IsUnspecified()) || request.DestAddr.IP.Equal(srcAddr.IP)) && (request.DestAddr.Port == 0 || request.DestAddr.Port == srcAddr.Port) //nolint:lll
if !srcEqual {
continue
}

connKey := srcAddr.String() + "--" + pk.DstAddr.String()

if target, ok := conns.Load(connKey); !ok {
// if the 'connection' doesn't exist, create one and store it
targetNew, err := dial(ctx, "udp", pk.DstAddr.String())
if err != nil {
sf.logger.Errorf("connect to %v failed, %v", pk.DstAddr, err)
// TODO:continue or return Error?
continue
}
conns.Store(connKey, targetNew)
// read from remote server and write to original client
sf.goFunc(func() {
// read from remote server and write to client
bufPool := sf.bufferPool.Get()
defer func() {
target.Close()
bindLn.Close()
targetNew.Close()
conns.Delete(connKey)
sf.bufferPool.Put(bufPool)
}()

for {
buf := bufPool[:cap(bufPool)]
n, err := target.Read(buf)
n, err := targetNew.Read(buf)
if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
return
}
sf.logger.Errorf("read data from remote %s failed, %v", target.RemoteAddr().String(), err)
sf.logger.Errorf("read data from remote %s failed, %v", targetNew.RemoteAddr().String(), err)
return
}

pkb, err := statute.NewDatagram(target.RemoteAddr().String(), buf[:n])
if err != nil {
continue
}
tmpBufPool := sf.bufferPool.Get()
proBuf := tmpBufPool
proBuf = append(proBuf, pkb.Header()...)
proBuf = append(proBuf, pkb.Data...)
proBuf = append(proBuf, pk.Header()...)
proBuf = append(proBuf, buf[:n]...)
if _, err := bindLn.WriteTo(proBuf, srcAddr); err != nil {
sf.bufferPool.Put(tmpBufPool)
sf.logger.Errorf("write data to client %s failed, %v", bindLn.LocalAddr(), err)
sf.logger.Errorf("write data to client %s failed, %v", srcAddr, err)
return
}
sf.bufferPool.Put(tmpBufPool)
}
})
}

// 把消息写给remote sever
if _, err := target.Write(pk.Data); err != nil {
sf.logger.Errorf("write data to remote %s failed, %v", target.RemoteAddr().String(), err)
return
if _, err := targetNew.Write(pk.Data); err != nil {
sf.logger.Errorf("write data to remote server %s failed, %v", targetNew.RemoteAddr().String(), err)
return
}
} else {
if _, err := target.(net.Conn).Write(pk.Data); err != nil {
sf.logger.Errorf("write data to remote server %s failed, %v", target.(net.Conn).RemoteAddr().String(), err)
return
}
}
}
})
Expand All @@ -280,12 +281,11 @@ func (sf *Server) handleAssociate(ctx context.Context, writer io.Writer, request
_, err := request.Reader.Read(buf[:cap(buf)])
// sf.logger.Errorf("read data from client %s, %d bytesm, err is %+v", request.RemoteAddr.String(), num, err)
if err != nil {
if err == io.EOF {
bindLn.Close()
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
return nil
}
if strings.Contains(err.Error(), "use of closed network connection") {
return err
}
return err
}
}
}
Expand Down
39 changes: 23 additions & 16 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package socks5

import (
"bytes"
"encoding/binary"
"errors"
"io"
"log"
Expand Down Expand Up @@ -111,24 +112,29 @@ func TestSOCKS5_Connect(t *testing.T) {
func TestSOCKS5_Associate(t *testing.T) {
locIP := net.ParseIP("127.0.0.1")
// Create a local listener
lAddr := &net.UDPAddr{IP: locIP, Port: 12399}
l, err := net.ListenUDP("udp", lAddr)
serverAddr := &net.UDPAddr{IP: locIP, Port: 12399}
server, err := net.ListenUDP("udp", serverAddr)
require.NoError(t, err)
defer l.Close()
defer server.Close()

go func() {
buf := make([]byte, 2048)
for {
n, remote, err := l.ReadFrom(buf)
n, remote, err := server.ReadFrom(buf)
if err != nil {
return
}
require.Equal(t, []byte("ping"), buf[:n])

l.WriteTo([]byte("pong"), remote) //nolint: errcheck
server.WriteTo([]byte("pong"), remote) //nolint: errcheck
}
}()

clientAddr := &net.UDPAddr{IP: locIP, Port: 12499}
client, err := net.ListenUDP("udp", clientAddr)
require.NoError(t, err)
defer client.Close()

// Create a socks server
cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}}
proxySrv := NewServer(
Expand Down Expand Up @@ -158,8 +164,8 @@ func TestSOCKS5_Associate(t *testing.T) {
Reserved: 0,
DstAddr: statute.AddrSpec{
FQDN: "",
IP: locIP,
Port: lAddr.Port,
IP: clientAddr.IP,
Port: clientAddr.Port,
AddrType: statute.ATYPIPv4,
},
}
Expand All @@ -185,19 +191,20 @@ func TestSOCKS5_Associate(t *testing.T) {
require.Equal(t, statute.VersionSocks5, rspHead.Version)
require.Equal(t, statute.RepSuccess, rspHead.Response)

ipByte := []byte(serverAddr.IP.To4())
portByte := make([]byte, 2)
binary.BigEndian.PutUint16(portByte, uint16(serverAddr.Port))

msgBytes := []byte{0, 0, 0, statute.ATYPIPv4}
msgBytes = append(msgBytes, ipByte...)
msgBytes = append(msgBytes, portByte...)
msgBytes = append(msgBytes, []byte("ping")...)
client.WriteTo(msgBytes, &net.UDPAddr{IP: locIP, Port: rspHead.BndAddr.Port}) //nolint: errcheck
// t.Logf("proxy bind listen port: %d", rspHead.BndAddr.Port)
udpConn, err := net.DialUDP("udp", nil, &net.UDPAddr{
IP: locIP,
Port: rspHead.BndAddr.Port,
})
require.NoError(t, err)
// Send a ping
udpConn.Write(append([]byte{0, 0, 0, statute.ATYPIPv4, 0, 0, 0, 0, 0, 0}, []byte("ping")...)) //nolint: errcheck
response := make([]byte, 1024)
n, _, err := udpConn.ReadFrom(response)
n, _, err := client.ReadFrom(response)
require.NoError(t, err)
assert.Equal(t, []byte("pong"), response[n-4:n])

time.Sleep(time.Second * 1)
}

Expand Down
2 changes: 2 additions & 0 deletions statute/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ func NewUserPassRequest(ver byte, user, pass []byte) UserPassRequest {
}

// ParseUserPassRequest parse user's password request.
//
//nolint:nakedret
func ParseUserPassRequest(r io.Reader) (nup UserPassRequest, err error) {
tmp := []byte{0, 0}

Expand Down
2 changes: 2 additions & 0 deletions statute/datagram.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ func NewDatagram(destAddr string, data []byte) (p Datagram, err error) {
}

// ParseDatagram parse to datagram from bytes
//
//nolint:nakedret
func ParseDatagram(b []byte) (da Datagram, err error) {
if len(b) < 4+net.IPv4len+2 { // no enough data
err = errors.New("datagram to short")
Expand Down

0 comments on commit ef8e9fe

Please sign in to comment.