Skip to content

Commit

Permalink
Creates Backend for Merging User Accounts (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
gbdubs authored Jan 9, 2024
1 parent e4dbd9f commit 1a5a0cd
Show file tree
Hide file tree
Showing 23 changed files with 1,190 additions and 71 deletions.
1 change: 1 addition & 0 deletions cmd/server/pactasrv/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "pactasrv",
srcs = [
"admin.go",
"analysis.go",
"audit_logs.go",
"blobs.go",
Expand Down
170 changes: 170 additions & 0 deletions cmd/server/pactasrv/admin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package pactasrv

import (
"context"
"fmt"

"github.com/RMI/pacta/db"
"github.com/RMI/pacta/oapierr"
api "github.com/RMI/pacta/openapi/pacta"
"github.com/RMI/pacta/pacta"
"go.uber.org/zap"
)

// Merges two users together
// (POST /admin/merge-users)
func (s *Server) MergeUsers(ctx context.Context, request api.MergeUsersRequestObject) (api.MergeUsersResponseObject, error) {
req := request.Body
actorUserInfo, err := s.getactorInfoOrErrIfAnon(ctx)
if err != nil {
return nil, err
}
fieldsIfErr := []zap.Field{
zap.String("actor_user_id", string(actorUserInfo.UserID)),
zap.String("from_user_id", req.FromUserId),
zap.String("to_user_id", req.ToUserId),
}
if !actorUserInfo.IsAdmin && !actorUserInfo.IsSuperAdmin {
return nil, oapierr.Forbidden("only admins can merge users", fieldsIfErr...)
}

sourceUID := pacta.UserID(req.FromUserId)
destUID := pacta.UserID(req.ToUserId)
if sourceUID == destUID {
return nil, oapierr.BadRequest("cannot merge user into themselves", fieldsIfErr...)
}

var (
numIncompleteUploads, numAnalyses, numPortfolios, numPortfolioGroups, numAuditLogsCreated int
buris []pacta.BlobURI
)

err = s.DB.Transactional(ctx, func(tx db.Tx) error {
sourceOwner, err := s.DB.GetOwnerForUser(tx, sourceUID)
if err != nil {
return fmt.Errorf("failed to get owner for source user: %w", err)
}
destOwner, err := s.DB.GetOwnerForUser(tx, destUID)
if err != nil {
return fmt.Errorf("failed to get owner for destination user: %w", err)
}

if err = s.DB.RecordUserMerge(tx, sourceUID, destUID, actorUserInfo.UserID); err != nil {
return fmt.Errorf("failed to record user merge: %w", err)
}
if err = s.DB.RecordOwnerMerge(tx, sourceOwner, destOwner, actorUserInfo.UserID); err != nil {
return fmt.Errorf("failed to record owner merge: %w", err)
}

auditLogsToCreate := []*pacta.AuditLog{}
addAuditLog := func(t pacta.AuditLogTargetType, id string) {
auditLogsToCreate = append(auditLogsToCreate, &pacta.AuditLog{
Action: pacta.AuditLogAction_TransferOwnership,
ActorType: pacta.AuditLogActorType_Admin,
ActorID: string(actorUserInfo.UserID),
ActorOwner: &pacta.Owner{ID: actorUserInfo.OwnerID},
PrimaryTargetType: t,
PrimaryTargetID: id,
PrimaryTargetOwner: &pacta.Owner{ID: destOwner},
SecondaryTargetType: pacta.AuditLogTargetType_User,
SecondaryTargetID: string(sourceUID),
SecondaryTargetOwner: &pacta.Owner{ID: sourceOwner},
})
}

incompleteUploads, err := s.DB.IncompleteUploadsByOwner(tx, sourceOwner)
if err != nil {
return fmt.Errorf("failed to get incomplete uploads for source owner: %w", err)
}
for i, upload := range incompleteUploads {
err := s.DB.UpdateIncompleteUpload(tx, upload.ID, db.SetIncompleteUploadOwner(destOwner))
if err != nil {
return fmt.Errorf("failed to update upload owner %d/%d: %w", i+1, len(incompleteUploads), err)
}
addAuditLog(pacta.AuditLogTargetType_IncompleteUpload, string(upload.ID))
}
numIncompleteUploads = len(incompleteUploads)

analyses, err := s.DB.AnalysesByOwner(tx, sourceOwner)
if err != nil {
return fmt.Errorf("failed to get analyses for source owner: %w", err)
}
for i, analysis := range analyses {
err := s.DB.UpdateAnalysis(tx, analysis.ID, db.SetAnalysisOwner(destOwner))
if err != nil {
return fmt.Errorf("failed to update analysis owner %d/%d: %w", i+1, len(analyses), err)
}
addAuditLog(pacta.AuditLogTargetType_Analysis, string(analysis.ID))
}
numAnalyses = len(analyses)

portfolios, err := s.DB.PortfoliosByOwner(tx, sourceOwner)
if err != nil {
return fmt.Errorf("failed to get portfolios for source owner: %w", err)
}
for i, portfolio := range portfolios {
err := s.DB.UpdatePortfolio(tx, portfolio.ID, db.SetPortfolioOwner(destOwner))
if err != nil {
return fmt.Errorf("failed to update portfolio owner %d/%d: %w", i+1, len(portfolios), err)
}
addAuditLog(pacta.AuditLogTargetType_Portfolio, string(portfolio.ID))
}
numPortfolios = len(portfolios)

portfolioGroups, err := s.DB.PortfolioGroupsByOwner(tx, sourceOwner)
if err != nil {
return fmt.Errorf("failed to get portfolio groups for source owner: %w", err)
}
for i, portfolioGroup := range portfolioGroups {
err := s.DB.UpdatePortfolioGroup(tx, portfolioGroup.ID, db.SetPortfolioGroupOwner(destOwner))
if err != nil {
return fmt.Errorf("failed to update portfolio group owner %d/%d: %w", i+1, len(portfolioGroups), err)
}
addAuditLog(pacta.AuditLogTargetType_PortfolioGroup, string(portfolioGroup.ID))
}
numPortfolioGroups = len(portfolioGroups)

if err := s.DB.CreateAuditLogs(tx, auditLogsToCreate); err != nil {
return fmt.Errorf("failed to create audit logs: %w", err)
}
numAuditLogsCreated = len(auditLogsToCreate)

// Now that we've transferred all the entities, we can delete the user.
deletedUserBuris, err := s.DB.DeleteUser(tx, sourceUID)
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
if len(deletedUserBuris) > 0 {
// Note in this case we won't commit the transaction, so this data won't be orphaned.
return fmt.Errorf("failed to delete user: user still has blobs: %v", deletedUserBuris)
}

return nil
})
if err != nil {
fieldsIfErr := append(fieldsIfErr, zap.Error(err))
return nil, oapierr.Internal("failed to merge users", fieldsIfErr...)
}

if err := s.deleteBlobs(ctx, buris...); err != nil {
return nil, err
}

s.Logger.Info("user merge completed successfully",
zap.String("actor_user_id", string(actorUserInfo.UserID)),
zap.String("from_user_id", req.FromUserId),
zap.String("to_user_id", req.ToUserId),
zap.Int("num_incomplete_uploads", numIncompleteUploads),
zap.Int("num_analyses", numAnalyses),
zap.Int("num_portfolios", numPortfolios),
zap.Int("num_portfolio_groups", numPortfolioGroups),
zap.Int("num_audit_logs_created", numAuditLogsCreated),
)
return api.MergeUsers200JSONResponse{
AuditLogsCreated: numAuditLogsCreated,
IncompleteUploadCount: numIncompleteUploads,
PortfolioCount: numPortfolios,
PortfolioGroupCount: numPortfolioGroups,
AnalysisCount: numAnalyses,
}, nil
}
14 changes: 2 additions & 12 deletions cmd/server/pactasrv/blobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package pactasrv

import (
"context"
"fmt"

"github.com/RMI/pacta/db"
"github.com/RMI/pacta/oapierr"
Expand All @@ -12,7 +11,7 @@ import (
)

func (s *Server) AccessBlobContent(ctx context.Context, request api.AccessBlobContentRequestObject) (api.AccessBlobContentResponseObject, error) {
actorInfo, err := s.getActorInfoOrFail(ctx)
actorInfo, err := s.getactorInfoOrErrIfAnon(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -69,16 +68,7 @@ func (s *Server) AccessBlobContent(ctx context.Context, request api.AccessBlobCo
return nil, oapierr.Internal("error getting blobs", zap.Error(err), zap.Strings("blob_ids", asStrs(blobIDs)))
}

err = s.DB.Transactional(ctx, func(tx db.Tx) error {
for i, al := range auditLogs {
_, err := s.DB.CreateAuditLog(tx, al)
if err != nil {
return fmt.Errorf("creating audit log %d/%d: %w", i+1, len(auditLogs), err)
}
}
return nil
})
if err != nil {
if err = s.DB.CreateAuditLogs(s.DB.NoTxn(ctx), auditLogs); err != nil {
return nil, oapierr.Internal("error creating audit logs - no download URLs generated", zap.Error(err), zap.Strings("blob_ids", asStrs(blobIDs)))
}

Expand Down
2 changes: 2 additions & 0 deletions cmd/server/pactasrv/conv/oapi_to_pacta.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ func auditLogActionFromOAPI(i api.AuditLogAction) (pacta.AuditLogAction, error)
return pacta.AuditLogAction_EnableSharing, nil
case api.AuditLogActionDisableSharing:
return pacta.AuditLogAction_DisableSharing, nil
case api.AuditLogActionTransferOwnership:
return pacta.AuditLogAction_TransferOwnership, nil
}
return "", oapierr.BadRequest("unknown audit log action", zap.String("audit_log_action", string(i)))
}
Expand Down
2 changes: 2 additions & 0 deletions cmd/server/pactasrv/conv/pacta_to_oapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,8 @@ func auditLogActionToOAPI(i pacta.AuditLogAction) (api.AuditLogAction, error) {
return api.AuditLogActionEnableSharing, nil
case pacta.AuditLogAction_DisableSharing:
return api.AuditLogActionDisableSharing, nil
case pacta.AuditLogAction_TransferOwnership:
return api.AuditLogActionTransferOwnership, nil
}
return "", oapierr.Internal(fmt.Sprintf("auditLogActionToOAPI: unknown action: %q", i))
}
Expand Down
8 changes: 6 additions & 2 deletions cmd/server/pactasrv/pactasrv.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,14 @@ type DB interface {
User(tx db.Tx, id pacta.UserID) (*pacta.User, error)
Users(tx db.Tx, ids []pacta.UserID) (map[pacta.UserID]*pacta.User, error)
UpdateUser(tx db.Tx, id pacta.UserID, mutations ...db.UpdateUserFn) error
DeleteUser(tx db.Tx, id pacta.UserID) error
DeleteUser(tx db.Tx, id pacta.UserID) ([]pacta.BlobURI, error)

CreateAuditLog(tx db.Tx, a *pacta.AuditLog) (pacta.AuditLogID, error)
CreateAuditLogs(tx db.Tx, as []*pacta.AuditLog) error
AuditLogs(tx db.Tx, q *db.AuditLogQuery) ([]*pacta.AuditLog, *db.PageInfo, error)

RecordUserMerge(tx db.Tx, fromUserID, toUserID, actorUserID pacta.UserID) error
RecordOwnerMerge(tx db.Tx, fromUserID, toUserID pacta.OwnerID, actorUserID pacta.UserID) error
}

type Blob interface {
Expand Down Expand Up @@ -211,7 +215,7 @@ type actorInfo struct {
IsSuperAdmin bool
}

func (s *Server) getActorInfoOrFail(ctx context.Context) (*actorInfo, error) {
func (s *Server) getactorInfoOrErrIfAnon(ctx context.Context) (*actorInfo, error) {
actorUserID, err := getUserID(ctx)
if err != nil {
return nil, err
Expand Down
5 changes: 4 additions & 1 deletion cmd/server/pactasrv/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,13 @@ func (s *Server) UpdateUser(ctx context.Context, request api.UpdateUserRequestOb
// (DELETE /user/{id})
func (s *Server) DeleteUser(ctx context.Context, request api.DeleteUserRequestObject) (api.DeleteUserResponseObject, error) {
// TODO(#12) Implement Authorization
err := s.DB.DeleteUser(s.DB.NoTxn(ctx), pacta.UserID(request.Id))
blobURIs, err := s.DB.DeleteUser(s.DB.NoTxn(ctx), pacta.UserID(request.Id))
if err != nil {
return nil, oapierr.Internal("failed to delete user", zap.Error(err))
}
if err := s.deleteBlobs(ctx, blobURIs...); err != nil {
return nil, err
}
return api.DeleteUser204Response{}, nil
}

Expand Down
2 changes: 2 additions & 0 deletions db/sqldb/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ go_library(
"initiative.go",
"initiative_invitation.go",
"initiative_user.go",
"merge.go",
"owner.go",
"pacta_version.go",
"portfolio.go",
Expand Down Expand Up @@ -48,6 +49,7 @@ go_test(
"initiative_invitation_test.go",
"initiative_test.go",
"initiative_user_test.go",
"merge_test.go",
"owner_test.go",
"pacta_version_test.go",
"portfolio_group_test.go",
Expand Down
52 changes: 45 additions & 7 deletions db/sqldb/audit_log.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ func (d *DB) AuditLogs(tx db.Tx, q *db.AuditLogQuery) ([]*pacta.AuditLog, *db.Pa
if err != nil {
return nil, nil, fmt.Errorf("converting cursor to offset: %w", err)
}
q, err = d.expandAuditLogQueryToAccountForMerges(tx, q)
if err != nil {
return nil, nil, fmt.Errorf("expanding audit_log query to account for merges: %w", err)
}
sql, args, err := auditLogQuery(q)
if err != nil {
return nil, nil, fmt.Errorf("building audit_log query: %w", err)
Expand All @@ -55,8 +59,42 @@ func (d *DB) AuditLogs(tx db.Tx, q *db.AuditLogQuery) ([]*pacta.AuditLog, *db.Pa
}

func (d *DB) CreateAuditLog(tx db.Tx, a *pacta.AuditLog) (pacta.AuditLogID, error) {
sql, args, id, err := d.buildCreateAuditLogQuery(tx, a)
if err != nil {
return "", err
}
err = d.exec(tx, sql, args...)
if err != nil {
return "", fmt.Errorf("creating audit_log row: %w", err)
}
return id, nil
}

func (d *DB) CreateAuditLogs(tx db.Tx, als []*pacta.AuditLog) error {
if len(als) == 0 {
return nil
}
if len(als) == 1 {
_, err := d.CreateAuditLog(tx, als[0])
return err
}
batch := &pgx.Batch{}
for _, al := range als {
sql, args, _, err := d.buildCreateAuditLogQuery(tx, al)
if err != nil {
return fmt.Errorf("building batch audit_log updates: %w", err)
}
batch.Queue(sql, args...)
}
if err := d.ExecBatch(tx, batch); err != nil {
return fmt.Errorf("batch creating audit_logs: %w", err)
}
return nil
}

func (d *DB) buildCreateAuditLogQuery(tx db.Tx, a *pacta.AuditLog) (string, []interface{}, pacta.AuditLogID, error) {
if err := validateAuditLogForCreation(a); err != nil {
return "", fmt.Errorf("validating audit_log for creation: %w", err)
return "", nil, "", fmt.Errorf("validating audit_log for creation: %w", err)
}
id := pacta.AuditLogID(d.randomID(auditLogIDNamespace))
ownerFn := func(o *pacta.Owner) pgtype.Text {
Expand All @@ -70,7 +108,7 @@ func (d *DB) CreateAuditLog(tx db.Tx, a *pacta.AuditLog) (pacta.AuditLogID, erro
stt.Valid = true
stt.String = string(a.SecondaryTargetType)
}
err := d.exec(tx, `
sql := `
INSERT INTO audit_log
(
id, action, actor_type, actor_id, actor_owner_id,
Expand All @@ -79,13 +117,13 @@ func (d *DB) CreateAuditLog(tx db.Tx, a *pacta.AuditLog) (pacta.AuditLogID, erro
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11);
`, id, a.Action, a.ActorType, a.ActorID, ownerFn(a.ActorOwner),
`
args := []interface{}{
id, a.Action, a.ActorType, a.ActorID, ownerFn(a.ActorOwner),
a.PrimaryTargetType, a.PrimaryTargetID, ownerFn(a.PrimaryTargetOwner),
stt, a.SecondaryTargetID, ownerFn(a.SecondaryTargetOwner))
if err != nil {
return "", fmt.Errorf("creating audit_log row: %w", err)
stt, a.SecondaryTargetID, ownerFn(a.SecondaryTargetOwner),
}
return id, nil
return sql, args, id, nil
}

func rowsToAuditLogs(rows pgx.Rows) ([]*pacta.AuditLog, error) {
Expand Down
Loading

0 comments on commit 1a5a0cd

Please sign in to comment.