forked from trpc-group/trpc-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconsistenthash.go
199 lines (178 loc) · 5.13 KB
/
consistenthash.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
// Tencent is pleased to support the open source community by making tRPC available.
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
// If you have downloaded a copy of the tRPC source code from Tencent,
// please note that tRPC source code is licensed under the Apache 2.0 License that can be found in the LICENSE file.
// Package consistenthash provides consistent hash utilities.
package consistenthash
import (
"errors"
"fmt"
"sort"
"strconv"
"sync"
"time"
"github.com/cespare/xxhash"
"trpc.group/trpc-go/trpc-go/naming/loadbalance"
"trpc.group/trpc-go/trpc-go/naming/registry"
)
// Constants used when building the hash ring and when breaking ties on it.
const (
// defaultReplicas is the number of virtual nodes created per real node when
// the caller supplies a non-positive replica count.
defaultReplicas int = 100
// prime is the numeric prefix innerRepr puts in front of the key before the
// key is hashed a second time to choose among nodes sharing one ring hash.
// NOTE(review): the value equals the 32-bit FNV prime — confirm that is the intended origin.
prime = 16777619
)
// Hash is the hash function type: it maps a byte slice to a 64-bit value.
// The same function is applied to virtual-node identifiers and to request keys.
type Hash func(data []byte) uint64
// defaultHashFunc is the hash used by NewConsistentHash: xxhash's 64-bit Sum64.
var defaultHashFunc Hash = xxhash.Sum64
func init() {
loadbalance.Register("consistent_hash", NewConsistentHash())
}
// NewConsistentHash creates a new ConsistentHash that hashes with
// defaultHashFunc and starts with an empty picker cache.
func NewConsistentHash() *ConsistentHash {
	ch := new(ConsistentHash)
	ch.pickers = &sync.Map{}
	ch.hashFunc = defaultHashFunc
	return ch
}
// NewCustomConsistentHash creates a new ConsistentHash that hashes node
// identifiers and request keys with the supplied hashFunc. The function is
// stored as given and invoked on every Pick.
func NewCustomConsistentHash(hashFunc Hash) *ConsistentHash {
	var pickers sync.Map
	return &ConsistentHash{
		hashFunc: hashFunc,
		pickers:  &pickers,
	}
}
// ConsistentHash defines the consistent hash load balancer. It keeps one
// chPicker per service name and delegates node selection to that picker.
type ConsistentHash struct {
// pickers maps a service name (string) to its *chPicker.
pickers *sync.Map
// interval is copied into every picker created by Select. Neither constructor
// in this file sets it, so it is the zero value unless assigned elsewhere.
interval time.Duration
// hashFunc is handed to every picker and hashes node identifiers and request keys.
hashFunc Hash
}
// Select implements loadbalance.LoadBalancer.
//
// It applies the given options, looks up (or lazily creates) the picker cached
// for serviceName, and lets that picker choose a node from list. Errors are
// whatever the picker's Pick returns.
func (ch *ConsistentHash) Select(serviceName string, list []*registry.Node,
	opt ...loadbalance.Option) (*registry.Node, error) {
	var opts loadbalance.Options
	for _, apply := range opt {
		apply(&opts)
	}

	cached, found := ch.pickers.Load(serviceName)
	if !found {
		// LoadOrStore returns the already-stored picker if another goroutine
		// won the race, otherwise the one passed in here.
		cached, _ = ch.pickers.LoadOrStore(serviceName, &chPicker{
			interval: ch.interval,
			hashFunc: ch.hashFunc,
		})
	}
	return cached.(*chPicker).Pick(list, &opts)
}
// chPicker is the picker of the consistent hash. It caches the hash ring
// built for one service.
type chPicker struct {
// list is the node list handed to the most recent ring rebuild; updateState
// compares incoming lists against it to decide whether to rebuild.
list []*registry.Node
// hashFunc hashes virtual-node identifiers and request keys.
hashFunc Hash
keys Uint64Slice // sorted hashes of all virtual nodes; its length is #(node)*replica
hashMap map[uint64][]*registry.Node // maps each hash in keys to the node(s) that produced it
// mu is held for the whole of updateState, which reads and writes list, keys and hashMap.
mu sync.Mutex
// interval is copied from ConsistentHash by Select; it is not read in the
// visible part of this file.
interval time.Duration
}
// Pick picks a node for opts.Key from list.
//
// It returns loadbalance.ErrNoServerAvailable when list is empty or the ring
// has no node for the chosen position, a "missing key" error when opts.Key is
// empty, and any error reported while refreshing the ring.
func (p *chPicker) Pick(list []*registry.Node, opts *loadbalance.Options) (*registry.Node, error) {
	if len(list) == 0 {
		return nil, loadbalance.ErrNoServerAvailable
	}
	// A key is mandatory: without it there is nothing to hash.
	if opts.Key == "" {
		return nil, errors.New("missing key")
	}

	ring, owners, err := p.updateState(list, opts.Replicas)
	if err != nil {
		return nil, err
	}

	// Binary-search the sorted ring for the first virtual node whose hash is
	// not smaller than the key's hash, wrapping to the start when the key
	// hashes past the last entry.
	target := p.hashFunc([]byte(opts.Key))
	pos := sort.Search(len(ring), func(i int) bool { return ring[i] >= target })
	if pos == len(ring) {
		pos = 0
	}

	candidates, found := owners[ring[pos]]
	if !found {
		return nil, loadbalance.ErrNoServerAvailable
	}
	if len(candidates) == 1 {
		return candidates[0], nil
	}
	// Several nodes share this ring hash: choose among them with a second
	// hash of the key so the choice stays stable per key.
	tieBreak := p.hashFunc(innerRepr(opts.Key))
	chosen := int(tieBreak % uint64(len(candidates)))
	return candidates[chosen], nil
}
// updateState returns the sorted ring hashes and the hash-to-nodes map for
// list, rebuilding them only when list differs from the list used for the
// cached ring (compared by address, position by position).
//
// replicas is the number of virtual nodes per real node; a non-positive value
// selects defaultReplicas. An error is returned when list contains a nil node.
//
// The ring is assembled in local variables and published to the picker only
// after it is complete and sorted. A rejected list therefore leaves the
// previously cached list, keys and map untouched instead of a half-built,
// unsorted ring.
func (p *chPicker) updateState(list []*registry.Node, replicas int) (Uint64Slice, map[uint64][]*registry.Node, error) {
	p.mu.Lock()
	defer p.mu.Unlock()

	// If the node list is the same as at the last update, reuse the cached ring.
	if isNodeSliceEqualBCE(p.list, list) {
		return p.keys, p.hashMap, nil
	}

	actualReplicas := replicas
	if actualReplicas <= 0 {
		actualReplicas = defaultReplicas
	}

	keys := make(Uint64Slice, 0, len(list)*actualReplicas)
	hashMap := make(map[uint64][]*registry.Node)
	for _, node := range list {
		if node == nil {
			// Nothing has been committed yet, so the cached state stays valid.
			return nil, nil, errors.New("list contains nil node")
		}
		for j := 0; j < actualReplicas; j++ {
			// Each virtual node is identified by its replica index followed
			// by the node address.
			hash := p.hashFunc([]byte(strconv.Itoa(j) + node.Address))
			keys = append(keys, hash)
			// Distinct nodes may collide on one hash; keep all of them.
			hashMap[hash] = append(hashMap[hash], node)
		}
	}
	sort.Sort(keys)

	// Commit the fully built ring.
	p.list, p.keys, p.hashMap = list, keys, hashMap
	return keys, hashMap, nil
}
// Uint64Slice is a slice of uint64 with Len, Less and Swap methods so it can
// be sorted in ascending order with sort.Sort.
type Uint64Slice []uint64

