From b54d575a50a3bc10694fbda3d6b67191962e91c6 Mon Sep 17 00:00:00 2001 From: Steven Normore Date: Mon, 15 Jan 2024 10:50:03 -0500 Subject: [PATCH 1/4] ci: push container image tagged mls-dev from this branch for testing --- .github/workflows/push-mls.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/push-mls.yml b/.github/workflows/push-mls.yml index d55e127c..68f946d3 100644 --- a/.github/workflows/push-mls.yml +++ b/.github/workflows/push-mls.yml @@ -2,7 +2,7 @@ name: Push MLS Container on: push: branches: - - mls + - snor/mls-subscribe jobs: deploy: concurrency: main @@ -28,7 +28,7 @@ jobs: - name: Push id: push run: | - export DOCKER_IMAGE_TAG=mls + export DOCKER_IMAGE_TAG=mls-dev IMAGE_TO_DEPLOY=xmtp/node-go@$(dev/docker/build) echo Successfully pushed $IMAGE_TO_DEPLOY echo "docker_image=${IMAGE_TO_DEPLOY}" >> $GITHUB_OUTPUT From 798df89c4a30d305a2e6e3b972e82f295488f39a Mon Sep 17 00:00:00 2001 From: Steven Normore Date: Mon, 15 Jan 2024 15:36:22 -0500 Subject: [PATCH 2/4] feat: mls subscribe via db --- pkg/mls/api/v1/service.go | 336 +++++++++++++++++++++++++---- pkg/mls/api/v1/service_test.go | 371 +++++++++++++++++++++++++++++---- pkg/mls/store/store.go | 110 +++++----- pkg/mls/store/store_test.go | 214 ++++++++++--------- 4 files changed, 806 insertions(+), 225 deletions(-) diff --git a/pkg/mls/api/v1/service.go b/pkg/mls/api/v1/service.go index 348c149c..83359c65 100644 --- a/pkg/mls/api/v1/service.go +++ b/pkg/mls/api/v1/service.go @@ -330,11 +330,69 @@ func (s *Service) SendWelcomeMessages(ctx context.Context, req *mlsv1.SendWelcom } func (s *Service) QueryGroupMessages(ctx context.Context, req *mlsv1.QueryGroupMessagesRequest) (*mlsv1.QueryGroupMessagesResponse, error) { - return s.store.QueryGroupMessagesV1(ctx, req) + if req.PagingInfo == nil { + req.PagingInfo = &mlsv1.PagingInfo{} + } + if req.PagingInfo.Direction == mlsv1.SortDirection_SORT_DIRECTION_UNSPECIFIED { + req.PagingInfo.Direction = mlsv1.SortDirection_SORT_DIRECTION_DESCENDING + } + if req.PagingInfo.Limit == 0 || req.PagingInfo.Limit > mlsstore.MaxQueryPageSize { + req.PagingInfo.Limit = mlsstore.MaxQueryPageSize + } + + msgs, err := s.store.QueryGroupMessagesV1(ctx, req) + if err != nil { + return nil, err + } + + pbMsgs := make([]*mlsv1.GroupMessage, 0, len(msgs)) + for _, msg := range msgs { + pbMsgs = append(pbMsgs, toProtoGroupMessage(msg)) + } + + pagingInfo := &mlsv1.PagingInfo{Limit: uint32(req.PagingInfo.Limit), IdCursor: 0, Direction: req.PagingInfo.Direction} + if len(pbMsgs) >= int(req.PagingInfo.Limit) { + lastMsg := msgs[len(pbMsgs)-1] + pagingInfo.IdCursor = lastMsg.Id + } + + return &mlsv1.QueryGroupMessagesResponse{ + Messages: pbMsgs, + PagingInfo: pagingInfo, + }, nil } func (s *Service) QueryWelcomeMessages(ctx context.Context, req *mlsv1.QueryWelcomeMessagesRequest) (*mlsv1.QueryWelcomeMessagesResponse, error) { - return s.store.QueryWelcomeMessagesV1(ctx, req) + if req.PagingInfo == nil { + req.PagingInfo = &mlsv1.PagingInfo{} + } + if req.PagingInfo.Direction == mlsv1.SortDirection_SORT_DIRECTION_UNSPECIFIED { + req.PagingInfo.Direction = mlsv1.SortDirection_SORT_DIRECTION_DESCENDING + } + if req.PagingInfo.Limit == 0 || req.PagingInfo.Limit > mlsstore.MaxQueryPageSize { + req.PagingInfo.Limit = mlsstore.MaxQueryPageSize + } + + msgs, err := s.store.QueryWelcomeMessagesV1(ctx, req) + if err != nil { + return nil, err + } + + pbMsgs := make([]*mlsv1.WelcomeMessage, 0, len(msgs)) + for _, msg := range msgs { + pbMsgs = append(pbMsgs, toProtoWelcomeMessage(msg)) + } + + pagingInfo := &mlsv1.PagingInfo{Limit: uint32(req.PagingInfo.Limit), IdCursor: 0, Direction: req.PagingInfo.Direction} + if len(pbMsgs) >= int(req.PagingInfo.Limit) { + lastMsg := msgs[len(pbMsgs)-1] + pagingInfo.IdCursor = lastMsg.Id + } + + return &mlsv1.QueryWelcomeMessagesResponse{ + Messages: pbMsgs, + PagingInfo: pagingInfo, + }, nil } func (s *Service) SubscribeGroupMessages(req *mlsv1.SubscribeGroupMessagesRequest, stream mlsv1.MlsApi_SubscribeGroupMessagesServer) error { @@ -344,24 +402,22 @@ func (s *Service) SubscribeGroupMessages(req *mlsv1.SubscribeGroupMessagesReques // See: https://github.com/xmtp/libxmtp/pull/58 _ = stream.SendHeader(metadata.Pairs("subscribed", "true")) - var streamLock sync.Mutex + var hasMessagesLock sync.Mutex + var hasMessages bool + setHasMessages := func() { + hasMessagesLock.Lock() + defer hasMessagesLock.Unlock() + hasMessages = true + } + + var retErr error + for _, filter := range req.Filters { + filter := filter + natsSubject := buildNatsSubjectForGroupMessages(filter.GroupId) sub, err := s.nc.Subscribe(natsSubject, func(natsMsg *nats.Msg) { - var msg mlsv1.GroupMessage - err := pb.Unmarshal(natsMsg.Data, &msg) - if err != nil { - log.Error("parsing group message from bytes", zap.Error(err)) - return - } - func() { - streamLock.Lock() - defer streamLock.Unlock() - err := stream.Send(&msg) - if err != nil { - log.Error("sending group message to subscribe", zap.Error(err)) - } - }() + setHasMessages() }) if err != nil { log.Error("error subscribing to group messages", zap.Error(err)) @@ -370,14 +426,106 @@ func (s *Service) SubscribeGroupMessages(req *mlsv1.SubscribeGroupMessagesReques defer func() { _ = sub.Unsubscribe() }() + + go func() { + pagingInfo := &mlsv1.PagingInfo{ + Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING, + } + if filter.IdCursor > 0 { + pagingInfo.IdCursor = filter.IdCursor + } else { + latestMsg, err := s.store.GetLatestGroupMessage(stream.Context(), filter.GroupId) + if err != nil && !mlsstore.IsNotFoundError(err) { + log.Error("error getting latest group message", zap.Error(err)) + retErr = err + return + } + if latestMsg != nil { + pagingInfo.IdCursor = latestMsg.Id + } + } + + activeTicker := time.NewTicker(100 * time.Millisecond) + defer activeTicker.Stop() + passiveTicker := time.NewTicker(5 * time.Second) + defer passiveTicker.Stop() + for { + select { + case <-stream.Context().Done(): + return + case <-s.ctx.Done(): + return + case <-passiveTicker.C: + setHasMessages() + case <-activeTicker.C: + var skip bool + func() { + hasMessagesLock.Lock() + defer hasMessagesLock.Unlock() + if !hasMessages { + skip = true + } + hasMessages = false + }() + if skip { + continue + } + + for { + select { + case <-stream.Context().Done(): + return + case <-s.ctx.Done(): + return + default: + } + + msgs, err := s.store.QueryGroupMessagesV1(stream.Context(), &mlsv1.QueryGroupMessagesRequest{ + GroupId: filter.GroupId, + PagingInfo: pagingInfo, + }) + if err != nil { + if err == context.Canceled { + return + } + log.Error("error querying for subscription cursor messages", zap.Error(err)) + // Break out and try again during the next ticker period. + break + } + + for _, msg := range msgs { + pbMsg := toProtoGroupMessage(msg) + err := stream.Send(pbMsg) + if err != nil { + log.Error("error streaming group message", zap.Error(err)) + } + } + + // We can't just use resp.PagingInfo since we always + // want the cursor from the last message even if it's + // the last page. + if len(msgs) > 0 { + lastMsg := msgs[len(msgs)-1] + pagingInfo.IdCursor = lastMsg.Id + } + + if len(msgs) == 0 { + break + } + } + } + } + }() } select { case <-stream.Context().Done(): - return nil + break case <-s.ctx.Done(): - return nil + break } + + return retErr } func (s *Service) SubscribeWelcomeMessages(req *mlsv1.SubscribeWelcomeMessagesRequest, stream mlsv1.MlsApi_SubscribeWelcomeMessagesServer) error { @@ -387,24 +535,22 @@ func (s *Service) SubscribeWelcomeMessages(req *mlsv1.SubscribeWelcomeMessagesRe // See: https://github.com/xmtp/libxmtp/pull/58 _ = stream.SendHeader(metadata.Pairs("subscribed", "true")) - var streamLock sync.Mutex + var hasMessagesLock sync.Mutex + var hasMessages bool + setHasMessages := func() { + hasMessagesLock.Lock() + defer hasMessagesLock.Unlock() + hasMessages = true + } + + var retErr error + for _, filter := range req.Filters { + filter := filter + natsSubject := buildNatsSubjectForWelcomeMessages(filter.InstallationKey) sub, err := s.nc.Subscribe(natsSubject, func(natsMsg *nats.Msg) { - var msg mlsv1.WelcomeMessage - err := pb.Unmarshal(natsMsg.Data, &msg) - if err != nil { - log.Error("parsing welcome message from bytes", zap.Error(err)) - return - } - func() { - streamLock.Lock() - defer streamLock.Unlock() - err := stream.Send(&msg) - if err != nil { - log.Error("sending welcome message to subscribe", zap.Error(err)) - } - }() + setHasMessages() }) if err != nil { log.Error("error subscribing to welcome messages", zap.Error(err)) @@ -413,14 +559,106 @@ func (s *Service) SubscribeWelcomeMessages(req *mlsv1.SubscribeWelcomeMessagesRe defer func() { _ = sub.Unsubscribe() }() + + go func() { + pagingInfo := &mlsv1.PagingInfo{ + Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING, + } + if filter.IdCursor > 0 { + pagingInfo.IdCursor = filter.IdCursor + } else { + latestMsg, err := s.store.GetLatestWelcomeMessage(stream.Context(), filter.InstallationKey) + if err != nil && !mlsstore.IsNotFoundError(err) { + log.Error("error getting latest welcome message", zap.Error(err)) + retErr = err + return + } + if latestMsg != nil { + pagingInfo.IdCursor = latestMsg.Id + } + } + + activeTicker := time.NewTicker(200 * time.Millisecond) + defer activeTicker.Stop() + passiveTicker := time.NewTicker(5 * time.Second) + defer passiveTicker.Stop() + for { + select { + case <-stream.Context().Done(): + return + case <-s.ctx.Done(): + return + case <-passiveTicker.C: + setHasMessages() + case <-activeTicker.C: + var skip bool + func() { + hasMessagesLock.Lock() + defer hasMessagesLock.Unlock() + if !hasMessages { + skip = true + } + hasMessages = false + }() + if skip { + continue + } + + for { + select { + case <-stream.Context().Done(): + return + case <-s.ctx.Done(): + return + default: + } + + msgs, err := s.store.QueryWelcomeMessagesV1(stream.Context(), &mlsv1.QueryWelcomeMessagesRequest{ + InstallationKey: filter.InstallationKey, + PagingInfo: pagingInfo, + }) + if err != nil { + if err == context.Canceled { + return + } + log.Error("error querying for subscription cursor messages", zap.Error(err)) + // Break out and try again during the next ticker period. + break + } + + for _, msg := range msgs { + pbMsg := toProtoWelcomeMessage(msg) + err := stream.Send(pbMsg) + if err != nil { + log.Error("error streaming welcome message", zap.Error(err)) + } + } + + // We can't just use resp.PagingInfo since we always + // want the cursor from the last message even if it's + // the last page. + if len(msgs) > 0 { + lastMsg := msgs[len(msgs)-1] + pagingInfo.IdCursor = lastMsg.Id + } + + if len(msgs) == 0 { + break + } + } + } + } + }() } select { case <-stream.Context().Done(): - return nil + break case <-s.ctx.Done(): - return nil + break } + + return retErr } func buildNatsSubjectForGroupMessages(groupId []byte) string { @@ -512,3 +750,29 @@ func requireReadyToSend(groupId string, message []byte) error { } return nil } + +func toProtoGroupMessage(msg *mlsstore.GroupMessage) *mlsv1.GroupMessage { + return &mlsv1.GroupMessage{ + Version: &mlsv1.GroupMessage_V1_{ + V1: &mlsv1.GroupMessage_V1{ + Id: msg.Id, + GroupId: msg.GroupId, + CreatedNs: uint64(msg.CreatedAt.UnixNano()), + Data: msg.Data, + }, + }, + } +} + +func toProtoWelcomeMessage(msg *mlsstore.WelcomeMessage) *mlsv1.WelcomeMessage { + return &mlsv1.WelcomeMessage{ + Version: &mlsv1.WelcomeMessage_V1_{ + V1: &mlsv1.WelcomeMessage_V1{ + Id: msg.Id, + InstallationKey: msg.InstallationKey, + CreatedNs: uint64(msg.CreatedAt.UnixNano()), + Data: msg.Data, + }, + }, + } +} diff --git a/pkg/mls/api/v1/service_test.go b/pkg/mls/api/v1/service_test.go index a13495ca..0152a2ab 100644 --- a/pkg/mls/api/v1/service_test.go +++ b/pkg/mls/api/v1/service_test.go @@ -1,6 +1,7 @@ package api import ( + "bytes" "context" "errors" "fmt" @@ -248,13 +249,13 @@ func TestSendGroupMessages(t *testing.T) { }) require.NoError(t, err) - resp, err := svc.store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + msgs, err := svc.store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ GroupId: groupId, }) require.NoError(t, err) - require.Len(t, resp.Messages, 1) - require.Equal(t, resp.Messages[0].GetV1().Data, []byte("test")) - require.NotEmpty(t, resp.Messages[0].GetV1().CreatedNs) + require.Len(t, msgs, 1) + require.Equal(t, msgs[0].Data, []byte("test")) + require.NotEmpty(t, msgs[0].CreatedAt) } func TestSendWelcomeMessages(t *testing.T) { @@ -278,13 +279,13 @@ func TestSendWelcomeMessages(t *testing.T) { }) require.NoError(t, err) - resp, err := svc.store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + msgs, err := svc.store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ InstallationKey: installationId, }) require.NoError(t, err) - require.Len(t, resp.Messages, 1) - require.Equal(t, resp.Messages[0].GetV1().Data, []byte("test")) - require.NotEmpty(t, resp.Messages[0].GetV1().CreatedNs) + require.Len(t, msgs, 1) + require.Equal(t, msgs[0].Data, []byte("test")) + require.NotEmpty(t, msgs[0].CreatedAt) } func TestGetIdentityUpdates(t *testing.T) { @@ -336,34 +337,49 @@ func TestGetIdentityUpdates(t *testing.T) { require.Len(t, identityUpdates.Updates[0].Updates, 2) } -func TestSubscribeGroupMessages(t *testing.T) { +func TestSubscribeGroupMessages_WithoutCursor(t *testing.T) { ctx := context.Background() - svc, _, _, cleanup := newTestService(t, ctx) + svc, _, mlsValidationService, cleanup := newTestService(t, ctx) defer cleanup() groupId := []byte(test.RandomString(32)) + // Initial message that does not get included in the stream. + mlsValidationService.mockValidateGroupMessages(groupId) + _, err := svc.SendGroupMessages(ctx, &mlsv1.SendGroupMessagesRequest{ + Messages: []*mlsv1.GroupMessageInput{ + { + Version: &mlsv1.GroupMessageInput_V1_{ + V1: &mlsv1.GroupMessageInput_V1{ + Data: []byte("data0"), + }, + }, + }, + }, + }) + require.NoError(t, err) + + // Set of 10 messages that are included in the stream. msgs := make([]*mlsv1.GroupMessage, 10) for i := 0; i < 10; i++ { msgs[i] = &mlsv1.GroupMessage{ Version: &mlsv1.GroupMessage_V1_{ V1: &mlsv1.GroupMessage_V1{ - Id: uint64(i + 1), - CreatedNs: uint64(i + 1), - GroupId: groupId, - Data: []byte(fmt.Sprintf("data%d", i+1)), + GroupId: groupId, + Data: []byte(fmt.Sprintf("data%d", i+1)), }, }, } } + // Set up expectations of streaming the 10 messages. ctrl := gomock.NewController(t) stream := NewMockMlsApi_SubscribeGroupMessagesServer(ctrl) stream.EXPECT().SendHeader(map[string][]string{"subscribed": {"true"}}) for _, msg := range msgs { - stream.EXPECT().Send(newGroupMessageEqualsMatcher(msg)).Return(nil).Times(1) + stream.EXPECT().Send(newGroupMessageDataEqualsMatcher(msg)).Return(nil).Times(1) } - stream.EXPECT().Context().Return(ctx) + stream.EXPECT().Context().Return(ctx).AnyTimes() go func() { err := svc.SubscribeGroupMessages(&mlsv1.SubscribeGroupMessagesRequest{ @@ -377,35 +393,183 @@ func TestSubscribeGroupMessages(t *testing.T) { }() time.Sleep(50 * time.Millisecond) - for _, msg := range msgs { + // Send the messages (store and relay). + for i, msg := range msgs { + mlsValidationService.mockValidateGroupMessages(groupId) + _, err := svc.SendGroupMessages(ctx, &mlsv1.SendGroupMessagesRequest{ + Messages: []*mlsv1.GroupMessageInput{ + { + Version: &mlsv1.GroupMessageInput_V1_{ + V1: &mlsv1.GroupMessageInput_V1{ + Data: msg.GetV1().Data, + }, + }, + }, + }, + }) + require.NoError(t, err) + msgB, err := proto.Marshal(msg) require.NoError(t, err) + err = svc.HandleIncomingWakuRelayMessage(&wakupb.WakuMessage{ + ContentTopic: topic.BuildMLSV1GroupTopic(msg.GetV1().GroupId), + Timestamp: int64(msg.GetV1().CreatedNs), + Payload: msgB, + }) + require.NoError(t, err) + + if i == 4 { + time.Sleep(200 * time.Millisecond) + } + } + // Expectations should eventually be satisfied. + require.Eventually(t, ctrl.Satisfied, 5*time.Second, 100*time.Millisecond) +} + +func TestSubscribeGroupMessages_WithCursor(t *testing.T) { + ctx := context.Background() + svc, _, mlsValidationService, cleanup := newTestService(t, ctx) + defer cleanup() + + groupId := []byte(test.RandomString(32)) + + // Initial message before stream starts. + mlsValidationService.mockValidateGroupMessages(groupId) + initialMsgs := []*mlsv1.GroupMessageInput{ + { + Version: &mlsv1.GroupMessageInput_V1_{ + V1: &mlsv1.GroupMessageInput_V1{ + Data: []byte("data1"), + }, + }, + }, + { + Version: &mlsv1.GroupMessageInput_V1_{ + V1: &mlsv1.GroupMessageInput_V1{ + Data: []byte("data2"), + }, + }, + }, + { + Version: &mlsv1.GroupMessageInput_V1_{ + V1: &mlsv1.GroupMessageInput_V1{ + Data: []byte("data3"), + }, + }, + }, + } + for _, msg := range initialMsgs { + _, err := svc.SendGroupMessages(ctx, &mlsv1.SendGroupMessagesRequest{ + Messages: []*mlsv1.GroupMessageInput{msg}, + }) + require.NoError(t, err) + } + + // Set of 10 messages that are included in the stream. + msgs := make([]*mlsv1.GroupMessage, 10) + for i := 0; i < 10; i++ { + msgs[i] = &mlsv1.GroupMessage{ + Version: &mlsv1.GroupMessage_V1_{ + V1: &mlsv1.GroupMessage_V1{ + GroupId: groupId, + Data: []byte(fmt.Sprintf("data%d", i+4)), + }, + }, + } + } + + // Set up expectations of streaming the 10 messages. + ctrl := gomock.NewController(t) + stream := NewMockMlsApi_SubscribeGroupMessagesServer(ctrl) + stream.EXPECT().SendHeader(map[string][]string{"subscribed": {"true"}}) + stream.EXPECT().Send(newGroupMessageDataEqualsMatcher(&mlsv1.GroupMessage{ + Version: &mlsv1.GroupMessage_V1_{ + V1: &mlsv1.GroupMessage_V1{ + Data: []byte("data3"), + }, + }, + })).Return(nil).Times(1) + for _, msg := range msgs { + stream.EXPECT().Send(newGroupMessageDataEqualsMatcher(msg)).Return(nil).Times(1) + } + stream.EXPECT().Context().Return(ctx).AnyTimes() + + go func() { + err := svc.SubscribeGroupMessages(&mlsv1.SubscribeGroupMessagesRequest{ + Filters: []*mlsv1.SubscribeGroupMessagesRequest_Filter{ + { + GroupId: groupId, + IdCursor: 2, + }, + }, + }, stream) + require.NoError(t, err) + }() + time.Sleep(50 * time.Millisecond) + + // Send the messages (store and relay). + for i, msg := range msgs { + mlsValidationService.mockValidateGroupMessages(groupId) + _, err := svc.SendGroupMessages(ctx, &mlsv1.SendGroupMessagesRequest{ + Messages: []*mlsv1.GroupMessageInput{ + { + Version: &mlsv1.GroupMessageInput_V1_{ + V1: &mlsv1.GroupMessageInput_V1{ + Data: msg.GetV1().Data, + }, + }, + }, + }, + }) + require.NoError(t, err) + + msgB, err := proto.Marshal(msg) + require.NoError(t, err) err = svc.HandleIncomingWakuRelayMessage(&wakupb.WakuMessage{ ContentTopic: topic.BuildMLSV1GroupTopic(msg.GetV1().GroupId), Timestamp: int64(msg.GetV1().CreatedNs), Payload: msgB, }) require.NoError(t, err) + + if i == 4 { + time.Sleep(200 * time.Millisecond) + } } + // Expectations should eventually be satisfied. require.Eventually(t, ctrl.Satisfied, 5*time.Second, 100*time.Millisecond) } -func TestSubscribeWelcomeMessages(t *testing.T) { +func TestSubscribeWelcomeMessages_WithoutCursor(t *testing.T) { ctx := context.Background() svc, _, _, cleanup := newTestService(t, ctx) defer cleanup() installationKey := []byte(test.RandomString(32)) + // Initial message that does not get included in the stream. + _, err := svc.SendWelcomeMessages(ctx, &mlsv1.SendWelcomeMessagesRequest{ + Messages: []*mlsv1.WelcomeMessageInput{ + { + Version: &mlsv1.WelcomeMessageInput_V1_{ + V1: &mlsv1.WelcomeMessageInput_V1{ + InstallationKey: installationKey, + Data: []byte("data0"), + }, + }, + }, + }, + }) + require.NoError(t, err) + + // Set of 10 messages that are included in the stream. msgs := make([]*mlsv1.WelcomeMessage, 10) for i := 0; i < 10; i++ { msgs[i] = &mlsv1.WelcomeMessage{ Version: &mlsv1.WelcomeMessage_V1_{ V1: &mlsv1.WelcomeMessage_V1{ - Id: uint64(i + 1), - CreatedNs: uint64(i + 1), InstallationKey: installationKey, Data: []byte(fmt.Sprintf("data%d", i+1)), }, @@ -413,13 +577,14 @@ func TestSubscribeWelcomeMessages(t *testing.T) { } } + // Set up expectations of streaming the 10 messages. ctrl := gomock.NewController(t) stream := NewMockMlsApi_SubscribeWelcomeMessagesServer(ctrl) stream.EXPECT().SendHeader(map[string][]string{"subscribed": {"true"}}) for _, msg := range msgs { - stream.EXPECT().Send(newWelcomeMessageEqualsMatcher(msg)).Return(nil).Times(1) + stream.EXPECT().Send(newWelcomeMessageDataEqualsMatcher(msg)).Return(nil).Times(1) } - stream.EXPECT().Context().Return(ctx) + stream.EXPECT().Context().Return(ctx).AnyTimes() go func() { err := svc.SubscribeWelcomeMessages(&mlsv1.SubscribeWelcomeMessagesRequest{ @@ -433,49 +598,185 @@ func TestSubscribeWelcomeMessages(t *testing.T) { }() time.Sleep(50 * time.Millisecond) - for _, msg := range msgs { + // Send the messages (store and relay). + for i, msg := range msgs { + _, err := svc.SendWelcomeMessages(ctx, &mlsv1.SendWelcomeMessagesRequest{ + Messages: []*mlsv1.WelcomeMessageInput{ + { + Version: &mlsv1.WelcomeMessageInput_V1_{ + V1: &mlsv1.WelcomeMessageInput_V1{ + InstallationKey: installationKey, + Data: msg.GetV1().Data, + }, + }, + }, + }, + }) + require.NoError(t, err) + msgB, err := proto.Marshal(msg) require.NoError(t, err) + err = svc.HandleIncomingWakuRelayMessage(&wakupb.WakuMessage{ + ContentTopic: topic.BuildMLSV1WelcomeTopic(msg.GetV1().InstallationKey), + Timestamp: int64(msg.GetV1().CreatedNs), + Payload: msgB, + }) + require.NoError(t, err) + + if i == 4 { + time.Sleep(200 * time.Millisecond) + } + } + + // Expectations should eventually be satisfied. + require.Eventually(t, ctrl.Satisfied, 5*time.Second, 100*time.Millisecond) +} + +func TestSubscribeWelcomeMessages_WithCursor(t *testing.T) { + ctx := context.Background() + svc, _, _, cleanup := newTestService(t, ctx) + defer cleanup() + + installationKey := []byte(test.RandomString(32)) + + // Initial message before stream starts. + initialMsgs := []*mlsv1.WelcomeMessageInput{ + { + Version: &mlsv1.WelcomeMessageInput_V1_{ + V1: &mlsv1.WelcomeMessageInput_V1{ + InstallationKey: installationKey, + Data: []byte("data1"), + }, + }, + }, + { + Version: &mlsv1.WelcomeMessageInput_V1_{ + V1: &mlsv1.WelcomeMessageInput_V1{ + InstallationKey: installationKey, + Data: []byte("data2"), + }, + }, + }, + { + Version: &mlsv1.WelcomeMessageInput_V1_{ + V1: &mlsv1.WelcomeMessageInput_V1{ + InstallationKey: installationKey, + Data: []byte("data3"), + }, + }, + }, + } + for _, msg := range initialMsgs { + _, err := svc.SendWelcomeMessages(ctx, &mlsv1.SendWelcomeMessagesRequest{ + Messages: []*mlsv1.WelcomeMessageInput{msg}, + }) + require.NoError(t, err) + } + + // Set of 10 messages that are included in the stream. + msgs := make([]*mlsv1.WelcomeMessage, 10) + for i := 0; i < 10; i++ { + msgs[i] = &mlsv1.WelcomeMessage{ + Version: &mlsv1.WelcomeMessage_V1_{ + V1: &mlsv1.WelcomeMessage_V1{ + InstallationKey: installationKey, + Data: []byte(fmt.Sprintf("data%d", i+4)), + }, + }, + } + } + + // Set up expectations of streaming the 10 messages. + ctrl := gomock.NewController(t) + stream := NewMockMlsApi_SubscribeWelcomeMessagesServer(ctrl) + stream.EXPECT().SendHeader(map[string][]string{"subscribed": {"true"}}) + stream.EXPECT().Send(newWelcomeMessageDataEqualsMatcher(&mlsv1.WelcomeMessage{ + Version: &mlsv1.WelcomeMessage_V1_{ + V1: &mlsv1.WelcomeMessage_V1{ + Data: []byte("data3"), + }, + }, + })).Return(nil).Times(1) + for _, msg := range msgs { + stream.EXPECT().Send(newWelcomeMessageDataEqualsMatcher(msg)).Return(nil).Times(1) + } + stream.EXPECT().Context().Return(ctx).AnyTimes() + + go func() { + err := svc.SubscribeWelcomeMessages(&mlsv1.SubscribeWelcomeMessagesRequest{ + Filters: []*mlsv1.SubscribeWelcomeMessagesRequest_Filter{ + { + InstallationKey: installationKey, + IdCursor: 2, + }, + }, + }, stream) + require.NoError(t, err) + }() + time.Sleep(50 * time.Millisecond) + + // Send the messages (store and relay). + for i, msg := range msgs { + _, err := svc.SendWelcomeMessages(ctx, &mlsv1.SendWelcomeMessagesRequest{ + Messages: []*mlsv1.WelcomeMessageInput{ + { + Version: &mlsv1.WelcomeMessageInput_V1_{ + V1: &mlsv1.WelcomeMessageInput_V1{ + InstallationKey: installationKey, + Data: msg.GetV1().Data, + }, + }, + }, + }, + }) + require.NoError(t, err) + msgB, err := proto.Marshal(msg) + require.NoError(t, err) err = svc.HandleIncomingWakuRelayMessage(&wakupb.WakuMessage{ ContentTopic: topic.BuildMLSV1WelcomeTopic(msg.GetV1().InstallationKey), Timestamp: int64(msg.GetV1().CreatedNs), Payload: msgB, }) require.NoError(t, err) + + if i == 4 { + time.Sleep(200 * time.Millisecond) + } } + // Expectations should eventually be satisfied. require.Eventually(t, ctrl.Satisfied, 5*time.Second, 100*time.Millisecond) } -type groupMessageEqualsMatcher struct { +type groupMessageDataEqualsMatcher struct { obj *mlsv1.GroupMessage } -func newGroupMessageEqualsMatcher(obj *mlsv1.GroupMessage) *groupMessageEqualsMatcher { - return &groupMessageEqualsMatcher{obj} +func newGroupMessageDataEqualsMatcher(obj *mlsv1.GroupMessage) *groupMessageDataEqualsMatcher { + return &groupMessageDataEqualsMatcher{obj} } -func (m *groupMessageEqualsMatcher) Matches(obj interface{}) bool { - return proto.Equal(m.obj, obj.(*mlsv1.GroupMessage)) +func (m *groupMessageDataEqualsMatcher) Matches(obj interface{}) bool { + return bytes.Equal(m.obj.GetV1().Data, obj.(*mlsv1.GroupMessage).GetV1().Data) } -func (m *groupMessageEqualsMatcher) String() string { +func (m *groupMessageDataEqualsMatcher) String() string { return m.obj.String() } -type welcomeMessageEqualsMatcher struct { +type welcomeMessageDataEqualsMatcher struct { obj *mlsv1.WelcomeMessage } -func newWelcomeMessageEqualsMatcher(obj *mlsv1.WelcomeMessage) *welcomeMessageEqualsMatcher { - return &welcomeMessageEqualsMatcher{obj} +func newWelcomeMessageDataEqualsMatcher(obj *mlsv1.WelcomeMessage) *welcomeMessageDataEqualsMatcher { + return &welcomeMessageDataEqualsMatcher{obj} } -func (m *welcomeMessageEqualsMatcher) Matches(obj interface{}) bool { - return proto.Equal(m.obj, obj.(*mlsv1.WelcomeMessage)) +func (m *welcomeMessageDataEqualsMatcher) Matches(obj interface{}) bool { + return bytes.Equal(m.obj.GetV1().Data, obj.(*mlsv1.WelcomeMessage).GetV1().Data) } -func (m *welcomeMessageEqualsMatcher) String() string { +func (m *welcomeMessageDataEqualsMatcher) String() string { return m.obj.String() } diff --git a/pkg/mls/store/store.go b/pkg/mls/store/store.go index 34d48207..cb5a28a6 100644 --- a/pkg/mls/store/store.go +++ b/pkg/mls/store/store.go @@ -15,7 +15,7 @@ import ( "go.uber.org/zap" ) -const maxPageSize = 100 +const MaxQueryPageSize = 100 type Store struct { config Config @@ -30,8 +30,10 @@ type MlsStore interface { GetIdentityUpdates(ctx context.Context, walletAddresses []string, startTimeNs int64) (map[string]IdentityUpdateList, error) InsertGroupMessage(ctx context.Context, groupId []byte, data []byte) (*GroupMessage, error) InsertWelcomeMessage(ctx context.Context, installationId []byte, data []byte) (*WelcomeMessage, error) - QueryGroupMessagesV1(ctx context.Context, query *mlsv1.QueryGroupMessagesRequest) (*mlsv1.QueryGroupMessagesResponse, error) - QueryWelcomeMessagesV1(ctx context.Context, query *mlsv1.QueryWelcomeMessagesRequest) (*mlsv1.QueryWelcomeMessagesResponse, error) + GetLatestGroupMessage(ctx context.Context, groupId []byte) (*GroupMessage, error) + GetLatestWelcomeMessage(ctx context.Context, installationKey []byte) (*WelcomeMessage, error) + QueryGroupMessagesV1(ctx context.Context, query *mlsv1.QueryGroupMessagesRequest) ([]*GroupMessage, error) + QueryWelcomeMessagesV1(ctx context.Context, query *mlsv1.QueryWelcomeMessagesRequest) ([]*WelcomeMessage, error) } func New(ctx context.Context, config Config) (*Store, error) { @@ -215,7 +217,31 @@ func (s *Store) InsertWelcomeMessage(ctx context.Context, installationId []byte, return &message, nil } -func (s *Store) QueryGroupMessagesV1(ctx context.Context, req *mlsv1.QueryGroupMessagesRequest) (*mlsv1.QueryGroupMessagesResponse, error) { +func (s *Store) GetLatestGroupMessage(ctx context.Context, groupId []byte) (*GroupMessage, error) { + var msg GroupMessage + err := s.db.NewSelect().Model(&msg).Where("group_id = ?", groupId).Order("id DESC").Limit(1).Scan(ctx) + if err != nil { + if err.Error() == "sql: no rows in result set" { + return nil, NewNotFoundError(err) + } + return nil, err + } + return &msg, nil +} + +func (s *Store) GetLatestWelcomeMessage(ctx context.Context, installationKey []byte) (*WelcomeMessage, error) { + var msg WelcomeMessage + err := s.db.NewSelect().Model(&msg).Where("installation_key = ?", installationKey).Order("id DESC").Limit(1).Scan(ctx) + if err != nil { + if err.Error() == "sql: no rows in result set" { + return nil, NewNotFoundError(err) + } + return nil, err + } + return &msg, nil +} + +func (s *Store) QueryGroupMessagesV1(ctx context.Context, req *mlsv1.QueryGroupMessagesRequest) ([]*GroupMessage, error) { msgs := make([]*GroupMessage, 0) if len(req.GroupId) == 0 { @@ -237,8 +263,8 @@ func (s *Store) QueryGroupMessagesV1(ctx context.Context, req *mlsv1.QueryGroupM q = q.Order("id ASC") } - pageSize := maxPageSize - if req.PagingInfo != nil && req.PagingInfo.Limit > 0 && req.PagingInfo.Limit <= maxPageSize { + pageSize := MaxQueryPageSize + if req.PagingInfo != nil && req.PagingInfo.Limit > 0 && req.PagingInfo.Limit <= MaxQueryPageSize { pageSize = int(req.PagingInfo.Limit) } q = q.Limit(pageSize) @@ -256,33 +282,10 @@ func (s *Store) QueryGroupMessagesV1(ctx context.Context, req *mlsv1.QueryGroupM return nil, err } - messages := make([]*mlsv1.GroupMessage, 0, len(msgs)) - for _, msg := range msgs { - messages = append(messages, &mlsv1.GroupMessage{ - Version: &mlsv1.GroupMessage_V1_{ - V1: &mlsv1.GroupMessage_V1{ - Id: msg.Id, - CreatedNs: uint64(msg.CreatedAt.UnixNano()), - GroupId: msg.GroupId, - Data: msg.Data, - }, - }, - }) - } - - pagingInfo := &mlsv1.PagingInfo{Limit: uint32(pageSize), IdCursor: 0, Direction: direction} - if len(messages) >= pageSize { - lastMsg := msgs[len(messages)-1] - pagingInfo.IdCursor = lastMsg.Id - } - - return &mlsv1.QueryGroupMessagesResponse{ - Messages: messages, - PagingInfo: pagingInfo, - }, nil + return msgs, nil } -func (s *Store) QueryWelcomeMessagesV1(ctx context.Context, req *mlsv1.QueryWelcomeMessagesRequest) (*mlsv1.QueryWelcomeMessagesResponse, error) { +func (s *Store) QueryWelcomeMessagesV1(ctx context.Context, req *mlsv1.QueryWelcomeMessagesRequest) ([]*WelcomeMessage, error) { msgs := make([]*WelcomeMessage, 0) if len(req.InstallationKey) == 0 { @@ -304,8 +307,8 @@ func (s *Store) QueryWelcomeMessagesV1(ctx context.Context, req *mlsv1.QueryWelc q = q.Order("id ASC") } - pageSize := maxPageSize - if req.PagingInfo != nil && req.PagingInfo.Limit > 0 && req.PagingInfo.Limit <= maxPageSize { + pageSize := MaxQueryPageSize + if req.PagingInfo != nil && req.PagingInfo.Limit > 0 && req.PagingInfo.Limit <= MaxQueryPageSize { pageSize = int(req.PagingInfo.Limit) } q = q.Limit(pageSize) @@ -323,29 +326,7 @@ func (s *Store) QueryWelcomeMessagesV1(ctx context.Context, req *mlsv1.QueryWelc return nil, err } - messages := make([]*mlsv1.WelcomeMessage, 0, len(msgs)) - for _, msg := range msgs { - messages = append(messages, &mlsv1.WelcomeMessage{ - Version: &mlsv1.WelcomeMessage_V1_{ - V1: &mlsv1.WelcomeMessage_V1{ - Id: msg.Id, - CreatedNs: uint64(msg.CreatedAt.UnixNano()), - Data: msg.Data, - }, - }, - }) - } - - pagingInfo := &mlsv1.PagingInfo{Limit: uint32(pageSize), IdCursor: 0, Direction: direction} - if len(messages) >= pageSize { - lastMsg := msgs[len(messages)-1] - pagingInfo.IdCursor = lastMsg.Id - } - - return &mlsv1.QueryWelcomeMessagesResponse{ - Messages: messages, - PagingInfo: pagingInfo, - }, nil + return msgs, nil } func (s *Store) migrate(ctx context.Context) error { @@ -416,3 +397,20 @@ func IsAlreadyExistsError(err error) bool { _, ok := err.(*AlreadyExistsError) return ok } + +type NotFoundError struct { + Err error +} + +func (e *NotFoundError) Error() string { + return e.Err.Error() +} + +func NewNotFoundError(err error) *NotFoundError { + return &NotFoundError{err} +} + +func IsNotFoundError(err error) bool { + _, ok := err.(*NotFoundError) + return ok +} diff --git a/pkg/mls/store/store_test.go b/pkg/mls/store/store_test.go index 9f196110..9823af45 100644 --- a/pkg/mls/store/store_test.go +++ b/pkg/mls/store/store_test.go @@ -294,9 +294,9 @@ func TestQueryGroupMessagesV1_MissingGroup(t *testing.T) { ctx := context.Background() - resp, err := store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{}) + msgs, err := store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{}) require.EqualError(t, err, "group is required") - require.Nil(t, resp) + require.Nil(t, msgs) } func TestQueryWelcomeMessagesV1_MissingInstallation(t *testing.T) { @@ -305,9 +305,9 @@ func TestQueryWelcomeMessagesV1_MissingInstallation(t *testing.T) { ctx := context.Background() - resp, err := store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{}) + msgs, err := store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{}) require.EqualError(t, err, "installation is required") - require.Nil(t, resp) + require.Nil(t, msgs) } func TestQueryGroupMessagesV1_Filter(t *testing.T) { @@ -322,38 +322,38 @@ func TestQueryGroupMessagesV1_Filter(t *testing.T) { _, err = store.InsertGroupMessage(ctx, []byte("group1"), []byte("data3")) require.NoError(t, err) - resp, err := store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + msgs, err := store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ GroupId: []byte("unknown"), }) require.NoError(t, err) - require.Len(t, resp.Messages, 0) + require.Len(t, msgs, 0) - resp, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + msgs, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ GroupId: []byte("group1"), }) require.NoError(t, err) - require.Len(t, resp.Messages, 2) - require.Equal(t, []byte("data3"), resp.Messages[0].GetV1().Data) - require.Equal(t, []byte("data1"), resp.Messages[1].GetV1().Data) + require.Len(t, msgs, 2) + require.Equal(t, []byte("data3"), msgs[0].Data) + require.Equal(t, []byte("data1"), msgs[1].Data) - resp, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + msgs, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ GroupId: []byte("group2"), }) require.NoError(t, err) - require.Len(t, resp.Messages, 1) - require.Equal(t, []byte("data2"), resp.Messages[0].GetV1().Data) + require.Len(t, msgs, 1) + require.Equal(t, []byte("data2"), msgs[0].Data) // Sort ascending - resp, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + msgs, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ GroupId: []byte("group1"), PagingInfo: &mlsv1.PagingInfo{ Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING, }, }) require.NoError(t, err) - require.Len(t, resp.Messages, 2) - require.Equal(t, []byte("data1"), resp.Messages[0].GetV1().Data) - require.Equal(t, []byte("data3"), resp.Messages[1].GetV1().Data) + require.Len(t, msgs, 2) + require.Equal(t, []byte("data1"), msgs[0].Data) + require.Equal(t, []byte("data3"), msgs[1].Data) } func TestQueryWelcomeMessagesV1_Filter(t *testing.T) { @@ -368,38 +368,38 @@ func TestQueryWelcomeMessagesV1_Filter(t *testing.T) { _, err = store.InsertWelcomeMessage(ctx, []byte("installation1"), []byte("data3")) require.NoError(t, err) - resp, err := store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + msgs, err := store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ InstallationKey: []byte("unknown"), }) require.NoError(t, err) - require.Len(t, resp.Messages, 0) + require.Len(t, msgs, 0) - resp, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + msgs, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ InstallationKey: []byte("installation1"), }) require.NoError(t, err) - require.Len(t, resp.Messages, 2) - require.Equal(t, []byte("data3"), resp.Messages[0].GetV1().Data) - require.Equal(t, []byte("data1"), resp.Messages[1].GetV1().Data) + require.Len(t, msgs, 2) + require.Equal(t, []byte("data3"), msgs[0].Data) + require.Equal(t, []byte("data1"), msgs[1].Data) - resp, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + msgs, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ InstallationKey: []byte("installation2"), }) require.NoError(t, err) - require.Len(t, resp.Messages, 1) - require.Equal(t, []byte("data2"), resp.Messages[0].GetV1().Data) + require.Len(t, msgs, 1) + require.Equal(t, []byte("data2"), msgs[0].Data) // Sort ascending - resp, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + msgs, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ InstallationKey: []byte("installation1"), PagingInfo: &mlsv1.PagingInfo{ Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING, }, }) require.NoError(t, err) - require.Len(t, resp.Messages, 2) - require.Equal(t, []byte("data1"), resp.Messages[0].GetV1().Data) - require.Equal(t, []byte("data3"), resp.Messages[1].GetV1().Data) + require.Len(t, msgs, 2) + require.Equal(t, []byte("data1"), msgs[0].Data) + require.Equal(t, []byte("data3"), msgs[1].Data) } func TestQueryGroupMessagesV1_Paginate(t *testing.T) { @@ -424,77 +424,86 @@ func TestQueryGroupMessagesV1_Paginate(t *testing.T) { _, err = store.InsertGroupMessage(ctx, []byte("group1"), []byte("content8")) require.NoError(t, err) - resp, err := store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + msgs, err := store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ GroupId: []byte("group1"), }) require.NoError(t, err) - require.Len(t, resp.Messages, 6) - require.Equal(t, []byte("content8"), resp.Messages[0].GetV1().Data) - require.Equal(t, []byte("content7"), resp.Messages[1].GetV1().Data) - require.Equal(t, []byte("content6"), resp.Messages[2].GetV1().Data) - require.Equal(t, []byte("content5"), resp.Messages[3].GetV1().Data) - require.Equal(t, []byte("content3"), resp.Messages[4].GetV1().Data) - require.Equal(t, []byte("content1"), resp.Messages[5].GetV1().Data) + require.Len(t, msgs, 6) + require.Equal(t, []byte("content8"), msgs[0].Data) + require.Equal(t, []byte("content7"), msgs[1].Data) + require.Equal(t, []byte("content6"), msgs[2].Data) + require.Equal(t, []byte("content5"), msgs[3].Data) + require.Equal(t, []byte("content3"), msgs[4].Data) + require.Equal(t, []byte("content1"), msgs[5].Data) - thirdMsg := resp.Messages[2] - fifthMsg := resp.Messages[4] + thirdMsg := msgs[2] + fifthMsg := msgs[4] - resp, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + msgs, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ GroupId: []byte("group1"), PagingInfo: &mlsv1.PagingInfo{ Limit: 2, }, }) require.NoError(t, err) - require.Len(t, resp.Messages, 2) - require.Equal(t, []byte("content8"), resp.Messages[0].GetV1().Data) - require.Equal(t, []byte("content7"), resp.Messages[1].GetV1().Data) + require.Len(t, msgs, 2) + require.Equal(t, []byte("content8"), msgs[0].Data) + require.Equal(t, []byte("content7"), msgs[1].Data) // Order descending by default - resp, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + msgs, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ GroupId: []byte("group1"), PagingInfo: &mlsv1.PagingInfo{ Limit: 2, - IdCursor: thirdMsg.GetV1().Id, + IdCursor: thirdMsg.Id, }, }) require.NoError(t, err) - require.Len(t, resp.Messages, 2) - require.Equal(t, []byte("content5"), resp.Messages[0].GetV1().Data) - require.Equal(t, []byte("content3"), resp.Messages[1].GetV1().Data) + require.Len(t, msgs, 2) + require.Equal(t, []byte("content5"), msgs[0].Data) + require.Equal(t, []byte("content3"), msgs[1].Data) // Next page from previous response - resp, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ - GroupId: []byte("group1"), - PagingInfo: resp.PagingInfo, + lastMsg := msgs[len(msgs)-1] + msgs, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + GroupId: []byte("group1"), + PagingInfo: &mlsv1.PagingInfo{ + Limit: 2, + IdCursor: lastMsg.Id, + }, }) require.NoError(t, err) - require.Len(t, resp.Messages, 1) - require.Equal(t, []byte("content1"), resp.Messages[0].GetV1().Data) + require.Len(t, msgs, 1) + require.Equal(t, []byte("content1"), msgs[0].Data) // Order ascending - resp, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + msgs, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ GroupId: []byte("group1"), PagingInfo: &mlsv1.PagingInfo{ Limit: 2, Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING, - IdCursor: fifthMsg.GetV1().Id, + IdCursor: fifthMsg.Id, }, }) require.NoError(t, err) - require.Len(t, resp.Messages, 2) - require.Equal(t, []byte("content5"), resp.Messages[0].GetV1().Data) - require.Equal(t, []byte("content6"), resp.Messages[1].GetV1().Data) + require.Len(t, msgs, 2) + require.Equal(t, []byte("content5"), msgs[0].Data) + require.Equal(t, []byte("content6"), msgs[1].Data) // Next page from previous response - resp, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ - GroupId: []byte("group1"), - PagingInfo: resp.PagingInfo, + lastMsg = msgs[len(msgs)-1] + msgs, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + GroupId: []byte("group1"), + PagingInfo: &mlsv1.PagingInfo{ + Limit: 2, + Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING, + IdCursor: lastMsg.Id, + }, }) require.NoError(t, err) - require.Len(t, resp.Messages, 2) - require.Equal(t, []byte("content7"), resp.Messages[0].GetV1().Data) - require.Equal(t, []byte("content8"), resp.Messages[1].GetV1().Data) + require.Len(t, msgs, 2) + require.Equal(t, []byte("content7"), msgs[0].Data) + require.Equal(t, []byte("content8"), msgs[1].Data) } func TestQueryWelcomeMessagesV1_Paginate(t *testing.T) { @@ -519,75 +528,84 @@ func TestQueryWelcomeMessagesV1_Paginate(t *testing.T) { _, err = store.InsertWelcomeMessage(ctx, []byte("installation1"), []byte("content8")) require.NoError(t, err) - resp, err := store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + msgs, err := store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ InstallationKey: []byte("installation1"), }) require.NoError(t, err) - require.Len(t, resp.Messages, 6) - require.Equal(t, []byte("content8"), resp.Messages[0].GetV1().Data) - require.Equal(t, []byte("content7"), resp.Messages[1].GetV1().Data) - require.Equal(t, []byte("content6"), resp.Messages[2].GetV1().Data) - require.Equal(t, []byte("content5"), resp.Messages[3].GetV1().Data) - require.Equal(t, []byte("content3"), resp.Messages[4].GetV1().Data) - require.Equal(t, []byte("content1"), resp.Messages[5].GetV1().Data) + require.Len(t, msgs, 6) + require.Equal(t, []byte("content8"), msgs[0].Data) + require.Equal(t, []byte("content7"), msgs[1].Data) + require.Equal(t, []byte("content6"), msgs[2].Data) + require.Equal(t, []byte("content5"), msgs[3].Data) + require.Equal(t, []byte("content3"), msgs[4].Data) + require.Equal(t, []byte("content1"), msgs[5].Data) - thirdMsg := resp.Messages[2] - fifthMsg := resp.Messages[4] + thirdMsg := msgs[2] + fifthMsg := msgs[4] - resp, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + msgs, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ InstallationKey: []byte("installation1"), PagingInfo: &mlsv1.PagingInfo{ Limit: 2, }, }) require.NoError(t, err) - require.Len(t, resp.Messages, 2) - require.Equal(t, []byte("content8"), resp.Messages[0].GetV1().Data) - require.Equal(t, []byte("content7"), resp.Messages[1].GetV1().Data) + require.Len(t, msgs, 2) + require.Equal(t, []byte("content8"), msgs[0].Data) + require.Equal(t, []byte("content7"), msgs[1].Data) // Order descending by default - resp, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + msgs, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ InstallationKey: []byte("installation1"), PagingInfo: &mlsv1.PagingInfo{ Limit: 2, - IdCursor: thirdMsg.GetV1().Id, + IdCursor: thirdMsg.Id, }, }) require.NoError(t, err) - require.Len(t, resp.Messages, 2) - require.Equal(t, []byte("content5"), resp.Messages[0].GetV1().Data) - require.Equal(t, []byte("content3"), resp.Messages[1].GetV1().Data) + require.Len(t, msgs, 2) + require.Equal(t, []byte("content5"), msgs[0].Data) + require.Equal(t, []byte("content3"), msgs[1].Data) // Next page from previous response - resp, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + lastMsg := msgs[len(msgs)-1] + msgs, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ InstallationKey: []byte("installation1"), - PagingInfo: resp.PagingInfo, + PagingInfo: &mlsv1.PagingInfo{ + Limit: 2, + IdCursor: lastMsg.Id, + }, }) require.NoError(t, err) - require.Len(t, resp.Messages, 1) - require.Equal(t, []byte("content1"), resp.Messages[0].GetV1().Data) + require.Len(t, msgs, 1) + require.Equal(t, []byte("content1"), msgs[0].Data) // Order ascending - resp, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + msgs, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ InstallationKey: []byte("installation1"), PagingInfo: &mlsv1.PagingInfo{ Limit: 2, Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING, - IdCursor: fifthMsg.GetV1().Id, + IdCursor: fifthMsg.Id, }, }) require.NoError(t, err) - require.Len(t, resp.Messages, 2) - require.Equal(t, []byte("content5"), resp.Messages[0].GetV1().Data) - require.Equal(t, []byte("content6"), resp.Messages[1].GetV1().Data) + require.Len(t, msgs, 2) + require.Equal(t, []byte("content5"), msgs[0].Data) + require.Equal(t, []byte("content6"), msgs[1].Data) // Next page from previous response - resp, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + lastMsg = msgs[len(msgs)-1] + msgs, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ InstallationKey: []byte("installation1"), - PagingInfo: resp.PagingInfo, + PagingInfo: &mlsv1.PagingInfo{ + Limit: 2, + Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING, + IdCursor: lastMsg.Id, + }, }) require.NoError(t, err) - require.Len(t, resp.Messages, 2) - require.Equal(t, []byte("content7"), resp.Messages[0].GetV1().Data) - require.Equal(t, []byte("content8"), resp.Messages[1].GetV1().Data) + require.Len(t, msgs, 2) + require.Equal(t, []byte("content7"), msgs[0].Data) + require.Equal(t, []byte("content8"), msgs[1].Data) } From 57bd804c13eb905fd49a48a778f6893163ca39b5 Mon Sep 17 00:00:00 2001 From: Steven Normore Date: Fri, 19 Jan 2024 18:57:09 -0500 Subject: [PATCH 3/4] fix: no need for full mls messages over pubsub --- pkg/mls/api/v1/service.go | 67 +++++++++------------------------- pkg/mls/api/v1/service_test.go | 17 ++------- 2 files changed, 21 insertions(+), 63 deletions(-) diff --git a/pkg/mls/api/v1/service.go b/pkg/mls/api/v1/service.go index 83359c65..7fcdcb68 100644 --- a/pkg/mls/api/v1/service.go +++ b/pkg/mls/api/v1/service.go @@ -20,7 +20,6 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" - pb "google.golang.org/protobuf/proto" emptypb "google.golang.org/protobuf/types/known/emptypb" ) @@ -88,28 +87,18 @@ func (s *Service) Close() { func (s *Service) HandleIncomingWakuRelayMessage(wakuMsg *wakupb.WakuMessage) error { if topic.IsMLSV1Group(wakuMsg.ContentTopic) { - var msg mlsv1.GroupMessage - err := pb.Unmarshal(wakuMsg.Payload, &msg) - if err != nil { - return err - } - if msg.GetV1() == nil { - return nil - } - err = s.nc.Publish(buildNatsSubjectForGroupMessages(msg.GetV1().GroupId), wakuMsg.Payload) + // The waku message payload is just the group ID as bytes since we only + // need to use it as a signal that a new message was published, without + // any other content. + err := s.nc.Publish(buildNatsSubjectForGroupMessages(wakuMsg.Payload), wakuMsg.Payload) if err != nil { return err } } else if topic.IsMLSV1Welcome(wakuMsg.ContentTopic) { - var msg mlsv1.WelcomeMessage - err := pb.Unmarshal(wakuMsg.Payload, &msg) - if err != nil { - return err - } - if msg.GetV1() == nil { - return nil - } - err = s.nc.Publish(buildNatsSubjectForWelcomeMessages(msg.GetV1().InstallationKey), wakuMsg.Payload) + // The waku message payload is just the installation key as bytes since + // we only need to use it as a signal that a new message was published, + // without any other content. + err := s.nc.Publish(buildNatsSubjectForWelcomeMessages(wakuMsg.Payload), wakuMsg.Payload) if err != nil { return err } @@ -261,24 +250,13 @@ func (s *Service) SendGroupMessages(ctx context.Context, req *mlsv1.SendGroupMes return nil, status.Errorf(codes.Internal, "failed to insert message: %s", err) } - msgB, err := pb.Marshal(&mlsv1.GroupMessage{ - Version: &mlsv1.GroupMessage_V1_{ - V1: &mlsv1.GroupMessage_V1{ - Id: msg.Id, - CreatedNs: uint64(msg.CreatedAt.UnixNano()), - GroupId: msg.GroupId, - Data: msg.Data, - }, - }, - }) - if err != nil { - return nil, err - } - err = s.publishToWakuRelay(ctx, &wakupb.WakuMessage{ ContentTopic: topic.BuildMLSV1GroupTopic(decodedGroupId), Timestamp: msg.CreatedAt.UnixNano(), - Payload: msgB, + // The waku message payload is just the group ID as bytes since we + // only need to use it as a signal that a new message was + // published, without any other content. + Payload: msg.GroupId, }) if err != nil { return nil, err @@ -303,24 +281,13 @@ func (s *Service) SendWelcomeMessages(ctx context.Context, req *mlsv1.SendWelcom return nil, status.Errorf(codes.Internal, "failed to insert message: %s", err) } - msgB, err := pb.Marshal(&mlsv1.WelcomeMessage{ - Version: &mlsv1.WelcomeMessage_V1_{ - V1: &mlsv1.WelcomeMessage_V1{ - Id: msg.Id, - CreatedNs: uint64(msg.CreatedAt.UnixNano()), - InstallationKey: msg.InstallationKey, - Data: msg.Data, - }, - }, - }) - if err != nil { - return nil, err - } - err = s.publishToWakuRelay(ctx, &wakupb.WakuMessage{ - ContentTopic: topic.BuildMLSV1WelcomeTopic(input.GetV1().InstallationKey), + ContentTopic: topic.BuildMLSV1WelcomeTopic(msg.InstallationKey), Timestamp: msg.CreatedAt.UnixNano(), - Payload: msgB, + // The waku message payload is just the installation key as bytes + // since we only need to use it as a signal that a new message was + // published, without any other content. + Payload: msg.InstallationKey, }) if err != nil { return nil, err diff --git a/pkg/mls/api/v1/service_test.go b/pkg/mls/api/v1/service_test.go index 0152a2ab..168c18ef 100644 --- a/pkg/mls/api/v1/service_test.go +++ b/pkg/mls/api/v1/service_test.go @@ -18,7 +18,6 @@ import ( test "github.com/xmtp/xmtp-node-go/pkg/testing" "github.com/xmtp/xmtp-node-go/pkg/topic" "go.uber.org/mock/gomock" - "google.golang.org/protobuf/proto" ) type mockedMLSValidationService struct { @@ -409,12 +408,10 @@ func TestSubscribeGroupMessages_WithoutCursor(t *testing.T) { }) require.NoError(t, err) - msgB, err := proto.Marshal(msg) - require.NoError(t, err) err = svc.HandleIncomingWakuRelayMessage(&wakupb.WakuMessage{ ContentTopic: topic.BuildMLSV1GroupTopic(msg.GetV1().GroupId), Timestamp: int64(msg.GetV1().CreatedNs), - Payload: msgB, + Payload: msg.GetV1().GroupId, }) require.NoError(t, err) @@ -524,12 +521,10 @@ func TestSubscribeGroupMessages_WithCursor(t *testing.T) { }) require.NoError(t, err) - msgB, err := proto.Marshal(msg) - require.NoError(t, err) err = svc.HandleIncomingWakuRelayMessage(&wakupb.WakuMessage{ ContentTopic: topic.BuildMLSV1GroupTopic(msg.GetV1().GroupId), Timestamp: int64(msg.GetV1().CreatedNs), - Payload: msgB, + Payload: msg.GetV1().GroupId, }) require.NoError(t, err) @@ -614,12 +609,10 @@ func TestSubscribeWelcomeMessages_WithoutCursor(t *testing.T) { }) require.NoError(t, err) - msgB, err := proto.Marshal(msg) - require.NoError(t, err) err = svc.HandleIncomingWakuRelayMessage(&wakupb.WakuMessage{ ContentTopic: topic.BuildMLSV1WelcomeTopic(msg.GetV1().InstallationKey), Timestamp: int64(msg.GetV1().CreatedNs), - Payload: msgB, + Payload: msg.GetV1().InstallationKey, }) require.NoError(t, err) @@ -731,12 +724,10 @@ func TestSubscribeWelcomeMessages_WithCursor(t *testing.T) { }) require.NoError(t, err) - msgB, err := proto.Marshal(msg) - require.NoError(t, err) err = svc.HandleIncomingWakuRelayMessage(&wakupb.WakuMessage{ ContentTopic: topic.BuildMLSV1WelcomeTopic(msg.GetV1().InstallationKey), Timestamp: int64(msg.GetV1().CreatedNs), - Payload: msgB, + Payload: msg.GetV1().InstallationKey, }) require.NoError(t, err) From 40d197e25353506dfcb8edf4f15488e45a0485c6 Mon Sep 17 00:00:00 2001 From: Steven Normore Date: Fri, 19 Jan 2024 21:29:31 -0500 Subject: [PATCH 4/4] mls: remove unnecessary passive ticker from subscribe --- pkg/mls/api/v1/service.go | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/pkg/mls/api/v1/service.go b/pkg/mls/api/v1/service.go index 7fcdcb68..62413e2b 100644 --- a/pkg/mls/api/v1/service.go +++ b/pkg/mls/api/v1/service.go @@ -412,19 +412,15 @@ func (s *Service) SubscribeGroupMessages(req *mlsv1.SubscribeGroupMessagesReques } } - activeTicker := time.NewTicker(100 * time.Millisecond) - defer activeTicker.Stop() - passiveTicker := time.NewTicker(5 * time.Second) - defer passiveTicker.Stop() + ticker := time.NewTicker(200 * time.Millisecond) + defer ticker.Stop() for { select { case <-stream.Context().Done(): return case <-s.ctx.Done(): return - case <-passiveTicker.C: - setHasMessages() - case <-activeTicker.C: + case <-ticker.C: var skip bool func() { hasMessagesLock.Lock() @@ -545,19 +541,15 @@ func (s *Service) SubscribeWelcomeMessages(req *mlsv1.SubscribeWelcomeMessagesRe } } - activeTicker := time.NewTicker(200 * time.Millisecond) - defer activeTicker.Stop() - passiveTicker := time.NewTicker(5 * time.Second) - defer passiveTicker.Stop() + ticker := time.NewTicker(200 * time.Millisecond) + defer ticker.Stop() for { select { case <-stream.Context().Done(): return case <-s.ctx.Done(): return - case <-passiveTicker.C: - setHasMessages() - case <-activeTicker.C: + case <-ticker.C: var skip bool func() { hasMessagesLock.Lock()