Skip to content

Commit

Permalink
AES => GCM && reset GCM cipher.AEAD
Browse files Browse the repository at this point in the history
  • Loading branch information
molon committed Sep 25, 2024
1 parent d18e4ba commit 8dacd17
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 51 deletions.
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
## Features

- **Supports keyset-based and offset-based pagination**: You can freely choose high-performance keyset pagination based on multiple indexed columns, or use offset pagination.
- **Optional cursor encryption**: Supports encrypting cursors using `AES` or `Base64` to ensure the security of pagination information.
- **Optional cursor encryption**: Supports encrypting cursors using `GCM(AES)` or `Base64` to ensure the security of pagination information.
- **Flexible query strategies**: Optionally skip the `TotalCount` query to improve performance, especially in large datasets.
- **Non-generic support**: Even without using Go generics, you can paginate using the `any` type for flexible use cases.

Expand All @@ -33,14 +33,16 @@ resp, err := p.Paginate(context.Background(), &relay.PaginateRequest[*User]{

### Middleware

If you need to encrypt cursors, you can use `cursor.Base64` or `cursor.AES` middlewares:
If you need to encrypt cursors, you can use `cursor.Base64` or `cursor.GCM` middlewares:

```go
// Encrypt cursors with Base64
cursor.Base64(gormrelay.NewOffsetAdapter[*User](db))

// Encrypt cursors with AES
cursor.AES(encryptionKey)(gormrelay.NewKeysetAdapter[*User](db))
// Encrypt cursors with GCM(AES)
gcm, err := cursor.NewGCM(encryptionKey)
require.NoError(t, err)
cursor.GCM(gcm)(gormrelay.NewKeysetAdapter[*User](db))
```

If you need to append `PrimaryOrderBys` to `PaginateRequest.OrderBys`
Expand Down
53 changes: 24 additions & 29 deletions cursor/aes.go → cursor/gcm.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,20 @@ import (
"github.com/theplant/relay"
)

func encryptAES(plainText string, key []byte) (string, error) {
block, err := aes.NewCipher(key)
if err != nil {
return "", errors.New("could not create cipher block")
}

gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}

func encryptGCM(gcm cipher.AEAD, plainText string) (string, error) {
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", err
return "", errors.Wrap(err, "could not generate nonce")
}

cipherText := gcm.Seal(nonce, nonce, []byte(plainText), nil)
return base64.RawURLEncoding.EncodeToString(cipherText), nil
}

func decryptAES(cipherText string, key []byte) (string, error) {
block, err := aes.NewCipher(key)
if err != nil {
return "", errors.New("could not create cipher block")
}

gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}

func decryptGCM(gcm cipher.AEAD, cipherText string) (string, error) {
decodedCipherText, err := base64.RawURLEncoding.DecodeString(cipherText)
if err != nil {
return "", err
return "", errors.Wrap(err, "could not decode cipher text")
}

nonceSize := gcm.NonceSize()
Expand All @@ -57,25 +37,40 @@ func decryptAES(cipherText string, key []byte) (string, error) {
nonce, dataCipherText := decodedCipherText[:nonceSize], decodedCipherText[nonceSize:]
plainText, err := gcm.Open(nil, nonce, dataCipherText, nil)
if err != nil {
return "", err
return "", errors.Wrap(err, "could not decrypt cipher text")
}

return string(plainText), nil
}

func AES[T any](encryptionKey []byte) relay.CursorMiddleware[T] {
// NewGCM creates a new GCM cipher
// Concurrent safe: https://github.com/golang/go/issues/41689
func NewGCM(key []byte) (cipher.AEAD, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, errors.New("could not create cipher block")
}

gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, errors.New("could not create GCM")
}
return gcm, nil
}

func GCM[T any](gcm cipher.AEAD) relay.CursorMiddleware[T] {
return func(next relay.ApplyCursorsFunc[T]) relay.ApplyCursorsFunc[T] {
return func(ctx context.Context, req *relay.ApplyCursorsRequest) (*relay.ApplyCursorsResponse[T], error) {
if req.After != nil {
decodedCursor, err := decryptAES(*req.After, encryptionKey)
decodedCursor, err := decryptGCM(gcm, *req.After)
if err != nil {
return nil, errors.Wrap(err, "invalid after cursor")
}
req.After = lo.ToPtr(decodedCursor)
}

if req.Before != nil {
decodedCursor, err := decryptAES(*req.Before, encryptionKey)
decodedCursor, err := decryptGCM(gcm, *req.Before)
if err != nil {
return nil, errors.Wrap(err, "invalid before cursor")
}
Expand All @@ -95,7 +90,7 @@ func AES[T any](encryptionKey []byte) relay.CursorMiddleware[T] {
if err != nil {
return "", err
}
encryptedCursor, err := encryptAES(cursor, encryptionKey)
encryptedCursor, err := encryptGCM(gcm, cursor)
if err != nil {
return "", err
}
Expand Down
20 changes: 12 additions & 8 deletions cursor/aes_test.go → cursor/gcm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,41 +6,45 @@ import (
"io"
"testing"

"github.com/pkg/errors"
"github.com/stretchr/testify/require"
)

func generateAESKey(length int) ([]byte, error) {
func generateGCMKey(length int) ([]byte, error) {
key := make([]byte, length)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
return nil, err
return nil, errors.Wrap(err, "could not generate key")
}
return key, nil
}

func TestAES(t *testing.T) {
aesKey, err := generateAESKey(32)
func TestGCM(t *testing.T) {
gcmKey, err := generateGCMKey(32)
require.NoError(t, err)

gcm, err := NewGCM(gcmKey)
require.NoError(t, err)

plainText := `{"ID":225}`

{
cipherText, err := encryptAES(plainText, aesKey)
cipherText, err := encryptGCM(gcm, plainText)
require.NoError(t, err)

t.Logf("cipherText: %s", cipherText)

decryptedText, err := decryptAES(cipherText, aesKey)
decryptedText, err := decryptGCM(gcm, cipherText)
require.NoError(t, err)
require.Equal(t, plainText, decryptedText)
}

{
cipherText, err := encryptAES(base64.RawURLEncoding.EncodeToString([]byte(plainText)), aesKey)
cipherText, err := encryptGCM(gcm, base64.RawURLEncoding.EncodeToString([]byte(plainText)))
require.NoError(t, err)

t.Logf("cipherText: %s", cipherText)

decryptedText, err := decryptAES(cipherText, aesKey)
decryptedText, err := decryptGCM(gcm, cipherText)
require.NoError(t, err)

plainTextData, err := base64.RawURLEncoding.DecodeString(decryptedText)
Expand Down
26 changes: 16 additions & 10 deletions gormrelay/relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,10 @@ func TestTotalCountZero(t *testing.T) {
t.Run("offset", func(t *testing.T) { testCase(t, NewOffsetAdapter) })
}

func generateAESKey(length int) ([]byte, error) {
func generateGCMKey(length int) ([]byte, error) {
key := make([]byte, length)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
return nil, err
return nil, errors.Wrap(err, "could not generate key")
}
return key, nil
}
Expand Down Expand Up @@ -320,19 +320,22 @@ func TestMiddleware(t *testing.T) {
})
})

t.Run("AES", func(t *testing.T) {
encryptionKey, err := generateAESKey(32)
t.Run("GCM", func(t *testing.T) {
encryptionKey, err := generateGCMKey(32)
require.NoError(t, err)

gcm, err := cursor.NewGCM(encryptionKey)
require.NoError(t, err)

t.Run("keyset", func(t *testing.T) {
testCase(t, func(db *gorm.DB) relay.ApplyCursorsFunc[*User] {
return cursor.AES[*User](encryptionKey)(NewKeysetAdapter[*User](db))
return cursor.GCM[*User](gcm)(NewKeysetAdapter[*User](db))
})
})

t.Run("offset", func(t *testing.T) {
testCase(t, func(db *gorm.DB) relay.ApplyCursorsFunc[*User] {
return cursor.AES[*User](encryptionKey)(NewOffsetAdapter[*User](db))
return cursor.GCM[*User](gcm)(NewOffsetAdapter[*User](db))
})
})
})
Expand Down Expand Up @@ -376,19 +379,22 @@ func TestMiddleware(t *testing.T) {
func TestAppendCursorMiddleware(t *testing.T) {
resetDB(t)

encryptionKey, err := generateAESKey(32)
encryptionKey, err := generateGCMKey(32)
require.NoError(t, err)

gcm, err := cursor.NewGCM(encryptionKey)
require.NoError(t, err)

aesMiddleware := cursor.AES[*User](encryptionKey)
gcmMiddleware := cursor.GCM[*User](gcm)

testCase := func(t *testing.T, f func(db *gorm.DB) relay.ApplyCursorsFunc[*User]) {
p := relay.New(
false,
10, 10,
f(db),
)
p = relay.AppendCursorMiddleware(aesMiddleware)(p) // test add single middleware
p = relay.AppendCursorMiddleware(cursor.Base64[*User], aesMiddleware)(p) // test add multiple middlewares
p = relay.AppendCursorMiddleware(gcmMiddleware)(p) // test add single middleware
p = relay.AppendCursorMiddleware(cursor.Base64[*User], gcmMiddleware)(p) // test add multiple middlewares
p = relay.PrimaryOrderBy[*User](relay.OrderBy{Field: "ID", Desc: false})(p) // test a pagination middleware

resp, err := p.Paginate(context.Background(), &relay.PaginateRequest[*User]{
Expand Down

0 comments on commit 8dacd17

Please sign in to comment.