Skip to content

Commit

Permalink
fix(contacts)_: fix trust status not being saved to cache when changed
Browse files Browse the repository at this point in the history
  • Loading branch information
jrainville committed Oct 25, 2024
1 parent 6ee6206 commit 86df4ff
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 9 deletions.
30 changes: 21 additions & 9 deletions protocol/messenger_contact_verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,18 @@ import (
const minContactVerificationMessageLen = 1
const maxContactVerificationMessageLen = 280

var (
ErrContactNotMutual = errors.New("must be a mutual contact")
)

func (m *Messenger) SendContactVerificationRequest(ctx context.Context, contactID string, challenge string) (*MessengerResponse, error) {
if len(challenge) < minContactVerificationMessageLen || len(challenge) > maxContactVerificationMessageLen {
return nil, errors.New("invalid verification request challenge length")
}

contact, ok := m.allContacts.Load(contactID)
if !ok || !contact.mutual() {
return nil, errors.New("must be a mutual contact")
return nil, ErrContactNotMutual
}

verifRequest := &verification.Request{
Expand Down Expand Up @@ -138,7 +142,7 @@ func (m *Messenger) SendContactVerificationRequest(ctx context.Context, contactI
func (m *Messenger) GetVerificationRequestSentTo(ctx context.Context, contactID string) (*verification.Request, error) {
_, ok := m.allContacts.Load(contactID)
if !ok {
return nil, errors.New("contact not found")
return nil, ErrContactNotFound
}

return m.verificationDatabase.GetLatestVerificationRequestSentTo(contactID)
Expand Down Expand Up @@ -279,7 +283,7 @@ func (m *Messenger) AcceptContactVerificationRequest(ctx context.Context, id str

contact, ok := m.allContacts.Load(contactID)
if !ok || !contact.mutual() {
return nil, errors.New("must be a mutual contact")
return nil, ErrContactNotMutual
}

chat, ok := m.allChats.Load(contactID)
Expand Down Expand Up @@ -394,7 +398,7 @@ func (m *Messenger) VerifiedTrusted(ctx context.Context, request *requests.Verif

contact, ok := m.allContacts.Load(contactID)
if !ok || !contact.mutual() {
return nil, errors.New("must be a mutual contact")
return nil, ErrContactNotMutual
}

err = m.setTrustStatusForContact(context.Background(), contactID, verification.TrustStatusTRUSTED)
Expand Down Expand Up @@ -589,7 +593,7 @@ func (m *Messenger) DeclineContactVerificationRequest(ctx context.Context, id st

contact, ok := m.allContacts.Load(verifRequest.From)
if !ok || !contact.mutual() {
return nil, errors.New("must be a mutual contact")
return nil, ErrContactNotMutual
}
contactID := verifRequest.From
contact, err = m.setContactVerificationStatus(contactID, VerificationStatusVERIFIED)
Expand Down Expand Up @@ -686,7 +690,7 @@ func (m *Messenger) DeclineContactVerificationRequest(ctx context.Context, id st
func (m *Messenger) setContactVerificationStatus(contactID string, verificationStatus VerificationStatus) (*Contact, error) {
contact, ok := m.allContacts.Load(contactID)
if !ok || !contact.mutual() {
return nil, errors.New("must be a mutual contact")
return nil, ErrContactNotMutual
}

contact.VerificationStatus = verificationStatus
Expand Down Expand Up @@ -714,13 +718,21 @@ func (m *Messenger) setContactVerificationStatus(contactID string, verificationS
}

func (m *Messenger) setTrustStatusForContact(ctx context.Context, contactID string, trustStatus verification.TrustStatus) error {
contact, ok := m.allContacts.Load(contactID)
if !ok {
return ErrContactNotFound
}

currentTime := m.getTimesource().GetCurrentTime()

err := m.verificationDatabase.SetTrustStatus(contactID, trustStatus, currentTime)
if err != nil {
return err
}

contact.TrustStatus = trustStatus
m.allContacts.Store(contactID, contact)

return m.SyncTrustedUser(ctx, contactID, trustStatus, m.dispatchMessage)
}

Expand Down Expand Up @@ -784,7 +796,7 @@ func (m *Messenger) HandleRequestContactVerification(state *ReceivedMessageState
contact := state.CurrentMessageState.Contact
if !contact.mutual() {
m.logger.Debug("Received a verification request for a non added mutual contact", zap.String("contactID", contactID))
return errors.New("must be a mutual contact")
return ErrContactNotMutual
}

persistedVR, err := m.verificationDatabase.GetVerificationRequest(id)
Expand Down Expand Up @@ -875,7 +887,7 @@ func (m *Messenger) HandleAcceptContactVerification(state *ReceivedMessageState,
contact := state.CurrentMessageState.Contact
if !contact.mutual() {
m.logger.Debug("Received a verification response for a non mutual contact", zap.String("contactID", contactID))
return errors.New("must be a mutual contact")
return ErrContactNotMutual
}

persistedVR, err := m.verificationDatabase.GetVerificationRequest(request.Id)
Expand Down Expand Up @@ -964,7 +976,7 @@ func (m *Messenger) HandleDeclineContactVerification(state *ReceivedMessageState
contact := state.CurrentMessageState.Contact
if !contact.mutual() {
m.logger.Debug("Received a verification decline for a non mutual contact", zap.String("contactID", contactID))
return errors.New("must be a mutual contact")
return ErrContactNotMutual
}

persistedVR, err := m.verificationDatabase.GetVerificationRequest(request.Id)
Expand Down
38 changes: 38 additions & 0 deletions protocol/messenger_contact_verification_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -769,3 +769,41 @@ func (s *MessengerVerificationRequests) newMessenger(shh types.Waku) *Messenger
s.Require().NoError(err)
return messenger
}

func (s *MessengerVerificationRequests) TestTrustStatus() {
theirMessenger := s.newMessenger(s.shh)
defer TearDownMessenger(&s.Suite, theirMessenger)

s.mutualContact(theirMessenger)

theirPk := types.EncodeHex(crypto.FromECDSAPub(&theirMessenger.identity.PublicKey))

// Test Mark as Trusted
err := s.m.MarkAsTrusted(context.Background(), theirPk)
s.Require().NoError(err)

contactFromCache, ok := s.m.allContacts.Load(theirPk)
s.Require().True(ok)
s.Require().Equal(verification.TrustStatusTRUSTED, contactFromCache.TrustStatus)

// Test Remove Trust Mark
err = s.m.RemoveTrustStatus(context.Background(), theirPk)
s.Require().NoError(err)

contactFromCache, ok = s.m.allContacts.Load(theirPk)
s.Require().True(ok)
s.Require().Equal(verification.TrustStatusUNKNOWN, contactFromCache.TrustStatus)

// Test Mark as Untrustoworthy
err = s.m.MarkAsUntrustworthy(context.Background(), theirPk)
s.Require().NoError(err)

contactFromCache, ok = s.m.allContacts.Load(theirPk)
s.Require().True(ok)
s.Require().Equal(verification.TrustStatusUNTRUSTWORTHY, contactFromCache.TrustStatus)

// Test calling with an unknown contact
err = s.m.MarkAsTrusted(context.Background(), "0x00000123")
s.Require().Error(err)
s.Require().Equal("contact not found", err.Error())
}

0 comments on commit 86df4ff

Please sign in to comment.