Skip to content

Commit

Permalink
bugfix: fix udp proxy confused implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
cxz66666 committed Sep 2, 2023
1 parent ea3e2a0 commit d319bc3
Showing 1 changed file with 52 additions and 52 deletions.
104 changes: 52 additions & 52 deletions handle.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package socks5

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -167,33 +169,15 @@ func (sf *Server) handleAssociate(ctx context.Context, writer io.Writer, request
return net.Dial(net_, addr)
}
}

target, err := dial(ctx, "udp", request.DestAddr.String())
if err != nil {
msg := err.Error()
resp := statute.RepHostUnreachable
if strings.Contains(msg, "refused") {
resp = statute.RepConnectionRefused
} else if strings.Contains(msg, "network is unreachable") {
resp = statute.RepNetworkUnreachable
}
if err := SendReply(writer, resp, nil); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("connect to %v failed, %v", request.RawDestAddr, err)
}
defer target.Close()

bindLn, err := net.ListenUDP("udp", nil)
if err != nil {
if err := SendReply(writer, statute.RepServerFailure, nil); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
}
return fmt.Errorf("listen udp failed, %v", err)
}
defer bindLn.Close()

sf.logger.Errorf("target addr %v, listen addr: %s", target.RemoteAddr(), bindLn.LocalAddr())
sf.logger.Errorf("client want to used addr %v, listen addr: %s", request.DestAddr, bindLn.LocalAddr())
// send BND.ADDR and BND.PORT, client used
if err = SendReply(writer, statute.RepSuccess, bindLn.LocalAddr()); err != nil {
return fmt.Errorf("failed to send reply, %v", err)
Expand All @@ -204,71 +188,87 @@ func (sf *Server) handleAssociate(ctx context.Context, writer io.Writer, request
conns := sync.Map{}
bufPool := sf.bufferPool.Get()
defer func() {
target.Close()
bindLn.Close()
sf.bufferPool.Put(bufPool)
bindLn.Close()
conns.Range(func(key, value any) bool {
if connTarget, ok := value.(net.Conn); !ok {
sf.logger.Errorf("conns has illegal item %v:%v", key, value)
} else {
connTarget.Close()
}
return true
})
}()
for {
n, srcAddr, err := bindLn.ReadFrom(bufPool[:cap(bufPool)])
n, srcAddr, err := bindLn.ReadFromUDP(bufPool[:cap(bufPool)])
if err != nil {
if err == io.EOF {
return
}
if strings.Contains(err.Error(), "use of closed network connection") {
sf.logger.Errorf("read data from bind listen address %s failed, %v", bindLn.LocalAddr(), err)
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
return
}
continue
}

pk, err := statute.ParseDatagram(bufPool[:n])
if err != nil {
continue
}

if _, ok := conns.LoadOrStore(srcAddr.String(), struct{}{}); !ok {
// check src addr whether equal requst.DestAddr
srcEqual := ((request.DestAddr.IP.IsUnspecified()) || bytes.Equal(request.DestAddr.IP, srcAddr.IP)) && (request.DestAddr.Port == 0 || request.DestAddr.Port == srcAddr.Port)
if !srcEqual {
continue
}

connKey := srcAddr.String() + "--" + pk.DstAddr.String()

if target, ok := conns.Load(connKey); !ok {
// if the 'connection' doesn't exist, create one and store it
targetNew, err := dial(ctx, "udp", pk.DstAddr.String())
if err != nil {
sf.logger.Errorf("connect to %v failed, %v", pk.DstAddr, err)
// TODO:continue or return Error?
continue
}
conns.Store(connKey, targetNew)
// read from remote server and write to original client
sf.goFunc(func() {
// read from remote server and write to client
bufPool := sf.bufferPool.Get()
defer func() {
target.Close()
bindLn.Close()
targetNew.Close()
conns.Delete(connKey)
sf.bufferPool.Put(bufPool)
}()

for {
buf := bufPool[:cap(bufPool)]
n, err := target.Read(buf)
n, err := targetNew.Read(buf)
if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
return
}
sf.logger.Errorf("read data from remote %s failed, %v", target.RemoteAddr().String(), err)
sf.logger.Errorf("read data from remote %s failed, %v", targetNew.RemoteAddr().String(), err)
return
}

pkb, err := statute.NewDatagram(target.RemoteAddr().String(), buf[:n])
if err != nil {
continue
}
tmpBufPool := sf.bufferPool.Get()
proBuf := tmpBufPool
proBuf = append(proBuf, pkb.Header()...)
proBuf = append(proBuf, pkb.Data...)
proBuf = append(proBuf, pk.Header()...)
proBuf = append(proBuf, buf[:n]...)
if _, err := bindLn.WriteTo(proBuf, srcAddr); err != nil {
sf.bufferPool.Put(tmpBufPool)
sf.logger.Errorf("write data to client %s failed, %v", bindLn.LocalAddr(), err)
sf.logger.Errorf("write data to client %s failed, %v", srcAddr, err)
return
}
sf.bufferPool.Put(tmpBufPool)
}
})
}

// 把消息写给remote sever
if _, err := target.Write(pk.Data); err != nil {
sf.logger.Errorf("write data to remote %s failed, %v", target.RemoteAddr().String(), err)
return
if _, err := targetNew.Write(pk.Data); err != nil {
sf.logger.Errorf("write data to remote server %s failed, %v", targetNew.RemoteAddr().String(), err)
return
}
} else {
if _, err := target.(net.Conn).Write(pk.Data); err != nil {
sf.logger.Errorf("write data to remote server %s failed, %v", target.(net.Conn).RemoteAddr().String(), err)
return
}
}
}
})
Expand All @@ -280,10 +280,10 @@ func (sf *Server) handleAssociate(ctx context.Context, writer io.Writer, request
_, err := request.Reader.Read(buf[:cap(buf)])
// sf.logger.Errorf("read data from client %s, %d bytesm, err is %+v", request.RemoteAddr.String(), num, err)
if err != nil {
if err == io.EOF {
bindLn.Close()
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
return nil
}
if strings.Contains(err.Error(), "use of closed network connection") {
} else {

Check failure on line 286 in handle.go

View workflow job for this annotation

GitHub Actions / runner / golangci-lint

`if` block ends with a `return` statement, so drop this `else` and outdent its block (golint)
return err
}
}
Expand Down

0 comments on commit d319bc3

Please sign in to comment.