diff --git a/internal/application/common/context.go b/internal/application/common/context.go index 6ea42ad..185043c 100644 --- a/internal/application/common/context.go +++ b/internal/application/common/context.go @@ -31,14 +31,12 @@ type AllClaims struct { CustomClaims } -const RequestIDKey = "RequestIDKey" - // GetRequestID extracts the request ID from the context. // // It originally comes from a header with the request, or is rolled while processing // the request. func GetRequestID(ctx context.Context) string { - if reqID, ok := ctx.Value(RequestIDKey).(string); ok { + if reqID, ok := ctx.Value(CtxKeyRequestID{}).(string); ok { return reqID } diff --git a/internal/application/middleware/reqid.go b/internal/application/middleware/reqid.go index e79403f..7e48e80 100644 --- a/internal/application/middleware/reqid.go +++ b/internal/application/middleware/reqid.go @@ -2,6 +2,7 @@ package middleware import ( "context" + aulogging "github.com/StephanHCB/go-autumn-logging" "github.com/eurofurence/reg-room-service/internal/application/common" "net/http" "regexp" @@ -18,6 +19,8 @@ var ValidRequestIdRegex = regexp.MustCompile("^[0-9a-f]{8}$") // It also adds it to the response under the same header. // This automatically also leads to all logging using this context to log the request id. func RequestIdMiddleware(next http.Handler) http.Handler { + aulogging.RequestIdRetriever = common.GetRequestID + handlerFunc := func(w http.ResponseWriter, r *http.Request) { reqUuidStr := r.Header.Get(RequestIDHeader) if !ValidRequestIdRegex.MatchString(reqUuidStr) { diff --git a/internal/application/middleware/security.go b/internal/application/middleware/security.go index 0e15a32..ff1ff54 100644 --- a/internal/application/middleware/security.go +++ b/internal/application/middleware/security.go @@ -73,18 +73,6 @@ func fromApiTokenHeader(r *http.Request) string { return r.Header.Get(common.ApiKeyHeader) } -// TODO Remove after legacy system was replaced with 2FA -// See reference https://github.com/eurofurence/reg-room-service/issues/57 -func storeAdminRequestHeaderIfAvailable(ctx context.Context, r *http.Request) context.Context { - adminHeader := r.Header.Get(adminRequestHeader) - - if adminHeader == "" { - return ctx - } - - return context.WithValue(ctx, common.CtxKeyAdminHeader{}, adminHeader) -} - // --- validating the individual pieces --- // important - if any of these return an error, you must abort processing via "return" and log the error message @@ -263,7 +251,6 @@ func CheckRequestAuthorization(conf *config.SecurityConfig) func(http.Handler) h return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ctx = storeAdminRequestHeaderIfAvailable(ctx, r) apiTokenHeaderValue := fromApiTokenHeader(r) authHeaderValue := fromAuthHeader(r) idTokenCookieValue := parseAuthCookie(r, conf.Oidc.IDTokenCookieName) diff --git a/internal/application/middleware/security_test.go b/internal/application/middleware/security_test.go index 2e0910d..953402a 100644 --- a/internal/application/middleware/security_test.go +++ b/internal/application/middleware/security_test.go @@ -275,45 +275,3 @@ func TestCookiesExpiredJwt(t *testing.T) { require.Nil(t, ctx.Value(common.CtxKeyClaims{})) tstRequireNoAuthServiceCall(t) } - -// TODO Remove after legacy system was replaced with 2FA -// See reference https://github.com/eurofurence/reg-room-service/issues/57 -func TestStoreAdminHeaderInContext(t *testing.T) { - docs.Description("stores the header value for legacy system admin calls") - - tests := []struct { - name string - shouldStore bool - }{ - { - name: "should not store value in context", - shouldStore: false, - }, - { - name: "should not store value in context", - shouldStore: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - - r, err := http.NewRequest(http.MethodGet, "http://test.local", nil) - require.NoError(t, err) - if tt.shouldStore { - r.Header.Add(adminRequestHeader, "available") - } - - ctx = storeAdminRequestHeaderIfAvailable(ctx, r) - val, ok := ctx.Value(common.CtxKeyAdminHeader{}).(string) - if tt.shouldStore { - require.True(t, ok) - require.NotEmpty(t, val) - } else { - require.False(t, ok) - require.Empty(t, val) - } - }) - } -} diff --git a/internal/application/web/response.go b/internal/application/web/response.go index ea77e98..aaa08bf 100644 --- a/internal/application/web/response.go +++ b/internal/application/web/response.go @@ -46,7 +46,7 @@ func SendErrorResponse(ctx context.Context, w http.ResponseWriter, err error) { // which contains relevant information about the failed request to the client. // The function will also set the http status according to the provided status. func SendAPIErrorResponse(ctx context.Context, w http.ResponseWriter, apiErr common.APIError) { - aulogging.InfoErrf(ctx, apiErr, fmt.Sprintf("api response status %d: %v", apiErr.Status(), apiErr)) + aulogging.InfoErrf(ctx, apiErr, fmt.Sprintf("api response status %d: %v", apiErr.Status(), apiErr.Response())) for _, cause := range apiErr.InternalCauses() { aulogging.InfoErrf(ctx, cause, fmt.Sprintf("... caused by: %v", cause)) } diff --git a/internal/repository/database/historizeddb/implementation.go b/internal/repository/database/historizeddb/implementation.go index cd8e5f8..8cf0f4b 100644 --- a/internal/repository/database/historizeddb/implementation.go +++ b/internal/repository/database/historizeddb/implementation.go @@ -59,8 +59,8 @@ func (r *HistorizingRepository) AddGroup(ctx context.Context, group *entity.Grou return r.wrappedRepository.AddGroup(ctx, group) } -func (r *HistorizingRepository) FindGroups(ctx context.Context, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) ([]string, error) { - return r.wrappedRepository.FindGroups(ctx, minOccupancy, maxOccupancy, anyOfMemberID) +func (r *HistorizingRepository) FindGroups(ctx context.Context, name string, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) ([]string, error) { + return r.wrappedRepository.FindGroups(ctx, name, minOccupancy, maxOccupancy, anyOfMemberID) } func (r *HistorizingRepository) UpdateGroup(ctx context.Context, group *entity.Group) error { diff --git a/internal/repository/database/inmemorydb/implementation.go b/internal/repository/database/inmemorydb/implementation.go index 996df9d..8a115bb 100644 --- a/internal/repository/database/inmemorydb/implementation.go +++ b/internal/repository/database/inmemorydb/implementation.go @@ -71,12 +71,13 @@ func (r *InMemoryRepository) GetGroups(_ context.Context) ([]*entity.Group, erro return result, nil } -func (r *InMemoryRepository) FindGroups(_ context.Context, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) ([]string, error) { +func (r *InMemoryRepository) FindGroups(_ context.Context, name string, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) ([]string, error) { result := make([]string, 0) for _, grp := range r.groups { if !grp.Group.DeletedAt.Valid { if len(grp.Members) >= int(minOccupancy) && - (maxOccupancy == -1 || len(grp.Members) <= maxOccupancy) { + (maxOccupancy == -1 || len(grp.Members) <= maxOccupancy) && + (name == "" || name == grp.Group.Name) { matches := len(anyOfMemberID) == 0 for _, wantedID := range anyOfMemberID { for _, actualMember := range grp.Members { diff --git a/internal/repository/database/interface.go b/internal/repository/database/interface.go index 825d8d7..af1849f 100644 --- a/internal/repository/database/interface.go +++ b/internal/repository/database/interface.go @@ -20,7 +20,9 @@ type Repository interface { // // A group matches the list of badge numbers in anyOfMemberID if at least one of those badge numbers // is a member of the group. An empty list or nil means no condition. - FindGroups(ctx context.Context, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) ([]string, error) + // + // If name is not the empty string, finds only groups of that name. + FindGroups(ctx context.Context, name string, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) ([]string, error) AddGroup(ctx context.Context, group *entity.Group) (string, error) UpdateGroup(ctx context.Context, group *entity.Group) error GetGroupByID(ctx context.Context, id string) (*entity.Group, error) // may return soft deleted entities! diff --git a/internal/repository/database/mysqldb/implementation.go b/internal/repository/database/mysqldb/implementation.go index e97fa26..cd399c0 100644 --- a/internal/repository/database/mysqldb/implementation.go +++ b/internal/repository/database/mysqldb/implementation.go @@ -146,17 +146,21 @@ func (r *MysqlRepository) GetGroups(ctx context.Context) ([]*entity.Group, error return getAllNonDeleted[entity.Group](ctx, r.db, groupDesc) } -func (r *MysqlRepository) FindGroups(ctx context.Context, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) ([]string, error) { - query, params := buildFindQuery(minOccupancy, maxOccupancy, anyOfMemberID) +func (r *MysqlRepository) FindGroups(ctx context.Context, name string, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) ([]string, error) { + query, params := buildFindQuery(name, minOccupancy, maxOccupancy, anyOfMemberID) return r.findGroupIDsByQuery(ctx, query, params) } -func buildFindQuery(minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) (string, map[string]any) { +func buildFindQuery(name string, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) (string, map[string]any) { params := make(map[string]any) query := strings.Builder{} query.WriteString("SELECT g.id AS id FROM room_groups g WHERE (@use_named_params = 1) ") params["use_named_params"] = 1 // must always have at least one named param, or you get an error when using a param map + if name != "" { + query.WriteString("AND name = @name ") + params["name"] = name + } if minOccupancy > 0 { query.WriteString("AND (SELECT count(*) FROM room_group_members m WHERE m.group_id = g.id) >= @min_occ ") params["min_occ"] = minOccupancy @@ -191,13 +195,13 @@ func (r *MysqlRepository) findGroupIDsByQuery(ctx context.Context, query string, }() for rows.Next() { - groupID := "" - err = r.db.ScanRows(rows, &groupID) + entry := entity.Group{} + err = r.db.ScanRows(rows, &entry) if err != nil { aulogging.Logger.Ctx(ctx).Error().WithErr(err).Printf("error reading group id during find: %s", err.Error()) return result, err } - result = append(result, groupID) + result = append(result, entry.ID) } return result, nil @@ -373,13 +377,13 @@ func (r *MysqlRepository) findRoomIDsByQuery(ctx context.Context, query string, }() for rows.Next() { - roomID := "" - err = r.db.ScanRows(rows, &roomID) + entry := entity.Room{} + err = r.db.ScanRows(rows, &entry) if err != nil { aulogging.Logger.Ctx(ctx).Error().WithErr(err).Printf("error reading group id during find: %s", err.Error()) return result, err } - result = append(result, roomID) + result = append(result, entry.ID) } return result, nil @@ -493,7 +497,7 @@ func getByID[E anyMemberCollection]( logDescription string, ) (*E, error) { var g E - err := db.First(&g, id).Error + err := db.First(&g, "id = ?", id).Error if err != nil { aulogging.InfoErrf(ctx, err, "mysql error during %s select - might be ok: %s", logDescription, err.Error()) } @@ -507,7 +511,7 @@ func deleteByID[E anyMemberCollection]( logDescription string, ) error { var g E - err := db.First(&g, id).Error + err := db.First(&g, "id = ?", id).Error if err != nil { aulogging.WarnErrf(ctx, err, "mysql error during %s soft delete - %s not found: %s", logDescription, logDescription, err.Error()) return err diff --git a/internal/service/groups/groups.go b/internal/service/groups/groups.go index a13f585..4d59221 100644 --- a/internal/service/groups/groups.go +++ b/internal/service/groups/groups.go @@ -113,7 +113,7 @@ func (g *groupService) FindGroups(ctx context.Context, minSize uint, maxSize int func (g *groupService) findGroupsFullAccess(ctx context.Context, minSize uint, maxSize int, memberIDs []int64, publicOnly bool) ([]*modelsv1.Group, error) { result := make([]*modelsv1.Group, 0) - groupIDs, err := g.DB.FindGroups(ctx, minSize, maxSize, memberIDs) + groupIDs, err := g.DB.FindGroups(ctx, "", minSize, maxSize, memberIDs) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return result, nil @@ -254,10 +254,22 @@ func (g *groupService) CreateGroup(ctx context.Context, group *modelsv1.GroupCre return "", common.NewBadRequest(ctx, common.GroupDataInvalid, validation) } + // check for name conflicts + matchingIDs, err := g.DB.FindGroups(ctx, group.Name, 0, -1, nil) + if err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return "", errGroupRead(ctx, err.Error()) + } + } + + if len(matchingIDs) > 0 { + return "", common.NewConflict(ctx, common.GroupDataDuplicate, common.Details("another group with this name already exists")) + } + // Create a new group in the database groupID, err := g.DB.AddGroup(ctx, &entity.Group{ Name: group.Name, - Flags: fmt.Sprintf(",%s,", strings.Join(group.Flags, ",")), + Flags: collectFlags(group.Flags), Comments: common.Deref(group.Comments), MaximumSize: maxGroupSize(), Owner: ownerID, @@ -367,7 +379,7 @@ func (g *groupService) UpdateGroup(ctx context.Context, group *modelsv1.Group) e // do not touch fields that we do not wish to change, like createdAt or referenced members dbGroup.Name = group.Name - dbGroup.Flags = fmt.Sprintf(",%s,", strings.Join(group.Flags, ",")) + dbGroup.Flags = collectFlags(group.Flags) dbGroup.Comments = common.Deref(group.Comments) dbGroup.MaximumSize = group.MaximumSize @@ -519,6 +531,13 @@ func aggregateFlags(input string) []string { return tags } +func collectFlags(input []string) string { + if len(input) == 0 { + return "," + } + return fmt.Sprintf(",%s,", strings.Join(input, ",")) +} + func errNoGroup(ctx context.Context) error { return common.NewNotFound(ctx, common.GroupMemberNotFound, common.Details("not in a group")) } diff --git a/internal/service/rbac/rbac.go b/internal/service/rbac/rbac.go index 473c825..2c79fc4 100644 --- a/internal/service/rbac/rbac.go +++ b/internal/service/rbac/rbac.go @@ -80,7 +80,7 @@ func NewValidator(ctx context.Context) (Validator, error) { manager.isUser = true for _, group := range claims.Groups { - if group == conf.Security.Oidc.AdminGroup && hasValidAdminHeader(ctx) { + if group == conf.Security.Oidc.AdminGroup { manager.isUser = false manager.isAdmin = true break @@ -90,15 +90,3 @@ func NewValidator(ctx context.Context) (Validator, error) { return manager, nil } - -// TODO remove after 2FA is available -// See reference https://github.com/eurofurence/reg-payment-service/issues/57 -func hasValidAdminHeader(ctx context.Context) bool { - adminHeaderValue, ok := ctx.Value(common.CtxKeyAdminHeader{}).(string) - if !ok { - return false - } - - // legacy system implementation requires check against constant value "available" - return adminHeaderValue == "available" -} diff --git a/internal/service/rbac/rbac_test.go b/internal/service/rbac/rbac_test.go index 492a19c..904ccfb 100644 --- a/internal/service/rbac/rbac_test.go +++ b/internal/service/rbac/rbac_test.go @@ -29,7 +29,6 @@ func TestNewRBACValidator(t *testing.T) { inputJWT string inputAPIKey string inputClaims *common.AllClaims - includeAdminHeader bool customAdminHeaderValue string } @@ -72,7 +71,6 @@ func TestNewRBACValidator(t *testing.T) { EMail: "peter@peter.eu", }, }, - includeAdminHeader: true, }, expected: expected{ isAdmin: true, @@ -80,59 +78,6 @@ func TestNewRBACValidator(t *testing.T) { roles: []string{"admin", "test"}, }, }, - // TODO remove test case after 2FA is available - // See reference https://github.com/eurofurence/reg-payment-service/issues/57 - { - name: "Should not create manager with admin role when no admin header is set", - args: args{ - inputJWT: "valid", - inputAPIKey: "", - inputClaims: &common.AllClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Subject: "123456", - }, - CustomClaims: common.CustomClaims{ - Groups: []string{"admin", "test"}, - Name: "Peter", - EMail: "peter@peter.eu", - }, - }, - includeAdminHeader: false, - }, - expected: expected{ - isAdmin: false, - isRegisteredUser: true, - subject: "123456", - roles: []string{"admin", "test"}, - }, - }, - // TODO remove test case after 2FA is available - // See reference https://github.com/eurofurence/reg-payment-service/issues/57 - { - name: "Should not create manager with admin role when no valid admin header is set", - args: args{ - inputJWT: "valid", - inputAPIKey: "", - inputClaims: &common.AllClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Subject: "123456", - }, - CustomClaims: common.CustomClaims{ - Groups: []string{"admin", "test"}, - Name: "Peter", - EMail: "peter@peter.eu", - }, - }, - includeAdminHeader: true, - customAdminHeaderValue: "test-12345", - }, - expected: expected{ - isAdmin: false, - isRegisteredUser: true, - subject: "123456", - roles: []string{"admin", "test"}, - }, - }, { name: "Should create manager with registered user role", args: args{ @@ -218,9 +163,6 @@ func TestNewRBACValidator(t *testing.T) { if tt.args.inputClaims != nil { ctx = context.WithValue(ctx, common.CtxKeyClaims{}, tt.args.inputClaims) - if tt.args.includeAdminHeader { - ctx = context.WithValue(ctx, common.CtxKeyAdminHeader{}, coalesce(tt.args.customAdminHeaderValue, "available")) - } } mgr, err := NewValidator(ctx) diff --git a/internal/service/rooms/rooms.go b/internal/service/rooms/rooms.go index ed9b0b9..989b4ed 100644 --- a/internal/service/rooms/rooms.go +++ b/internal/service/rooms/rooms.go @@ -156,9 +156,21 @@ func (r *roomService) CreateRoom(ctx context.Context, room *modelsv1.RoomCreate) return "", common.NewBadRequest(ctx, common.RoomDataInvalid, validation) } + // check for name conflicts + matchingIDs, err := r.DB.FindRooms(ctx, room.Name, 0, -1, 0, 0, nil) + if err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return "", errRoomRead(ctx, err.Error()) + } + } + + if len(matchingIDs) > 0 { + return "", common.NewConflict(ctx, common.RoomDataDuplicate, common.Details("another room with this name already exists")) + } + roomID, err := r.DB.AddRoom(ctx, &entity.Room{ Name: room.Name, - Flags: fmt.Sprintf(",%s,", strings.Join(room.Flags, ",")), + Flags: collectFlags(room.Flags), Comments: common.Deref(room.Comments), Size: room.Size, }) @@ -223,7 +235,7 @@ func (r *roomService) UpdateRoom(ctx context.Context, room *modelsv1.Room) error // do not touch fields that we do not wish to change, like createdAt or referenced occupants dbRoom.Name = room.Name - dbRoom.Flags = fmt.Sprintf(",%s,", strings.Join(room.Flags, ",")) + dbRoom.Flags = collectFlags(room.Flags) dbRoom.Comments = common.Deref(room.Comments) dbRoom.Size = room.Size @@ -328,6 +340,13 @@ func aggregateFlags(input string) []string { return tags } +func collectFlags(input []string) string { + if len(input) == 0 { + return "," + } + return fmt.Sprintf(",%s,", strings.Join(input, ",")) +} + func toOccupants(roomMembers []*entity.RoomMember) []modelsv1.Member { members := make([]modelsv1.Member, 0) for _, m := range roomMembers { diff --git a/test/acceptance/groups_delete_test.go b/test/acceptance/groups_delete_test.go index f42678a..61a8e13 100644 --- a/test/acceptance/groups_delete_test.go +++ b/test/acceptance/groups_delete_test.go @@ -68,7 +68,7 @@ func TestGroupsDelete_UserNotMemberDeny(t *testing.T) { docs.Given("Given an attendee with an active registration who is in a group") id1 := setupExistingGroup(t, "kittens", false, "101") - _ = setupExistingGroup(t, "kittens", false, "202") + _ = setupExistingGroup(t, "puppies", false, "202") token := tstValidUserToken(t, 202) docs.When("When they attempt to delete a different group they are not a member of")