Skip to content

Commit

Permalink
fix: concurrent map writes (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
nii236 authored Jul 4, 2023
1 parent b783885 commit ae388da
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 27 deletions.
76 changes: 55 additions & 21 deletions relay/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package relay
import (
"encoding/json"
"strings"
"sync"

"go.uber.org/zap/zapcore"
)
Expand Down Expand Up @@ -81,30 +82,43 @@ func cachedMessageKey(topic string) string {
}

// TopicClientSet stores topic -> clients relationship
type TopicClientSet map[string]map[*client]struct{}
type TopicClientSet struct {
*sync.RWMutex
Data map[string]map[*client]struct{}
}

func NewTopicClientSet() TopicClientSet {
return make(map[string]map[*client]struct{})
func NewTopicClientSet() *TopicClientSet {
return &TopicClientSet{
RWMutex: &sync.RWMutex{},
Data: map[string]map[*client]struct{}{},
}
}

func (ts TopicClientSet) Get(topic string) map[*client]struct{} {
return ts[topic]
func (ts *TopicClientSet) Get(topic string) map[*client]struct{} {
ts.RLock()
defer ts.RUnlock()
return ts.Data[topic]
}

func (ts TopicClientSet) Set(topic string, c *client) {
if _, ok := ts[topic]; !ok {
ts[topic] = make(map[*client]struct{})
func (ts *TopicClientSet) Set(topic string, c *client) {
ts.Lock()
defer ts.Unlock()
if _, ok := ts.Data[topic]; !ok {
ts.Data[topic] = make(map[*client]struct{})
}

ts[topic][c] = struct{}{}
ts.Data[topic][c] = struct{}{}
}

// GetTopicsByClient returns the topics associated with the specified client,
// meanwhile, remove the client from these topics if `clear` is true
// returns the topics the client has associated with
func (ts TopicClientSet) GetTopicsByClient(c *client, clear bool) []string {
func (ts *TopicClientSet) GetTopicsByClient(c *client, clear bool) []string {
// Write lock in a read func because we may remove the client from the topics
ts.Lock()
defer ts.Unlock()
topics := []string{}
for topic, set := range ts {
for topic, set := range ts.Data {
if _, ok := set[c]; ok {
topics = append(topics, topic)
}
Expand All @@ -115,26 +129,46 @@ func (ts TopicClientSet) GetTopicsByClient(c *client, clear bool) []string {
return topics
}

func (ts TopicClientSet) Unset(topic string, c *client) {
delete(ts[topic], c)
func (ts *TopicClientSet) Unset(topic string, c *client) {
ts.Lock()
defer ts.Unlock()
delete(ts.Data[topic], c)
}

func (ts *TopicClientSet) Len(topic string) int {
ts.RLock()
defer ts.RUnlock()
return len(ts.Data[topic])
}

func (ts TopicClientSet) Len(topic string) int {
return len(ts[topic])
func (ts *TopicClientSet) Clear(topic string) {
ts.Lock()
defer ts.Unlock()
delete(ts.Data, topic)
}

func (ts TopicClientSet) Clear(topic string) {
delete(ts, topic)
type TopicSet struct {
*sync.RWMutex
Data map[string]struct{}
}

type TopicSet map[string]struct{}
func NewTopicSet() *TopicSet {
return &TopicSet{
RWMutex: &sync.RWMutex{},
Data: map[string]struct{}{},
}
}

func NewTopicSet() TopicSet {
return make(map[string]struct{})
func (tm TopicSet) Set(topic string) {
tm.Lock()
defer tm.Unlock()
tm.Data[topic] = struct{}{}
}

func (tm TopicSet) MarshalLogArray(encoder zapcore.ArrayEncoder) error {
for topic := range tm {
tm.Lock()
defer tm.Unlock()
for topic := range tm.Data {
encoder.AppendString(topic)
}
return nil
Expand Down
4 changes: 2 additions & 2 deletions relay/wsconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ type client struct {
active bool // heartbeat related
role RoleType // dapp or wallet
session string // session id
pubTopics TopicSet
subTopics TopicSet
pubTopics *TopicSet
subTopics *TopicSet

sendbuf chan SocketMessage // send buffer
ping chan struct{}
Expand Down
8 changes: 4 additions & 4 deletions relay/wsserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ type WsServer struct {
redisConn *redis.Client
redisSubConn *redis.PubSub

publishers TopicClientSet
subscribers TopicClientSet
publishers *TopicClientSet
subscribers *TopicClientSet

localCh chan SocketMessage // for handling message of local clients
}
Expand Down Expand Up @@ -109,12 +109,12 @@ func (ws *WsServer) Run() {
switch message.Type {
case Pub:
// do not modify wsserver's local variable in seperate goroutine
message.client.pubTopics[message.Topic] = struct{}{}
message.client.pubTopics.Set(message.Topic)
ws.publishers.Set(message.Topic, message.client)
go ws.pubMessage(message)
log.Info("local message", zap.Any("client", message.client), zap.Any("message", message))
case Sub:
message.client.subTopics[message.Topic] = struct{}{}
message.client.subTopics.Set(message.Topic)
ws.subscribers.Set(message.Topic, message.client)
go ws.subMessage(message)
log.Info("local message", zap.Any("client", message.client), zap.Any("message", message))
Expand Down

0 comments on commit ae388da

Please sign in to comment.