diff --git a/be1-go/hub/standard_hub/helpers/data_structures.go b/be1-go/hub/standard_hub/helpers/data_structures.go new file mode 100644 index 0000000000..142d6ddbf2 --- /dev/null +++ b/be1-go/hub/standard_hub/helpers/data_structures.go @@ -0,0 +1,60 @@ +package helpers + +import ( + "maps" + "popstellar/channel" + "slices" + "sync" +) + +// MessageIds provides a thread-safe structure that stores a channel id with its corresponding message ids +type IdsByChannel struct { + sync.RWMutex + table map[string][]string +} + +func NewIdsByChannel() IdsByChannel { + return IdsByChannel{ + table: make(map[string][]string), + } +} + +func (i *IdsByChannel) Add(channel string, id string) { + i.Lock() + defer i.Unlock() + messageIds, channelStored := i.table[channel] + if !channelStored { + i.table[channel] = append(i.table[channel], id) + return + } + alreadyStoredId := slices.Contains(messageIds, id) + if !alreadyStoredId { + i.table[channel] = append(i.table[channel], id) + } +} + +func (i *IdsByChannel) GetAll() map[string][]string { + i.RLock() + defer i.RUnlock() + tableCopy := make(map[string][]string) + maps.Copy(tableCopy, i.table) + return tableCopy +} + +func (i *IdsByChannel) IsEmpty() bool { + i.RLock() + defer i.RUnlock() + + return len(i.table) == 0 +} + +type ChannelByID struct { + sync.RWMutex + table map[string]channel.Channel +} + +func NewChannelById() ChannelByID { + return ChannelByID{ + table: make(map[string]channel.Channel), + } +} diff --git a/be1-go/hub/standard_hub/message_handling.go b/be1-go/hub/standard_hub/message_handling.go index e465427dd4..86bcef3cd4 100644 --- a/be1-go/hub/standard_hub/message_handling.go +++ b/be1-go/hub/standard_hub/message_handling.go @@ -78,7 +78,7 @@ func (h *Hub) handleRootChannelPublishMessage(sock socket.Socket, publish method h.rootInbox.StoreMessage(publish.Params.Message) h.hubInbox.StoreMessage(publish.Params.Message) - h.addMessageId(publish.Params.Channel, publish.Params.Message.MessageID) + h.messageIdsByChannel.Add(publish.Params.Channel, publish.Params.Message.MessageID) return nil } @@ -143,7 +143,7 @@ func (h *Hub) handleRootChannelBroadcastMessage(sock socket.Socket, h.rootInbox.StoreMessage(broadcast.Params.Message) h.hubInbox.StoreMessage(broadcast.Params.Message) - h.addMessageId(broadcast.Params.Channel, broadcast.Params.Message.MessageID) + h.messageIdsByChannel.Add(broadcast.Params.Channel, broadcast.Params.Message.MessageID) return nil } @@ -275,7 +275,7 @@ func (h *Hub) handlePublish(socket socket.Socket, byteMessage []byte) (int, erro return publish.ID, err } h.hubInbox.StoreMessage(publish.Params.Message) - h.addMessageId(publish.Params.Channel, publish.Params.Message.MessageID) + h.messageIdsByChannel.Add(publish.Params.Channel, publish.Params.Message.MessageID) return publish.ID, nil } @@ -295,7 +295,7 @@ func (h *Hub) handlePublish(socket socket.Socket, byteMessage []byte) (int, erro } h.hubInbox.StoreMessage(publish.Params.Message) - h.addMessageId(publish.Params.Channel, publish.Params.Message.MessageID) + h.messageIdsByChannel.Add(publish.Params.Channel, publish.Params.Message.MessageID) return publish.ID, nil } @@ -325,7 +325,7 @@ func (h *Hub) handleBroadcast(socket socket.Socket, byteMessage []byte) error { return nil } h.hubInbox.StoreMessage(broadcast.Params.Message) - h.addMessageId(broadcast.Params.Channel, broadcast.Params.Message.MessageID) + h.messageIdsByChannel.Add(broadcast.Params.Channel, broadcast.Params.Message.MessageID) h.Unlock() @@ -435,7 +435,7 @@ func (h *Hub) handleHeartbeat(socket socket.Socket, receivedIds := heartbeat.Params - missingIds := getMissingIds(receivedIds, h.messageIdsByChannel, h.blacklist) + missingIds := getMissingIds(receivedIds, h.messageIdsByChannel.GetAll(), h.blacklist) if len(missingIds) > 0 { err = h.sendGetMessagesByIdToServer(socket, missingIds) @@ -592,7 +592,7 @@ func (h *Hub) handleReceivedMessage(socket socket.Socket, messageData message.Me h.Lock() h.hubInbox.StoreMessage(publish.Params.Message) - h.addMessageId(publish.Params.Channel, publish.Params.Message.MessageID) + h.messageIdsByChannel.Add(publish.Params.Channel, publish.Params.Message.MessageID) h.Unlock() return nil } diff --git a/be1-go/hub/standard_hub/mod.go b/be1-go/hub/standard_hub/mod.go index 76478906a2..f26db06977 100644 --- a/be1-go/hub/standard_hub/mod.go +++ b/be1-go/hub/standard_hub/mod.go @@ -4,9 +4,9 @@ import ( "context" "encoding/base64" "encoding/json" - "golang.org/x/exp/slices" "popstellar/channel" "popstellar/crypto" + "popstellar/hub/standard_hub/helpers" "popstellar/inbox" jsonrpc "popstellar/message" "popstellar/message/answer" @@ -88,7 +88,7 @@ type Hub struct { // messageIdsByChannel stores all the message ids and the corresponding channel ids // to help servers determine in which channel the message ids go - messageIdsByChannel map[string][]string + messageIdsByChannel helpers.IdsByChannel // peersInfo stores the info of the peers: public key, client and server endpoints associated with the socket ID peersInfo map[string]method.ServerInfo @@ -154,7 +154,7 @@ func NewHub(pubKeyOwner kyber.Point, clientServerAddress string, serverServerAdd hubInbox: *inbox.NewInbox(rootChannel), rootInbox: *inbox.NewInbox(rootChannel), queries: newQueries(), - messageIdsByChannel: make(map[string][]string), + messageIdsByChannel: helpers.NewIdsByChannel(), peersInfo: make(map[string]method.ServerInfo), peersGreeted: make([]string, 0), blacklist: make([]string, 0), @@ -528,26 +528,24 @@ func (h *Hub) sendGetMessagesByIdToServer(socket socket.Socket, missingIds map[s // sendHeartbeatToServers sends a heartbeat message to all servers func (h *Hub) sendHeartbeatToServers() { - h.Lock() - defer h.Unlock() - if len(h.messageIdsByChannel) > 0 { - heartbeatMessage := method.Heartbeat{ - Base: query.Base{ - JSONRPCBase: jsonrpc.JSONRPCBase{ - JSONRPC: "2.0", - }, - Method: "heartbeat", + if h.messageIdsByChannel.IsEmpty() { + return + } + heartbeatMessage := method.Heartbeat{ + Base: query.Base{ + JSONRPCBase: jsonrpc.JSONRPCBase{ + JSONRPC: "2.0", }, - Params: h.messageIdsByChannel, - } - - buf, err := json.Marshal(heartbeatMessage) - if err != nil { - h.log.Err(err).Msg("Failed to marshal and send heartbeat query") - } + Method: "heartbeat", + }, + Params: h.messageIdsByChannel.GetAll(), + } - h.serverSockets.SendToAll(buf) + buf, err := json.Marshal(heartbeatMessage) + if err != nil { + h.log.Err(err).Msg("Failed to marshal and send heartbeat query") } + h.serverSockets.SendToAll(buf) } // createLao creates a new LAO using the data in the publish parameter. @@ -667,16 +665,3 @@ func generateKeys() (kyber.Point, kyber.Scalar) { return point, secret } - -// addMessageId adds a message ID to the map of messageIds by channel of the hub -func (h *Hub) addMessageId(channelId string, messageId string) { - messageIds, channelStored := h.messageIdsByChannel[channelId] - if !channelStored { - h.messageIdsByChannel[channelId] = append(h.messageIdsByChannel[channelId], messageId) - } else { - alreadyStored := slices.Contains(messageIds, messageId) - if !alreadyStored { - h.messageIdsByChannel[channelId] = append(h.messageIdsByChannel[channelId], messageId) - } - } -}