Skip to content

Commit

Permalink
Remove inbox validation for uploading key packages
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas committed Aug 14, 2024
1 parent 24971f4 commit a816114
Show file tree
Hide file tree
Showing 13 changed files with 71 additions and 184 deletions.
2 changes: 1 addition & 1 deletion dev/docker/env
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
set -e

function docker_compose() {
docker-compose -f dev/docker/docker-compose.yml -p xmtpd "$@"
docker compose -f dev/docker/docker-compose.yml -p xmtpd "$@"
}
2 changes: 1 addition & 1 deletion dev/e2e/docker/env
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
set -e

function docker_compose() {
docker-compose -f dev/e2e/docker/docker-compose.yml -p xmtpd-e2e "$@"
docker compose -f dev/e2e/docker/docker-compose.yml -p xmtpd-e2e "$@"
}
10 changes: 5 additions & 5 deletions pkg/api/message/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,17 @@ func (s *Service) Publish(ctx context.Context, req *proto.PublishRequest) (*prot
log.Debug("received message")

if len(env.ContentTopic) > MaxContentTopicNameSize {
return nil, status.Errorf(codes.InvalidArgument, "topic length too big")
return nil, status.Error(codes.InvalidArgument, "topic length too big")
}

if len(env.Message) > MaxMessageSize {
return nil, status.Errorf(codes.InvalidArgument, "message too big")
return nil, status.Error(codes.InvalidArgument, "message too big")
}

if !topic.IsEphemeral(env.ContentTopic) {
_, err := s.store.InsertMessage(env)
if err != nil {
return nil, status.Errorf(codes.Internal, err.Error())
return nil, status.Error(codes.Internal, err.Error())
}
}

Expand All @@ -150,7 +150,7 @@ func (s *Service) Publish(ctx context.Context, req *proto.PublishRequest) (*prot
Payload: env.Message,
})
if err != nil {
return nil, status.Errorf(codes.Internal, err.Error())
return nil, status.Error(codes.Internal, err.Error())
}

metrics.EmitPublishedEnvelope(ctx, log, env)
Expand Down Expand Up @@ -393,7 +393,7 @@ func (s *Service) BatchQuery(ctx context.Context, req *proto.BatchQueryRequest)
// We execute the query using the existing Query API
resp, err := s.Query(ctx, query)
if err != nil {
return nil, status.Errorf(codes.Internal, err.Error())
return nil, status.Error(codes.Internal, err.Error())
}
responses = append(responses, resp)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SET statement_timeout = 0;

--bun:split
ALTER TABLE installations
ADD COLUMN inbox_id BYTEA NOT NULL,
ADD COLUMN expiration BIGINT NOT NULL;

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SET statement_timeout = 0;

--bun:split
ALTER TABLE installations
DROP COLUMN IF EXISTS inbox_id,
DROP COLUMN IF EXISTS expiration;

12 changes: 8 additions & 4 deletions pkg/mls/api/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ func (s *Service) HandleIncomingWakuRelayMessage(wakuMsg *wakupb.WakuMessage) er
return nil
}

