From 1a5a0cde7c7742bfb9b170ae26fc5ae874bfa522 Mon Sep 17 00:00:00 2001 From: Grady Berry Ward Date: Mon, 8 Jan 2024 17:59:36 -0700 Subject: [PATCH] Creates Backend for Merging User Accounts (#117) --- cmd/server/pactasrv/BUILD.bazel | 1 + cmd/server/pactasrv/admin.go | 170 +++++++++ cmd/server/pactasrv/blobs.go | 14 +- cmd/server/pactasrv/conv/oapi_to_pacta.go | 2 + cmd/server/pactasrv/conv/pacta_to_oapi.go | 2 + cmd/server/pactasrv/pactasrv.go | 8 +- cmd/server/pactasrv/user.go | 5 +- db/sqldb/BUILD.bazel | 2 + db/sqldb/audit_log.go | 52 ++- db/sqldb/audit_log_test.go | 221 ++++++++++-- db/sqldb/golden/human_readable_schema.sql | 17 +- db/sqldb/golden/schema_dump.sql | 31 +- db/sqldb/merge.go | 128 +++++++ db/sqldb/merge_test.go | 335 ++++++++++++++++++ .../0009_support_user_merge.down.sql | 28 ++ .../migrations/0009_support_user_merge.up.sql | 19 + db/sqldb/owner.go | 60 +++- db/sqldb/sqldb.go | 43 +++ db/sqldb/sqldb_test.go | 1 + db/sqldb/user.go | 60 +++- db/sqldb/user_test.go | 2 +- openapi/pacta.yaml | 56 +++ pacta/pacta.go | 4 + 23 files changed, 1190 insertions(+), 71 deletions(-) create mode 100644 cmd/server/pactasrv/admin.go create mode 100644 db/sqldb/merge.go create mode 100644 db/sqldb/merge_test.go create mode 100644 db/sqldb/migrations/0009_support_user_merge.down.sql create mode 100644 db/sqldb/migrations/0009_support_user_merge.up.sql diff --git a/cmd/server/pactasrv/BUILD.bazel b/cmd/server/pactasrv/BUILD.bazel index e287aa1..bcbb45f 100644 --- a/cmd/server/pactasrv/BUILD.bazel +++ b/cmd/server/pactasrv/BUILD.bazel @@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library") go_library( name = "pactasrv", srcs = [ + "admin.go", "analysis.go", "audit_logs.go", "blobs.go", diff --git a/cmd/server/pactasrv/admin.go b/cmd/server/pactasrv/admin.go new file mode 100644 index 0000000..7ca772e --- /dev/null +++ b/cmd/server/pactasrv/admin.go @@ -0,0 +1,170 @@ +package pactasrv + +import ( + "context" + "fmt" + + "github.com/RMI/pacta/db" + "github.com/RMI/pacta/oapierr" + api "github.com/RMI/pacta/openapi/pacta" + "github.com/RMI/pacta/pacta" + "go.uber.org/zap" +) + +// Merges two users together +// (POST /admin/merge-users) +func (s *Server) MergeUsers(ctx context.Context, request api.MergeUsersRequestObject) (api.MergeUsersResponseObject, error) { + req := request.Body + actorUserInfo, err := s.getactorInfoOrErrIfAnon(ctx) + if err != nil { + return nil, err + } + fieldsIfErr := []zap.Field{ + zap.String("actor_user_id", string(actorUserInfo.UserID)), + zap.String("from_user_id", req.FromUserId), + zap.String("to_user_id", req.ToUserId), + } + if !actorUserInfo.IsAdmin && !actorUserInfo.IsSuperAdmin { + return nil, oapierr.Forbidden("only admins can merge users", fieldsIfErr...) + } + + sourceUID := pacta.UserID(req.FromUserId) + destUID := pacta.UserID(req.ToUserId) + if sourceUID == destUID { + return nil, oapierr.BadRequest("cannot merge user into themselves", fieldsIfErr...) + } + + var ( + numIncompleteUploads, numAnalyses, numPortfolios, numPortfolioGroups, numAuditLogsCreated int + buris []pacta.BlobURI + ) + + err = s.DB.Transactional(ctx, func(tx db.Tx) error { + sourceOwner, err := s.DB.GetOwnerForUser(tx, sourceUID) + if err != nil { + return fmt.Errorf("failed to get owner for source user: %w", err) + } + destOwner, err := s.DB.GetOwnerForUser(tx, destUID) + if err != nil { + return fmt.Errorf("failed to get owner for destination user: %w", err) + } + + if err = s.DB.RecordUserMerge(tx, sourceUID, destUID, actorUserInfo.UserID); err != nil { + return fmt.Errorf("failed to record user merge: %w", err) + } + if err = s.DB.RecordOwnerMerge(tx, sourceOwner, destOwner, actorUserInfo.UserID); err != nil { + return fmt.Errorf("failed to record owner merge: %w", err) + } + + auditLogsToCreate := []*pacta.AuditLog{} + addAuditLog := func(t pacta.AuditLogTargetType, id string) { + auditLogsToCreate = append(auditLogsToCreate, &pacta.AuditLog{ + Action: pacta.AuditLogAction_TransferOwnership, + ActorType: pacta.AuditLogActorType_Admin, + ActorID: string(actorUserInfo.UserID), + ActorOwner: &pacta.Owner{ID: actorUserInfo.OwnerID}, + PrimaryTargetType: t, + PrimaryTargetID: id, + PrimaryTargetOwner: &pacta.Owner{ID: destOwner}, + SecondaryTargetType: pacta.AuditLogTargetType_User, + SecondaryTargetID: string(sourceUID), + SecondaryTargetOwner: &pacta.Owner{ID: sourceOwner}, + }) + } + + incompleteUploads, err := s.DB.IncompleteUploadsByOwner(tx, sourceOwner) + if err != nil { + return fmt.Errorf("failed to get incomplete uploads for source owner: %w", err) + } + for i, upload := range incompleteUploads { + err := s.DB.UpdateIncompleteUpload(tx, upload.ID, db.SetIncompleteUploadOwner(destOwner)) + if err != nil { + return fmt.Errorf("failed to update upload owner %d/%d: %w", i+1, len(incompleteUploads), err) + } + addAuditLog(pacta.AuditLogTargetType_IncompleteUpload, string(upload.ID)) + } + numIncompleteUploads = len(incompleteUploads) + + analyses, err := s.DB.AnalysesByOwner(tx, sourceOwner) + if err != nil { + return fmt.Errorf("failed to get analyses for source owner: %w", err) + } + for i, analysis := range analyses { + err := s.DB.UpdateAnalysis(tx, analysis.ID, db.SetAnalysisOwner(destOwner)) + if err != nil { + return fmt.Errorf("failed to update analysis owner %d/%d: %w", i+1, len(analyses), err) + } + addAuditLog(pacta.AuditLogTargetType_Analysis, string(analysis.ID)) + } + numAnalyses = len(analyses) + + portfolios, err := s.DB.PortfoliosByOwner(tx, sourceOwner) + if err != nil { + return fmt.Errorf("failed to get portfolios for source owner: %w", err) + } + for i, portfolio := range portfolios { + err := s.DB.UpdatePortfolio(tx, portfolio.ID, db.SetPortfolioOwner(destOwner)) + if err != nil { + return fmt.Errorf("failed to update portfolio owner %d/%d: %w", i+1, len(portfolios), err) + } + addAuditLog(pacta.AuditLogTargetType_Portfolio, string(portfolio.ID)) + } + numPortfolios = len(portfolios) + + portfolioGroups, err := s.DB.PortfolioGroupsByOwner(tx, sourceOwner) + if err != nil { + return fmt.Errorf("failed to get portfolio groups for source owner: %w", err) + } + for i, portfolioGroup := range portfolioGroups { + err := s.DB.UpdatePortfolioGroup(tx, portfolioGroup.ID, db.SetPortfolioGroupOwner(destOwner)) + if err != nil { + return fmt.Errorf("failed to update portfolio group owner %d/%d: %w", i+1, len(portfolioGroups), err) + } + addAuditLog(pacta.AuditLogTargetType_PortfolioGroup, string(portfolioGroup.ID)) + } + numPortfolioGroups = len(portfolioGroups) + + if err := s.DB.CreateAuditLogs(tx, auditLogsToCreate); err != nil { + return fmt.Errorf("failed to create audit logs: %w", err) + } + numAuditLogsCreated = len(auditLogsToCreate) + + // Now that we've transferred all the entities, we can delete the user. + deletedUserBuris, err := s.DB.DeleteUser(tx, sourceUID) + if err != nil { + return fmt.Errorf("failed to delete user: %w", err) + } + if len(deletedUserBuris) > 0 { + // Note in this case we won't commit the transaction, so this data won't be orphaned. + return fmt.Errorf("failed to delete user: user still has blobs: %v", deletedUserBuris) + } + + return nil + }) + if err != nil { + fieldsIfErr := append(fieldsIfErr, zap.Error(err)) + return nil, oapierr.Internal("failed to merge users", fieldsIfErr...) + } + + if err := s.deleteBlobs(ctx, buris...); err != nil { + return nil, err + } + + s.Logger.Info("user merge completed successfully", + zap.String("actor_user_id", string(actorUserInfo.UserID)), + zap.String("from_user_id", req.FromUserId), + zap.String("to_user_id", req.ToUserId), + zap.Int("num_incomplete_uploads", numIncompleteUploads), + zap.Int("num_analyses", numAnalyses), + zap.Int("num_portfolios", numPortfolios), + zap.Int("num_portfolio_groups", numPortfolioGroups), + zap.Int("num_audit_logs_created", numAuditLogsCreated), + ) + return api.MergeUsers200JSONResponse{ + AuditLogsCreated: numAuditLogsCreated, + IncompleteUploadCount: numIncompleteUploads, + PortfolioCount: numPortfolios, + PortfolioGroupCount: numPortfolioGroups, + AnalysisCount: numAnalyses, + }, nil +} diff --git a/cmd/server/pactasrv/blobs.go b/cmd/server/pactasrv/blobs.go index 4b17866..48008ab 100644 --- a/cmd/server/pactasrv/blobs.go +++ b/cmd/server/pactasrv/blobs.go @@ -2,7 +2,6 @@ package pactasrv import ( "context" - "fmt" "github.com/RMI/pacta/db" "github.com/RMI/pacta/oapierr" @@ -12,7 +11,7 @@ import ( ) func (s *Server) AccessBlobContent(ctx context.Context, request api.AccessBlobContentRequestObject) (api.AccessBlobContentResponseObject, error) { - actorInfo, err := s.getActorInfoOrFail(ctx) + actorInfo, err := s.getactorInfoOrErrIfAnon(ctx) if err != nil { return nil, err } @@ -69,16 +68,7 @@ func (s *Server) AccessBlobContent(ctx context.Context, request api.AccessBlobCo return nil, oapierr.Internal("error getting blobs", zap.Error(err), zap.Strings("blob_ids", asStrs(blobIDs))) } - err = s.DB.Transactional(ctx, func(tx db.Tx) error { - for i, al := range auditLogs { - _, err := s.DB.CreateAuditLog(tx, al) - if err != nil { - return fmt.Errorf("creating audit log %d/%d: %w", i+1, len(auditLogs), err) - } - } - return nil - }) - if err != nil { + if err = s.DB.CreateAuditLogs(s.DB.NoTxn(ctx), auditLogs); err != nil { return nil, oapierr.Internal("error creating audit logs - no download URLs generated", zap.Error(err), zap.Strings("blob_ids", asStrs(blobIDs))) } diff --git a/cmd/server/pactasrv/conv/oapi_to_pacta.go b/cmd/server/pactasrv/conv/oapi_to_pacta.go index 6c6885c..eb207cb 100644 --- a/cmd/server/pactasrv/conv/oapi_to_pacta.go +++ b/cmd/server/pactasrv/conv/oapi_to_pacta.go @@ -144,6 +144,8 @@ func auditLogActionFromOAPI(i api.AuditLogAction) (pacta.AuditLogAction, error) return pacta.AuditLogAction_EnableSharing, nil case api.AuditLogActionDisableSharing: return pacta.AuditLogAction_DisableSharing, nil + case api.AuditLogActionTransferOwnership: + return pacta.AuditLogAction_TransferOwnership, nil } return "", oapierr.BadRequest("unknown audit log action", zap.String("audit_log_action", string(i))) } diff --git a/cmd/server/pactasrv/conv/pacta_to_oapi.go b/cmd/server/pactasrv/conv/pacta_to_oapi.go index c547959..2e41a7b 100644 --- a/cmd/server/pactasrv/conv/pacta_to_oapi.go +++ b/cmd/server/pactasrv/conv/pacta_to_oapi.go @@ -452,6 +452,8 @@ func auditLogActionToOAPI(i pacta.AuditLogAction) (api.AuditLogAction, error) { return api.AuditLogActionEnableSharing, nil case pacta.AuditLogAction_DisableSharing: return api.AuditLogActionDisableSharing, nil + case pacta.AuditLogAction_TransferOwnership: + return api.AuditLogActionTransferOwnership, nil } return "", oapierr.Internal(fmt.Sprintf("auditLogActionToOAPI: unknown action: %q", i)) } diff --git a/cmd/server/pactasrv/pactasrv.go b/cmd/server/pactasrv/pactasrv.go index ad150eb..66f59a7 100644 --- a/cmd/server/pactasrv/pactasrv.go +++ b/cmd/server/pactasrv/pactasrv.go @@ -114,10 +114,14 @@ type DB interface { User(tx db.Tx, id pacta.UserID) (*pacta.User, error) Users(tx db.Tx, ids []pacta.UserID) (map[pacta.UserID]*pacta.User, error) UpdateUser(tx db.Tx, id pacta.UserID, mutations ...db.UpdateUserFn) error - DeleteUser(tx db.Tx, id pacta.UserID) error + DeleteUser(tx db.Tx, id pacta.UserID) ([]pacta.BlobURI, error) CreateAuditLog(tx db.Tx, a *pacta.AuditLog) (pacta.AuditLogID, error) + CreateAuditLogs(tx db.Tx, as []*pacta.AuditLog) error AuditLogs(tx db.Tx, q *db.AuditLogQuery) ([]*pacta.AuditLog, *db.PageInfo, error) + + RecordUserMerge(tx db.Tx, fromUserID, toUserID, actorUserID pacta.UserID) error + RecordOwnerMerge(tx db.Tx, fromUserID, toUserID pacta.OwnerID, actorUserID pacta.UserID) error } type Blob interface { @@ -211,7 +215,7 @@ type actorInfo struct { IsSuperAdmin bool } -func (s *Server) getActorInfoOrFail(ctx context.Context) (*actorInfo, error) { +func (s *Server) getactorInfoOrErrIfAnon(ctx context.Context) (*actorInfo, error) { actorUserID, err := getUserID(ctx) if err != nil { return nil, err diff --git a/cmd/server/pactasrv/user.go b/cmd/server/pactasrv/user.go index 04ab164..0ee13a3 100644 --- a/cmd/server/pactasrv/user.go +++ b/cmd/server/pactasrv/user.go @@ -60,10 +60,13 @@ func (s *Server) UpdateUser(ctx context.Context, request api.UpdateUserRequestOb // (DELETE /user/{id}) func (s *Server) DeleteUser(ctx context.Context, request api.DeleteUserRequestObject) (api.DeleteUserResponseObject, error) { // TODO(#12) Implement Authorization - err := s.DB.DeleteUser(s.DB.NoTxn(ctx), pacta.UserID(request.Id)) + blobURIs, err := s.DB.DeleteUser(s.DB.NoTxn(ctx), pacta.UserID(request.Id)) if err != nil { return nil, oapierr.Internal("failed to delete user", zap.Error(err)) } + if err := s.deleteBlobs(ctx, blobURIs...); err != nil { + return nil, err + } return api.DeleteUser204Response{}, nil } diff --git a/db/sqldb/BUILD.bazel b/db/sqldb/BUILD.bazel index e2e7625..79ba81b 100644 --- a/db/sqldb/BUILD.bazel +++ b/db/sqldb/BUILD.bazel @@ -12,6 +12,7 @@ go_library( "initiative.go", "initiative_invitation.go", "initiative_user.go", + "merge.go", "owner.go", "pacta_version.go", "portfolio.go", @@ -48,6 +49,7 @@ go_test( "initiative_invitation_test.go", "initiative_test.go", "initiative_user_test.go", + "merge_test.go", "owner_test.go", "pacta_version_test.go", "portfolio_group_test.go", diff --git a/db/sqldb/audit_log.go b/db/sqldb/audit_log.go index be5966c..d0e0adb 100644 --- a/db/sqldb/audit_log.go +++ b/db/sqldb/audit_log.go @@ -35,6 +35,10 @@ func (d *DB) AuditLogs(tx db.Tx, q *db.AuditLogQuery) ([]*pacta.AuditLog, *db.Pa if err != nil { return nil, nil, fmt.Errorf("converting cursor to offset: %w", err) } + q, err = d.expandAuditLogQueryToAccountForMerges(tx, q) + if err != nil { + return nil, nil, fmt.Errorf("expanding audit_log query to account for merges: %w", err) + } sql, args, err := auditLogQuery(q) if err != nil { return nil, nil, fmt.Errorf("building audit_log query: %w", err) @@ -55,8 +59,42 @@ func (d *DB) AuditLogs(tx db.Tx, q *db.AuditLogQuery) ([]*pacta.AuditLog, *db.Pa } func (d *DB) CreateAuditLog(tx db.Tx, a *pacta.AuditLog) (pacta.AuditLogID, error) { + sql, args, id, err := d.buildCreateAuditLogQuery(tx, a) + if err != nil { + return "", err + } + err = d.exec(tx, sql, args...) + if err != nil { + return "", fmt.Errorf("creating audit_log row: %w", err) + } + return id, nil +} + +func (d *DB) CreateAuditLogs(tx db.Tx, als []*pacta.AuditLog) error { + if len(als) == 0 { + return nil + } + if len(als) == 1 { + _, err := d.CreateAuditLog(tx, als[0]) + return err + } + batch := &pgx.Batch{} + for _, al := range als { + sql, args, _, err := d.buildCreateAuditLogQuery(tx, al) + if err != nil { + return fmt.Errorf("building batch audit_log updates: %w", err) + } + batch.Queue(sql, args...) + } + if err := d.ExecBatch(tx, batch); err != nil { + return fmt.Errorf("batch creating audit_logs: %w", err) + } + return nil +} + +func (d *DB) buildCreateAuditLogQuery(tx db.Tx, a *pacta.AuditLog) (string, []interface{}, pacta.AuditLogID, error) { if err := validateAuditLogForCreation(a); err != nil { - return "", fmt.Errorf("validating audit_log for creation: %w", err) + return "", nil, "", fmt.Errorf("validating audit_log for creation: %w", err) } id := pacta.AuditLogID(d.randomID(auditLogIDNamespace)) ownerFn := func(o *pacta.Owner) pgtype.Text { @@ -70,7 +108,7 @@ func (d *DB) CreateAuditLog(tx db.Tx, a *pacta.AuditLog) (pacta.AuditLogID, erro stt.Valid = true stt.String = string(a.SecondaryTargetType) } - err := d.exec(tx, ` + sql := ` INSERT INTO audit_log ( id, action, actor_type, actor_id, actor_owner_id, @@ -79,13 +117,13 @@ func (d *DB) CreateAuditLog(tx db.Tx, a *pacta.AuditLog) (pacta.AuditLogID, erro ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11); - `, id, a.Action, a.ActorType, a.ActorID, ownerFn(a.ActorOwner), + ` + args := []interface{}{ + id, a.Action, a.ActorType, a.ActorID, ownerFn(a.ActorOwner), a.PrimaryTargetType, a.PrimaryTargetID, ownerFn(a.PrimaryTargetOwner), - stt, a.SecondaryTargetID, ownerFn(a.SecondaryTargetOwner)) - if err != nil { - return "", fmt.Errorf("creating audit_log row: %w", err) + stt, a.SecondaryTargetID, ownerFn(a.SecondaryTargetOwner), } - return id, nil + return sql, args, id, nil } func rowsToAuditLogs(rows pgx.Rows) ([]*pacta.AuditLog, error) { diff --git a/db/sqldb/audit_log_test.go b/db/sqldb/audit_log_test.go index 0349d52..38694ff 100644 --- a/db/sqldb/audit_log_test.go +++ b/db/sqldb/audit_log_test.go @@ -135,7 +135,7 @@ func TestAuditSearch(t *testing.T) { actorOwner2 := &pacta.Owner{ID: "owner2"} targetType1 := pacta.AuditLogTargetType_Portfolio targetType2 := pacta.AuditLogTargetType_IncompleteUpload - targetID1 := "portfolio-1" + targetID1 := actorID1 targetID2 := "incomplete-upload-2" targetOwner1 := &pacta.Owner{ID: "owner3"} targetOwner2 := &pacta.Owner{ID: "owner4"} @@ -222,7 +222,7 @@ func TestAuditSearch(t *testing.T) { for i, a := range auditLogs { actual[i] = a.ID } - if diff := cmp.Diff(c.expected, actual, sortAuditLogIDs()); diff != "" { + if diff := cmp.Diff(c.expected, actual, auditLogIDCmpOpts()); diff != "" { t.Errorf("unexpected diff:\n%s", diff) } }) @@ -237,31 +237,31 @@ func TestAuditSearch(t *testing.T) { }{{ name: "All Match", where: []*db.AuditLogQueryWhere{ - &db.AuditLogQueryWhere{InID: []pacta.AuditLogID{alID1}}, - &db.AuditLogQueryWhere{MinCreatedAt: beforeCreation}, - &db.AuditLogQueryWhere{MaxCreatedAt: afterCreation}, - &db.AuditLogQueryWhere{InAction: []pacta.AuditLogAction{action1}}, - &db.AuditLogQueryWhere{InActorType: []pacta.AuditLogActorType{actorType1}}, - &db.AuditLogQueryWhere{InActorID: []string{actorID1}}, - &db.AuditLogQueryWhere{InActorOwnerID: []pacta.OwnerID{actorOwner1.ID}}, - &db.AuditLogQueryWhere{InTargetType: []pacta.AuditLogTargetType{targetType1}}, - &db.AuditLogQueryWhere{InTargetID: []string{targetID1}}, - &db.AuditLogQueryWhere{InTargetOwnerID: []pacta.OwnerID{targetOwner1.ID}}, + {InID: []pacta.AuditLogID{alID1}}, + {MinCreatedAt: beforeCreation}, + {MaxCreatedAt: afterCreation}, + {InAction: []pacta.AuditLogAction{action1}}, + {InActorType: []pacta.AuditLogActorType{actorType1}}, + {InActorID: []string{actorID1}}, + {InActorOwnerID: []pacta.OwnerID{actorOwner1.ID}}, + {InTargetType: []pacta.AuditLogTargetType{targetType1}}, + {InTargetID: []string{targetID1}}, + {InTargetOwnerID: []pacta.OwnerID{targetOwner1.ID}}, }, expected: []pacta.AuditLogID{alID1}, }, { name: "One Does not Match", where: []*db.AuditLogQueryWhere{ - &db.AuditLogQueryWhere{InID: []pacta.AuditLogID{alID1}}, - &db.AuditLogQueryWhere{MinCreatedAt: beforeCreation}, - &db.AuditLogQueryWhere{MaxCreatedAt: afterCreation}, - &db.AuditLogQueryWhere{InAction: []pacta.AuditLogAction{action1}}, - &db.AuditLogQueryWhere{InActorType: []pacta.AuditLogActorType{actorType2}}, - &db.AuditLogQueryWhere{InActorID: []string{actorID1}}, - &db.AuditLogQueryWhere{InActorOwnerID: []pacta.OwnerID{actorOwner1.ID}}, - &db.AuditLogQueryWhere{InTargetType: []pacta.AuditLogTargetType{targetType1}}, - &db.AuditLogQueryWhere{InTargetID: []string{targetID1}}, - &db.AuditLogQueryWhere{InTargetOwnerID: []pacta.OwnerID{targetOwner1.ID}}, + {InID: []pacta.AuditLogID{alID1}}, + {MinCreatedAt: beforeCreation}, + {MaxCreatedAt: afterCreation}, + {InAction: []pacta.AuditLogAction{action1}}, + {InActorType: []pacta.AuditLogActorType{actorType2}}, + {InActorID: []string{actorID1}}, + {InActorOwnerID: []pacta.OwnerID{actorOwner1.ID}}, + {InTargetType: []pacta.AuditLogTargetType{targetType1}}, + {InTargetID: []string{targetID1}}, + {InTargetOwnerID: []pacta.OwnerID{targetOwner1.ID}}, }, expected: []pacta.AuditLogID{}, }} @@ -279,7 +279,171 @@ func TestAuditSearch(t *testing.T) { for i, a := range auditLogs { actual[i] = a.ID } - if diff := cmp.Diff(c.expected, actual, sortAuditLogIDs()); diff != "" { + if diff := cmp.Diff(c.expected, actual, auditLogIDCmpOpts()); diff != "" { + t.Errorf("unexpected diff:\n%s", diff) + } + }) + } + }) +} + +func TestAuditSearchAfterMerge(t *testing.T) { + action1 := pacta.AuditLogAction_AddTo + action2 := pacta.AuditLogAction_Create + actorType1 := pacta.AuditLogActorType_Owner + actorType2 := pacta.AuditLogActorType_System + actorID1 := "user1" + actorID2 := "user2" + actorOwner1 := &pacta.Owner{ID: "owner1"} + actorOwner2 := &pacta.Owner{ID: "owner2"} + targetType1 := pacta.AuditLogTargetType_Portfolio + targetType2 := pacta.AuditLogTargetType_IncompleteUpload + targetID1 := actorID1 + targetID2 := "incomplete-upload-2" + targetOwner1 := &pacta.Owner{ID: "owner3"} + targetOwner2 := &pacta.Owner{ID: "owner4"} + + ctx := context.Background() + tdb := createDBForTesting(t) + tx := tdb.NoTxn(ctx) + + alID1, err0 := tdb.CreateAuditLog(tx, &pacta.AuditLog{ActorType: actorType1, ActorID: actorID1, ActorOwner: actorOwner1, Action: action1, PrimaryTargetType: targetType1, PrimaryTargetID: targetID2, PrimaryTargetOwner: targetOwner2}) + alID2, err1 := tdb.CreateAuditLog(tx, &pacta.AuditLog{ActorType: actorType2, ActorID: actorID2, ActorOwner: actorOwner2, Action: action2, PrimaryTargetType: targetType2, PrimaryTargetID: targetID1, PrimaryTargetOwner: targetOwner1}) + alID3, err2 := tdb.CreateAuditLog(tx, &pacta.AuditLog{ActorType: actorType2, ActorID: actorID2, ActorOwner: actorOwner2, Action: action2, PrimaryTargetType: targetType2, PrimaryTargetID: "something", PrimaryTargetOwner: targetOwner1, SecondaryTargetType: targetType2, SecondaryTargetID: targetID2, SecondaryTargetOwner: targetOwner2}) + noErrDuringSetup(t, err0, err1, err2) + + t.Run("Pre-Merge Tests", func(t *testing.T) { + cases := []struct { + name string + where *db.AuditLogQueryWhere + expected []pacta.AuditLogID + }{{ + name: "By ActorID 1", + where: &db.AuditLogQueryWhere{InActorID: []string{actorID1}}, + expected: []pacta.AuditLogID{alID1}, + }, { + name: "By ActorID 2", + where: &db.AuditLogQueryWhere{InActorID: []string{actorID2}}, + expected: []pacta.AuditLogID{alID2, alID3}, + }, { + name: "By ActorOwnerID 1", + where: &db.AuditLogQueryWhere{InActorOwnerID: []pacta.OwnerID{actorOwner1.ID}}, + expected: []pacta.AuditLogID{alID1}, + }, { + name: "By ActorOwnerID 2", + where: &db.AuditLogQueryWhere{InActorOwnerID: []pacta.OwnerID{actorOwner2.ID}}, + expected: []pacta.AuditLogID{alID2, alID3}, + }, { + name: "By TargetID = ActorID 1", + where: &db.AuditLogQueryWhere{InTargetID: []string{actorID1}}, + expected: []pacta.AuditLogID{alID2}, + }, { + name: "By TargetID = ActorID 2", + where: &db.AuditLogQueryWhere{InTargetID: []string{actorID2}}, + expected: []pacta.AuditLogID{}, + }, { + name: "By TargetID = Something Else", + where: &db.AuditLogQueryWhere{InTargetID: []string{targetID2}}, + expected: []pacta.AuditLogID{alID1, alID3}, + }, { + name: "By TargetOwnerID 1", + where: &db.AuditLogQueryWhere{InTargetOwnerID: []pacta.OwnerID{targetOwner1.ID}}, + expected: []pacta.AuditLogID{alID2, alID3}, + }, { + name: "By TargetOwnerID 2", + where: &db.AuditLogQueryWhere{InTargetOwnerID: []pacta.OwnerID{targetOwner2.ID}}, + expected: []pacta.AuditLogID{alID1, alID3}, + }} + + for i, c := range cases { + t.Run(fmt.Sprintf("case %d: %q", i, c.name), func(t *testing.T) { + auditLogs, _, err := tdb.AuditLogs(tx, &db.AuditLogQuery{ + Limit: 10, + Wheres: []*db.AuditLogQueryWhere{c.where}, + }) + if err != nil { + t.Fatalf("getting audit logs: %v", err) + } + actual := make([]pacta.AuditLogID, len(auditLogs)) + for i, a := range auditLogs { + actual[i] = a.ID + } + if diff := cmp.Diff(c.expected, actual, auditLogIDCmpOpts()); diff != "" { + t.Errorf("unexpected diff:\n%s", diff) + } + }) + } + }) + + t.Run("Executing Merges", func(t *testing.T) { + if err := tdb.RecordUserMerge(tx, pacta.UserID(actorID1), pacta.UserID(actorID2), "some-admin-owner"); err != nil { + t.Fatalf("merging users: %v", err) + } + if err := tdb.RecordUserMerge(tx, pacta.UserID(actorID1), pacta.UserID(actorID2), "some-admin-owner"); err != nil { + t.Fatalf("merging users duplicatively should be fine: %v", err) + } + if err := tdb.RecordOwnerMerge(tx, actorOwner1.ID, actorOwner2.ID, "some-admin-owner"); err != nil { + t.Fatalf("merging owners: %v", err) + } + }) + + t.Run("Post-Merge Tests", func(t *testing.T) { + cases := []struct { + name string + where *db.AuditLogQueryWhere + expected []pacta.AuditLogID + }{{ + name: "By ActorID 1", + where: &db.AuditLogQueryWhere{InActorID: []string{actorID1}}, + expected: []pacta.AuditLogID{alID1, alID2, alID3}, + }, { + name: "By ActorID 2", + where: &db.AuditLogQueryWhere{InActorID: []string{actorID2}}, + expected: []pacta.AuditLogID{alID1, alID2, alID3}, + }, { + name: "By ActorOwnerID 1", + where: &db.AuditLogQueryWhere{InActorOwnerID: []pacta.OwnerID{actorOwner1.ID}}, + expected: []pacta.AuditLogID{alID1, alID2, alID3}, + }, { + name: "By ActorOwnerID 2", + where: &db.AuditLogQueryWhere{InActorOwnerID: []pacta.OwnerID{actorOwner2.ID}}, + expected: []pacta.AuditLogID{alID1, alID2, alID3}, + }, { + name: "By TargetID = ActorID 1", + where: &db.AuditLogQueryWhere{InTargetID: []string{actorID1}}, + expected: []pacta.AuditLogID{alID2}, + }, { + name: "By TargetID = ActorID 2", + where: &db.AuditLogQueryWhere{InTargetID: []string{actorID2}}, + expected: []pacta.AuditLogID{alID2}, + }, { + name: "By TargetID = Something Else", + where: &db.AuditLogQueryWhere{InTargetID: []string{targetID2}}, + expected: []pacta.AuditLogID{alID1, alID3}, + }, { + name: "By TargetOwnerID 1", + where: &db.AuditLogQueryWhere{InTargetOwnerID: []pacta.OwnerID{targetOwner1.ID}}, + expected: []pacta.AuditLogID{alID2, alID3}, + }, { + name: "By TargetOwnerID 2", + where: &db.AuditLogQueryWhere{InTargetOwnerID: []pacta.OwnerID{targetOwner2.ID}}, + expected: []pacta.AuditLogID{alID1, alID3}, + }} + + for i, c := range cases { + t.Run(fmt.Sprintf("case %d: %q", i, c.name), func(t *testing.T) { + auditLogs, _, err := tdb.AuditLogs(tx, &db.AuditLogQuery{ + Limit: 10, + Wheres: []*db.AuditLogQueryWhere{c.where}, + }) + if err != nil { + t.Fatalf("getting audit logs: %v", err) + } + actual := make([]pacta.AuditLogID, len(auditLogs)) + for i, a := range auditLogs { + actual[i] = a.ID + } + if diff := cmp.Diff(c.expected, actual, auditLogIDCmpOpts()); diff != "" { t.Errorf("unexpected diff:\n%s", diff) } }) @@ -294,8 +458,11 @@ func auditLogCmpOpts() cmp.Option { } } -func sortAuditLogIDs() cmp.Option { - return cmpopts.SortSlices(func(a, b pacta.AuditLogID) bool { - return string(a) < string(b) - }) +func auditLogIDCmpOpts() cmp.Option { + return cmp.Options{ + cmpopts.SortSlices(func(a, b pacta.AuditLogID) bool { + return string(a) < string(b) + }), + cmpopts.EquateEmpty(), + } } diff --git a/db/sqldb/golden/human_readable_schema.sql b/db/sqldb/golden/human_readable_schema.sql index 7c53299..8ab3e8c 100644 --- a/db/sqldb/golden/human_readable_schema.sql +++ b/db/sqldb/golden/human_readable_schema.sql @@ -18,7 +18,8 @@ CREATE TYPE audit_log_action AS ENUM ( 'DISABLE_ADMIN_DEBUG', 'DOWNLOAD', 'ENABLE_SHARING', - 'DISABLE_SHARING'); + 'DISABLE_SHARING', + 'TRANSFER_OWNERSHIP'); CREATE TYPE audit_log_actor_type AS ENUM ( 'USER', 'ADMIN', @@ -175,6 +176,13 @@ ALTER TABLE ONLY owner ADD CONSTRAINT owner_initiative_id_fkey FOREIGN KEY (init ALTER TABLE ONLY owner ADD CONSTRAINT owner_user_id_fkey FOREIGN KEY (user_id) REFERENCES pacta_user(id) ON DELETE RESTRICT; +CREATE TABLE owner_merges ( + actor_user_id text NOT NULL, + from_owner_id text NOT NULL, + merged_at timestamp with time zone DEFAULT now() NOT NULL, + to_owner_id text NOT NULL); + + CREATE TABLE pacta_user ( admin boolean NOT NULL, authn_id text NOT NULL, @@ -269,4 +277,11 @@ CREATE TABLE schema_migrations_history ( id integer NOT NULL, version bigint NOT NULL); ALTER SEQUENCE schema_migrations_history_id_seq OWNED BY schema_migrations_history.id; + + +CREATE TABLE user_merges ( + actor_user_id text NOT NULL, + from_user_id text NOT NULL, + merged_at timestamp with time zone DEFAULT now() NOT NULL, + to_user_id text NOT NULL); ALTER TABLE ONLY schema_migrations ADD CONSTRAINT schema_migrations_pkey PRIMARY KEY (version); \ No newline at end of file diff --git a/db/sqldb/golden/schema_dump.sql b/db/sqldb/golden/schema_dump.sql index 80a8354..077da0a 100644 --- a/db/sqldb/golden/schema_dump.sql +++ b/db/sqldb/golden/schema_dump.sql @@ -42,7 +42,8 @@ CREATE TYPE public.audit_log_action AS ENUM ( 'DISABLE_ADMIN_DEBUG', 'DOWNLOAD', 'ENABLE_SHARING', - 'DISABLE_SHARING' + 'DISABLE_SHARING', + 'TRANSFER_OWNERSHIP' ); @@ -317,6 +318,20 @@ CREATE TABLE public.owner ( ALTER TABLE public.owner OWNER TO postgres; +-- +-- Name: owner_merges; Type: TABLE; Schema: public; Owner: postgres +-- + +CREATE TABLE public.owner_merges ( + from_owner_id text NOT NULL, + to_owner_id text NOT NULL, + actor_user_id text NOT NULL, + merged_at timestamp with time zone DEFAULT now() NOT NULL +); + + +ALTER TABLE public.owner_merges OWNER TO postgres; + -- -- Name: pacta_user; Type: TABLE; Schema: public; Owner: postgres -- @@ -478,6 +493,20 @@ ALTER TABLE public.schema_migrations_history_id_seq OWNER TO postgres; ALTER SEQUENCE public.schema_migrations_history_id_seq OWNED BY public.schema_migrations_history.id; +-- +-- Name: user_merges; Type: TABLE; Schema: public; Owner: postgres +-- + +CREATE TABLE public.user_merges ( + from_user_id text NOT NULL, + to_user_id text NOT NULL, + actor_user_id text NOT NULL, + merged_at timestamp with time zone DEFAULT now() NOT NULL +); + + +ALTER TABLE public.user_merges OWNER TO postgres; + -- -- Name: schema_migrations_history id; Type: DEFAULT; Schema: public; Owner: postgres -- diff --git a/db/sqldb/merge.go b/db/sqldb/merge.go new file mode 100644 index 0000000..453e3ba --- /dev/null +++ b/db/sqldb/merge.go @@ -0,0 +1,128 @@ +package sqldb + +import ( + "fmt" + + "github.com/RMI/pacta/db" + "github.com/RMI/pacta/pacta" +) + +func (d *DB) RecordUserMerge(tx db.Tx, fromUserID, toUserID, actorUserID pacta.UserID) error { + err := d.exec(tx, ` + INSERT INTO user_merges + (from_user_id, to_user_id, actor_user_id) + VALUES ($1, $2, $3);`, fromUserID, toUserID, actorUserID) + if err != nil { + return fmt.Errorf("inserting user merge: %w", err) + } + return nil +} + +func (d *DB) RecordOwnerMerge(tx db.Tx, fromOwnerID, toOwnerID pacta.OwnerID, actorUserID pacta.UserID) error { + err := d.exec(tx, ` + INSERT INTO owner_merges + (from_owner_id, to_owner_id, actor_user_id) + VALUES ($1, $2, $3);`, fromOwnerID, toOwnerID, actorUserID) + if err != nil { + return fmt.Errorf("inserting owner merge: %w", err) + } + return nil +} + +func (d *DB) expandAuditLogQueryToAccountForMerges(tx db.Tx, q *db.AuditLogQuery) (*db.AuditLogQuery, error) { + var err error + for _, w := range q.Wheres { + w.InActorID, err = d.findAllMergedUsers(tx, w.InActorID) + if err != nil { + return nil, fmt.Errorf("finding merged users for actor_id: %w", err) + } + + w.InTargetID, err = d.findAllMergedUsers(tx, w.InTargetID) + if err != nil { + return nil, fmt.Errorf("finding merged users for target_id: %w", err) + } + + w.InActorOwnerID, err = d.findAllMergedOwners(tx, w.InActorOwnerID) + if err != nil { + return nil, fmt.Errorf("finding merged owners for actor_owner_id: %w", err) + } + + w.InTargetOwnerID, err = d.findAllMergedOwners(tx, w.InTargetOwnerID) + if err != nil { + return nil, fmt.Errorf("finding merged owners for actor_owner_id: %w", err) + } + } + return q, nil +} + +func (d *DB) findAllMergedOwners(tx db.Tx, in []pacta.OwnerID) ([]pacta.OwnerID, error) { + relationshipFn := func(id pacta.OwnerID) ([]pacta.OwnerID, error) { + return d.findMergedOwners(tx, pacta.OwnerID(id)) + } + return recursivelyExpandRelationships(in, relationshipFn) +} + +func (d *DB) findMergedOwners(tx db.Tx, id pacta.OwnerID) ([]pacta.OwnerID, error) { + rows, err := d.query(tx, ` + (SELECT from_owner_id FROM owner_merges WHERE to_owner_id = $1) + UNION + (SELECT to_owner_id FROM owner_merges WHERE from_owner_id = $1);`, id) + if err != nil { + return nil, fmt.Errorf("querying owner_merges: %w", err) + } + ownerIDs, err := mapRowsToIDs[pacta.OwnerID]("merged_owners", rows) + if err != nil { + return nil, fmt.Errorf("mapping rows to owner ids: %w", err) + } + return ownerIDs, nil +} + +func (d *DB) findAllMergedUsers(tx db.Tx, in []string) ([]string, error) { + relationshipFn := func(id string) ([]string, error) { + others, err := d.findMergedUsers(tx, pacta.UserID(id)) + if err != nil { + return nil, fmt.Errorf("finding merged users for %q: %w", id, err) + } + return asStrs(others), nil + } + return recursivelyExpandRelationships(in, relationshipFn) +} + +func (d *DB) findMergedUsers(tx db.Tx, id pacta.UserID) ([]pacta.UserID, error) { + rows, err := d.query(tx, ` + (SELECT from_user_id FROM user_merges WHERE to_user_id = $1) + UNION + (SELECT to_user_id FROM user_merges WHERE from_user_id = $1);`, id) + if err != nil { + return nil, fmt.Errorf("querying user_merges: %w", err) + } + userIDs, err := mapRowsToIDs[pacta.UserID]("merged_users", rows) + if err != nil { + return nil, fmt.Errorf("mapping rows to user ids: %w", err) + } + return userIDs, nil +} + +func recursivelyExpandRelationships[S ~string](in []S, relatedFn func(S) ([]S, error)) ([]S, error) { + if len(in) == 0 { + return in, nil + } + all := asSet(in) + lookedUp := map[S]bool{} + for len(lookedUp) < len(all) { + for s := range all { + if lookedUp[s] { + continue + } + related, err := relatedFn(s) + if err != nil { + return nil, fmt.Errorf("finding relationships for %q: %w", s, err) + } + for _, r := range related { + all[r] = true + } + lookedUp[s] = true + } + } + return keys(all), nil +} diff --git a/db/sqldb/merge_test.go b/db/sqldb/merge_test.go new file mode 100644 index 0000000..1e5ee12 --- /dev/null +++ b/db/sqldb/merge_test.go @@ -0,0 +1,335 @@ +package sqldb + +import ( + "context" + "fmt" + "testing" + + "github.com/RMI/pacta/pacta" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func TestMergeUsers(t *testing.T) { + ctx := context.Background() + tdb := createDBForTesting(t) + tx := tdb.NoTxn(ctx) + a := userForTestingWithKey(t, tdb, "A") + b := userForTestingWithKey(t, tdb, "B") + c := userForTestingWithKey(t, tdb, "C") + d := userForTestingWithKey(t, tdb, "D") + e := userForTestingWithKey(t, tdb, "E") + f := userForTestingWithKey(t, tdb, "F") + + usersToIDs := func(users []*pacta.User) []string { + ids := []string{} + for _, u := range users { + ids = append(ids, string(u.ID)) + } + return ids + } + cmpOpts := cmpopts.SortSlices(func(a, b string) bool { return a < b }) + runTests := func(name string, tests []struct { + in []*pacta.User + expected []*pacta.User + }) { + t.Helper() + for _, test := range tests { + t.Run(fmt.Sprintf("%s_case_%v", name, test.in), func(t *testing.T) { + got, err := tdb.findAllMergedUsers(tx, usersToIDs(test.in)) + if err != nil { + t.Fatalf("get users: %v", err) + } + if diff := cmp.Diff(usersToIDs(test.expected), got, cmpOpts); diff != "" { + t.Fatalf("users mismatch (-want +got):\n%s", diff) + } + }) + + } + } + + runTests("Pre-Merge", []struct { + in []*pacta.User + expected []*pacta.User + }{ + { + in: []*pacta.User{}, + expected: []*pacta.User{}, + }, { + in: []*pacta.User{a}, + expected: []*pacta.User{a}, + }, { + in: []*pacta.User{c, c, c}, + expected: []*pacta.User{c}, + }, { + in: []*pacta.User{e}, + expected: []*pacta.User{e}, + }, { + in: []*pacta.User{a, b}, + expected: []*pacta.User{a, b}, + }, { + in: []*pacta.User{a, b, c, d, e, f}, + expected: []*pacta.User{a, b, c, d, e, f}, + }, + }) + + // Merge A and B + if err := tdb.RecordUserMerge(tx, a.ID, b.ID, "AdminID"); err != nil { + t.Fatalf("record merge: %v", err) + } + // Merge C and D + if err := tdb.RecordUserMerge(tx, c.ID, d.ID, "AdminID"); err != nil { + t.Fatalf("record merge: %v", err) + } + + runTests("Post-First-Merge", []struct { + in []*pacta.User + expected []*pacta.User + }{ + { + in: []*pacta.User{}, + expected: []*pacta.User{}, + }, { + in: []*pacta.User{a}, + expected: []*pacta.User{a, b}, + }, { + in: []*pacta.User{e}, + expected: []*pacta.User{e}, + }, { + in: []*pacta.User{c, c, c}, + expected: []*pacta.User{c, d}, + }, { + in: []*pacta.User{a, b}, + expected: []*pacta.User{a, b}, + }, { + in: []*pacta.User{a, b, c, d, e, f}, + expected: []*pacta.User{a, b, c, d, e, f}, + }, + }) + + // Merge B and E + if err := tdb.RecordUserMerge(tx, b.ID, e.ID, "AdminID"); err != nil { + t.Fatalf("record merge: %v", err) + } + // Merge B and F + if err := tdb.RecordUserMerge(tx, f.ID, b.ID, "AdminID"); err != nil { + t.Fatalf("record merge: %v", err) + } + + runTests("Post-Second-Merge", []struct { + in []*pacta.User + expected []*pacta.User + }{ + { + in: []*pacta.User{}, + expected: []*pacta.User{}, + }, { + in: []*pacta.User{a}, + expected: []*pacta.User{a, b, e, f}, + }, { + in: []*pacta.User{e}, + expected: []*pacta.User{a, b, e, f}, + }, { + in: []*pacta.User{c, c, c}, + expected: []*pacta.User{c, d}, + }, { + in: []*pacta.User{a, b}, + expected: []*pacta.User{a, b, e, f}, + }, { + in: []*pacta.User{a, b, c, d, e, f}, + expected: []*pacta.User{a, b, c, d, e, f}, + }, + }) + + // Merge D and E + if err := tdb.RecordUserMerge(tx, d.ID, e.ID, "AdminID"); err != nil { + t.Fatalf("record merge: %v", err) + } + + runTests("Post-Third-Merge", []struct { + in []*pacta.User + expected []*pacta.User + }{ + { + in: []*pacta.User{}, + expected: []*pacta.User{}, + }, { + in: []*pacta.User{a}, + expected: []*pacta.User{a, b, c, d, e, f}, + }, { + in: []*pacta.User{e}, + expected: []*pacta.User{a, b, c, d, e, f}, + }, { + in: []*pacta.User{c, c, c}, + expected: []*pacta.User{a, b, c, d, e, f}, + }, { + in: []*pacta.User{a, b}, + expected: []*pacta.User{a, b, c, d, e, f}, + }, { + in: []*pacta.User{a, b, c, d, e, f}, + expected: []*pacta.User{a, b, c, d, e, f}, + }, + }) +} + +func TestMergeOwners(t *testing.T) { + ctx := context.Background() + tdb := createDBForTesting(t) + tx := tdb.NoTxn(ctx) + uA := userForTestingWithKey(t, tdb, "A") + uB := userForTestingWithKey(t, tdb, "B") + uC := userForTestingWithKey(t, tdb, "C") + uD := userForTestingWithKey(t, tdb, "D") + uE := userForTestingWithKey(t, tdb, "E") + uF := userForTestingWithKey(t, tdb, "F") + a, err0 := tdb.GetOwnerForUser(tx, uA.ID) + b, err1 := tdb.GetOwnerForUser(tx, uB.ID) + c, err2 := tdb.GetOwnerForUser(tx, uC.ID) + d, err3 := tdb.GetOwnerForUser(tx, uD.ID) + e, err4 := tdb.GetOwnerForUser(tx, uE.ID) + f, err5 := tdb.GetOwnerForUser(tx, uF.ID) + noErrDuringSetup(t, err0, err1, err2, err3, err4, err5) + + cmpOpts := cmpopts.SortSlices(func(a, b pacta.OwnerID) bool { return a < b }) + runTests := func(name string, tests []struct { + in []pacta.OwnerID + expected []pacta.OwnerID + }) { + t.Helper() + for _, test := range tests { + t.Run(fmt.Sprintf("%s_case_%v", name, test.in), func(t *testing.T) { + got, err := tdb.findAllMergedOwners(tx, test.in) + if err != nil { + t.Fatalf("get owners: %v", err) + } + if diff := cmp.Diff(test.expected, got, cmpOpts); diff != "" { + t.Fatalf("owners mismatch (-want +got):\n%s", diff) + } + }) + + } + } + + runTests("Pre-Merge", []struct { + in []pacta.OwnerID + expected []pacta.OwnerID + }{ + { + in: []pacta.OwnerID{}, + expected: []pacta.OwnerID{}, + }, { + in: []pacta.OwnerID{a}, + expected: []pacta.OwnerID{a}, + }, { + in: []pacta.OwnerID{c, c, c}, + expected: []pacta.OwnerID{c}, + }, { + in: []pacta.OwnerID{e}, + expected: []pacta.OwnerID{e}, + }, { + in: []pacta.OwnerID{a, b}, + expected: []pacta.OwnerID{a, b}, + }, { + in: []pacta.OwnerID{a, b, c, d, e, f}, + expected: []pacta.OwnerID{a, b, c, d, e, f}, + }, + }) + + // Merge A and B + if err := tdb.RecordOwnerMerge(tx, a, b, "AdminID"); err != nil { + t.Fatalf("record merge: %v", err) + } + // Merge C and D + if err := tdb.RecordOwnerMerge(tx, c, d, "AdminID"); err != nil { + t.Fatalf("record merge: %v", err) + } + + runTests("Post-First-Merge", []struct { + in []pacta.OwnerID + expected []pacta.OwnerID + }{ + { + in: []pacta.OwnerID{}, + expected: []pacta.OwnerID{}, + }, { + in: []pacta.OwnerID{a}, + expected: []pacta.OwnerID{a, b}, + }, { + in: []pacta.OwnerID{e}, + expected: []pacta.OwnerID{e}, + }, { + in: []pacta.OwnerID{c, c, c}, + expected: []pacta.OwnerID{c, d}, + }, { + in: []pacta.OwnerID{a, b}, + expected: []pacta.OwnerID{a, b}, + }, { + in: []pacta.OwnerID{a, b, c, d, e, f}, + expected: []pacta.OwnerID{a, b, c, d, e, f}, + }, + }) + + // Merge B and E + if err := tdb.RecordOwnerMerge(tx, b, e, "AdminID"); err != nil { + t.Fatalf("record merge: %v", err) + } + // Merge B and F + if err := tdb.RecordOwnerMerge(tx, f, b, "AdminID"); err != nil { + t.Fatalf("record merge: %v", err) + } + + runTests("Post-Second-Merge", []struct { + in []pacta.OwnerID + expected []pacta.OwnerID + }{ + { + in: []pacta.OwnerID{}, + expected: []pacta.OwnerID{}, + }, { + in: []pacta.OwnerID{a}, + expected: []pacta.OwnerID{a, b, e, f}, + }, { + in: []pacta.OwnerID{e}, + expected: []pacta.OwnerID{a, b, e, f}, + }, { + in: []pacta.OwnerID{c, c, c}, + expected: []pacta.OwnerID{c, d}, + }, { + in: []pacta.OwnerID{a, b}, + expected: []pacta.OwnerID{a, b, e, f}, + }, { + in: []pacta.OwnerID{a, b, c, d, e, f}, + expected: []pacta.OwnerID{a, b, c, d, e, f}, + }, + }) + + // Merge D and E + if err := tdb.RecordOwnerMerge(tx, d, e, "AdminID"); err != nil { + t.Fatalf("record merge: %v", err) + } + + runTests("Post-Third-Merge", []struct { + in []pacta.OwnerID + expected []pacta.OwnerID + }{ + { + in: []pacta.OwnerID{}, + expected: []pacta.OwnerID{}, + }, { + in: []pacta.OwnerID{a}, + expected: []pacta.OwnerID{a, b, c, d, e, f}, + }, { + in: []pacta.OwnerID{e}, + expected: []pacta.OwnerID{a, b, c, d, e, f}, + }, { + in: []pacta.OwnerID{c, c, c}, + expected: []pacta.OwnerID{a, b, c, d, e, f}, + }, { + in: []pacta.OwnerID{a, b}, + expected: []pacta.OwnerID{a, b, c, d, e, f}, + }, { + in: []pacta.OwnerID{a, b, c, d, e, f}, + expected: []pacta.OwnerID{a, b, c, d, e, f}, + }, + }) +} diff --git a/db/sqldb/migrations/0009_support_user_merge.down.sql b/db/sqldb/migrations/0009_support_user_merge.down.sql new file mode 100644 index 0000000..eddfb5a --- /dev/null +++ b/db/sqldb/migrations/0009_support_user_merge.down.sql @@ -0,0 +1,28 @@ +BEGIN; + +DROP TABLE owner_merges; +DROP TABLE user_merges; + +-- There isn't a way to delete a value from an enum, so this is the workaround +-- https://stackoverflow.com/a/56777227/17909149 + +ALTER TABLE audit_log ALTER action TYPE TEXT; + +DROP TYPE audit_log_action; +CREATE TYPE audit_log_action AS ENUM ( + 'CREATE', + 'UPDATE', + 'DELETE', + 'ADD_TO', + 'REMOVE_FROM', + 'ENABLE_ADMIN_DEBUG', + 'DISABLE_ADMIN_DEBUG', + 'DOWNLOAD', + 'ENABLE_SHARING', + 'DISABLE_SHARING'); + +ALTER TABLE audit_log + ALTER action TYPE audit_log_action + USING audit_log_action::audit_log_action; + +COMMIT; diff --git a/db/sqldb/migrations/0009_support_user_merge.up.sql b/db/sqldb/migrations/0009_support_user_merge.up.sql new file mode 100644 index 0000000..95c8acb --- /dev/null +++ b/db/sqldb/migrations/0009_support_user_merge.up.sql @@ -0,0 +1,19 @@ +BEGIN; + +ALTER TYPE audit_log_action ADD VALUE 'TRANSFER_OWNERSHIP'; + +CREATE TABLE user_merges ( + from_user_id TEXT NOT NULL, + to_user_id TEXT NOT NULL, + actor_user_id TEXT NOT NULL, + merged_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() +); + +CREATE TABLE owner_merges ( + from_owner_id TEXT NOT NULL, + to_owner_id TEXT NOT NULL, + actor_user_id TEXT NOT NULL, + merged_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() +); + +COMMIT; \ No newline at end of file diff --git a/db/sqldb/owner.go b/db/sqldb/owner.go index a781c57..4b044d9 100644 --- a/db/sqldb/owner.go +++ b/db/sqldb/owner.go @@ -112,6 +112,64 @@ func (d *DB) GetOrCreateOwnerForInitiative(tx db.Tx, iID pacta.InitiativeID) (pa return ownerID, nil } +func (d *DB) DeleteOwner(tx db.Tx, oID pacta.OwnerID) ([]pacta.BlobURI, error) { + var buris []pacta.BlobURI + err := d.RunOrContinueTransaction(tx, func(tx db.Tx) error { + portfolios, err := d.PortfoliosByOwner(tx, oID) + if err != nil { + return fmt.Errorf("getting portfolios for owner: %w", err) + } + for _, portfolio := range portfolios { + newBuris, err := d.DeletePortfolio(tx, portfolio.ID) + if err != nil { + return fmt.Errorf("deleting portfolio: %w", err) + } + buris = append(buris, newBuris...) + } + analyses, err := d.AnalysesByOwner(tx, oID) + if err != nil { + return fmt.Errorf("getting analyses for owner: %w", err) + } + for _, analysis := range analyses { + newBuris, err := d.DeleteAnalysis(tx, analysis.ID) + if err != nil { + return fmt.Errorf("deleting analysis: %w", err) + } + buris = append(buris, newBuris...) + } + pgroups, err := d.PortfolioGroupsByOwner(tx, oID) + if err != nil { + return fmt.Errorf("getting portfolio groups for owner: %w", err) + } + for _, pgroup := range pgroups { + err := d.DeletePortfolioGroup(tx, pgroup.ID) + if err != nil { + return fmt.Errorf("deleting portfolio group: %w", err) + } + } + incompleteUploads, err := d.IncompleteUploadsByOwner(tx, oID) + if err != nil { + return fmt.Errorf("getting incomplete uploads for owner: %w", err) + } + for _, iu := range incompleteUploads { + newBuri, err := d.DeleteIncompleteUpload(tx, iu.ID) + if err != nil { + return fmt.Errorf("deleting incomplete upload: %w", err) + } + buris = append(buris, newBuri) + } + err = d.exec(tx, `DELETE FROM owner WHERE id = $1;`, oID) + if err != nil { + return fmt.Errorf("deleting actual owner: %w", err) + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("deleting owner: %w", err) + } + return buris, nil +} + func (d *DB) createOwner(tx db.Tx, o *pacta.Owner) (pacta.OwnerID, error) { if err := validateOwnerForCreation(o); err != nil { return "", fmt.Errorf("validating owner for creation: %w", err) @@ -183,5 +241,3 @@ func validateOwnerForCreation(o *pacta.Owner) error { } return nil } - -// TODO(grady) take on owner deletion diff --git a/db/sqldb/sqldb.go b/db/sqldb/sqldb.go index adaab07..d4436a8 100644 --- a/db/sqldb/sqldb.go +++ b/db/sqldb/sqldb.go @@ -90,6 +90,7 @@ type DBConn interface { Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) + SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults } func (d *DB) withConn(tx db.Tx, fn func(*ctxtx, DBConn) error) error { @@ -141,6 +142,24 @@ func (d *DB) exec(tx db.Tx, sql string, args ...interface{}) error { return err } +func (d *DB) ExecBatch(tx db.Tx, batch *pgx.Batch) error { + return d.withConn(tx, func(c *ctxtx, conn DBConn) error { + batchResults := conn.SendBatch(c.ctx, batch) + defer batchResults.Close() + n := batch.Len() + for i := 0; i < n; i++ { + _, err := batchResults.Exec() + if err != nil { + return fmt.Errorf("batch op %d/%d failed: %w", i+1, n, err) + } + } + if err := batchResults.Close(); err != nil { + return fmt.Errorf("closing batch results: %w", err) + } + return nil + }) +} + type rowScanner interface { Scan(...interface{}) error } @@ -390,6 +409,30 @@ func mapRowsToIDs[T ~string](name string, rows pgx.Rows) ([]T, error) { return mapRows(name, rows, fn) } +func asSet[T comparable](ts []T) map[T]bool { + result := map[T]bool{} + for _, t := range ts { + result[t] = true + } + return result +} + +func keys[K comparable, V any](m map[K]V) []K { + result := []K{} + for k := range m { + result = append(result, k) + } + return result +} + +func asStrs[S ~string](ss []S) []string { + result := make([]string, len(ss)) + for i, s := range ss { + result[i] = string(s) + } + return result +} + func checkSizesEquivalent(name string, sizes ...int) error { for i := 1; i < len(sizes); i++ { if sizes[i] != sizes[0] { diff --git a/db/sqldb/sqldb_test.go b/db/sqldb/sqldb_test.go index a62f975..68b118b 100644 --- a/db/sqldb/sqldb_test.go +++ b/db/sqldb/sqldb_test.go @@ -90,6 +90,7 @@ func TestSchemaHistory(t *testing.T) { {ID: 6, Version: 6}, // 0006_initiative_primary_key {ID: 7, Version: 7}, // 0007_audit_log_actor_type {ID: 8, Version: 8}, // 0008_indexes_on_blob_ids + {ID: 9, Version: 9}, // 0009_support_user_merge } if diff := cmp.Diff(want, got); diff != "" { diff --git a/db/sqldb/user.go b/db/sqldb/user.go index 91948a2..670bc05 100644 --- a/db/sqldb/user.go +++ b/db/sqldb/user.go @@ -72,10 +72,6 @@ func (d *DB) GetOrCreateUserByAuthn(tx db.Tx, authnMechanism pacta.AuthnMechanis if err != nil { return fmt.Errorf("creating user: %w", err) } - _, err = d.createOwner(tx, &pacta.Owner{User: &pacta.User{ID: uID}}) - if err != nil { - return fmt.Errorf("creating owner: %w", err) - } u, err = d.User(tx, uID) if err != nil { return fmt.Errorf("reading back created user: %w", err) @@ -119,14 +115,24 @@ func (d *DB) createUser(tx db.Tx, u *pacta.User) (pacta.UserID, error) { pl.String = string(u.PreferredLanguage) } id := pacta.UserID(d.randomID(userIDNamespace)) - err := d.exec(tx, ` - INSERT INTO pacta_user - (id, authn_mechanism, authn_id, entered_email, canonical_email, admin, super_admin, name, preferred_language) - VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9); - `, id, u.AuthnMechanism, u.AuthnID, u.EnteredEmail, u.CanonicalEmail, false, false, u.Name, pl) + err := d.RunOrContinueTransaction(tx, func(db.Tx) error { + err := d.exec(tx, ` + INSERT INTO pacta_user + (id, authn_mechanism, authn_id, entered_email, canonical_email, admin, super_admin, name, preferred_language) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9); + `, id, u.AuthnMechanism, u.AuthnID, u.EnteredEmail, u.CanonicalEmail, false, false, u.Name, pl) + if err != nil { + return fmt.Errorf("creating pacta_user row for %q: %w", id, err) + } + _, err = d.createOwner(tx, &pacta.Owner{User: &pacta.User{ID: id}}) + if err != nil { + return fmt.Errorf("creating owner: %w", err) + } + return nil + }) if err != nil { - return "", fmt.Errorf("creating pacta_user row for %q: %w", id, err) + return "", fmt.Errorf("creating user: %w", err) } return id, nil } @@ -155,19 +161,39 @@ func (d *DB) UpdateUser(tx db.Tx, id pacta.UserID, mutations ...db.UpdateUserFn) return nil } -func (d *DB) DeleteUser(tx db.Tx, id pacta.UserID) error { +func (d *DB) DeleteUser(tx db.Tx, id pacta.UserID) ([]pacta.BlobURI, error) { + buris := []pacta.BlobURI{} err := d.RunOrContinueTransaction(tx, func(db.Tx) error { - // TODO(grady) add entity deletions here - err := d.exec(tx, `DELETE FROM pacta_user WHERE id = $1;`, id) + userOwnerID, err := d.GetOwnerForUser(tx, id) + if err != nil { + if !db.IsNotFound(err) { + return fmt.Errorf("getting owner for user: %w", err) + } + } else { + newBuris, err := d.DeleteOwner(tx, userOwnerID) + if err != nil { + return fmt.Errorf("deleting owner: %w", err) + } + buris = append(buris, newBuris...) + } + err = d.exec(tx, `DELETE FROM initiative_invitation WHERE used_by_user_id = $1;`, id) if err != nil { - return fmt.Errorf("deleting user: %w", err) + return fmt.Errorf("deleting initiative_invitation rows: %w", err) + } + err = d.exec(tx, `UPDATE portfolio_initiative_membership SET added_by_user_id = NULL WHERE added_by_user_id = $1;`, id) + if err != nil { + return fmt.Errorf("clearing portfolio_initiative_membership.added_by_user_id: %w", err) + } + err = d.exec(tx, `DELETE FROM pacta_user WHERE id = $1;`, id) + if err != nil { + return fmt.Errorf("deleting actual user: %w", err) } return nil }) if err != nil { - return fmt.Errorf("performing initiative deletion: %w", err) + return nil, fmt.Errorf("performing user deletion: %w", err) } - return nil + return buris, nil } func (d *DB) putUser(tx db.Tx, u *pacta.User) error { diff --git a/db/sqldb/user_test.go b/db/sqldb/user_test.go index 3d7d557..abb18dd 100644 --- a/db/sqldb/user_test.go +++ b/db/sqldb/user_test.go @@ -218,7 +218,7 @@ func TestDeleteUser(t *testing.T) { userID, err0 := tdb.createUser(tx, u) noErrDuringSetup(t, err0) - err := tdb.DeleteUser(tx, userID) + _, err := tdb.DeleteUser(tx, userID) if err != nil { t.Fatalf("deleting user: %v", err) } diff --git a/openapi/pacta.yaml b/openapi/pacta.yaml index e278719..9ff9b66 100644 --- a/openapi/pacta.yaml +++ b/openapi/pacta.yaml @@ -136,6 +136,25 @@ paths: responses: '204': description: pacta version created successfully + /admin/merge-users: + post: + summary: Merges two users together + description: Merges two users together + operationId: mergeUsers + requestBody: + description: a request describing the two users to merge + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/MergeUsersReq' + responses: + '200': + description: the users were merged successfully + content: + application/json: + schema: + $ref: '#/components/schemas/MergeUsersResp' /initiative/{id}: get: summary: Returns an initiative by ID @@ -1879,6 +1898,7 @@ components: - AuditLogActionDownload - AuditLogActionEnableSharing - AuditLogActionDisableSharing + - AuditLogActionTransferOwnership AuditLogActorType: type: string enum: @@ -2058,6 +2078,42 @@ components: secondaryTargetOwner: type: string description: the id of the owner of the secondary object this action was performed on + MergeUsersReq: + type: object + required: + - fromUserId + - toUserId + properties: + fromUserId: + type: string + description: the user id of the user to merge records from, and to be deleted after the merge + toUserId: + type: string + description: the user id of the user to recieve merged records and to exist after the merge + MergeUsersResp: + type: object + required: + - incompleteUploadCount + - analysisCount + - portfolioCount + - portfolioGroupCount + - auditLogsCreated + properties: + incompleteUploadCount: + type: integer + description: the number of incomplete uploads that were transferred to the new user + analysisCount: + type: integer + description: the number of analyses that were transferred to the new user + portfolioCount: + type: integer + description: the number of portfolios that were transferred to the new user + portfolioGroupCount: + type: integer + description: the number of portfolio groups that were transferred to the new user + auditLogsCreated: + type: integer + description: the number of audit logs that were created to record the merge Error: type: object required: diff --git a/pacta/pacta.go b/pacta/pacta.go index 3d2836a..71fa86d 100644 --- a/pacta/pacta.go +++ b/pacta/pacta.go @@ -579,6 +579,7 @@ const ( AuditLogAction_Download AuditLogAction = "DOWNLOAD" AuditLogAction_EnableSharing AuditLogAction = "ENABLE_SHARING" AuditLogAction_DisableSharing AuditLogAction = "DISABLE_SHARING" + AuditLogAction_TransferOwnership AuditLogAction = "TRANSFER_OWNERSHIP" ) var AuditLogActionValues = []AuditLogAction{ @@ -592,6 +593,7 @@ var AuditLogActionValues = []AuditLogAction{ AuditLogAction_Download, AuditLogAction_EnableSharing, AuditLogAction_DisableSharing, + AuditLogAction_TransferOwnership, } func ParseAuditLogAction(s string) (AuditLogAction, error) { @@ -616,6 +618,8 @@ func ParseAuditLogAction(s string) (AuditLogAction, error) { return AuditLogAction_EnableSharing, nil case "DISABLE_SHARING": return AuditLogAction_DisableSharing, nil + case "TRANSFER_OWNERSHIP": + return AuditLogAction_TransferOwnership, nil } return "", fmt.Errorf("unknown AuditLogAction: %q", s) }