diff --git a/orm/composite_create.go b/orm/composite_create.go index fd60a94e..2993d306 100644 --- a/orm/composite_create.go +++ b/orm/composite_create.go @@ -53,6 +53,7 @@ func (q *CreateCompositeQuery) AppendTemplate(b []byte) ([]byte, error) { } func (q *CreateCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + b = appendComment(b, q.q.comment) if q.q.stickyErr != nil { return nil, q.q.stickyErr } diff --git a/orm/composite_drop.go b/orm/composite_drop.go index 2a169b07..1873dcf8 100644 --- a/orm/composite_drop.go +++ b/orm/composite_drop.go @@ -50,6 +50,7 @@ func (q *DropCompositeQuery) AppendTemplate(b []byte) ([]byte, error) { } func (q *DropCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + b = appendComment(b, q.q.comment) if q.q.stickyErr != nil { return nil, q.q.stickyErr } diff --git a/orm/delete.go b/orm/delete.go index c54cd10f..834ef456 100644 --- a/orm/delete.go +++ b/orm/delete.go @@ -52,6 +52,7 @@ func (q *DeleteQuery) AppendTemplate(b []byte) ([]byte, error) { } func (q *DeleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + b = appendComment(b, q.q.comment) if q.q.stickyErr != nil { return nil, q.q.stickyErr } diff --git a/orm/insert.go b/orm/insert.go index a7a54357..1876ca8f 100644 --- a/orm/insert.go +++ b/orm/insert.go @@ -56,6 +56,7 @@ func (q *InsertQuery) AppendTemplate(b []byte) ([]byte, error) { var _ QueryAppender = (*InsertQuery)(nil) func (q *InsertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + b = appendComment(b, q.q.comment) if q.q.stickyErr != nil { return nil, q.q.stickyErr } diff --git a/orm/query.go b/orm/query.go index 37ae194d..4f64e732 100644 --- a/orm/query.go +++ b/orm/query.go @@ -81,6 +81,7 @@ type Query struct { onConflict *SafeQueryAppender returning []*SafeQueryAppender + comment string } func NewQuery(db DB, model ...interface{}) *Query { @@ -779,6 +780,12 @@ func (q *Query) Apply(fn func(*Query) (*Query, error)) *Query { return qq } +// Comment adds a comment to the query, wrapped by /* ... */. +func (q *Query) Comment(c string) *Query { + q.comment = c + return q +} + // Count returns number of rows matching the query using count aggregate function. func (q *Query) Count() (int, error) { if q.stickyErr != nil { diff --git a/orm/select.go b/orm/select.go index d3b38742..8d11f219 100644 --- a/orm/select.go +++ b/orm/select.go @@ -53,6 +53,7 @@ func (q *SelectQuery) AppendTemplate(b []byte) ([]byte, error) { } func (q *SelectQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { //nolint:gocyclo + b = appendComment(b, q.q.comment) if q.q.stickyErr != nil { return nil, q.q.stickyErr } diff --git a/orm/table_create.go b/orm/table_create.go index 384c729d..d846ab66 100644 --- a/orm/table_create.go +++ b/orm/table_create.go @@ -64,6 +64,7 @@ func (q *CreateTableQuery) AppendTemplate(b []byte) ([]byte, error) { } func (q *CreateTableQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + b = appendComment(b, q.q.comment) if q.q.stickyErr != nil { return nil, q.q.stickyErr } diff --git a/orm/table_drop.go b/orm/table_drop.go index 599ac395..140e0a2e 100644 --- a/orm/table_drop.go +++ b/orm/table_drop.go @@ -50,6 +50,7 @@ func (q *DropTableQuery) AppendTemplate(b []byte) ([]byte, error) { } func (q *DropTableQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + b = appendComment(b, q.q.comment) if q.q.stickyErr != nil { return nil, q.q.stickyErr } diff --git a/orm/update.go b/orm/update.go index ce6396fd..406a7c89 100644 --- a/orm/update.go +++ b/orm/update.go @@ -57,6 +57,7 @@ func (q *UpdateQuery) AppendTemplate(b []byte) ([]byte, error) { } func (q *UpdateQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + b = appendComment(b, q.q.comment) if q.q.stickyErr != nil { return nil, q.q.stickyErr } diff --git a/orm/util.go b/orm/util.go index b7963ba0..a1b5f82f 100644 --- a/orm/util.go +++ b/orm/util.go @@ -1,7 +1,9 @@ package orm import ( + "fmt" "reflect" + "strings" "github.com/go-pg/pg/v10/types" ) @@ -149,3 +151,13 @@ func appendColumns(b []byte, table types.Safe, fields []*Field) []byte { } return b } + +// appendComment adds comment in the header of the query into buffer +func appendComment(b []byte, name string) []byte { + if name == "" { + return b + } + name = strings.ReplaceAll(name, `/*`, `/\*`) + name = strings.ReplaceAll(name, `*/`, `*\/`) + return append(b, fmt.Sprintf("/* %s */ ", name)...) +} diff --git a/orm/util_test.go b/orm/util_test.go new file mode 100644 index 00000000..8612a181 --- /dev/null +++ b/orm/util_test.go @@ -0,0 +1,30 @@ +package orm + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Comment - escape comment symbols", func() { + It("only open sequence", func() { + var res []byte + c := "/* comment" + + s := appendComment(res, c) + Expect(s).To(Equal([]byte("/* /\\* comment */ "))) + }) + It("only close sequence", func() { + var res []byte + c := "comment */" + + s := appendComment(res, c) + Expect(s).To(Equal([]byte("/* comment *\\/ */ "))) + }) + It("open and close sequences", func() { + var res []byte + c := "/* comment */" + + s := appendComment(res, c) + Expect(s).To(Equal([]byte("/* /\\* comment *\\/ */ "))) + }) +})