Skip to content

Commit

Permalink
Introduce DB#InsertObtainID() function
Browse files Browse the repository at this point in the history
  • Loading branch information
yhabteab committed Sep 26, 2024
1 parent afba056 commit 1317301
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 7 deletions.
12 changes: 12 additions & 0 deletions database/contracts.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package database

import (
"context"
"github.com/jmoiron/sqlx"
)

// Entity is implemented by each type that works with the database package.
type Entity interface {
Fingerprinter
Expand Down Expand Up @@ -54,3 +59,10 @@ type PgsqlOnConflictConstrainter interface {
// PgsqlOnConflictConstraint returns the primary or unique key constraint name of the PostgreSQL table.
PgsqlOnConflictConstraint() string
}

// TxOrDB is just a helper interface that can represent a *[sqlx.Tx] or *[DB] instance.
type TxOrDB interface {
sqlx.ExtContext

PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error)
}
6 changes: 6 additions & 0 deletions database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -836,3 +836,9 @@ func (db *DB) Log(ctx context.Context, query string, counter *com.Counter) perio
db.logger.Debugf("Finished executing %q with %d rows in %s", query, counter.Total(), tick.Elapsed)
}))
}

var (
// Assert TxOrDB interface compliance of the DB and sqlx.Tx types.
_ TxOrDB = (*DB)(nil)
_ TxOrDB = (*sqlx.Tx)(nil)
)
37 changes: 37 additions & 0 deletions database/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/icinga/icinga-go-library/com"
"github.com/icinga/icinga-go-library/strcase"
"github.com/icinga/icinga-go-library/types"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
)

Expand Down Expand Up @@ -44,6 +45,42 @@ func SplitOnDupId[T IDer]() com.BulkChunkSplitPolicy[T] {
}
}

// InsertObtainID executes the given query and fetches the last inserted ID.
//
// Using this method for database tables that don't define an auto-incrementing ID, or none at all,
// will not work. The only supported column that can be retrieved with this method is id.
//
// This function expects [TxOrDB] as an executor of the provided query, and is usually a *[sqlx.Tx] or *[DB] instance.
//
// Returns the retrieved ID on success and error on any database inserting/retrieving failure.
func InsertObtainID(ctx context.Context, conn TxOrDB, stmt string, arg any) (int64, error) {
var resultID int64
switch conn.DriverName() {
case PostgreSQL:
stmt = stmt + " RETURNING id"
query, args, err := conn.BindNamed(stmt, arg)
if err != nil {
return 0, errors.Wrapf(err, "can't bind named query %q", stmt)
}

if err := sqlx.GetContext(ctx, conn, &resultID, query, args...); err != nil {
return 0, CantPerformQuery(err, query)
}
default:
result, err := sqlx.NamedExecContext(ctx, conn, stmt, arg)
if err != nil {
return 0, CantPerformQuery(err, stmt)
}

resultID, err = result.LastInsertId()
if err != nil {
return 0, errors.Wrap(err, "can't retrieve last inserted ID")
}
}

return resultID, nil
}

// unsafeSetSessionVariableIfExists sets the given MySQL/MariaDB system variable for the specified database session.
//
// NOTE: It is unsafe to use this function with untrusted/user supplied inputs and poses an SQL injection,
Expand Down
51 changes: 44 additions & 7 deletions database/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,48 @@ import (
"time"
)

func TestSetMysqlSessionVars(t *testing.T) {
func TestDatabaseUtils(t *testing.T) {
t.Parallel()

ctx := context.Background()
db := GetTestDB(ctx, t, "ICINGAGOLIBRARY")

t.Run("SetMySQLSessionVars", func(t *testing.T) {
t.Parallel()
if db.DriverName() != MySQL {
t.Skipf("skipping set session vars test for %q driver", db.DriverName())
}

setMysqlSessionVars(ctx, db, t)
})

t.Run("InsertObtainID", func(t *testing.T) {
t.Parallel()

defer func() {
_, err := db.ExecContext(ctx, "DROP TABLE IF EXISTS igl_test_insert_obtain")
assert.NoError(t, err, "dropping test database table should not fail")
}()

var err error
if db.DriverName() == PostgreSQL {
_, err = db.ExecContext(ctx, "CREATE TABLE igl_test_insert_obtain (id SERIAL PRIMARY KEY, name VARCHAR(255))")
} else {
_, err = db.ExecContext(ctx, "CREATE TABLE igl_test_insert_obtain (id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255))")
}
require.NoError(t, err, "creating test database table should not fail")

id, err := InsertObtainID(ctx, db, "INSERT INTO igl_test_insert_obtain (name) VALUES (:name)", map[string]any{"name": "test1"})
require.NoError(t, err, "inserting new row into test database table should not fail")
assert.Equal(t, id, int64(1))

id, err = InsertObtainID(ctx, db, "INSERT INTO igl_test_insert_obtain (name) VALUES (:name)", map[string]any{"name": "test2"})
require.NoError(t, err, "inserting new row into test database table should not fail")
assert.Equal(t, id, int64(2))
})
}

func setMysqlSessionVars(ctx context.Context, db *DB, t *testing.T) {
vars := map[string][]struct {
name string
value string
Expand Down Expand Up @@ -45,14 +86,10 @@ func TestSetMysqlSessionVars(t *testing.T) {
},
}

ctx := context.Background()
db := GetTestDB(ctx, t, "ICINGAGOLIBRARY")
if db.DriverName() != MySQL {
t.Skipf("skipping set session vars test for %q driver", db.DriverName())
}

for name, vs := range vars {
t.Run(name, func(t *testing.T) {
t.Parallel()

for _, v := range vs {
conn, err := db.DB.Conn(ctx)
require.NoError(t, err, "connecting to MySQL/MariaDB database should not fail")
Expand Down

0 comments on commit 1317301

Please sign in to comment.