// Len returns the number of elements in the slice.
func (s Uint64Slice) Len() int { return len(s) }

// Less reports whether the element at i is smaller than the element at j.
func (s Uint64Slice) Less(i, j int) bool {
	left, right := s[i], s[j]
	return left < right
}

// Swap exchanges the elements at i and j.
func (s Uint64Slice) Swap(i, j int) {
	tmp := s[i]
	s[i] = s[j]
	s[j] = tmp
}
// isNodeSliceEqualBCE reports whether a and b contain nodes with the same
// addresses in the same order. "BCE" refers to the re-slice below, which is
// there to help bounds-check elimination.
//
// A nil slice and an empty non-nil slice are not equal. A nil entry on either
// side is never equal, even when both sides are nil at the same index. The
// function is used to decide whether a cached ring can be reused, and a list
// holding nil nodes must always be sent back through validation rather than
// dereferenced here.
func isNodeSliceEqualBCE(a, b []*registry.Node) bool {
	if len(a) != len(b) {
		return false
	}
	if (a == nil) != (b == nil) {
		return false
	}
	// Lengths are equal, so this re-slice is a no-op at run time; it only
	// tells the compiler that b[i] is in range for every index of a.
	b = b[:len(a)]
	for i, v := range a {
		// Guard both sides before touching Address to avoid a nil dereference.
		if v == nil || b[i] == nil {
			return false
		}
		if v.Address != b[i].Address {
			return false
		}
	}
	return true
}
// innerRepr returns the bytes of "<prime>:<key>", with key rendered by the
// %v verb. Pick hashes this form to choose among nodes that share one ring
// hash, so the second hash differs from the hash of the bare key.
func innerRepr(key interface{}) []byte {
	repr := fmt.Sprintf("%d:%v", prime, key)
	return []byte(repr)
}