Skip to content

Commit

Permalink
Merge pull request #95 from eurofurence/issue-94-fixes-part1
Browse files Browse the repository at this point in the history
Issue 94 fixes part1
  • Loading branch information
Jumpy-Squirrel authored Nov 9, 2024
2 parents 88b0177 + 2ad6d4f commit 5c116fb
Show file tree
Hide file tree
Showing 14 changed files with 73 additions and 152 deletions.
4 changes: 1 addition & 3 deletions internal/application/common/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,12 @@ type AllClaims struct {
CustomClaims
}

const RequestIDKey = "RequestIDKey"

// GetRequestID extracts the request ID from the context.
//
// It originally comes from a header with the request, or is rolled while processing
// the request.
func GetRequestID(ctx context.Context) string {
if reqID, ok := ctx.Value(RequestIDKey).(string); ok {
if reqID, ok := ctx.Value(CtxKeyRequestID{}).(string); ok {
return reqID
}

Expand Down
3 changes: 3 additions & 0 deletions internal/application/middleware/reqid.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middleware

import (
"context"
aulogging "github.com/StephanHCB/go-autumn-logging"
"github.com/eurofurence/reg-room-service/internal/application/common"
"net/http"
"regexp"
Expand All @@ -18,6 +19,8 @@ var ValidRequestIdRegex = regexp.MustCompile("^[0-9a-f]{8}$")
// It also adds it to the response under the same header.
// This automatically also leads to all logging using this context to log the request id.
func RequestIdMiddleware(next http.Handler) http.Handler {
aulogging.RequestIdRetriever = common.GetRequestID

handlerFunc := func(w http.ResponseWriter, r *http.Request) {
reqUuidStr := r.Header.Get(RequestIDHeader)
if !ValidRequestIdRegex.MatchString(reqUuidStr) {
Expand Down
13 changes: 0 additions & 13 deletions internal/application/middleware/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,6 @@ func fromApiTokenHeader(r *http.Request) string {
return r.Header.Get(common.ApiKeyHeader)
}

// TODO Remove after legacy system was replaced with 2FA
// See reference https://github.com/eurofurence/reg-room-service/issues/57
func storeAdminRequestHeaderIfAvailable(ctx context.Context, r *http.Request) context.Context {
adminHeader := r.Header.Get(adminRequestHeader)

if adminHeader == "" {
return ctx
}

return context.WithValue(ctx, common.CtxKeyAdminHeader{}, adminHeader)
}

// --- validating the individual pieces ---

// important - if any of these return an error, you must abort processing via "return" and log the error message
Expand Down Expand Up @@ -263,7 +251,6 @@ func CheckRequestAuthorization(conf *config.SecurityConfig) func(http.Handler) h
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

ctx = storeAdminRequestHeaderIfAvailable(ctx, r)
apiTokenHeaderValue := fromApiTokenHeader(r)
authHeaderValue := fromAuthHeader(r)
idTokenCookieValue := parseAuthCookie(r, conf.Oidc.IDTokenCookieName)
Expand Down
42 changes: 0 additions & 42 deletions internal/application/middleware/security_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,45 +275,3 @@ func TestCookiesExpiredJwt(t *testing.T) {
require.Nil(t, ctx.Value(common.CtxKeyClaims{}))
tstRequireNoAuthServiceCall(t)
}

// TODO Remove after legacy system was replaced with 2FA
// See reference https://github.com/eurofurence/reg-room-service/issues/57
func TestStoreAdminHeaderInContext(t *testing.T) {
docs.Description("stores the header value for legacy system admin calls")

tests := []struct {
name string
shouldStore bool
}{
{
name: "should not store value in context",
shouldStore: false,
},
{
name: "should not store value in context",
shouldStore: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()

r, err := http.NewRequest(http.MethodGet, "http://test.local", nil)
require.NoError(t, err)
if tt.shouldStore {
r.Header.Add(adminRequestHeader, "available")
}

ctx = storeAdminRequestHeaderIfAvailable(ctx, r)
val, ok := ctx.Value(common.CtxKeyAdminHeader{}).(string)
if tt.shouldStore {
require.True(t, ok)
require.NotEmpty(t, val)
} else {
require.False(t, ok)
require.Empty(t, val)
}
})
}
}
2 changes: 1 addition & 1 deletion internal/application/web/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func SendErrorResponse(ctx context.Context, w http.ResponseWriter, err error) {
// which contains relevant information about the failed request to the client.
// The function will also set the http status according to the provided status.
func SendAPIErrorResponse(ctx context.Context, w http.ResponseWriter, apiErr common.APIError) {
aulogging.InfoErrf(ctx, apiErr, fmt.Sprintf("api response status %d: %v", apiErr.Status(), apiErr))
aulogging.InfoErrf(ctx, apiErr, fmt.Sprintf("api response status %d: %v", apiErr.Status(), apiErr.Response()))
for _, cause := range apiErr.InternalCauses() {
aulogging.InfoErrf(ctx, cause, fmt.Sprintf("... caused by: %v", cause))
}
Expand Down
4 changes: 2 additions & 2 deletions internal/repository/database/historizeddb/implementation.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ func (r *HistorizingRepository) AddGroup(ctx context.Context, group *entity.Grou
return r.wrappedRepository.AddGroup(ctx, group)
}

func (r *HistorizingRepository) FindGroups(ctx context.Context, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) ([]string, error) {
return r.wrappedRepository.FindGroups(ctx, minOccupancy, maxOccupancy, anyOfMemberID)
func (r *HistorizingRepository) FindGroups(ctx context.Context, name string, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) ([]string, error) {
return r.wrappedRepository.FindGroups(ctx, name, minOccupancy, maxOccupancy, anyOfMemberID)
}

func (r *HistorizingRepository) UpdateGroup(ctx context.Context, group *entity.Group) error {
Expand Down
5 changes: 3 additions & 2 deletions internal/repository/database/inmemorydb/implementation.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,13 @@ func (r *InMemoryRepository) GetGroups(_ context.Context) ([]*entity.Group, erro
return result, nil
}

func (r *InMemoryRepository) FindGroups(_ context.Context, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) ([]string, error) {
func (r *InMemoryRepository) FindGroups(_ context.Context, name string, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) ([]string, error) {
result := make([]string, 0)
for _, grp := range r.groups {
if !grp.Group.DeletedAt.Valid {
if len(grp.Members) >= int(minOccupancy) &&
(maxOccupancy == -1 || len(grp.Members) <= maxOccupancy) {
(maxOccupancy == -1 || len(grp.Members) <= maxOccupancy) &&
(name == "" || name == grp.Group.Name) {
matches := len(anyOfMemberID) == 0
for _, wantedID := range anyOfMemberID {
for _, actualMember := range grp.Members {
Expand Down
4 changes: 3 additions & 1 deletion internal/repository/database/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ type Repository interface {
//
// A group matches the list of badge numbers in anyOfMemberID if at least one of those badge numbers
// is a member of the group. An empty list or nil means no condition.
FindGroups(ctx context.Context, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) ([]string, error)
//
// If name is not the empty string, finds only groups of that name.
FindGroups(ctx context.Context, name string, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) ([]string, error)
AddGroup(ctx context.Context, group *entity.Group) (string, error)
UpdateGroup(ctx context.Context, group *entity.Group) error
GetGroupByID(ctx context.Context, id string) (*entity.Group, error) // may return soft deleted entities!
Expand Down
26 changes: 15 additions & 11 deletions internal/repository/database/mysqldb/implementation.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,17 +146,21 @@ func (r *MysqlRepository) GetGroups(ctx context.Context) ([]*entity.Group, error
return getAllNonDeleted[entity.Group](ctx, r.db, groupDesc)
}

func (r *MysqlRepository) FindGroups(ctx context.Context, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) ([]string, error) {
query, params := buildFindQuery(minOccupancy, maxOccupancy, anyOfMemberID)
func (r *MysqlRepository) FindGroups(ctx context.Context, name string, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) ([]string, error) {
query, params := buildFindQuery(name, minOccupancy, maxOccupancy, anyOfMemberID)

return r.findGroupIDsByQuery(ctx, query, params)
}

func buildFindQuery(minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) (string, map[string]any) {
func buildFindQuery(name string, minOccupancy uint, maxOccupancy int, anyOfMemberID []int64) (string, map[string]any) {
params := make(map[string]any)
query := strings.Builder{}
query.WriteString("SELECT g.id AS id FROM room_groups g WHERE (@use_named_params = 1) ")
params["use_named_params"] = 1 // must always have at least one named param, or you get an error when using a param map
if name != "" {
query.WriteString("AND name = @name ")
params["name"] = name
}
if minOccupancy > 0 {
query.WriteString("AND (SELECT count(*) FROM room_group_members m WHERE m.group_id = g.id) >= @min_occ ")
params["min_occ"] = minOccupancy
Expand Down Expand Up @@ -191,13 +195,13 @@ func (r *MysqlRepository) findGroupIDsByQuery(ctx context.Context, query string,
}()

for rows.Next() {
groupID := ""
err = r.db.ScanRows(rows, &groupID)
entry := entity.Group{}
err = r.db.ScanRows(rows, &entry)
if err != nil {
aulogging.Logger.Ctx(ctx).Error().WithErr(err).Printf("error reading group id during find: %s", err.Error())
return result, err
}
result = append(result, groupID)
result = append(result, entry.ID)
}

return result, nil
Expand Down Expand Up @@ -373,13 +377,13 @@ func (r *MysqlRepository) findRoomIDsByQuery(ctx context.Context, query string,
}()

for rows.Next() {
roomID := ""
err = r.db.ScanRows(rows, &roomID)
entry := entity.Room{}
err = r.db.ScanRows(rows, &entry)
if err != nil {
aulogging.Logger.Ctx(ctx).Error().WithErr(err).Printf("error reading group id during find: %s", err.Error())
return result, err
}
result = append(result, roomID)
result = append(result, entry.ID)
}

return result, nil
Expand Down Expand Up @@ -493,7 +497,7 @@ func getByID[E anyMemberCollection](
logDescription string,
) (*E, error) {
var g E
err := db.First(&g, id).Error
err := db.First(&g, "id = ?", id).Error
if err != nil {
aulogging.InfoErrf(ctx, err, "mysql error during %s select - might be ok: %s", logDescription, err.Error())
}
Expand All @@ -507,7 +511,7 @@ func deleteByID[E anyMemberCollection](
logDescription string,
) error {
var g E
err := db.First(&g, id).Error
err := db.First(&g, "id = ?", id).Error
if err != nil {
aulogging.WarnErrf(ctx, err, "mysql error during %s soft delete - %s not found: %s", logDescription, logDescription, err.Error())
return err
Expand Down
25 changes: 22 additions & 3 deletions internal/service/groups/groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (g *groupService) FindGroups(ctx context.Context, minSize uint, maxSize int
func (g *groupService) findGroupsFullAccess(ctx context.Context, minSize uint, maxSize int, memberIDs []int64, publicOnly bool) ([]*modelsv1.Group, error) {
result := make([]*modelsv1.Group, 0)

groupIDs, err := g.DB.FindGroups(ctx, minSize, maxSize, memberIDs)
groupIDs, err := g.DB.FindGroups(ctx, "", minSize, maxSize, memberIDs)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return result, nil
Expand Down Expand Up @@ -254,10 +254,22 @@ func (g *groupService) CreateGroup(ctx context.Context, group *modelsv1.GroupCre
return "", common.NewBadRequest(ctx, common.GroupDataInvalid, validation)
}

// check for name conflicts
matchingIDs, err := g.DB.FindGroups(ctx, group.Name, 0, -1, nil)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return "", errGroupRead(ctx, err.Error())
}
}

if len(matchingIDs) > 0 {
return "", common.NewConflict(ctx, common.GroupDataDuplicate, common.Details("another group with this name already exists"))
}

// Create a new group in the database
groupID, err := g.DB.AddGroup(ctx, &entity.Group{
Name: group.Name,
Flags: fmt.Sprintf(",%s,", strings.Join(group.Flags, ",")),
Flags: collectFlags(group.Flags),
Comments: common.Deref(group.Comments),
MaximumSize: maxGroupSize(),
Owner: ownerID,
Expand Down Expand Up @@ -367,7 +379,7 @@ func (g *groupService) UpdateGroup(ctx context.Context, group *modelsv1.Group) e

// do not touch fields that we do not wish to change, like createdAt or referenced members
dbGroup.Name = group.Name
dbGroup.Flags = fmt.Sprintf(",%s,", strings.Join(group.Flags, ","))
dbGroup.Flags = collectFlags(group.Flags)
dbGroup.Comments = common.Deref(group.Comments)
dbGroup.MaximumSize = group.MaximumSize

Expand Down Expand Up @@ -519,6 +531,13 @@ func aggregateFlags(input string) []string {
return tags
}

func collectFlags(input []string) string {
if len(input) == 0 {
return ","
}
return fmt.Sprintf(",%s,", strings.Join(input, ","))
}

func errNoGroup(ctx context.Context) error {
return common.NewNotFound(ctx, common.GroupMemberNotFound, common.Details("not in a group"))
}
Expand Down
14 changes: 1 addition & 13 deletions internal/service/rbac/rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func NewValidator(ctx context.Context) (Validator, error) {
manager.isUser = true

for _, group := range claims.Groups {
if group == conf.Security.Oidc.AdminGroup && hasValidAdminHeader(ctx) {
if group == conf.Security.Oidc.AdminGroup {
manager.isUser = false
manager.isAdmin = true
break
Expand All @@ -90,15 +90,3 @@ func NewValidator(ctx context.Context) (Validator, error) {

return manager, nil
}

// TODO remove after 2FA is available
// See reference https://github.com/eurofurence/reg-payment-service/issues/57
func hasValidAdminHeader(ctx context.Context) bool {
adminHeaderValue, ok := ctx.Value(common.CtxKeyAdminHeader{}).(string)
if !ok {
return false
}

// legacy system implementation requires check against constant value "available"
return adminHeaderValue == "available"
}
58 changes: 0 additions & 58 deletions internal/service/rbac/rbac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ func TestNewRBACValidator(t *testing.T) {
inputJWT string
inputAPIKey string
inputClaims *common.AllClaims
includeAdminHeader bool
customAdminHeaderValue string
}

Expand Down Expand Up @@ -72,67 +71,13 @@ func TestNewRBACValidator(t *testing.T) {
EMail: "[email protected]",
},
},
includeAdminHeader: true,
},
expected: expected{
isAdmin: true,
subject: "123456",
roles: []string{"admin", "test"},
},
},
// TODO remove test case after 2FA is available
// See reference https://github.com/eurofurence/reg-payment-service/issues/57
{
name: "Should not create manager with admin role when no admin header is set",
args: args{
inputJWT: "valid",
inputAPIKey: "",
inputClaims: &common.AllClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: "123456",
},
CustomClaims: common.CustomClaims{
Groups: []string{"admin", "test"},
Name: "Peter",
EMail: "[email protected]",
},
},
includeAdminHeader: false,
},
expected: expected{
isAdmin: false,
isRegisteredUser: true,
subject: "123456",
roles: []string{"admin", "test"},
},
},
// TODO remove test case after 2FA is available
// See reference https://github.com/eurofurence/reg-payment-service/issues/57
{
name: "Should not create manager with admin role when no valid admin header is set",
args: args{
inputJWT: "valid",
inputAPIKey: "",
inputClaims: &common.AllClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: "123456",
},
CustomClaims: common.CustomClaims{
Groups: []string{"admin", "test"},
Name: "Peter",
EMail: "[email protected]",
},
},
includeAdminHeader: true,
customAdminHeaderValue: "test-12345",
},
expected: expected{
isAdmin: false,
isRegisteredUser: true,
subject: "123456",
roles: []string{"admin", "test"},
},
},
{
name: "Should create manager with registered user role",
args: args{
Expand Down Expand Up @@ -218,9 +163,6 @@ func TestNewRBACValidator(t *testing.T) {

if tt.args.inputClaims != nil {
ctx = context.WithValue(ctx, common.CtxKeyClaims{}, tt.args.inputClaims)
if tt.args.includeAdminHeader {
ctx = context.WithValue(ctx, common.CtxKeyAdminHeader{}, coalesce(tt.args.customAdminHeaderValue, "available"))
}
}

mgr, err := NewValidator(ctx)
Expand Down
Loading

0 comments on commit 5c116fb

Please sign in to comment.