/*
*
DEPRECATED: Use UploadKeyPackage instead
*
*/
func (s *Service) RegisterInstallation(ctx context.Context, req *mlsv1.RegisterInstallationRequest) (*mlsv1.RegisterInstallationResponse, error) {
if err := validateRegisterInstallationRequest(req); err != nil {
return nil, err
Expand All @@ -126,9 +131,9 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *mlsv1.RegisterI
if len(results) != 1 {
return nil, status.Errorf(codes.Internal, "unexpected number of results: %d", len(results))
}

installationKey := results[0].InstallationKey
credential := results[0].Credential
if err = s.store.CreateInstallation(ctx, installationKey, credential.InboxId, req.KeyPackage.KeyPackageTlsSerialized, results[0].Expiration); err != nil {
if err = s.store.CreateOrUpdateInstallation(ctx, installationKey, req.KeyPackage.KeyPackageTlsSerialized); err != nil {
return nil, err
}
return &mlsv1.RegisterInstallationResponse{
Expand Down Expand Up @@ -178,9 +183,8 @@ func (s *Service) UploadKeyPackage(ctx context.Context, req *mlsv1.UploadKeyPack
}

installationId := validationResults[0].InstallationKey
expiration := validationResults[0].Expiration

if err = s.store.UpdateKeyPackage(ctx, installationId, keyPackageBytes, expiration); err != nil {
if err = s.store.CreateOrUpdateInstallation(ctx, installationId, keyPackageBytes); err != nil {
return nil, status.Errorf(codes.Internal, "failed to insert key packages: %s", err)
}

Expand Down
9 changes: 5 additions & 4 deletions pkg/mls/api/v1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ func TestRegisterInstallation(t *testing.T) {
defer cleanup()

installationId := test.RandomBytes(32)
inboxId := test.RandomInboxId()
keyPackage := []byte("test")

mockValidateInboxIdKeyPackages(mlsValidationService, installationId, inboxId)
mockValidateInboxIdKeyPackages(mlsValidationService, installationId, test.RandomInboxId())

res, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{
KeyPackage: &mlsv1.KeyPackageUpload{
KeyPackageTlsSerialized: []byte("test"),
KeyPackageTlsSerialized: keyPackage,
},
IsInboxIdCredential: false,
})
Expand All @@ -98,7 +98,8 @@ func TestRegisterInstallation(t *testing.T) {
installation, err := queries.New(mlsDb.DB).GetInstallation(ctx, installationId)
require.NoError(t, err)

require.Equal(t, inboxId, installation.InboxID)
require.Equal(t, installationId, installation.ID)
require.Equal(t, []byte("test"), installation.KeyPackage)
}

func TestRegisterInstallationError(t *testing.T) {
Expand Down
23 changes: 7 additions & 16 deletions pkg/mls/store/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -83,33 +83,24 @@ WHERE (address, inbox_id, association_sequence_id) =(
address,
inbox_id);

-- name: CreateInstallation :exec
INSERT INTO installations(id, created_at, updated_at, inbox_id, key_package, expiration)
VALUES (@id, @created_at, @updated_at, decode(@inbox_id, 'hex'), @key_package, @expiration);
-- name: CreateOrUpdateInstallation :exec
INSERT INTO installations(id, created_at, updated_at, key_package)
VALUES (@id, @created_at, @updated_at, @key_package)
ON CONFLICT (id)
DO UPDATE SET
key_package = @key_package, updated_at = @updated_at;

-- name: GetInstallation :one
SELECT
id,
created_at,
updated_at,
encode(inbox_id, 'hex') AS inbox_id,
key_package,
expiration
key_package
FROM
installations
WHERE
id = $1;

-- name: UpdateKeyPackage :execrows
UPDATE
installations
SET
key_package = @key_package,
updated_at = @updated_at,
expiration = @expiration
WHERE
id = @id;

-- name: FetchKeyPackages :many
SELECT
id,
Expand Down
2 changes: 0 additions & 2 deletions pkg/mls/store/queries/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

69 changes: 12 additions & 57 deletions pkg/mls/store/queries/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 6 additions & 28 deletions pkg/mls/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ type IdentityStore interface {
type MlsStore interface {
IdentityStore

CreateInstallation(ctx context.Context, installationId []byte, inboxId string, keyPackage []byte, expiration uint64) error
UpdateKeyPackage(ctx context.Context, installationId, keyPackage []byte, expiration uint64) error
CreateOrUpdateInstallation(ctx context.Context, installationId []byte, keyPackage []byte) error
FetchKeyPackages(ctx context.Context, installationIds [][]byte) ([]queries.FetchKeyPackagesRow, error)
InsertGroupMessage(ctx context.Context, groupId []byte, data []byte) (*queries.GroupMessage, error)
InsertWelcomeMessage(ctx context.Context, installationId []byte, data []byte, hpkePublicKey []byte) (*queries.WelcomeMessage, error)
Expand Down Expand Up @@ -246,38 +245,17 @@ func (s *Store) GetInboxLogs(ctx context.Context, batched_req *identity.GetIdent
}

// Creates the installation and last resort key package
func (s *Store) CreateInstallation(ctx context.Context, installationId []byte, inboxId string, keyPackage []byte, expiration uint64) error {
createdAt := nowNs()
func (s *Store) CreateOrUpdateInstallation(ctx context.Context, installationId []byte, keyPackage []byte) error {
now := nowNs()

return s.queries.CreateInstallation(ctx, queries.CreateInstallationParams{
return s.queries.CreateOrUpdateInstallation(ctx, queries.CreateOrUpdateInstallationParams{
ID: installationId,
CreatedAt: createdAt,
InboxID: inboxId,
CreatedAt: now,
UpdatedAt: now,
KeyPackage: keyPackage,
Expiration: int64(expiration),
})
}

// Insert a new key package, ignoring any that may already exist
func (s *Store) UpdateKeyPackage(ctx context.Context, installationId, keyPackage []byte, expiration uint64) error {
rowsUpdated, err := s.queries.UpdateKeyPackage(ctx, queries.UpdateKeyPackageParams{
ID: installationId,
UpdatedAt: nowNs(),
KeyPackage: keyPackage,
Expiration: int64(expiration),
})

if err != nil {
return err
}

if rowsUpdated == 0 {
return errors.New("installation id unknown")
}

return nil
}

func (s *Store) FetchKeyPackages(ctx context.Context, installationIds [][]byte) ([]queries.FetchKeyPackagesRow, error) {
return s.queries.FetchKeyPackages(ctx, installationIds)
}
Expand Down
16 changes: 8 additions & 8 deletions pkg/mls/store/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,8 @@ func TestCreateInstallation(t *testing.T) {

ctx := context.Background()
installationId := test.RandomBytes(32)
inboxId := test.RandomInboxId()

err := store.CreateInstallation(ctx, installationId, inboxId, test.RandomBytes(32), 0)
err := store.CreateOrUpdateInstallation(ctx, installationId, test.RandomBytes(32))
require.NoError(t, err)

installationFromDb, err := store.queries.GetInstallation(ctx, installationId)
Expand All @@ -185,21 +184,23 @@ func TestUpdateKeyPackage(t *testing.T) {

ctx := context.Background()
installationId := test.RandomBytes(32)
inboxId := test.RandomInboxId()
keyPackage := test.RandomBytes(32)

err := store.CreateInstallation(ctx, installationId, inboxId, keyPackage, 0)
err := store.CreateOrUpdateInstallation(ctx, installationId, keyPackage)
require.NoError(t, err)
afterCreate, err := store.queries.GetInstallation(ctx, installationId)
require.NoError(t, err)

keyPackage2 := test.RandomBytes(32)
err = store.UpdateKeyPackage(ctx, installationId, keyPackage2, 1)
err = store.CreateOrUpdateInstallation(ctx, installationId, keyPackage2)
require.NoError(t, err)

installationFromDb, err := store.queries.GetInstallation(ctx, installationId)
require.NoError(t, err)

require.Equal(t, keyPackage2, installationFromDb.KeyPackage)
require.Equal(t, int64(1), installationFromDb.Expiration)
require.Greater(t, installationFromDb.UpdatedAt, afterCreate.UpdatedAt)
require.Equal(t, installationFromDb.CreatedAt, afterCreate.CreatedAt)
}

func TestConsumeLastResortKeyPackage(t *testing.T) {
Expand All @@ -209,9 +210,8 @@ func TestConsumeLastResortKeyPackage(t *testing.T) {
ctx := context.Background()
installationId := test.RandomBytes(32)
keyPackage := test.RandomBytes(32)
inboxId := test.RandomInboxId()

err := store.CreateInstallation(ctx, installationId, inboxId, keyPackage, 0)
err := store.CreateOrUpdateInstallation(ctx, installationId, keyPackage)
require.NoError(t, err)

fetchResult, err := store.FetchKeyPackages(ctx, [][]byte{installationId})
Expand Down
Loading

0 comments on commit a816114

Please sign in to comment.