Skip to content
This repository has been archived by the owner on May 25, 2023. It is now read-only.

Commit

Permalink
Make IPSet concurrency-safe.
Browse files Browse the repository at this point in the history
Addresses the non-documentation bits of #110.

Signed-off-by: David Anderson <[email protected]>
  • Loading branch information
danderson committed Jan 29, 2021
1 parent 06debf9 commit 54dbd4c
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 24 deletions.
5 changes: 4 additions & 1 deletion inlining_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ func TestInlining(t *testing.T) {
})
for _, want := range []string{
"(*IPSet).Add",
"(*IPSet).Clone",
"(*IPSet).Remove",
"(*IPSet).RemoveRange",
"(*IPSet).removeRangeLocked",
"(*uint128).halves",
"IP.BitLen",
"IP.IPAddr",
Expand All @@ -60,6 +61,7 @@ func TestInlining(t *testing.T) {
"IP.Prior",
"IP.Unmap",
"IP.Zone",
"IP.lessOrEq",
"IP.v4",
"IP.v6",
"IP.v6u16",
Expand All @@ -74,6 +76,7 @@ func TestInlining(t *testing.T) {
"IPPrefix.Masked",
"IPPrefix.Valid",
"IPRange.Prefixes",
"IPRange.entirelyBefore",
"IPRange.prefixFrom128AndBits",
"IPRange.prefixFrom128AndBits-fm",
"IPv4",
Expand Down
81 changes: 61 additions & 20 deletions ipset.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

package netaddr

import "sort"
import (
"sort"
"sync"
)

// IPSet represents a set of IP addresses.
//
Expand All @@ -15,21 +18,15 @@ import "sort"
// nothing on an empty set. Ranges may be fully, partially, or not
// overlapping.
type IPSet struct {
mu sync.Mutex // protects all fields

// in are the ranges in the set.
in []IPRange

// out are the ranges to be removed from 'in'.
out []IPRange
}

// toInOnly updates s to clear s.out, by merging any s.out into s.in.
func (s *IPSet) toInOnly() {
if len(s.out) > 0 {
s.in = s.Ranges()
s.out = nil
}
}

// Clone returns a copy of s that shares no memory with s.
func (s *IPSet) Clone() *IPSet {
return &IPSet{
Expand All @@ -45,12 +42,20 @@ func (s *IPSet) AddPrefix(p IPPrefix) { s.AddRange(p.Range()) }

// AddRange adds r to s.
func (s *IPSet) AddRange(r IPRange) {
s.mu.Lock()
defer s.mu.Unlock()
s.addRangeLocked(r)
}

func (s *IPSet) addRangeLocked(r IPRange) {
if !r.Valid() {
return
}
// If there are any removals (s.out), then we need to compact the set
// first to get the order right.
s.toInOnly()
if len(s.out) > 0 {
s.rangesLocked()
}
s.in = append(s.in, r)
}

Expand All @@ -59,7 +64,9 @@ func (s *IPSet) Remove(ip IP) { s.RemoveRange(IPRange{ip, ip}) }

// RemoveFreePrefix removes and returns a Prefix of length bits from the IPSet.
func (s *IPSet) RemoveFreePrefix(bitLen uint8) (p IPPrefix, ok bool) {
prefixes := s.Prefixes()
s.mu.Lock()
defer s.mu.Unlock()
prefixes := s.prefixesLocked()
if len(prefixes) == 0 {
return IPPrefix{}, false
}
Expand All @@ -83,7 +90,7 @@ func (s *IPSet) RemoveFreePrefix(bitLen uint8) (p IPPrefix, ok bool) {
}

prefix := IPPrefix{IP: bestFit.IP, Bits: bitLen}
s.RemovePrefix(prefix)
s.removeRangeLocked(prefix.Range())
return prefix, true
}

Expand All @@ -92,30 +99,43 @@ func (s *IPSet) RemovePrefix(p IPPrefix) { s.RemoveRange(p.Range()) }

// RemoveRange removes r from s.
func (s *IPSet) RemoveRange(r IPRange) {
s.mu.Lock()
defer s.mu.Unlock()
s.removeRangeLocked(r)
}

func (s *IPSet) removeRangeLocked(r IPRange) {
if r.Valid() {
s.out = append(s.out, r)
}
}

// AddSet adds all ranges in b to s.
func (s *IPSet) AddSet(b *IPSet) {
for _, r := range b.Ranges() {
s.AddRange(r)
rr := b.Ranges()
s.mu.Lock()
defer s.mu.Unlock()
for _, r := range rr {
s.addRangeLocked(r)
}
}

// RemoveSet removes all ranges in b from s.
func (s *IPSet) RemoveSet(b *IPSet) {
for _, r := range b.Ranges() {
s.RemoveRange(r)
rr := b.Ranges()
s.mu.Lock()
defer s.mu.Unlock()
for _, r := range rr {
s.removeRangeLocked(r)
}
}

// Complement updates s to contain the complement of its current
// contents.
func (s *IPSet) Complement() {
s.toInOnly()
s.out = s.in
s.mu.Lock()
defer s.mu.Unlock()
s.out = s.rangesLocked()
s.in = []IPRange{
IPPrefix{IP: IPv4(0, 0, 0, 0), Bits: 0}.Range(),
IPPrefix{IP: IPv6Unspecified(), Bits: 0}.Range(),
Expand Down Expand Up @@ -154,6 +174,12 @@ var debugf = discardf
// Ranges returns the minimum and sorted set of IP
// ranges that covers s.
func (s *IPSet) Ranges() []IPRange {
s.mu.Lock()
defer s.mu.Unlock()
return s.rangesLocked()
}

func (s *IPSet) rangesLocked() []IPRange {
const debug = false
if debug {
debugf("ranges start in=%v out=%v", s.in, s.out)
Expand All @@ -170,6 +196,14 @@ func (s *IPSet) Ranges() []IPRange {
debugf("ranges sort in=%v out=%v", in, out)
}

if len(out) == 0 {
// Fast path that avoids allocating further, if no removals
// are needed.
s.in = in
s.out = nil
return s.in
}

// in and out are sorted in ascending range order, and have no
// overlaps within each other. We can run a merge of the two lists
// in one pass.
Expand Down Expand Up @@ -274,7 +308,8 @@ func (s *IPSet) Ranges() []IPRange {
}
}

// TODO: possibly update s.in and s.out, if #110 supports that.
s.in = ret
s.out = nil

return ret
}
Expand All @@ -284,8 +319,14 @@ func (s *IPSet) Ranges() []IPRange {
// returning a new slice of prefixes that covers all of the given 'add'
// prefixes with all the 'remove' prefixes removed.
func (s *IPSet) Prefixes() []IPPrefix {
s.mu.Lock()
defer s.mu.Unlock()
return s.prefixesLocked()
}

func (s *IPSet) prefixesLocked() []IPPrefix {
var out []IPPrefix
for _, r := range s.Ranges() {
for _, r := range s.rangesLocked() {
out = append(out, r.Prefixes()...)
}
return out
Expand Down
4 changes: 2 additions & 2 deletions ipset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -481,11 +481,11 @@ func TestIPSetOverlaps(t *testing.T) {
for _, test := range tests {
got := test.a.Overlaps(test.b)
if got != test.want {
t.Errorf("(%s).Overlaps(%s) = %v, want %v", test.a, test.b, got, test.want)
t.Errorf("(%s).Overlaps(%s) = %v, want %v", test.a.Ranges(), test.b.Ranges(), got, test.want)
}
got = test.b.Overlaps(test.a)
if got != test.want {
t.Errorf("(%s).Overlaps(%s) = %v, want %v", test.b, test.a, got, test.want)
t.Errorf("(%s).Overlaps(%s) = %v, want %v", test.b.Ranges(), test.a.Ranges(), got, test.want)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion netaddr.go
Original file line number Diff line number Diff line change
Expand Up @@ -1295,7 +1295,7 @@ func mergeIPRanges(rr []IPRange) (out []IPRange, valid bool) {
// the caller.
switch len(rr) {
case 0:
return nil, true
return []IPRange{}, true
case 1:
return []IPRange{rr[0]}, true
}
Expand Down

0 comments on commit 54dbd4c

Please sign in to comment.