diff --git a/cmd/init/main.go b/cmd/init/main.go index ad6ede8..1327020 100644 --- a/cmd/init/main.go +++ b/cmd/init/main.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "log" "time" @@ -85,7 +86,7 @@ func createDefaultUser(ctx context.Context, db *gorm.DB) error { return nil } - if s, ok := status.FromError(err); !ok || s.Code() != codes.NotFound { + if !errors.Is(err, gorm.ErrRecordNotFound) { return status.Errorf(codes.Internal, "error %v", err) } diff --git a/go.mod b/go.mod index 3285dfa..1388ae9 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,6 @@ require ( github.com/instill-ai/protogen-go v0.3.3-alpha.0.20231129095217-f8d4e5951d35 github.com/instill-ai/usage-client v0.2.4-alpha.0.20231019203021-70410a0a8061 github.com/instill-ai/x v0.3.0-alpha - github.com/jackc/pgx/v5 v5.3.0 github.com/knadh/koanf v1.4.4 github.com/mennanov/fieldmask-utils v0.5.0 github.com/openfga/go-sdk v0.2.3 @@ -71,6 +70,7 @@ require ( github.com/influxdata/line-protocol/v2 v2.2.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.3.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/klauspost/asmfmt v1.3.2 // indirect diff --git a/pkg/acl/acl.go b/pkg/acl/acl.go index f2f7126..6541068 100644 --- a/pkg/acl/acl.go +++ b/pkg/acl/acl.go @@ -105,7 +105,7 @@ func (c *ACLClient) GetOrganizationUserMembership(orgUID uuid.UUID, userUID uuid for _, tuple := range *data.Tuples { return *tuple.Key.Relation, nil } - return "", fmt.Errorf("no permission") + return "", ErrMembershipNotFound } func (c *ACLClient) GetOrganizationUsers(orgUID uuid.UUID) ([]*Relation, error) { diff --git a/pkg/acl/errors.go b/pkg/acl/errors.go new file mode 100644 index 0000000..6c945a3 --- /dev/null +++ b/pkg/acl/errors.go @@ -0,0 +1,5 @@ +package acl + +import "errors" + +var ErrMembershipNotFound = errors.New("membership not found") diff --git a/pkg/handler/errors.go b/pkg/handler/errors.go new file mode 100644 index 0000000..32f8ba8 --- /dev/null +++ b/pkg/handler/errors.go @@ -0,0 +1,10 @@ +package handler + +import "errors" + +var ErrCheckUpdateImmutableFields = errors.New("update immutable fields error") +var ErrCheckOutputOnlyFields = errors.New("can not contain output only fields") +var ErrCheckRequiredFields = errors.New("required fields missing") +var ErrFieldMask = errors.New("field mask error") +var ErrResourceID = errors.New("resource ID error") +var ErrUpdateMask = errors.New("update mask error") diff --git a/pkg/handler/privatehandler.go b/pkg/handler/privatehandler.go index 14a26d0..2651bb9 100644 --- a/pkg/handler/privatehandler.go +++ b/pkg/handler/privatehandler.go @@ -2,14 +2,11 @@ package handler import ( "context" - "fmt" "strings" "github.com/gofrs/uuid" "go.einride.tech/aip/filtering" "google.golang.org/genproto/googleapis/rpc/errdetails" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/instill-ai/mgmt-backend/pkg/logger" "github.com/instill-ai/mgmt-backend/pkg/service" @@ -35,7 +32,6 @@ func NewPrivateHandler(s service.Service) mgmtPB.MgmtPrivateServiceServer { // ListUsersAdmin lists all users func (h *PrivateHandler) ListUsersAdmin(ctx context.Context, req *mgmtPB.ListUsersAdminRequest) (*mgmtPB.ListUsersAdminResponse, error) { - logger, _ := logger.GetZapLogger(ctx) pageSize := req.GetPageSize() if pageSize == 0 { @@ -46,34 +42,7 @@ func (h *PrivateHandler) ListUsersAdmin(ctx context.Context, req *mgmtPB.ListUse pbUsers, totalSize, nextPageToken, err := h.Service.ListUsersAdmin(ctx, int(pageSize), req.GetPageToken(), filtering.Filter{}) if err != nil { - sta := status.Convert(err) - switch sta.Code() { - case codes.InvalidArgument: - st, e := sterr.CreateErrorBadRequest( - "list user error", []*errdetails.BadRequest_FieldViolation{ - { - Field: "ListUsersAdminRequest.page_token", - Description: sta.Message(), - }, - }) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.ListUsersAdminResponse{}, st.Err() - default: - st, e := sterr.CreateErrorResourceInfo( - sta.Code(), - "list user error", - "user", - "", - "", - sta.Message(), - ) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.ListUsersAdminResponse{}, st.Err() - } + return nil, err } resp := mgmtPB.ListUsersAdminResponse{ @@ -86,40 +55,12 @@ func (h *PrivateHandler) ListUsersAdmin(ctx context.Context, req *mgmtPB.ListUse // GetUserAdmin gets a user func (h *PrivateHandler) GetUserAdmin(ctx context.Context, req *mgmtPB.GetUserAdminRequest) (*mgmtPB.GetUserAdminResponse, error) { - logger, _ := logger.GetZapLogger(ctx) id := strings.TrimPrefix(req.GetName(), "users/") pbUser, err := h.Service.GetUserAdmin(ctx, id) if err != nil { - sta := status.Convert(err) - switch sta.Code() { - case codes.InvalidArgument: - st, e := sterr.CreateErrorBadRequest( - "get user error", []*errdetails.BadRequest_FieldViolation{ - { - Field: "GetUserAdminRequest.name", - Description: sta.Message(), - }, - }) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.GetUserAdminResponse{}, st.Err() - default: - st, e := sterr.CreateErrorResourceInfo( - sta.Code(), - "get user error", - "user", - fmt.Sprintf("id %s", id), - "", - sta.Message(), - ) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.GetUserAdminResponse{}, st.Err() - } + return nil, err } resp := mgmtPB.GetUserAdminResponse{ @@ -152,34 +93,7 @@ func (h *PrivateHandler) LookUpUserAdmin(ctx context.Context, req *mgmtPB.LookUp pbUser, err := h.Service.GetUserByUIDAdmin(ctx, uid) if err != nil { - sta := status.Convert(err) - switch sta.Code() { - case codes.InvalidArgument: - st, e := sterr.CreateErrorBadRequest( - "look up user error", []*errdetails.BadRequest_FieldViolation{ - { - Field: "LookUpUserAdminRequest.permalink", - Description: sta.Message(), - }, - }) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.LookUpUserAdminResponse{}, st.Err() - default: - st, e := sterr.CreateErrorResourceInfo( - sta.Code(), - "look up user error", - "user", - fmt.Sprintf("uid %s", uid), - "", - sta.Message(), - ) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.LookUpUserAdminResponse{}, st.Err() - } + return nil, err } resp := mgmtPB.LookUpUserAdminResponse{ @@ -190,7 +104,6 @@ func (h *PrivateHandler) LookUpUserAdmin(ctx context.Context, req *mgmtPB.LookUp // ListOrganizationsAdmin lists all organizations func (h *PrivateHandler) ListOrganizationsAdmin(ctx context.Context, req *mgmtPB.ListOrganizationsAdminRequest) (*mgmtPB.ListOrganizationsAdminResponse, error) { - logger, _ := logger.GetZapLogger(ctx) pageSize := req.GetPageSize() if pageSize == 0 { @@ -201,34 +114,7 @@ func (h *PrivateHandler) ListOrganizationsAdmin(ctx context.Context, req *mgmtPB pbOrganizations, totalSize, nextPageToken, err := h.Service.ListOrganizationsAdmin(ctx, int(pageSize), req.GetPageToken(), filtering.Filter{}) if err != nil { - sta := status.Convert(err) - switch sta.Code() { - case codes.InvalidArgument: - st, e := sterr.CreateErrorBadRequest( - "list organization error", []*errdetails.BadRequest_FieldViolation{ - { - Field: "ListOrganizationsAdminRequest.page_token", - Description: sta.Message(), - }, - }) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.ListOrganizationsAdminResponse{}, st.Err() - default: - st, e := sterr.CreateErrorResourceInfo( - sta.Code(), - "list organization error", - "organization", - "", - "", - sta.Message(), - ) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.ListOrganizationsAdminResponse{}, st.Err() - } + return nil, err } resp := mgmtPB.ListOrganizationsAdminResponse{ @@ -241,40 +127,12 @@ func (h *PrivateHandler) ListOrganizationsAdmin(ctx context.Context, req *mgmtPB // GetOrganizationAdmin gets a organization func (h *PrivateHandler) GetOrganizationAdmin(ctx context.Context, req *mgmtPB.GetOrganizationAdminRequest) (*mgmtPB.GetOrganizationAdminResponse, error) { - logger, _ := logger.GetZapLogger(ctx) id := strings.TrimPrefix(req.GetName(), "organizations/") pbOrganization, err := h.Service.GetOrganizationAdmin(ctx, id) if err != nil { - sta := status.Convert(err) - switch sta.Code() { - case codes.InvalidArgument: - st, e := sterr.CreateErrorBadRequest( - "get organization error", []*errdetails.BadRequest_FieldViolation{ - { - Field: "GetOrganizationAdminRequest.name", - Description: sta.Message(), - }, - }) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.GetOrganizationAdminResponse{}, st.Err() - default: - st, e := sterr.CreateErrorResourceInfo( - sta.Code(), - "get organization error", - "organization", - fmt.Sprintf("id %s", id), - "", - sta.Message(), - ) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.GetOrganizationAdminResponse{}, st.Err() - } + return nil, err } resp := mgmtPB.GetOrganizationAdminResponse{ @@ -285,56 +143,17 @@ func (h *PrivateHandler) GetOrganizationAdmin(ctx context.Context, req *mgmtPB.G // LookUpOrganizationAdmin gets a organization by permalink func (h *PrivateHandler) LookUpOrganizationAdmin(ctx context.Context, req *mgmtPB.LookUpOrganizationAdminRequest) (*mgmtPB.LookUpOrganizationAdminResponse, error) { - logger, _ := logger.GetZapLogger(ctx) uidStr := strings.TrimPrefix(req.GetPermalink(), "organizations/") // Validation: `uid` in request is valid uid, err := uuid.FromString(uidStr) if err != nil { - st, e := sterr.CreateErrorBadRequest( - "look up organization invalid uuid error", []*errdetails.BadRequest_FieldViolation{ - { - Field: "LookUpOrganizationAdminRequest.permalink", - Description: err.Error(), - }, - }, - ) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.LookUpOrganizationAdminResponse{}, st.Err() + return nil, err } pbOrganization, err := h.Service.GetOrganizationByUIDAdmin(ctx, uid) if err != nil { - sta := status.Convert(err) - switch sta.Code() { - case codes.InvalidArgument: - st, e := sterr.CreateErrorBadRequest( - "look up organization error", []*errdetails.BadRequest_FieldViolation{ - { - Field: "LookUpOrganizationAdminRequest.permalink", - Description: sta.Message(), - }, - }) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.LookUpOrganizationAdminResponse{}, st.Err() - default: - st, e := sterr.CreateErrorResourceInfo( - sta.Code(), - "look up organization error", - "organization", - fmt.Sprintf("uid %s", uid), - "", - sta.Message(), - ) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.LookUpOrganizationAdminResponse{}, st.Err() - } + return nil, err } resp := mgmtPB.LookUpOrganizationAdminResponse{ diff --git a/pkg/handler/publichandler.go b/pkg/handler/publichandler.go index 93eadb1..c9cf7fd 100644 --- a/pkg/handler/publichandler.go +++ b/pkg/handler/publichandler.go @@ -13,12 +13,8 @@ import ( "go.einride.tech/aip/filtering" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace" - "golang.org/x/crypto/bcrypt" - "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc" - "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" - "google.golang.org/grpc/status" fieldmask_utils "github.com/mennanov/fieldmask-utils" @@ -27,7 +23,6 @@ import ( "github.com/instill-ai/mgmt-backend/pkg/logger" "github.com/instill-ai/mgmt-backend/pkg/service" "github.com/instill-ai/mgmt-backend/pkg/usage" - "github.com/instill-ai/x/sterr" custom_otel "github.com/instill-ai/mgmt-backend/pkg/logger/otel" healthcheckPB "github.com/instill-ai/protogen-go/common/healthcheck/v1alpha" @@ -93,16 +88,12 @@ func (h *PublicHandler) AuthTokenIssuer(ctx context.Context, in *mgmtPB.AuthToke user, err := h.Service.GetUserAdmin(ctx, in.Username) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + return nil, err } - passwordHash, _, err := h.Service.GetUserPasswordHash(ctx, uuid.FromStringOrNil(*user.Uid)) - if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") - } - err = bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(in.Password)) + err = h.Service.CheckUserPassword(ctx, uuid.FromStringOrNil(*user.Uid), in.Password) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + return nil, err } jti, _ := uuid.NewV4() @@ -120,32 +111,23 @@ func (h *PublicHandler) AuthTokenIssuer(ctx context.Context, in *mgmtPB.AuthToke func (h *PublicHandler) AuthChangePassword(ctx context.Context, in *mgmtPB.AuthChangePasswordRequest) (*mgmtPB.AuthChangePasswordResponse, error) { - userId, userUid, err := h.Service.GetCtxUser(ctx) - if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") - } - user, err := h.Service.GetUser(ctx, userUid, userId) + ctxUserID, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") - } - - passwordHash, _, err := h.Service.GetUserPasswordHash(ctx, uuid.FromStringOrNil(*user.Uid)) - if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + return nil, err } - err = bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(in.OldPassword)) + user, err := h.Service.GetUser(ctx, ctxUserUID, ctxUserID) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + return nil, err } - passwordBytes, err := bcrypt.GenerateFromPassword([]byte(in.NewPassword), 10) + err = h.Service.CheckUserPassword(ctx, uuid.FromStringOrNil(*user.Uid), in.OldPassword) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Update Password Failed") + return nil, err } - err = h.Service.UpdateUserPasswordHash(ctx, uuid.FromStringOrNil(*user.Uid), string(passwordBytes)) + err = h.Service.UpdateUserPassword(ctx, uuid.FromStringOrNil(*user.Uid), in.NewPassword) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Update Password Failed") + return nil, err } return &mgmtPB.AuthChangePasswordResponse{}, nil @@ -178,13 +160,13 @@ func (h *PublicHandler) ListUsers(ctx context.Context, req *mgmtPB.ListUsersRequ logger, _ := logger.GetZapLogger(ctx) - _, userUid, err := h.Service.GetCtxUser(ctx) + _, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { span.SetStatus(1, err.Error()) return nil, err } - pbUsers, totalSize, nextPageToken, err := h.Service.ListUsers(ctx, userUid, int(req.GetPageSize()), req.GetPageToken(), filtering.Filter{}) + pbUsers, totalSize, nextPageToken, err := h.Service.ListUsers(ctx, ctxUserUID, int(req.GetPageSize()), req.GetPageToken(), filtering.Filter{}) if err != nil { span.SetStatus(1, err.Error()) return nil, err @@ -193,7 +175,7 @@ func (h *PublicHandler) ListUsers(ctx context.Context, req *mgmtPB.ListUsersRequ logger.Info(string(custom_otel.NewLogMessage( span, logUUID.String(), - userUid, + ctxUserUID, eventName, ))) @@ -219,31 +201,21 @@ func (h *PublicHandler) GetUser(ctx context.Context, req *mgmtPB.GetUserRequest) logger, _ := logger.GetZapLogger(ctx) - ctxUserId, ctxUserUID, err := h.Service.GetCtxUser(ctx) + ctxUserID, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + span.SetStatus(1, err.Error()) + return nil, err } - userId := strings.Split(req.Name, "/")[1] - if userId == "me" { - userId = ctxUserId + userID := strings.Split(req.Name, "/")[1] + if userID == "me" { + userID = ctxUserID } - pbUser, err := h.Service.GetUser(ctx, ctxUserUID, userId) + pbUser, err := h.Service.GetUser(ctx, ctxUserUID, userID) if err != nil { logger.Error(err.Error()) - st, e := sterr.CreateErrorResourceInfo( - codes.NotFound, - "get user error", - "user", - fmt.Sprintf("id %s", userId), - "", - err.Error(), - ) - if e != nil { - logger.Error(e.Error()) - } - return nil, st.Err() + return nil, err } logger.Info(string(custom_otel.NewLogMessage( @@ -277,62 +249,30 @@ func (h *PublicHandler) PatchAuthenticatedUser(ctx context.Context, req *mgmtPB. // Validate the field mask if !req.GetUpdateMask().IsValid(reqUser) { - st, e := sterr.CreateErrorBadRequest( - "update user invalid fieldmask error", []*errdetails.BadRequest_FieldViolation{ - { - Field: "PatchAuthenticatedUserRequest.update_mask", - Description: "invalid", - }, - }, - ) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.PatchAuthenticatedUserResponse{}, st.Err() + return nil, fmt.Errorf("err") } reqFieldMask, err := checkfield.CheckUpdateOutputOnlyFields(req.GetUpdateMask(), outputOnlyFields) if err != nil { - st, e := sterr.CreateErrorBadRequest( - "update user update OUTPUT_ONLY fields error", []*errdetails.BadRequest_FieldViolation{ - { - Field: "PatchAuthenticatedUserRequest OUTPUT_ONLY fields", - Description: err.Error(), - }, - }, - ) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.PatchAuthenticatedUserResponse{}, st.Err() + logger.Error(err.Error()) + return nil, err } mask, err := fieldmask_utils.MaskFromProtoFieldMask(reqFieldMask, strcase.ToCamel) if err != nil { logger.Error(err.Error()) - st, e := sterr.CreateErrorBadRequest( - "update user update mask error", []*errdetails.BadRequest_FieldViolation{ - { - Field: "PatchAuthenticatedUserRequest.update_mask", - Description: err.Error(), - }, - }, - ) - if e != nil { - logger.Error(e.Error()) - } - - return &mgmtPB.PatchAuthenticatedUserResponse{}, st.Err() + return nil, err } - ctxUserId, ctxUserUID, err := h.Service.GetCtxUser(ctx) + ctxUserID, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + span.SetStatus(1, err.Error()) + return nil, err } - pbUserToUpdate, err := h.Service.GetUser(ctx, ctxUserUID, ctxUserId) + pbUserToUpdate, err := h.Service.GetUser(ctx, ctxUserUID, ctxUserID) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + return nil, err } if mask.IsEmpty() { @@ -346,66 +286,18 @@ func (h *PublicHandler) PatchAuthenticatedUser(ctx context.Context, req *mgmtPB. // Handle immutable fields from the update mask err = checkfield.CheckUpdateImmutableFields(reqUser, pbUserToUpdate, immutableFields) if err != nil { - st, e := sterr.CreateErrorBadRequest( - "update authenticated user update IMMUTABLE fields error", []*errdetails.BadRequest_FieldViolation{ - { - Field: "PatchAuthenticatedUserRequest IMMUTABLE fields", - Description: err.Error(), - }, - }, - ) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.PatchAuthenticatedUserResponse{}, st.Err() + return nil, ErrCheckUpdateImmutableFields } // Only the fields mentioned in the field mask will be copied to `pbUserToUpdate`, other fields are left intact err = fieldmask_utils.StructToStruct(mask, reqUser, pbUserToUpdate) if err != nil { - logger.Error(err.Error()) - st, e := sterr.CreateErrorResourceInfo( - codes.Internal, - "update authenticated user error", "user", fmt.Sprintf("uid %s", *reqUser.Uid), - "", - err.Error(), - ) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.PatchAuthenticatedUserResponse{}, st.Err() - } - - pbUserUpdated, err := h.Service.UpdateUser(ctx, ctxUserUID, ctxUserId, pbUserToUpdate) - if err != nil { - sta := status.Convert(err) - switch sta.Code() { - case codes.InvalidArgument: - st, e := sterr.CreateErrorBadRequest( - "update authenticated user error", []*errdetails.BadRequest_FieldViolation{ - { - Field: "PatchAuthenticatedUserRequest", - Description: sta.Message(), - }, - }) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.PatchAuthenticatedUserResponse{}, st.Err() - default: - st, e := sterr.CreateErrorResourceInfo( - sta.Code(), - "update authenticated user error", - "user", - fmt.Sprintf("uid %s", ctxUserUID.String()), - "", - sta.Message(), - ) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.PatchAuthenticatedUserResponse{}, st.Err() - } + return nil, ErrFieldMask + } + + pbUserUpdated, err := h.Service.UpdateUser(ctx, ctxUserUID, ctxUserID, pbUserToUpdate) + if err != nil { + return nil, err } resp := mgmtPB.PatchAuthenticatedUserResponse{ @@ -447,57 +339,12 @@ func (h *PublicHandler) ExistUsername(ctx context.Context, req *mgmtPB.ExistUser // number, and a 63 character maximum. err := checkfield.CheckResourceID(id) if err != nil { - st, e := sterr.CreateErrorBadRequest( - "verify whether username is occupied bad request error", []*errdetails.BadRequest_FieldViolation{ - { - Field: "id", - Description: err.Error(), - }, - }, - ) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.ExistUsernameResponse{}, st.Err() + return nil, ErrResourceID } pbUser, err := h.Service.GetUserAdmin(ctx, id) if err != nil { - sta := status.Convert(err) - switch sta.Code() { - // user not exist - username not occupied - case codes.NotFound: - resp := mgmtPB.ExistUsernameResponse{ - Exists: false, - } - return &resp, nil - // invalid username - case codes.InvalidArgument: - st, e := sterr.CreateErrorBadRequest( - "verify whether username is occupied error", []*errdetails.BadRequest_FieldViolation{ - { - Field: "ExistUsernameRequest.name", - Description: sta.Message(), - }, - }) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.ExistUsernameResponse{}, st.Err() - default: - st, e := sterr.CreateErrorResourceInfo( - sta.Code(), - "verify whether username is occupied error", - "user", - fmt.Sprintf("id %s", id), - "", - sta.Message(), - ) - if e != nil { - logger.Error(e.Error()) - } - return &mgmtPB.ExistUsernameResponse{}, st.Err() - } + return nil, err } logger.Info(string(custom_otel.NewLogMessage( @@ -527,37 +374,28 @@ func (h *PublicHandler) CreateOrganization(ctx context.Context, req *mgmtPB.Crea // Set all OUTPUT_ONLY fields to zero value on the requested payload organization resource if err := checkfield.CheckCreateOutputOnlyFields(req.Organization, outputOnlyFieldsForOrganization); err != nil { - return nil, status.Errorf(codes.InvalidArgument, err.Error()) + return nil, ErrCheckOutputOnlyFields } // Return error if REQUIRED fields are not provided in the requested payload organization resource if err := checkfield.CheckRequiredFields(req.Organization, createRequiredFieldsForOrganization); err != nil { - return nil, status.Errorf(codes.InvalidArgument, err.Error()) + return nil, ErrCheckRequiredFields } // Return error if resource ID does not follow RFC-1034 if err := checkfield.CheckResourceID(req.Organization.GetId()); err != nil { - return nil, status.Errorf(codes.InvalidArgument, err.Error()) + return nil, ErrResourceID } - _, ctxUserUID, err := h.Service.GetCtxUser(ctx) + _, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") - } - - _, err = h.Service.GetOrganization(ctx, ctxUserUID, req.Organization.Id) - if err == nil { - return nil, status.Errorf(codes.AlreadyExists, "Organization ID already existed") - } - - _, createErr := h.Service.CreateOrganization(ctx, ctxUserUID, req.Organization) - if createErr != nil { - return nil, status.Errorf(codes.AlreadyExists, createErr.Error()) + span.SetStatus(1, err.Error()) + return nil, err } - pbCreatedOrg, err := h.Service.GetOrganization(ctx, ctxUserUID, req.Organization.Id) + pbCreatedOrg, createErr := h.Service.CreateOrganization(ctx, ctxUserUID, req.Organization) if createErr != nil { - return nil, status.Errorf(codes.AlreadyExists, err.Error()) + return nil, createErr } resp := &mgmtPB.CreateOrganizationResponse{ @@ -591,9 +429,10 @@ func (h *PublicHandler) ListOrganizations(ctx context.Context, req *mgmtPB.ListO logger, _ := logger.GetZapLogger(ctx) - _, ctxUserUID, err := h.Service.GetCtxUser(ctx) + _, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + span.SetStatus(1, err.Error()) + return nil, err } pbOrgs, totalSize, nextPageToken, err := h.Service.ListOrganizations(ctx, ctxUserUID, int(req.GetPageSize()), req.GetPageToken(), filtering.Filter{}) @@ -627,14 +466,15 @@ func (h *PublicHandler) GetOrganization(ctx context.Context, req *mgmtPB.GetOrga logger, _ := logger.GetZapLogger(ctx) - _, ctxUserUID, err := h.Service.GetCtxUser(ctx) + _, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + span.SetStatus(1, err.Error()) + return nil, err } id, err := resource.GetRscNameID(req.GetName()) if err != nil { - return nil, err + return nil, ErrResourceID } pbOrg, err := h.Service.GetOrganization(ctx, ctxUserUID, id) @@ -669,9 +509,10 @@ func (h *PublicHandler) UpdateOrganization(ctx context.Context, req *mgmtPB.Upda logger, _ := logger.GetZapLogger(ctx) - _, ctxUserUID, err := h.Service.GetCtxUser(ctx) + _, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + span.SetStatus(1, err.Error()) + return nil, err } id, err := resource.GetRscNameID(req.GetOrganization().Name) @@ -684,7 +525,7 @@ func (h *PublicHandler) UpdateOrganization(ctx context.Context, req *mgmtPB.Upda // Validate the field mask if !pbUpdateMask.IsValid(pbOrgReq) { - return nil, status.Error(codes.InvalidArgument, "The update_mask is invalid") + return nil, ErrUpdateMask } getResp, err := h.GetOrganization(ctx, &mgmtPB.GetOrganizationRequest{Name: pbOrgReq.GetName()}) @@ -696,7 +537,7 @@ func (h *PublicHandler) UpdateOrganization(ctx context.Context, req *mgmtPB.Upda mask, err := fieldmask_utils.MaskFromProtoFieldMask(pbUpdateMask, strcase.ToCamel) if err != nil { span.SetStatus(1, err.Error()) - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, ErrFieldMask } if mask.IsEmpty() { @@ -710,18 +551,18 @@ func (h *PublicHandler) UpdateOrganization(ctx context.Context, req *mgmtPB.Upda // Return error if IMMUTABLE fields are intentionally changed if err := checkfield.CheckUpdateImmutableFields(pbOrgReq, pbOrgToUpdate, immutableFields); err != nil { span.SetStatus(1, err.Error()) - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, ErrCheckUpdateImmutableFields } // Only the fields mentioned in the field mask will be copied to `pbPipelineToUpdate`, other fields are left intact err = fieldmask_utils.StructToStruct(mask, pbOrgReq, pbOrgToUpdate) if err != nil { span.SetStatus(1, err.Error()) - return nil, err + return nil, ErrFieldMask } pbOrg, err := h.Service.UpdateOrganization(ctx, ctxUserUID, id, pbOrgToUpdate) - fmt.Println(pbOrg, err) + if err != nil { return nil, err } @@ -755,9 +596,9 @@ func (h *PublicHandler) DeleteOrganization(ctx context.Context, req *mgmtPB.Dele id, err := resource.GetRscNameID(req.Name) if err != nil { span.SetStatus(1, err.Error()) - return nil, err + return nil, ErrResourceID } - _, userUid, err := h.Service.GetCtxUser(ctx) + _, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { span.SetStatus(1, err.Error()) return nil, err @@ -768,7 +609,7 @@ func (h *PublicHandler) DeleteOrganization(ctx context.Context, req *mgmtPB.Dele return nil, err } - if err := h.Service.DeleteOrganization(ctx, userUid, id); err != nil { + if err := h.Service.DeleteOrganization(ctx, ctxUserUID, id); err != nil { span.SetStatus(1, err.Error()) return nil, err } @@ -782,7 +623,7 @@ func (h *PublicHandler) DeleteOrganization(ctx context.Context, req *mgmtPB.Dele logger.Info(string(custom_otel.NewLogMessage( span, logUUID.String(), - userUid, + ctxUserUID, eventName, custom_otel.SetEventResource(existOrg.GetOrganization()), ))) @@ -804,42 +645,38 @@ func (h *PublicHandler) CreateToken(ctx context.Context, req *mgmtPB.CreateToken // Set all OUTPUT_ONLY fields to zero value on the requested payload token resource if err := checkfield.CheckCreateOutputOnlyFields(req.Token, outputOnlyFieldsForToken); err != nil { - return &mgmtPB.CreateTokenResponse{}, status.Errorf(codes.InvalidArgument, err.Error()) + return &mgmtPB.CreateTokenResponse{}, ErrCheckOutputOnlyFields } // Return error if REQUIRED fields are not provided in the requested payload token resource if err := checkfield.CheckRequiredFields(req.Token, createRequiredFieldsForToken); err != nil { - return &mgmtPB.CreateTokenResponse{}, status.Errorf(codes.InvalidArgument, err.Error()) + return &mgmtPB.CreateTokenResponse{}, ErrCheckRequiredFields } // Return error if resource ID does not follow RFC-1034 if err := checkfield.CheckResourceID(req.Token.GetId()); err != nil { - return &mgmtPB.CreateTokenResponse{}, status.Errorf(codes.InvalidArgument, err.Error()) + return &mgmtPB.CreateTokenResponse{}, ErrResourceID } // Return error if expiration is not provided if req.Token.GetExpiration() == nil { - return &mgmtPB.CreateTokenResponse{}, status.Errorf(codes.InvalidArgument, "no expiration info") + return &mgmtPB.CreateTokenResponse{}, ErrCheckRequiredFields } - _, ctxUserUID, err := h.Service.GetCtxUser(ctx) + _, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") - } - - _, err = h.Service.GetToken(ctx, ctxUserUID, req.Token.Id) - if err == nil { - return &mgmtPB.CreateTokenResponse{}, status.Errorf(codes.AlreadyExists, "Token ID already existed") + span.SetStatus(1, err.Error()) + return nil, err } - createErr := h.Service.CreateToken(ctx, ctxUserUID, req.Token) - if createErr != nil { - return &mgmtPB.CreateTokenResponse{}, status.Errorf(codes.InvalidArgument, createErr.Error()) + err = h.Service.CreateToken(ctx, ctxUserUID, req.Token) + if err != nil { + return nil, err } pbCreatedToken, err := h.Service.GetToken(ctx, ctxUserUID, req.Token.Id) - if createErr != nil { - return &mgmtPB.CreateTokenResponse{}, status.Errorf(codes.InvalidArgument, err.Error()) + if err != nil { + return nil, err } resp := &mgmtPB.CreateTokenResponse{ @@ -874,9 +711,10 @@ func (h *PublicHandler) ListTokens(ctx context.Context, req *mgmtPB.ListTokensRe logger, _ := logger.GetZapLogger(ctx) - _, ctxUserUID, err := h.Service.GetCtxUser(ctx) + _, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + span.SetStatus(1, err.Error()) + return nil, err } pbTokens, totalSize, nextPageToken, err := h.Service.ListTokens(ctx, ctxUserUID, int64(req.GetPageSize()), req.GetPageToken()) @@ -911,19 +749,20 @@ func (h *PublicHandler) GetToken(ctx context.Context, req *mgmtPB.GetTokenReques logger, _ := logger.GetZapLogger(ctx) - _, ctxUserUID, err := h.Service.GetCtxUser(ctx) + _, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + span.SetStatus(1, err.Error()) + return nil, err } id, err := resource.GetRscNameID(req.GetName()) if err != nil { - return &mgmtPB.GetTokenResponse{}, err + return nil, ErrResourceID } pbToken, err := h.Service.GetToken(ctx, ctxUserUID, id) if err != nil { - return &mgmtPB.GetTokenResponse{}, err + return nil, err } resp := &mgmtPB.GetTokenResponse{ @@ -953,18 +792,19 @@ func (h *PublicHandler) DeleteToken(ctx context.Context, req *mgmtPB.DeleteToken logger, _ := logger.GetZapLogger(ctx) - _, ctxUserUID, err := h.Service.GetCtxUser(ctx) + _, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + span.SetStatus(1, err.Error()) + return nil, err } existToken, err := h.GetToken(ctx, &mgmtPB.GetTokenRequest{Name: req.GetName()}) if err != nil { - return &mgmtPB.DeleteTokenResponse{}, err + return nil, err } if err := h.Service.DeleteToken(ctx, ctxUserUID, existToken.Token.GetId()); err != nil { - return &mgmtPB.DeleteTokenResponse{}, err + return nil, err } // We need to manually set the custom header to have a StatusCreated http response for REST endpoint @@ -994,14 +834,14 @@ func (h *PublicHandler) ValidateToken(ctx context.Context, req *mgmtPB.ValidateT authorization := resource.GetRequestSingleHeader(ctx, constant.HeaderAuthorization) apiToken := strings.Replace(authorization, "Bearer ", "", 1) - userUid, err := h.Service.ValidateToken(apiToken) + userUID, err := h.Service.ValidateToken(apiToken) if err != nil { span.SetStatus(1, err.Error()) return nil, err } - return &mgmtPB.ValidateTokenResponse{UserUid: userUid}, nil + return &mgmtPB.ValidateTokenResponse{UserUid: userUID}, nil } func (h *PublicHandler) ListPipelineTriggerRecords(ctx context.Context, req *mgmtPB.ListPipelineTriggerRecordsRequest) (*mgmtPB.ListPipelineTriggerRecordsResponse, error) { @@ -1015,14 +855,15 @@ func (h *PublicHandler) ListPipelineTriggerRecords(ctx context.Context, req *mgm logger, _ := logger.GetZapLogger(ctx) - ctxUserId, ctxUserUID, err := h.Service.GetCtxUser(ctx) + ctxUserID, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + span.SetStatus(1, err.Error()) + return nil, err } - pbUser, err := h.Service.GetUser(ctx, ctxUserUID, ctxUserId) + pbUser, err := h.Service.GetUser(ctx, ctxUserUID, ctxUserID) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListPipelineTriggerRecordsResponse{}, err + return nil, err } var mode mgmtPB.Mode @@ -1041,19 +882,19 @@ func (h *PublicHandler) ListPipelineTriggerRecords(ctx context.Context, req *mgm }...) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListPipelineTriggerRecordsResponse{}, err + return nil, err } filter, err := filtering.ParseFilter(req, declarations) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListPipelineTriggerRecordsResponse{}, err + return nil, err } pipelineTriggerRecords, totalSize, nextPageToken, err := h.Service.ListPipelineTriggerRecords(ctx, pbUser, int64(req.GetPageSize()), req.GetPageToken(), filter) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListPipelineTriggerRecordsResponse{}, err + return nil, err } resp := mgmtPB.ListPipelineTriggerRecordsResponse{ @@ -1084,14 +925,15 @@ func (h *PublicHandler) ListPipelineTriggerTableRecords(ctx context.Context, req logger, _ := logger.GetZapLogger(ctx) - ctxUserId, ctxUserUID, err := h.Service.GetCtxUser(ctx) + ctxUserID, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + span.SetStatus(1, err.Error()) + return nil, err } - pbUser, err := h.Service.GetUser(ctx, ctxUserUID, ctxUserId) + pbUser, err := h.Service.GetUser(ctx, ctxUserUID, ctxUserID) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListPipelineTriggerTableRecordsResponse{}, err + return nil, err } declarations, err := filtering.NewDeclarations([]filtering.DeclarationOption{ @@ -1105,19 +947,19 @@ func (h *PublicHandler) ListPipelineTriggerTableRecords(ctx context.Context, req }...) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListPipelineTriggerTableRecordsResponse{}, err + return nil, err } filter, err := filtering.ParseFilter(req, declarations) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListPipelineTriggerTableRecordsResponse{}, err + return nil, err } pipelineTriggerTableRecords, totalSize, nextPageToken, err := h.Service.ListPipelineTriggerTableRecords(ctx, pbUser, int64(req.GetPageSize()), req.GetPageToken(), filter) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListPipelineTriggerTableRecordsResponse{}, err + return nil, err } resp := mgmtPB.ListPipelineTriggerTableRecordsResponse{ @@ -1148,14 +990,15 @@ func (h *PublicHandler) ListPipelineTriggerChartRecords(ctx context.Context, req logger, _ := logger.GetZapLogger(ctx) - ctxUserId, ctxUserUID, err := h.Service.GetCtxUser(ctx) + ctxUserID, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + span.SetStatus(1, err.Error()) + return nil, err } - pbUser, err := h.Service.GetUser(ctx, ctxUserUID, ctxUserId) + pbUser, err := h.Service.GetUser(ctx, ctxUserUID, ctxUserID) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListPipelineTriggerChartRecordsResponse{}, err + return nil, err } var mode mgmtPB.Mode @@ -1174,19 +1017,19 @@ func (h *PublicHandler) ListPipelineTriggerChartRecords(ctx context.Context, req }...) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListPipelineTriggerChartRecordsResponse{}, err + return nil, err } filter, err := filtering.ParseFilter(req, declarations) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListPipelineTriggerChartRecordsResponse{}, err + return nil, err } pipelineTriggerChartRecords, err := h.Service.ListPipelineTriggerChartRecords(ctx, pbUser, int64(req.GetAggregationWindow()), filter) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListPipelineTriggerChartRecordsResponse{}, err + return nil, err } resp := mgmtPB.ListPipelineTriggerChartRecordsResponse{ @@ -1214,14 +1057,15 @@ func (h *PublicHandler) ListConnectorExecuteRecords(ctx context.Context, req *mg logger, _ := logger.GetZapLogger(ctx) - ctxUserId, ctxUserUID, err := h.Service.GetCtxUser(ctx) + ctxUserID, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + span.SetStatus(1, err.Error()) + return nil, err } - pbUser, err := h.Service.GetUser(ctx, ctxUserUID, ctxUserId) + pbUser, err := h.Service.GetUser(ctx, ctxUserUID, ctxUserID) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListConnectorExecuteRecordsResponse{}, err + return nil, err } var status mgmtPB.Status @@ -1240,19 +1084,19 @@ func (h *PublicHandler) ListConnectorExecuteRecords(ctx context.Context, req *mg }...) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListConnectorExecuteRecordsResponse{}, err + return nil, err } filter, err := filtering.ParseFilter(req, declarations) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListConnectorExecuteRecordsResponse{}, err + return nil, err } connectorExecuteRecords, totalSize, nextPageToken, err := h.Service.ListConnectorExecuteRecords(ctx, pbUser, int64(req.GetPageSize()), req.GetPageToken(), filter) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListConnectorExecuteRecordsResponse{}, err + return nil, err } resp := mgmtPB.ListConnectorExecuteRecordsResponse{ @@ -1283,14 +1127,15 @@ func (h *PublicHandler) ListConnectorExecuteTableRecords(ctx context.Context, re logger, _ := logger.GetZapLogger(ctx) - ctxUserId, ctxUserUID, err := h.Service.GetCtxUser(ctx) + ctxUserID, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + span.SetStatus(1, err.Error()) + return nil, err } - pbUser, err := h.Service.GetUser(ctx, ctxUserUID, ctxUserId) + pbUser, err := h.Service.GetUser(ctx, ctxUserUID, ctxUserID) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListConnectorExecuteTableRecordsResponse{}, err + return nil, err } declarations, err := filtering.NewDeclarations([]filtering.DeclarationOption{ @@ -1302,19 +1147,19 @@ func (h *PublicHandler) ListConnectorExecuteTableRecords(ctx context.Context, re }...) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListConnectorExecuteTableRecordsResponse{}, err + return nil, err } filter, err := filtering.ParseFilter(req, declarations) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListConnectorExecuteTableRecordsResponse{}, err + return nil, err } connectorExecuteTableRecords, totalSize, nextPageToken, err := h.Service.ListConnectorExecuteTableRecords(ctx, pbUser, int64(req.GetPageSize()), req.GetPageToken(), filter) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListConnectorExecuteTableRecordsResponse{}, err + return nil, err } resp := mgmtPB.ListConnectorExecuteTableRecordsResponse{ @@ -1345,14 +1190,15 @@ func (h *PublicHandler) ListConnectorExecuteChartRecords(ctx context.Context, re logger, _ := logger.GetZapLogger(ctx) - ctxUserId, ctxUserUID, err := h.Service.GetCtxUser(ctx) + ctxUserID, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") + span.SetStatus(1, err.Error()) + return nil, err } - pbUser, err := h.Service.GetUser(ctx, ctxUserUID, ctxUserId) + pbUser, err := h.Service.GetUser(ctx, ctxUserUID, ctxUserID) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListConnectorExecuteChartRecordsResponse{}, err + return nil, err } var status mgmtPB.Status @@ -1371,19 +1217,19 @@ func (h *PublicHandler) ListConnectorExecuteChartRecords(ctx context.Context, re }...) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListConnectorExecuteChartRecordsResponse{}, err + return nil, err } filter, err := filtering.ParseFilter(req, declarations) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListConnectorExecuteChartRecordsResponse{}, err + return nil, err } connectorExecuteChartRecords, err := h.Service.ListConnectorExecuteChartRecords(ctx, pbUser, int64(req.GetAggregationWindow()), filter) if err != nil { span.SetStatus(1, err.Error()) - return &mgmtPB.ListConnectorExecuteChartRecordsResponse{}, err + return nil, err } resp := mgmtPB.ListConnectorExecuteChartRecordsResponse{ @@ -1412,17 +1258,16 @@ func (h *PublicHandler) ListUserMemberships(ctx context.Context, req *mgmtPB.Lis logger, _ := logger.GetZapLogger(ctx) - _, userUid, err := h.Service.GetCtxUser(ctx) - if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") - } - + ctxUserID, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { span.SetStatus(1, err.Error()) return nil, err } - - pbMemberships, err := h.Service.ListUserMemberships(ctx, userUid, strings.Split(req.Parent, "/")[1]) + userID := strings.Split(req.Parent, "/")[1] + if userID == "me" { + userID = ctxUserID + } + pbMemberships, err := h.Service.ListUserMemberships(ctx, ctxUserUID, userID) if err != nil { span.SetStatus(1, err.Error()) return nil, err @@ -1431,7 +1276,7 @@ func (h *PublicHandler) ListUserMemberships(ctx context.Context, req *mgmtPB.Lis logger.Info(string(custom_otel.NewLogMessage( span, logUUID.String(), - userUid, + ctxUserUID, eventName, ))) @@ -1454,20 +1299,18 @@ func (h *PublicHandler) GetUserMembership(ctx context.Context, req *mgmtPB.GetUs logger, _ := logger.GetZapLogger(ctx) - _, userUid, err := h.Service.GetCtxUser(ctx) - if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") - } - + userID := strings.Split(req.Name, "/")[1] + orgID := strings.Split(req.Name, "/")[3] + ctxUserID, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { span.SetStatus(1, err.Error()) return nil, err } + if userID == "me" { + userID = ctxUserID + } - userID := strings.Split(req.Name, "/")[1] - orgID := strings.Split(req.Name, "/")[3] - - pbMembership, err := h.Service.GetUserMembership(ctx, userUid, userID, orgID) + pbMembership, err := h.Service.GetUserMembership(ctx, ctxUserUID, userID, orgID) if err != nil { span.SetStatus(1, err.Error()) return nil, err @@ -1476,7 +1319,7 @@ func (h *PublicHandler) GetUserMembership(ctx context.Context, req *mgmtPB.GetUs logger.Info(string(custom_otel.NewLogMessage( span, logUUID.String(), - userUid, + ctxUserUID, eventName, ))) @@ -1499,28 +1342,26 @@ func (h *PublicHandler) UpdateUserMembership(ctx context.Context, req *mgmtPB.Up logger, _ := logger.GetZapLogger(ctx) - _, userUid, err := h.Service.GetCtxUser(ctx) - if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") - } - + userID := strings.Split(req.Membership.Name, "/")[1] + orgID := strings.Split(req.Membership.Name, "/")[3] + ctxUserID, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { span.SetStatus(1, err.Error()) return nil, err } - - userID := strings.Split(req.Membership.Name, "/")[1] - orgID := strings.Split(req.Membership.Name, "/")[3] + if userID == "me" { + userID = ctxUserID + } if err := checkfield.CheckRequiredFields(req.Membership, requiredFieldsForUserMembership); err != nil { - return nil, status.Errorf(codes.InvalidArgument, err.Error()) + return nil, ErrCheckRequiredFields } if err := checkfield.CheckCreateOutputOnlyFields(req.Membership, outputOnlyFieldsForUserMembership); err != nil { - return nil, status.Errorf(codes.InvalidArgument, err.Error()) + return nil, ErrCheckOutputOnlyFields } - pbMembership, err := h.Service.UpdateUserMembership(ctx, userUid, userID, orgID, req.Membership) + pbMembership, err := h.Service.UpdateUserMembership(ctx, ctxUserUID, userID, orgID, req.Membership) if err != nil { span.SetStatus(1, err.Error()) return nil, err @@ -1529,7 +1370,7 @@ func (h *PublicHandler) UpdateUserMembership(ctx context.Context, req *mgmtPB.Up logger.Info(string(custom_otel.NewLogMessage( span, logUUID.String(), - userUid, + ctxUserUID, eventName, ))) @@ -1552,20 +1393,18 @@ func (h *PublicHandler) DeleteUserMembership(ctx context.Context, req *mgmtPB.De logger, _ := logger.GetZapLogger(ctx) - _, userUid, err := h.Service.GetCtxUser(ctx) - if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") - } - + userID := strings.Split(req.Name, "/")[1] + orgID := strings.Split(req.Name, "/")[3] + ctxUserID, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { span.SetStatus(1, err.Error()) return nil, err } + if userID == "me" { + userID = ctxUserID + } - userID := strings.Split(req.Name, "/")[1] - orgID := strings.Split(req.Name, "/")[3] - - err = h.Service.DeleteUserMembership(ctx, userUid, userID, orgID) + err = h.Service.DeleteUserMembership(ctx, ctxUserUID, userID, orgID) if err != nil { span.SetStatus(1, err.Error()) return nil, err @@ -1574,7 +1413,7 @@ func (h *PublicHandler) DeleteUserMembership(ctx context.Context, req *mgmtPB.De logger.Info(string(custom_otel.NewLogMessage( span, logUUID.String(), - userUid, + ctxUserUID, eventName, ))) @@ -1595,17 +1434,13 @@ func (h *PublicHandler) ListOrganizationMemberships(ctx context.Context, req *mg logger, _ := logger.GetZapLogger(ctx) - _, userUid, err := h.Service.GetCtxUser(ctx) - if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") - } - + _, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { span.SetStatus(1, err.Error()) return nil, err } - pbMemberships, err := h.Service.ListOrganizationMemberships(ctx, userUid, strings.Split(req.Parent, "/")[1]) + pbMemberships, err := h.Service.ListOrganizationMemberships(ctx, ctxUserUID, strings.Split(req.Parent, "/")[1]) if err != nil { span.SetStatus(1, err.Error()) return nil, err @@ -1614,7 +1449,7 @@ func (h *PublicHandler) ListOrganizationMemberships(ctx context.Context, req *mg logger.Info(string(custom_otel.NewLogMessage( span, logUUID.String(), - userUid, + ctxUserUID, eventName, ))) @@ -1637,11 +1472,7 @@ func (h *PublicHandler) GetOrganizationMembership(ctx context.Context, req *mgmt logger, _ := logger.GetZapLogger(ctx) - _, userUid, err := h.Service.GetCtxUser(ctx) - if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") - } - + _, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { span.SetStatus(1, err.Error()) return nil, err @@ -1650,7 +1481,7 @@ func (h *PublicHandler) GetOrganizationMembership(ctx context.Context, req *mgmt userID := strings.Split(req.Name, "/")[1] orgID := strings.Split(req.Name, "/")[3] - pbMembership, err := h.Service.GetOrganizationMembership(ctx, userUid, userID, orgID) + pbMembership, err := h.Service.GetOrganizationMembership(ctx, ctxUserUID, userID, orgID) if err != nil { span.SetStatus(1, err.Error()) return nil, err @@ -1659,7 +1490,7 @@ func (h *PublicHandler) GetOrganizationMembership(ctx context.Context, req *mgmt logger.Info(string(custom_otel.NewLogMessage( span, logUUID.String(), - userUid, + ctxUserUID, eventName, ))) @@ -1682,11 +1513,7 @@ func (h *PublicHandler) UpdateOrganizationMembership(ctx context.Context, req *m logger, _ := logger.GetZapLogger(ctx) - _, userUid, err := h.Service.GetCtxUser(ctx) - if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") - } - + _, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { span.SetStatus(1, err.Error()) return nil, err @@ -1696,14 +1523,14 @@ func (h *PublicHandler) UpdateOrganizationMembership(ctx context.Context, req *m orgID := strings.Split(req.Membership.Name, "/")[3] if err := checkfield.CheckRequiredFields(req.Membership, requiredFieldsForOrganizationMembership); err != nil { - return nil, status.Errorf(codes.InvalidArgument, err.Error()) + return nil, ErrCheckRequiredFields } if err := checkfield.CheckCreateOutputOnlyFields(req.Membership, outputOnlyFieldsForOrganizationMembership); err != nil { - return nil, status.Errorf(codes.InvalidArgument, err.Error()) + return nil, ErrCheckOutputOnlyFields } - pbMembership, err := h.Service.UpdateOrganizationMembership(ctx, userUid, userID, orgID, req.Membership) + pbMembership, err := h.Service.UpdateOrganizationMembership(ctx, ctxUserUID, userID, orgID, req.Membership) if err != nil { span.SetStatus(1, err.Error()) return nil, err @@ -1712,7 +1539,7 @@ func (h *PublicHandler) UpdateOrganizationMembership(ctx context.Context, req *m logger.Info(string(custom_otel.NewLogMessage( span, logUUID.String(), - userUid, + ctxUserUID, eventName, ))) @@ -1735,11 +1562,7 @@ func (h *PublicHandler) DeleteOrganizationMembership(ctx context.Context, req *m logger, _ := logger.GetZapLogger(ctx) - _, userUid, err := h.Service.GetCtxUser(ctx) - if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated request") - } - + _, ctxUserUID, err := h.Service.AuthenticateUser(ctx) if err != nil { span.SetStatus(1, err.Error()) return nil, err @@ -1748,7 +1571,7 @@ func (h *PublicHandler) DeleteOrganizationMembership(ctx context.Context, req *m orgID := strings.Split(req.Name, "/")[1] userID := strings.Split(req.Name, "/")[3] - err = h.Service.DeleteUserMembership(ctx, userUid, userID, orgID) + err = h.Service.DeleteUserMembership(ctx, ctxUserUID, userID, orgID) if err != nil { span.SetStatus(1, err.Error()) return nil, err @@ -1757,11 +1580,9 @@ func (h *PublicHandler) DeleteOrganizationMembership(ctx context.Context, req *m logger.Info(string(custom_otel.NewLogMessage( span, logUUID.String(), - userUid, + ctxUserUID, eventName, ))) - resp := mgmtPB.DeleteOrganizationMembershipResponse{} - - return &resp, nil + return &mgmtPB.DeleteOrganizationMembershipResponse{}, nil } diff --git a/pkg/middleware/interceptor.go b/pkg/middleware/interceptor.go index d71f588..4e99832 100644 --- a/pkg/middleware/interceptor.go +++ b/pkg/middleware/interceptor.go @@ -2,14 +2,21 @@ package middleware import ( "context" + "errors" + "golang.org/x/crypto/bcrypt" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "gorm.io/gorm" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery" + "github.com/instill-ai/mgmt-backend/pkg/acl" + "github.com/instill-ai/mgmt-backend/pkg/handler" + "github.com/instill-ai/mgmt-backend/pkg/repository" + "github.com/instill-ai/mgmt-backend/pkg/service" ) // RecoveryInterceptor - panic handler @@ -29,7 +36,7 @@ func UnaryAppendMetadataInterceptor(ctx context.Context, req interface{}, info * newCtx := metadata.NewIncomingContext(ctx, md) h, err := handler(newCtx, req) - return h, err + return h, InjectErrCode(err) } // CustomInterceptor - append metadatas for stream @@ -47,3 +54,65 @@ func StreamAppendMetadataInterceptor(srv interface{}, stream grpc.ServerStream, return err } + +func InjectErrCode(err error) error { + if err == nil { + return nil + } + + switch { + + case + errors.Is(err, gorm.ErrDuplicatedKey): + return status.Error(codes.AlreadyExists, err.Error()) + case + errors.Is(err, gorm.ErrRecordNotFound): + return status.Error(codes.NotFound, err.Error()) + + case + errors.Is(err, repository.ErrNoDataDeleted): + return status.Error(codes.NotFound, err.Error()) + + case + errors.Is(err, repository.ErrOwnerTypeNotMatch), + errors.Is(err, repository.ErrPageTokenDecode): + return status.Error(codes.InvalidArgument, err.Error()) + + case + errors.Is(err, service.ErrCanNotRemoveOwnerFromOrganization), + errors.Is(err, service.ErrCanNotSetAnotherOwner), + errors.Is(err, service.ErrInvalidRole), + errors.Is(err, service.ErrInvalidTokenTTL), + errors.Is(err, service.ErrStateCanOnlyBeActive), + errors.Is(err, service.ErrPasswordNotMatch): + return status.Error(codes.InvalidArgument, err.Error()) + + case + errors.Is(err, service.ErrNoPermission): + return status.Error(codes.PermissionDenied, err.Error()) + + case + errors.Is(err, service.ErrUnauthenticated): + return status.Error(codes.Unauthenticated, err.Error()) + + case + errors.Is(err, acl.ErrMembershipNotFound): + return status.Error(codes.NotFound, err.Error()) + + case + errors.Is(err, bcrypt.ErrMismatchedHashAndPassword): + return status.Error(codes.InvalidArgument, err.Error()) + + case + errors.Is(err, handler.ErrCheckUpdateImmutableFields), + errors.Is(err, handler.ErrCheckOutputOnlyFields), + errors.Is(err, handler.ErrCheckRequiredFields), + errors.Is(err, handler.ErrFieldMask), + errors.Is(err, handler.ErrResourceID), + errors.Is(err, handler.ErrUpdateMask): + return status.Error(codes.InvalidArgument, err.Error()) + + default: + return status.Error(codes.Internal, err.Error()) + } +} diff --git a/pkg/middleware/misc.go b/pkg/middleware/misc.go index 26eff61..e91ef9f 100644 --- a/pkg/middleware/misc.go +++ b/pkg/middleware/misc.go @@ -18,8 +18,6 @@ import ( "github.com/instill-ai/mgmt-backend/pkg/constant" "github.com/instill-ai/mgmt-backend/pkg/logger" - - mgmtPB "github.com/instill-ai/protogen-go/core/mgmt/v1alpha" ) // GetRequestSingleHeader get a request header, the header has to be single-value HTTP header @@ -160,8 +158,3 @@ func ErrorHandler(ctx context.Context, mux *runtime.ServeMux, marshaler runtime. } } } - -func InjectOwnerToContext(ctx context.Context, owner *mgmtPB.User) context.Context { - ctx = metadata.AppendToOutgoingContext(ctx, "Jwt-Sub", owner.GetUid()) - return ctx -} diff --git a/pkg/repository/errors.go b/pkg/repository/errors.go new file mode 100644 index 0000000..145ac58 --- /dev/null +++ b/pkg/repository/errors.go @@ -0,0 +1,7 @@ +package repository + +import "errors" + +var ErrPageTokenDecode = errors.New("page token decode error") +var ErrOwnerTypeNotMatch = errors.New("owner type not match") +var ErrNoDataDeleted = errors.New("No data deleted") diff --git a/pkg/repository/repository.go b/pkg/repository/repository.go index fa261c8..c782edd 100644 --- a/pkg/repository/repository.go +++ b/pkg/repository/repository.go @@ -8,10 +8,7 @@ import ( "time" "github.com/gofrs/uuid" - "github.com/jackc/pgx/v5/pgconn" "go.einride.tech/aip/filtering" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "gorm.io/gorm" "github.com/instill-ai/mgmt-backend/pkg/datamodel" @@ -111,7 +108,7 @@ func (r *repository) GetAllUsers(ctx context.Context) ([]*datamodel.Owner, error var users []*datamodel.Owner if result := r.db.Find(users).Where("owner_type = 'user'"); result.Error != nil { logger.Error(result.Error.Error()) - return users, status.Errorf(codes.Internal, "error %v", result.Error) + return nil, result.Error } return users, nil } @@ -122,7 +119,7 @@ func (r *repository) listOwners(ctx context.Context, ownerType string, pageSize totalSize := int64(0) if result := r.db.Model(&datamodel.Owner{}).Where("owner_type = ?", ownerType).Count(&totalSize); result.Error != nil { logger.Error(result.Error.Error()) - return nil, totalSize, "", status.Errorf(codes.Internal, "error %v", result.Error) + return nil, totalSize, "", result.Error } queryBuilder := r.db.Model(&datamodel.Owner{}).Order("create_time DESC, id DESC") @@ -133,9 +130,10 @@ func (r *repository) listOwners(ctx context.Context, ownerType string, pageSize } if pageToken != "" { + // TODO: check pageToken in handler createTime, uid, err := paginate.DecodeToken(pageToken) if err != nil { - return nil, totalSize, "", status.Errorf(codes.InvalidArgument, "Invalid page token: %s", err.Error()) + return nil, totalSize, "", ErrPageTokenDecode } queryBuilder = queryBuilder.Where("(create_time,uid) < (?::timestamp, ?)", createTime, uid) } @@ -146,14 +144,14 @@ func (r *repository) listOwners(ctx context.Context, ownerType string, pageSize rows, err := queryBuilder.Rows() if err != nil { logger.Error(err.Error()) - return nil, totalSize, "", status.Errorf(codes.Internal, "error %v", err.Error()) + return nil, totalSize, "", err } defer rows.Close() for rows.Next() { var item datamodel.Owner if err = r.db.ScanRows(rows, &item); err != nil { logger.Error(err.Error()) - return nil, totalSize, "", status.Errorf(codes.Internal, "error %v", err.Error()) + return nil, totalSize, "", err } createTime = item.CreateTime @@ -169,7 +167,7 @@ func (r *repository) listOwners(ctx context.Context, ownerType string, pageSize Where("owner_type = ?", ownerType). Order("create_time ASC, uid ASC"). Limit(1).Find(lastItem); result.Error != nil { - return nil, 0, "", status.Errorf(codes.Internal, result.Error.Error()) + return nil, 0, "", result.Error } if lastItem.UID.String() == lastUID.String() { nextPageToken = "" @@ -186,19 +184,14 @@ func (r *repository) listOwners(ctx context.Context, ownerType string, pageSize func (r *repository) createOwner(ctx context.Context, ownerType string, owner *datamodel.Owner) error { if ownerType != owner.OwnerType.String { - return fmt.Errorf("wrong ownerType") + return ErrOwnerTypeNotMatch } logger, _ := logger.GetZapLogger(ctx) if result := r.db.Model(&datamodel.Owner{}).Create(owner); result.Error != nil { - var pgErr *pgconn.PgError - if errors.As(result.Error, &pgErr) { - if pgErr.Code == "23505" { - return status.Errorf(codes.AlreadyExists, pgErr.Message) - } - } + fmt.Println("errors", errors.Is(result.Error, gorm.ErrDuplicatedKey)) logger.Error(result.Error.Error()) - return status.Errorf(codes.Internal, "error %v", result.Error) + return result.Error } return nil } @@ -206,7 +199,7 @@ func (r *repository) createOwner(ctx context.Context, ownerType string, owner *d func (r *repository) getOwner(ctx context.Context, ownerType string, id string) (*datamodel.Owner, error) { var owner datamodel.Owner if result := r.db.Model(&datamodel.Owner{}).Where("owner_type = ?", ownerType).Where("id = ?", id).First(&owner); result.Error != nil { - return nil, status.Error(codes.NotFound, "the owner is not found") + return nil, result.Error } return &owner, nil } @@ -214,7 +207,7 @@ func (r *repository) getOwner(ctx context.Context, ownerType string, id string) func (r *repository) getOwnerByUID(ctx context.Context, ownerType string, uid uuid.UUID) (*datamodel.Owner, error) { var owner datamodel.Owner if result := r.db.Model(&datamodel.Owner{}).Where("owner_type = ?", ownerType).Where("uid = ?", uid.String()).First(&owner); result.Error != nil { - return nil, status.Error(codes.NotFound, "the owner is not found") + return nil, result.Error } return &owner, nil } @@ -223,7 +216,7 @@ func (r *repository) updateOwner(ctx context.Context, ownerType string, id strin logger, _ := logger.GetZapLogger(ctx) if result := r.db.Select("*").Omit("UID").Omit("password_hash").Model(&datamodel.Owner{}).Where("owner_type = ?", ownerType).Where("id = ?", id).Updates(owner); result.Error != nil { logger.Error(result.Error.Error()) - return status.Errorf(codes.Internal, "error %v", result.Error) + return result.Error } return nil } @@ -237,11 +230,11 @@ func (r *repository) deleteOwner(ctx context.Context, ownerType string, id strin if result.Error != nil { logger.Error(result.Error.Error()) - return status.Errorf(codes.Internal, "error %v", result.Error) + return result.Error } if result.RowsAffected == 0 { - return status.Errorf(codes.NotFound, "the owner with id %s is not found", id) + return ErrNoDataDeleted } return nil @@ -253,7 +246,7 @@ func (r *repository) deleteOwner(ctx context.Context, ownerType string, id strin func (r *repository) GetUserPasswordHash(ctx context.Context, uid uuid.UUID) (string, time.Time, error) { var pw datamodel.Password if result := r.db.First(&pw, "uid = ?", uid.String()); result.Error != nil { - return "", time.Time{}, status.Error(codes.NotFound, "the user is not found") + return "", time.Time{}, result.Error } return pw.PasswordHash.String, pw.PasswordUpdateTime, nil } @@ -265,7 +258,7 @@ func (r *repository) UpdateUserPasswordHash(ctx context.Context, uid uuid.UUID, PasswordUpdateTime: updateTime, }); result.Error != nil { logger.Error(result.Error.Error()) - return status.Errorf(codes.Internal, "error %v", result.Error) + return result.Error } return nil } @@ -277,13 +270,13 @@ func (r *repository) ListAllValidTokens(ctx context.Context) (tokens []datamodel queryBuilder.Where("expire_time >= ?", time.Now()) rows, err := queryBuilder.Rows() if err != nil { - return nil, status.Errorf(codes.Internal, err.Error()) + return nil, err } defer rows.Close() for rows.Next() { var item datamodel.Token if err = r.db.ScanRows(rows, &item); err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, err } // createTime = item.CreateTime tokens = append(tokens, item) @@ -295,7 +288,7 @@ func (r *repository) ListAllValidTokens(ctx context.Context) (tokens []datamodel func (r *repository) ListTokens(ctx context.Context, owner string, pageSize int64, pageToken string) (tokens []*datamodel.Token, totalSize int64, nextPageToken string, err error) { if result := r.db.Model(&datamodel.Token{}).Where("owner = ?", owner).Count(&totalSize); result.Error != nil { - return nil, 0, "", status.Errorf(codes.Internal, result.Error.Error()) + return nil, 0, "", err } queryBuilder := r.db.Model(&datamodel.Token{}).Order("create_time DESC, uid DESC").Where("owner = ?", owner) @@ -311,7 +304,7 @@ func (r *repository) ListTokens(ctx context.Context, owner string, pageSize int6 if pageToken != "" { createTime, uid, err := paginate.DecodeToken(pageToken) if err != nil { - return nil, 0, "", status.Errorf(codes.InvalidArgument, "Invalid page token: %s", err.Error()) + return nil, 0, "", err } queryBuilder = queryBuilder.Where("(create_time,uid) < (?::timestamp, ?)", createTime, uid) } @@ -319,13 +312,13 @@ func (r *repository) ListTokens(ctx context.Context, owner string, pageSize int6 var createTime time.Time rows, err := queryBuilder.Rows() if err != nil { - return nil, 0, "", status.Errorf(codes.Internal, err.Error()) + return nil, 0, "", err } defer rows.Close() for rows.Next() { var item datamodel.Token if err = r.db.ScanRows(rows, &item); err != nil { - return nil, 0, "", status.Error(codes.Internal, err.Error()) + return nil, 0, "", err } createTime = item.CreateTime tokens = append(tokens, &item) @@ -338,7 +331,7 @@ func (r *repository) ListTokens(ctx context.Context, owner string, pageSize int6 Where("owner = ?", owner). Order("create_time ASC, uid ASC"). Limit(1).Find(lastItem); result.Error != nil { - return nil, 0, "", status.Errorf(codes.Internal, result.Error.Error()) + return nil, 0, "", err } if lastItem.UID.String() == lastUID.String() { nextPageToken = "" @@ -353,14 +346,8 @@ func (r *repository) ListTokens(ctx context.Context, owner string, pageSize int6 func (r *repository) CreateToken(ctx context.Context, token *datamodel.Token) error { logger, _ := logger.GetZapLogger(ctx) if result := r.db.Model(&datamodel.Token{}).Create(token); result.Error != nil { - var pgErr *pgconn.PgError - if errors.As(result.Error, &pgErr) { - if pgErr.Code == "23505" { - return status.Errorf(codes.AlreadyExists, pgErr.Message) - } - } logger.Error(result.Error.Error()) - return status.Errorf(codes.Internal, "error %v", result.Error) + return result.Error } return nil } @@ -369,7 +356,7 @@ func (r *repository) GetToken(ctx context.Context, owner string, id string) (*da queryBuilder := r.db.Model(&datamodel.Token{}).Where("id = ? AND owner = ?", id, owner) var token datamodel.Token if result := queryBuilder.First(&token); result.Error != nil { - return nil, status.Errorf(codes.NotFound, "[GetToken] The token id %s you specified is not found", id) + return nil, result.Error } return &token, nil } @@ -380,11 +367,11 @@ func (r *repository) DeleteToken(ctx context.Context, owner string, id string) e Delete(&datamodel.Token{}) if result.Error != nil { - return status.Error(codes.Internal, result.Error.Error()) + return result.Error } if result.RowsAffected == 0 { - return status.Errorf(codes.NotFound, "[DeleteToken] The token id %s you specified is not found", id) + return ErrNoDataDeleted } return nil diff --git a/pkg/service/errors.go b/pkg/service/errors.go new file mode 100644 index 0000000..b006cc9 --- /dev/null +++ b/pkg/service/errors.go @@ -0,0 +1,12 @@ +package service + +import "errors" + +var ErrNoPermission = errors.New("no permission") +var ErrUnauthenticated = errors.New("unauthenticated") +var ErrInvalidTokenTTL = errors.New("invalid token ttl") +var ErrInvalidRole = errors.New("invalid role") +var ErrStateCanOnlyBeActive = errors.New("state can only be active") +var ErrCanNotRemoveOwnerFromOrganization = errors.New("can not remove owner from organization") +var ErrCanNotSetAnotherOwner = errors.New("can not set another user as owner") +var ErrPasswordNotMatch = errors.New("password not match") diff --git a/pkg/service/metric.go b/pkg/service/metric.go index 48dfabf..291515b 100644 --- a/pkg/service/metric.go +++ b/pkg/service/metric.go @@ -5,18 +5,23 @@ import ( "fmt" "go.einride.tech/aip/filtering" + "google.golang.org/grpc/metadata" "github.com/instill-ai/mgmt-backend/pkg/constant" - "github.com/instill-ai/mgmt-backend/pkg/middleware" "github.com/instill-ai/mgmt-backend/pkg/repository" mgmtPB "github.com/instill-ai/protogen-go/core/mgmt/v1alpha" pipelinePB "github.com/instill-ai/protogen-go/vdp/pipeline/v1alpha" ) +func InjectOwnerToContext(ctx context.Context, owner *mgmtPB.User) context.Context { + ctx = metadata.AppendToOutgoingContext(ctx, "Jwt-Sub", owner.GetUid()) + return ctx +} + func (s *service) pipelineUIDLookup(ctx context.Context, filter filtering.Filter, owner *mgmtPB.User) (filtering.Filter, error) { - ctx = middleware.InjectOwnerToContext(ctx, owner) + ctx = InjectOwnerToContext(ctx, owner) // lookup pipeline uid if len(filter.CheckedExpr.GetExpr().GetCallExpr().GetArgs()) > 0 { @@ -48,7 +53,7 @@ func (s *service) pipelineUIDLookup(ctx context.Context, filter filtering.Filter func (s *service) connectorUIDLookup(ctx context.Context, filter filtering.Filter, owner *mgmtPB.User) (filtering.Filter, error) { - ctx = middleware.InjectOwnerToContext(ctx, owner) + ctx = InjectOwnerToContext(ctx, owner) // lookup connector uid if len(filter.CheckedExpr.GetExpr().GetCallExpr().GetArgs()) > 0 { diff --git a/pkg/service/service.go b/pkg/service/service.go index 82a5cab..cd9aa70 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -9,8 +9,7 @@ import ( "github.com/gofrs/uuid" "github.com/redis/go-redis/v9" "go.einride.tech/aip/filtering" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "golang.org/x/crypto/bcrypt" "github.com/instill-ai/mgmt-backend/internal/resource" "github.com/instill-ai/mgmt-backend/pkg/acl" @@ -24,7 +23,7 @@ import ( // Service interface type Service interface { - GetCtxUser(ctx context.Context) (string, uuid.UUID, error) + AuthenticateUser(ctx context.Context) (userID string, userUID uuid.UUID, err error) ListRole() []string CreateUser(ctx context.Context, ctxUserUID uuid.UUID, user *mgmtPB.User) (*mgmtPB.User, error) @@ -63,8 +62,8 @@ type Service interface { DeleteToken(ctx context.Context, ctxUserUID uuid.UUID, id string) error ValidateToken(accessToken string) (string, error) - GetUserPasswordHash(ctx context.Context, uid uuid.UUID) (string, time.Time, error) - UpdateUserPasswordHash(ctx context.Context, uid uuid.UUID, newPassword string) error + CheckUserPassword(ctx context.Context, uid uuid.UUID, password string) error + UpdateUserPassword(ctx context.Context, uid uuid.UUID, newPassword string) error ListPipelineTriggerRecords(ctx context.Context, owner *mgmtPB.User, pageSize int64, pageToken string, filter filtering.Filter) ([]*mgmtPB.PipelineTriggerRecord, int64, string, error) ListPipelineTriggerTableRecords(ctx context.Context, owner *mgmtPB.User, pageSize int64, pageToken string, filter filtering.Filter) ([]*mgmtPB.PipelineTriggerTableRecord, int64, string, error) @@ -102,23 +101,23 @@ func NewService(r repository.Repository, rc *redis.Client, i repository.InfluxDB } // GetUser returns the api user -func (s *service) GetCtxUser(ctx context.Context) (string, uuid.UUID, error) { +func (s *service) AuthenticateUser(ctx context.Context) (userID string, userUID uuid.UUID, err error) { // Verify if "jwt-sub" is in the header - headerctxUserUID := resource.GetRequestSingleHeader(ctx, constant.HeaderUserUIDKey) + headerCtxUserUID := resource.GetRequestSingleHeader(ctx, constant.HeaderUserUIDKey) - if headerctxUserUID != "" { - _, err := uuid.FromString(headerctxUserUID) + if headerCtxUserUID != "" { + _, err := uuid.FromString(headerCtxUserUID) if err != nil { - return "", uuid.Nil, status.Errorf(codes.Unauthenticated, "Unauthorized") + return "", uuid.Nil, ErrUnauthenticated } - user, err := s.repository.GetUserByUID(ctx, uuid.FromStringOrNil(headerctxUserUID)) + user, err := s.repository.GetUserByUID(ctx, uuid.FromStringOrNil(headerCtxUserUID)) if err != nil { - return "", uuid.Nil, status.Errorf(codes.Unauthenticated, "Unauthorized") + return "", uuid.Nil, ErrUnauthenticated } - return user.ID, uuid.FromStringOrNil(headerctxUserUID), nil + return user.ID, uuid.FromStringOrNil(headerCtxUserUID), nil } - return "", uuid.Nil, status.Errorf(codes.Unauthenticated, "Unauthorized") + return "", uuid.Nil, ErrUnauthenticated } @@ -127,24 +126,15 @@ func (s *service) ListRole() []string { return ListAllowedRoleName() } -// ListUser lists all users -// Return error types -// - codes.InvalidArgument -// - codes.Internal -func (s *service) ListUsers(ctx context.Context, ctxUserUID uuid.UUID, pageSize int, pageToken string, filter filtering.Filter) ([]*mgmtPB.User, int64, string, error) { +func (s *service) ListUsers(ctx context.Context, ctxUserUID uuid.UUID, pageSize int, pageToken string, filter filtering.Filter) (users []*mgmtPB.User, totalSize int64, nextPageToken string, err error) { dbUsers, totalSize, nextPageToken, err := s.repository.ListUsers(ctx, pageSize, pageToken, filter) if err != nil { - return nil, 0, "", err + return nil, 0, "", fmt.Errorf("users/ with page_size=%d page_token=%s: %w", pageSize, pageToken, err) } - pbUsers, err := s.DBUsers2PBUsers(ctx, dbUsers) - return pbUsers, totalSize, nextPageToken, err + users, err = s.DBUsers2PBUsers(ctx, dbUsers) + return users, totalSize, nextPageToken, err } -// CreateUser creates an user instance -// Return error types -// - codes.InvalidArgument -// - codes.NotFound -// - codes.Internal func (s *service) CreateUser(ctx context.Context, ctxUserUID uuid.UUID, user *mgmtPB.User) (*mgmtPB.User, error) { dbUser, err := s.PBUser2DBUser(user) @@ -153,7 +143,7 @@ func (s *service) CreateUser(ctx context.Context, ctxUserUID uuid.UUID, user *mg } if dbUser.Role.Valid { if r := Role(dbUser.Role.String); !ValidateRole(r) { - return nil, status.Errorf(codes.InvalidArgument, "`role` %s in the body is not valid. Please choose from: [ %v ]", r.GetName(), strings.Join(s.ListRole(), ", ")) + return nil, ErrInvalidRole } } if err := s.repository.CreateUser(ctx, dbUser); err != nil { @@ -162,25 +152,17 @@ func (s *service) CreateUser(ctx context.Context, ctxUserUID uuid.UUID, user *mg dbCreatedUser, err := s.repository.GetUser(ctx, dbUser.ID) if err != nil { - return nil, err + return nil, fmt.Errorf("users/%s: %w", dbUser.ID, err) } return s.DBUser2PBUser(ctx, dbCreatedUser) } -// GetUser gets a user by ID -// Return error types -// - codes.InvalidArgument -// - codes.NotFound func (s *service) GetUser(ctx context.Context, ctxUserUID uuid.UUID, id string) (*mgmtPB.User, error) { - // Validation: Required field - if id == "" { - return nil, status.Error(codes.InvalidArgument, "the required field `id` is not specified") - } dbUser, err := s.repository.GetUser(ctx, id) if err != nil { - return nil, err + return nil, fmt.Errorf("users/%s: %w", id, err) } return s.DBUser2PBUser(ctx, dbUser) } @@ -189,7 +171,7 @@ func (s *service) GetUserAdmin(ctx context.Context, id string) (*mgmtPB.User, er dbUser, err := s.repository.GetUser(ctx, id) if err != nil { - return nil, err + return nil, fmt.Errorf("users/%s: %w", id, err) } return s.DBUser2PBUser(ctx, dbUser) } @@ -198,7 +180,7 @@ func (s *service) GetUserByUIDAdmin(ctx context.Context, uid uuid.UUID) (*mgmtPB dbUser, err := s.repository.GetUserByUID(ctx, uid) if err != nil { - return nil, err + return nil, fmt.Errorf("users/%s: %w", uid, err) } return s.DBUser2PBUser(ctx, dbUser) } @@ -206,22 +188,18 @@ func (s *service) GetUserByUIDAdmin(ctx context.Context, uid uuid.UUID) (*mgmtPB func (s *service) ListUsersAdmin(ctx context.Context, pageSize int, pageToken string, filter filtering.Filter) ([]*mgmtPB.User, int64, string, error) { dbUsers, totalSize, nextPageToken, err := s.repository.ListUsers(ctx, pageSize, pageToken, filter) if err != nil { - return nil, 0, "", err + return nil, 0, "", fmt.Errorf("users/ with page_size=%d page_token=%s: %w", pageSize, pageToken, err) } pbUsers, err := s.DBUsers2PBUsers(ctx, dbUsers) return pbUsers, totalSize, nextPageToken, err } // UpdateUser updates a user by UUID -// Return error types -// - codes.InvalidArgument -// - codes.NotFound -// - codes.Internal func (s *service) UpdateUser(ctx context.Context, ctxUserUID uuid.UUID, id string, user *mgmtPB.User) (*mgmtPB.User, error) { // Check if the user exists if _, err := s.repository.GetUser(ctx, id); err != nil { - return nil, err + return nil, fmt.Errorf("users/%s: %w", id, err) } // Update the user @@ -232,39 +210,35 @@ func (s *service) UpdateUser(ctx context.Context, ctxUserUID uuid.UUID, id strin //Validation: role field if dbUser.Role.Valid { if r := Role(dbUser.Role.String); !ValidateRole(r) { - return nil, status.Errorf(codes.InvalidArgument, "`role` %s in the body is not valid. Please choose from: [ %v ]", r.GetName(), strings.Join(s.ListRole(), ", ")) + return nil, ErrInvalidRole } } if err := s.repository.UpdateUser(ctx, id, dbUser); err != nil { - return nil, err + return nil, fmt.Errorf("users/%s: %w", id, err) } dbUserUpdated, err := s.repository.GetUser(ctx, id) if err != nil { - return nil, err + return nil, fmt.Errorf("users/%s: %w", id, err) } return s.DBUser2PBUser(ctx, dbUserUpdated) } // DeleteUser deletes a user by ID -// Return error types -// - codes.InvalidArgument -// - codes.NotFound -// - codes.Internal func (s *service) DeleteUser(ctx context.Context, ctxUserUID uuid.UUID, id string) error { - // Validation: Required field - if id == "" { - return status.Error(codes.InvalidArgument, "the required field `id` is not specified") - } - return s.repository.DeleteUser(ctx, id) + err := s.repository.DeleteUser(ctx, id) + if err != nil { + return fmt.Errorf("users/%s: %w", id, err) + } + return nil } func (s *service) ListOrganizations(ctx context.Context, ctxUserUID uuid.UUID, pageSize int, pageToken string, filter filtering.Filter) ([]*mgmtPB.Organization, int64, string, error) { dbOrgs, totalSize, nextPageToken, err := s.repository.ListOrganizations(ctx, pageSize, pageToken, filter) if err != nil { - return nil, 0, "", err + return nil, 0, "", fmt.Errorf("organizations/ with page_size=%d page_token=%s: %w", pageSize, pageToken, err) } pbOrgs, err := s.DBOrgs2PBOrgs(ctx, dbOrgs) return pbOrgs, totalSize, nextPageToken, err @@ -286,7 +260,7 @@ func (s *service) CreateOrganization(ctx context.Context, ctxUserUID uuid.UUID, dbCreatedOrg, err := s.repository.GetOrganization(ctx, dbOrg.ID) if err != nil { - return nil, err + return nil, fmt.Errorf("organizations/%s: %w", dbOrg.ID, err) } err = s.aclClient.SetOrganizationUserMembership(dbOrg.UID, ctxUserUID, "owner") @@ -298,24 +272,27 @@ func (s *service) CreateOrganization(ctx context.Context, ctxUserUID uuid.UUID, } func (s *service) GetOrganization(ctx context.Context, ctxUserUID uuid.UUID, id string) (*mgmtPB.Organization, error) { - // Validation: Required field - if id == "" { - return nil, status.Error(codes.InvalidArgument, "the required field `id` is not specified") - } dbOrg, err := s.repository.GetOrganization(ctx, id) if err != nil { - return nil, err + return nil, fmt.Errorf("organizations/%s: %w", id, err) } return s.DBOrg2PBOrg(ctx, dbOrg) } func (s *service) UpdateOrganization(ctx context.Context, ctxUserUID uuid.UUID, id string, org *mgmtPB.Organization) (*mgmtPB.Organization, error) { - // Check if the org exists - if _, err := s.repository.GetOrganization(ctx, id); err != nil { + oriOrg, err := s.repository.GetOrganization(ctx, id) + if err != nil { + return nil, fmt.Errorf("organizations/%s: %w", id, err) + } + isOwner, err := s.aclClient.CheckOrganizationUserMembership(oriOrg.UID, ctxUserUID, "owner") + if err != nil { return nil, err } + if !isOwner { + return nil, ErrNoPermission + } // Update the user dbOrg, err := s.PBOrg2DBOrg(org) @@ -324,7 +301,7 @@ func (s *service) UpdateOrganization(ctx context.Context, ctxUserUID uuid.UUID, } if err := s.repository.UpdateOrganization(ctx, id, dbOrg); err != nil { - return nil, err + return nil, fmt.Errorf("organizations/%s: %w", id, err) } dbOrgUpdated, err := s.repository.GetOrganization(ctx, id) @@ -336,19 +313,31 @@ func (s *service) UpdateOrganization(ctx context.Context, ctxUserUID uuid.UUID, } func (s *service) DeleteOrganization(ctx context.Context, ctxUserUID uuid.UUID, id string) error { - // Validation: Required field - if id == "" { - return status.Error(codes.InvalidArgument, "the required field `id` is not specified") + org, err := s.repository.GetOrganization(ctx, id) + if err != nil { + return fmt.Errorf("organizations/%s: %w", id, err) + } + + isOwner, err := s.aclClient.CheckOrganizationUserMembership(org.UID, ctxUserUID, "owner") + if err != nil { + return err + } + if !isOwner { + return ErrNoPermission } - return s.repository.DeleteOrganization(ctx, id) + err = s.repository.DeleteOrganization(ctx, id) + if err != nil { + return fmt.Errorf("organizations/%s: %w", id, err) + } + return nil } func (s *service) GetOrganizationAdmin(ctx context.Context, id string) (*mgmtPB.Organization, error) { dbOrganization, err := s.repository.GetOrganization(ctx, id) if err != nil { - return nil, err + return nil, fmt.Errorf("organizations/%s: %w", id, err) } return s.DBOrg2PBOrg(ctx, dbOrganization) } @@ -357,7 +346,7 @@ func (s *service) GetOrganizationByUIDAdmin(ctx context.Context, uid uuid.UUID) dbOrganization, err := s.repository.GetOrganizationByUID(ctx, uid) if err != nil { - return nil, err + return nil, fmt.Errorf("organizations/%s: %w", uid, err) } return s.DBOrg2PBOrg(ctx, dbOrganization) } @@ -365,19 +354,31 @@ func (s *service) GetOrganizationByUIDAdmin(ctx context.Context, uid uuid.UUID) func (s *service) ListOrganizationsAdmin(ctx context.Context, pageSize int, pageToken string, filter filtering.Filter) ([]*mgmtPB.Organization, int64, string, error) { dbOrganizations, totalSize, nextPageToken, err := s.repository.ListOrganizations(ctx, pageSize, pageToken, filter) if err != nil { - return nil, 0, "", err + return nil, 0, "", fmt.Errorf("organizations/ with page_size=%d page_token=%s: %w", pageSize, pageToken, err) } pbOrganizations, err := s.DBOrgs2PBOrgs(ctx, dbOrganizations) return pbOrganizations, totalSize, nextPageToken, err } -func (s *service) GetUserPasswordHash(ctx context.Context, uid uuid.UUID) (string, time.Time, error) { - return s.repository.GetUserPasswordHash(ctx, uid) -} +func (s *service) CheckUserPassword(ctx context.Context, uid uuid.UUID, password string) error { + passwordHash, _, err := s.repository.GetUserPasswordHash(ctx, uid) + if err != nil { + return err + } -func (s *service) UpdateUserPasswordHash(ctx context.Context, uid uuid.UUID, newPassword string) error { + err = bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password)) + if err != nil { + return ErrPasswordNotMatch + } + return nil +} - return s.repository.UpdateUserPasswordHash(ctx, uid, newPassword, time.Now()) +func (s *service) UpdateUserPassword(ctx context.Context, uid uuid.UUID, newPassword string) error { + passwordBytes, err := bcrypt.GenerateFromPassword([]byte(newPassword), 10) + if err != nil { + return err + } + return s.repository.UpdateUserPasswordHash(ctx, uid, string(passwordBytes), time.Now()) } func (s *service) CreateToken(ctx context.Context, ctxUserUID uuid.UUID, token *mgmtPB.ApiToken) error { @@ -401,7 +402,7 @@ func (s *service) CreateToken(ctx context.Context, ctxUserUID uuid.UUID, token * } else if token.GetTtl() == -1 { dbToken.ExpireTime = time.Date(2099, 12, 31, 0, 0, 0, 0, time.Now().UTC().Location()) } else { - return status.Errorf(codes.InvalidArgument, "ttl should >= -1") + return ErrInvalidTokenTTL } case *mgmtPB.ApiToken_ExpireTime: dbToken.ExpireTime = token.GetExpireTime().AsTime() @@ -423,7 +424,7 @@ func (s *service) ListTokens(ctx context.Context, ctxUserUID uuid.UUID, pageSize ownerPermlink := fmt.Sprintf("users/%s", ctxUserUID.String()) dbTokens, pageSize, pageToken, err := s.repository.ListTokens(ctx, ownerPermlink, pageSize, pageToken) if err != nil { - return nil, 0, "", err + return nil, 0, "", fmt.Errorf("tokens/ with page_size=%d page_token=%s: %w", pageSize, pageToken, err) } pbTokens, err := s.DBTokens2PBTokens(ctx, dbTokens) @@ -435,7 +436,7 @@ func (s *service) GetToken(ctx context.Context, ctxUserUID uuid.UUID, id string) ownerPermlink := fmt.Sprintf("users/%s", ctxUserUID.String()) dbToken, err := s.repository.GetToken(ctx, ownerPermlink, id) if err != nil { - return nil, err + return nil, fmt.Errorf("tokens/%s: %w", id, err) } return s.DBToken2PBToken(ctx, dbToken) @@ -446,17 +447,16 @@ func (s *service) DeleteToken(ctx context.Context, ctxUserUID uuid.UUID, id stri ownerPermlink := fmt.Sprintf("users/%s", ctxUserUID.String()) token, err := s.repository.GetToken(ctx, ownerPermlink, id) if err != nil { - return err + return fmt.Errorf("tokens/%s: %w", id, err) } accessToken := token.AccessToken // TODO: should be more robust s.redisClient.Del(context.Background(), fmt.Sprintf(constant.AccessTokenKeyFormat, accessToken)) - delErr := s.repository.DeleteToken(ctx, ownerPermlink, id) - if delErr != nil { - return delErr + err = s.repository.DeleteToken(ctx, ownerPermlink, id) + if err != nil { + return fmt.Errorf("tokens/%s: %w", id, err) } - return nil } func (s *service) ValidateToken(accessToken string) (string, error) { @@ -470,7 +470,10 @@ func (s *service) ValidateToken(accessToken string) (string, error) { func (s *service) ListUserMemberships(ctx context.Context, ctxUserUID uuid.UUID, userID string) ([]*mgmtPB.UserMembership, error) { user, err := s.repository.GetUser(ctx, userID) if err != nil { - return nil, err + return nil, fmt.Errorf("users/%s: %w", userID, err) + } + if ctxUserUID != user.UID { + return nil, ErrNoPermission } orgRelations, err := s.aclClient.GetUserOrganizations(user.UID) @@ -487,7 +490,7 @@ func (s *service) ListUserMemberships(ctx context.Context, ctxUserUID uuid.UUID, for _, orgRelation := range orgRelations { org, err := s.repository.GetOrganizationByUID(ctx, orgRelation.UID) if err != nil { - return nil, err + return nil, fmt.Errorf("organizations/%s: %w", org.ID, err) } pbOrg, err := s.DBOrg2PBOrg(ctx, org) if err != nil { @@ -508,11 +511,14 @@ func (s *service) ListUserMemberships(ctx context.Context, ctxUserUID uuid.UUID, func (s *service) GetUserMembership(ctx context.Context, ctxUserUID uuid.UUID, userID string, orgID string) (*mgmtPB.UserMembership, error) { user, err := s.repository.GetUser(ctx, userID) if err != nil { - return nil, err + return nil, fmt.Errorf("users/%s: %w", userID, err) + } + if ctxUserUID != user.UID { + return nil, ErrNoPermission } org, err := s.repository.GetOrganization(ctx, orgID) if err != nil { - return nil, err + return nil, fmt.Errorf("organizations/%s: %w", orgID, err) } role, err := s.aclClient.GetOrganizationUserMembership(org.UID, user.UID) if err != nil { @@ -541,14 +547,14 @@ func (s *service) GetUserMembership(ctx context.Context, ctxUserUID uuid.UUID, u func (s *service) UpdateUserMembership(ctx context.Context, ctxUserUID uuid.UUID, userID string, orgID string, membership *mgmtPB.UserMembership) (*mgmtPB.UserMembership, error) { user, err := s.repository.GetUser(ctx, userID) if err != nil { - return nil, err + return nil, fmt.Errorf("users/%s: %w", userID, err) } if ctxUserUID != user.UID { - return nil, status.Errorf(codes.PermissionDenied, "Permission Denied") + return nil, ErrNoPermission } org, err := s.repository.GetOrganization(ctx, orgID) if err != nil { - return nil, err + return nil, fmt.Errorf("organizations/%s: %w", orgID, err) } pbUser, err := s.DBUser2PBUser(ctx, user) if err != nil { @@ -585,20 +591,20 @@ func (s *service) UpdateUserMembership(ctx context.Context, ctxUserUID uuid.UUID return updatedMembership, nil } - return nil, fmt.Errorf("state can only be 'active'") + return nil, ErrStateCanOnlyBeActive } func (s *service) DeleteUserMembership(ctx context.Context, ctxUserUID uuid.UUID, userID string, orgID string) error { user, err := s.repository.GetUser(ctx, userID) if err != nil { - return err + return fmt.Errorf("users/%s: %w", userID, err) } if ctxUserUID != user.UID { - return status.Errorf(codes.PermissionDenied, "Permission Denied") + return ErrNoPermission } org, err := s.repository.GetOrganization(ctx, orgID) if err != nil { - return err + return fmt.Errorf("organizations/%s: %w", orgID, err) } err = s.aclClient.DeleteOrganizationUserMembership(org.UID, user.UID) if err != nil { @@ -609,9 +615,21 @@ func (s *service) DeleteUserMembership(ctx context.Context, ctxUserUID uuid.UUID func (s *service) ListOrganizationMemberships(ctx context.Context, ctxUserUID uuid.UUID, orgID string) ([]*mgmtPB.OrganizationMembership, error) { org, err := s.repository.GetOrganization(ctx, orgID) + if err != nil { + return nil, fmt.Errorf("organizations/%s: %w", orgID, err) + } + + isOwner, err := s.aclClient.CheckOrganizationUserMembership(org.UID, ctxUserUID, "owner") + if err != nil { + return nil, err + } + isMember, err := s.aclClient.CheckOrganizationUserMembership(org.UID, ctxUserUID, "member") if err != nil { return nil, err } + if !isOwner && !isMember { + return nil, ErrNoPermission + } userRelations, err := s.aclClient.GetOrganizationUsers(org.UID) if err != nil { @@ -627,7 +645,7 @@ func (s *service) ListOrganizationMemberships(ctx context.Context, ctxUserUID uu for _, userRelation := range userRelations { user, err := s.repository.GetUserByUID(ctx, userRelation.UID) if err != nil { - return nil, err + return nil, fmt.Errorf("users/%s: %w", user.ID, err) } pbUser, err := s.DBUser2PBUser(ctx, user) if err != nil { @@ -648,12 +666,25 @@ func (s *service) ListOrganizationMemberships(ctx context.Context, ctxUserUID uu func (s *service) GetOrganizationMembership(ctx context.Context, ctxUserUID uuid.UUID, orgID string, userID string) (*mgmtPB.OrganizationMembership, error) { user, err := s.repository.GetUser(ctx, userID) if err != nil { - return nil, err + return nil, fmt.Errorf("users/%s: %w", userID, err) } org, err := s.repository.GetOrganization(ctx, orgID) + if err != nil { + return nil, fmt.Errorf("organizations/%s: %w", orgID, err) + } + + isOwner, err := s.aclClient.CheckOrganizationUserMembership(org.UID, ctxUserUID, "owner") if err != nil { return nil, err } + isMember, err := s.aclClient.CheckOrganizationUserMembership(org.UID, ctxUserUID, "member") + if err != nil { + return nil, err + } + if !isOwner && !isMember { + return nil, ErrNoPermission + } + role, err := s.aclClient.GetOrganizationUserMembership(org.UID, user.UID) if err != nil { return nil, err @@ -681,19 +712,23 @@ func (s *service) GetOrganizationMembership(ctx context.Context, ctxUserUID uuid func (s *service) UpdateOrganizationMembership(ctx context.Context, ctxUserUID uuid.UUID, orgID string, userID string, membership *mgmtPB.OrganizationMembership) (*mgmtPB.OrganizationMembership, error) { user, err := s.repository.GetUser(ctx, userID) if err != nil { - return nil, err + return nil, fmt.Errorf("users/%s: %w", userID, err) } org, err := s.repository.GetOrganization(ctx, orgID) if err != nil { - return nil, err + return nil, fmt.Errorf("organizations/%s: %w", orgID, err) } - role, err := s.aclClient.GetOrganizationUserMembership(org.UID, ctxUserUID) + isOwner, err := s.aclClient.CheckOrganizationUserMembership(org.UID, ctxUserUID, "owner") if err != nil { - return nil, status.Errorf(codes.PermissionDenied, "Permission Denied") + return nil, err } - if role != "owner" { - return nil, status.Errorf(codes.PermissionDenied, "Permission Denied") + if !isOwner { + return nil, ErrNoPermission + } + + if membership.Role == "owner" { + return nil, ErrCanNotSetAnotherOwner } pbUser, err := s.DBUser2PBUser(ctx, user) @@ -743,18 +778,22 @@ func (s *service) DeleteOrganizationMembership(ctx context.Context, ctxUserUID u user, err := s.repository.GetUser(ctx, userID) if err != nil { - return err + return fmt.Errorf("users/%s: %w", userID, err) } org, err := s.repository.GetOrganization(ctx, orgID) if err != nil { - return err + return fmt.Errorf("organizations/%s: %w", orgID, err) } - role, err := s.aclClient.GetOrganizationUserMembership(org.UID, ctxUserUID) + + isOwner, err := s.aclClient.CheckOrganizationUserMembership(org.UID, ctxUserUID, "owner") if err != nil { - return status.Errorf(codes.PermissionDenied, "Permission Denied") + return err + } + if !isOwner { + return ErrNoPermission } - if role != "owner" { - return status.Errorf(codes.PermissionDenied, "Permission Denied") + if isOwner && ctxUserUID == user.UID { + return ErrCanNotRemoveOwnerFromOrganization } err = s.aclClient.DeleteOrganizationUserMembership(org.UID, user.UID) if err != nil {