Skip to content

Commit

Permalink
Merge pull request #45 from kotori2/master
Browse files Browse the repository at this point in the history
add request to dial callback
  • Loading branch information
thinkgos authored Jan 8, 2024
2 parents 6191a34 + 5438f4d commit e9cb1db
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 10 deletions.
17 changes: 12 additions & 5 deletions handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,20 @@ func (sf *Server) handleRequest(write io.Writer, req *Request) error {
// handleConnect is used to handle a connect command
func (sf *Server) handleConnect(ctx context.Context, writer io.Writer, request *Request) error {
// Attempt to connect
dial := sf.dial
if dial == nil {
dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
return net.Dial(net_, addr)
var target net.Conn
var err error

if sf.dialWithRequest != nil {
target, err = sf.dialWithRequest(ctx, "tcp", request.DestAddr.String(), request)
} else {
dial := sf.dial
if dial == nil {
dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
return net.Dial(net_, addr)
}
}
target, err = dial(ctx, "tcp", request.DestAddr.String())
}
target, err := dial(ctx, "tcp", request.DestAddr.String())
if err != nil {
msg := err.Error()
resp := statute.RepHostUnreachable
Expand Down
12 changes: 11 additions & 1 deletion option.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,23 @@ func WithLogger(l Logger) Option {
}
}

// WithDial Optional function for dialing out
// WithDial Optional function for dialing out.
// The callback set by WithDialAndRequest will be called first.
func WithDial(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option {
return func(s *Server) {
s.dial = dial
}
}

// WithDialAndRequest Optional function for dialing out with the access of request detail.
func WithDialAndRequest(
dial func(ctx context.Context, network, addr string, request *Request) (net.Conn, error),
) Option {
return func(s *Server) {
s.dialWithRequest = dial
}
}

// WithGPool can be provided to do custom goroutine pool.
func WithGPool(pool GPool) Option {
return func(s *Server) {
Expand Down
8 changes: 4 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ type Server struct {
// logger can be used to provide a custom log target.
// Defaults to io.Discard.
logger Logger
// Optional function for dialing out
// Optional function for dialing out.
// The callback set by dialWithRequest will be called first.
dial func(ctx context.Context, network, addr string) (net.Conn, error)
// Optional function for dialing out with the access of request detail.
dialWithRequest func(ctx context.Context, network, addr string, request *Request) (net.Conn, error)
// buffer pool
bufferPool bufferpool.BufPool
// goroutine pool
Expand All @@ -64,9 +67,6 @@ func NewServer(opts ...Option) *Server {
resolver: DNSResolver{},
rules: NewPermitAll(),
logger: NewLogger(log.New(io.Discard, "socks5: ", log.LstdFlags)),
dial: func(ctx context.Context, net_, addr string) (net.Conn, error) {
return net.Dial(net_, addr)
},
}

for _, opt := range opts {
Expand Down
6 changes: 6 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package socks5

import (
"bytes"
"context"
"encoding/binary"
"errors"
"io"
Expand Down Expand Up @@ -42,6 +43,11 @@ func TestSOCKS5_Connect(t *testing.T) {
srv := NewServer(
WithAuthMethods([]Authenticator{cator}),
WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))),
WithDialAndRequest(func(ctx context.Context, network, addr string, request *Request) (net.Conn, error) {
require.Equal(t, network, "tcp")
require.Equal(t, addr, lAddr.String())
return net.Dial(network, addr)
}),
)

// Start listening
Expand Down

0 comments on commit e9cb1db

Please sign in to comment.