Skip to content

Commit

Permalink
dns: add edns0 subnet option support
Browse files Browse the repository at this point in the history
  • Loading branch information
ginuerzh committed Feb 9, 2020
1 parent 8121e20 commit cbc9c1f
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 28 deletions.
2 changes: 2 additions & 0 deletions cmd/gost/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,8 @@ func (r *route) GenRouters() ([]router, error) {
gost.ChainResolverOption(chain),
gost.TimeoutResolverOption(timeout),
gost.TTLResolverOption(ttl),
gost.PreferResolverOption(node.Get("prefer")),
gost.SrcIPResolverOption(net.ParseIP(node.Get("ip"))),
)
}

Expand Down
116 changes: 88 additions & 28 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -122,6 +123,8 @@ type resolverOptions struct {
chain *Chain
timeout time.Duration
ttl time.Duration
prefer string
srcIP net.IP
}

// ResolverOption allows a common way to set Resolver options.
Expand All @@ -148,6 +151,20 @@ func TTLResolverOption(ttl time.Duration) ResolverOption {
}
}

// PreferResolverOption sets the prefer for Resolver.
func PreferResolverOption(prefer string) ResolverOption {
return func(opts *resolverOptions) {
opts.prefer = prefer
}
}

// SrcIPResolverOption sets the source IP for Resolver.
func SrcIPResolverOption(ip net.IP) ResolverOption {
return func(opts *resolverOptions) {
opts.srcIP = ip
}
}

// Resolver is a name resolver for domain name.
// It contains a list of name servers.
type Resolver interface {
Expand Down Expand Up @@ -177,6 +194,7 @@ type resolver struct {
stopped chan struct{}
mux sync.RWMutex
prefer string // ipv4 or ipv6
srcIP net.IP // for edns0 subnet option
options resolverOptions
}

Expand Down Expand Up @@ -217,6 +235,12 @@ func (r *resolver) Init(opts ...ResolverOption) error {
if r.options.ttl != 0 {
r.ttl = r.options.ttl
}
if r.options.prefer != "" {
r.prefer = r.options.prefer
}
if r.options.srcIP != nil {
r.srcIP = r.options.srcIP
}

var nss []NameServer
for _, ns := range r.servers {
Expand Down Expand Up @@ -259,8 +283,9 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
host = host + "." + domain
}

ctx := context.Background()
for _, ns := range r.copyServers() {
ips, err = r.resolve(ns.exchanger, host)
ips, err = r.resolve(ctx, ns.exchanger, host)
if err != nil {
log.Logf("[resolver] %s via %s : %s", host, ns.String(), err)
continue
Expand All @@ -277,7 +302,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
return
}

func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) {
func (r *resolver) resolve(ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) {
if ex == nil {
return
}
Expand All @@ -286,7 +311,6 @@ func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error)
prefer := r.prefer
r.mux.RUnlock()

ctx := context.Background()
if prefer == "ipv6" { // prefer ipv6
mq := &dns.Msg{}
mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA)
Expand All @@ -302,9 +326,15 @@ func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error)
}

func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) {
mr, _, err := r.exchangeMsg(ctx, ex, mq)
if err != nil {
return
key := newResolverCacheKey(&mq.Question[0])
mr := r.cache.loadCache(key)
if mr == nil {
r.addSubnetOpt(mq)
mr, err = r.exchangeMsg(ctx, ex, mq)
if err != nil {
return
}
r.cache.storeCache(key, mr, r.TTL())
}

for _, ans := range mr.Answer {
Expand All @@ -319,49 +349,73 @@ func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (i
return
}

func (r *resolver) addSubnetOpt(m *dns.Msg) {
if m == nil || r.srcIP == nil {
return
}
opt := new(dns.OPT)
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
e := new(dns.EDNS0_SUBNET)
e.Code = dns.EDNS0SUBNET
if ip := r.srcIP.To4(); ip != nil {
e.Family = 1
e.SourceNetmask = 32
e.Address = ip.To4()
} else {
e.Family = 2
e.SourceNetmask = 128
e.Address = r.srcIP
}
opt.Option = append(opt.Option, e)
m.Extra = append(m.Extra, opt)
}

func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) {
mq := &dns.Msg{}
if err = mq.Unpack(query); err != nil {
return
}

var qs string
if len(mq.Question) > 0 {
qs = mq.Question[0].String()
if len(mq.Question) == 0 {
return nil, errors.New("empty question")
}

var mr *dns.Msg
for _, ns := range r.copyServers() {
var cache bool
mr, cache, err = r.exchangeMsg(ctx, ns.exchanger, mq)
log.Logf("[dns] exchange message %d via %s (cache hit: %v): %s", mq.Id, ns.String(), cache, qs)
if err == nil {
break
}
log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), err)
}
if err != nil {
return
}
return mr.Pack()
}

func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, cache bool, err error) {
// Only cache for single question.
if len(mq.Question) == 1 {
key := newResolverCacheKey(&mq.Question[0])
mr = r.cache.loadCache(key)
if mr != nil {
cache = true
log.Logf("[dns] exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
mr.Id = mq.Id
return
return mr.Pack()
}

defer func() {
r.cache.storeCache(key, mr, r.TTL())
if mr != nil {
r.cache.storeCache(key, mr, r.TTL())
}
}()
}

r.addSubnetOpt(mq)

for _, ns := range r.copyServers() {
log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), mq.Question[0].String())
mr, err = r.exchangeMsg(ctx, ns.exchanger, mq)
if err == nil {
break
}
log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), err)
}
if err != nil {
return
}
return mr.Pack()
}

func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) {
query, err := mq.Pack()
if err != nil {
return
Expand All @@ -386,6 +440,7 @@ func (r *resolver) TTL() time.Duration {
func (r *resolver) Reload(rd io.Reader) error {
var ttl, timeout, period time.Duration
var domain, prefer string
var srcIP net.IP
var nss []NameServer

if rd == nil || r.Stopped() {
Expand Down Expand Up @@ -422,6 +477,10 @@ func (r *resolver) Reload(rd io.Reader) error {
if len(ss) > 1 {
prefer = strings.ToLower(ss[1])
}
case "ip":
if len(ss) > 1 {
srcIP = net.ParseIP(ss[1])
}
case "nameserver": // nameserver option, compatible with /etc/resolv.conf
if len(ss) <= 1 {
break
Expand Down Expand Up @@ -461,6 +520,7 @@ func (r *resolver) Reload(rd io.Reader) error {
r.domain = domain
r.period = period
r.prefer = prefer
r.srcIP = srcIP
r.servers = nss
r.mux.Unlock()

Expand Down

0 comments on commit cbc9c1f

Please sign in to comment.