From 4fa9cc6405287861994d8b0f643b59335b4c1f7f Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Sat, 20 Apr 2024 22:31:06 -0400 Subject: [PATCH 1/3] Add StoreController to allowing extending interface --- database/controller.go | 73 +++++++++++++++++++ database/dialect.go | 16 +++- internal/dialect/dialectquery/dialectquery.go | 42 +++++++++++ internal/dialect/dialectquery/postgres.go | 4 + provider.go | 4 +- provider_run.go | 30 ++------ 6 files changed, 143 insertions(+), 26 deletions(-) create mode 100644 database/controller.go diff --git a/database/controller.go b/database/controller.go new file mode 100644 index 000000000..76d11a70f --- /dev/null +++ b/database/controller.go @@ -0,0 +1,73 @@ +package database + +import ( + "context" + "errors" +) + +// ErrNotSupported is returned when an optional method is not supported by the Store implementation. +var ErrNotSupported = errors.New("not supported") + +// A StoreController is used by the goose package to interact with a database. This type is a +// wrapper around the Store interface, but can be extended to include additional (optional) methods +// that are not part of the core Store interface. +type StoreController struct { + store Store +} + +var _ Store = (*StoreController)(nil) + +// NewStoreController returns a new StoreController that wraps the given Store. +// +// If the Store implements the following optional methods, the StoreController will call them as +// appropriate: +// +// - TableExists(context.Context, DBTxConn) (bool, error) +// +// If the Store does not implement a method, it will either return a [ErrNotSupported] error or fall +// back to the default behavior. +func NewStoreController(store Store) *StoreController { + return &StoreController{store: store} +} + +// TableExists is an optional method that checks if the version table exists in the database. It is +// recommended to implement this method if the database supports it, as it can be used to optimize +// certain operations. +func (c *StoreController) TableExists(ctx context.Context, db DBTxConn) (bool, error) { + if t, ok := c.store.(interface { + TableExists(context.Context, DBTxConn, string) (bool, error) + }); ok { + return t.TableExists(ctx, db, c.Tablename()) + } + return false, ErrNotSupported +} + +// Default methods + +func (c *StoreController) Tablename() string { + return c.store.Tablename() +} + +func (c *StoreController) CreateVersionTable(ctx context.Context, db DBTxConn) error { + return c.store.CreateVersionTable(ctx, db) +} + +func (c *StoreController) Insert(ctx context.Context, db DBTxConn, req InsertRequest) error { + return c.store.Insert(ctx, db, req) +} + +func (c *StoreController) Delete(ctx context.Context, db DBTxConn, version int64) error { + return c.store.Delete(ctx, db, version) +} + +func (c *StoreController) GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error) { + return c.store.GetMigration(ctx, db, version) +} + +func (c *StoreController) GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error) { + return c.store.GetLatestVersion(ctx, db) +} + +func (c *StoreController) ListMigrations(ctx context.Context, db DBTxConn) ([]*ListMigrationsResult, error) { + return c.store.ListMigrations(ctx, db) +} diff --git a/database/dialect.go b/database/dialect.go index ca7d24cf2..e7755ccfc 100644 --- a/database/dialect.go +++ b/database/dialect.go @@ -51,13 +51,13 @@ func NewStore(dialect Dialect, tablename string) (Store, error) { } return &store{ tablename: tablename, - querier: querier, + querier: dialectquery.NewQueryController(querier), }, nil } type store struct { tablename string - querier dialectquery.Querier + querier *dialectquery.QueryController } var _ Store = (*store)(nil) @@ -137,3 +137,15 @@ func (s *store) ListMigrations( } return migrations, nil } + +func (s *store) TableExists(ctx context.Context, db DBTxConn, name string) (bool, error) { + q := s.querier.TableExists(s.tablename) + if q == "" { + return false, ErrNotSupported + } + var exists bool + if err := db.QueryRowContext(ctx, q, name).Scan(&exists); err != nil { + return false, fmt.Errorf("failed to check if table exists: %w", err) + } + return exists, nil +} diff --git a/internal/dialect/dialectquery/dialectquery.go b/internal/dialect/dialectquery/dialectquery.go index 482771aa1..4f6dde736 100644 --- a/internal/dialect/dialectquery/dialectquery.go +++ b/internal/dialect/dialectquery/dialectquery.go @@ -26,3 +26,45 @@ type Querier interface { // The query should return the version_id and is_applied columns. ListMigrations(tableName string) string } + +type QueryController struct { + querier Querier +} + +// NewQueryController returns a new QueryController that wraps the given Querier. +func NewQueryController(querier Querier) *QueryController { + return &QueryController{querier: querier} +} + +// Optional methods + +func (c *QueryController) TableExists(tableName string) string { + if t, ok := c.querier.(interface { + TableExists(string) string + }); ok { + return t.TableExists(tableName) + } + return "" +} + +// Default methods + +func (c *QueryController) CreateTable(tableName string) string { + return c.querier.CreateTable(tableName) +} + +func (c *QueryController) InsertVersion(tableName string) string { + return c.querier.InsertVersion(tableName) +} + +func (c *QueryController) DeleteVersion(tableName string) string { + return c.querier.DeleteVersion(tableName) +} + +func (c *QueryController) GetMigrationByVersion(tableName string) string { + return c.querier.GetMigrationByVersion(tableName) +} + +func (c *QueryController) ListMigrations(tableName string) string { + return c.querier.ListMigrations(tableName) +} diff --git a/internal/dialect/dialectquery/postgres.go b/internal/dialect/dialectquery/postgres.go index 5103390f4..e4e077105 100644 --- a/internal/dialect/dialectquery/postgres.go +++ b/internal/dialect/dialectquery/postgres.go @@ -36,3 +36,7 @@ func (p *Postgres) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied from %s ORDER BY id DESC` return fmt.Sprintf(q, tableName) } + +func (p *Postgres) TableExists(tableName string) string { + return `SELECT EXISTS ( SELECT FROM pg_tables WHERE tablename = $1)` +} diff --git a/provider.go b/provider.go index 24a9eb5a7..6c93938e4 100644 --- a/provider.go +++ b/provider.go @@ -24,7 +24,7 @@ type Provider struct { mu sync.Mutex db *sql.DB - store database.Store + store *database.StoreController fsys fs.FS cfg config @@ -143,7 +143,7 @@ func newProvider( db: db, fsys: fsys, cfg: cfg, - store: store, + store: database.NewStoreController(store), migrations: migrations, }, nil } diff --git a/provider_run.go b/provider_run.go index 4d0760181..12baa4db9 100644 --- a/provider_run.go +++ b/provider_run.go @@ -330,34 +330,20 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err } func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) { - // existor is an interface that extends the Store interface with a method to check if the - // version table exists. This API is not stable and may change in the future. - type existor interface { - TableExists(context.Context, database.DBTxConn, string) (bool, error) + if ok, err := p.store.TableExists(ctx, conn); err != nil && !errors.Is(err, database.ErrNotSupported) { + return err + } else if ok { + return nil } - if e, ok := p.store.(existor); ok { - exists, err := e.TableExists(ctx, conn, p.store.Tablename()) - if err != nil { - return fmt.Errorf("failed to check if version table exists: %w", err) - } - if exists { - return nil - } - } else { - // feat(mf): this is where we can check if the version table exists instead of trying to fetch - // from a table that may not exist. https://github.com/pressly/goose/issues/461 - res, err := p.store.GetMigration(ctx, conn, 0) - if err == nil && res != nil { - return nil - } + // Fall back to the default behavior if the Store does not implement TableExists. + res, err := p.store.GetMigration(ctx, conn, 0) + if err == nil && res != nil { + return nil } return beginTx(ctx, conn, func(tx *sql.Tx) error { if err := p.store.CreateVersionTable(ctx, tx); err != nil { return err } - if p.cfg.disableVersioning { - return nil - } return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0}) }) } From 7863f54bd5408c92bc3d980b6c5c610e25696082 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Sat, 20 Apr 2024 22:42:42 -0400 Subject: [PATCH 2/3] docs --- internal/dialect/dialectquery/dialectquery.go | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/internal/dialect/dialectquery/dialectquery.go b/internal/dialect/dialectquery/dialectquery.go index 4f6dde736..fc51490ad 100644 --- a/internal/dialect/dialectquery/dialectquery.go +++ b/internal/dialect/dialectquery/dialectquery.go @@ -1,27 +1,22 @@ package dialectquery -// Querier is the interface that wraps the basic methods to create a dialect -// specific query. +// Querier is the interface that wraps the basic methods to create a dialect specific query. type Querier interface { // CreateTable returns the SQL query string to create the db version table. CreateTable(tableName string) string - // InsertVersion returns the SQL query string to insert a new version into - // the db version table. + // InsertVersion returns the SQL query string to insert a new version into the db version table. InsertVersion(tableName string) string - // DeleteVersion returns the SQL query string to delete a version from - // the db version table. + // DeleteVersion returns the SQL query string to delete a version from the db version table. DeleteVersion(tableName string) string - // GetMigrationByVersion returns the SQL query string to get a single - // migration by version. + // GetMigrationByVersion returns the SQL query string to get a single migration by version. // // The query should return the timestamp and is_applied columns. GetMigrationByVersion(tableName string) string - // ListMigrations returns the SQL query string to list all migrations in - // descending order by id. + // ListMigrations returns the SQL query string to list all migrations in descending order by id. // // The query should return the version_id and is_applied columns. ListMigrations(tableName string) string @@ -38,6 +33,10 @@ func NewQueryController(querier Querier) *QueryController { // Optional methods +// TableExists returns the SQL query string to check if the version table exists. If the Querier +// does not implement this method, it will return an empty string. +// +// The query should return a boolean value. func (c *QueryController) TableExists(tableName string) string { if t, ok := c.querier.(interface { TableExists(string) string From b8700dd3674df3c1dd910c44e70b3c7852122148 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Sat, 20 Apr 2024 23:11:26 -0400 Subject: [PATCH 3/3] minor improvement --- database/controller.go | 7 +++-- database/dialect.go | 4 +-- internal/dialect/dialectquery/dialectquery.go | 2 ++ internal/dialect/dialectquery/postgres.go | 2 +- provider_run.go | 16 ++++++---- provider_run_test.go | 31 ++++++++++--------- 6 files changed, 35 insertions(+), 27 deletions(-) diff --git a/database/controller.go b/database/controller.go index 76d11a70f..063b48d9c 100644 --- a/database/controller.go +++ b/database/controller.go @@ -2,6 +2,7 @@ package database import ( "context" + "database/sql" "errors" ) @@ -33,11 +34,11 @@ func NewStoreController(store Store) *StoreController { // TableExists is an optional method that checks if the version table exists in the database. It is // recommended to implement this method if the database supports it, as it can be used to optimize // certain operations. -func (c *StoreController) TableExists(ctx context.Context, db DBTxConn) (bool, error) { +func (c *StoreController) TableExists(ctx context.Context, db *sql.Conn) (bool, error) { if t, ok := c.store.(interface { - TableExists(context.Context, DBTxConn, string) (bool, error) + TableExists(ctx context.Context, db *sql.Conn) (bool, error) }); ok { - return t.TableExists(ctx, db, c.Tablename()) + return t.TableExists(ctx, db) } return false, ErrNotSupported } diff --git a/database/dialect.go b/database/dialect.go index e7755ccfc..63390c849 100644 --- a/database/dialect.go +++ b/database/dialect.go @@ -138,13 +138,13 @@ func (s *store) ListMigrations( return migrations, nil } -func (s *store) TableExists(ctx context.Context, db DBTxConn, name string) (bool, error) { +func (s *store) TableExists(ctx context.Context, db DBTxConn) (bool, error) { q := s.querier.TableExists(s.tablename) if q == "" { return false, ErrNotSupported } var exists bool - if err := db.QueryRowContext(ctx, q, name).Scan(&exists); err != nil { + if err := db.QueryRowContext(ctx, q, s.tablename).Scan(&exists); err != nil { return false, fmt.Errorf("failed to check if table exists: %w", err) } return exists, nil diff --git a/internal/dialect/dialectquery/dialectquery.go b/internal/dialect/dialectquery/dialectquery.go index fc51490ad..7d326f45b 100644 --- a/internal/dialect/dialectquery/dialectquery.go +++ b/internal/dialect/dialectquery/dialectquery.go @@ -22,6 +22,8 @@ type Querier interface { ListMigrations(tableName string) string } +var _ Querier = (*QueryController)(nil) + type QueryController struct { querier Querier } diff --git a/internal/dialect/dialectquery/postgres.go b/internal/dialect/dialectquery/postgres.go index e4e077105..37a6d103c 100644 --- a/internal/dialect/dialectquery/postgres.go +++ b/internal/dialect/dialectquery/postgres.go @@ -37,6 +37,6 @@ func (p *Postgres) ListMigrations(tableName string) string { return fmt.Sprintf(q, tableName) } -func (p *Postgres) TableExists(tableName string) string { +func (p *Postgres) TableExists(_ string) string { return `SELECT EXISTS ( SELECT FROM pg_tables WHERE tablename = $1)` } diff --git a/provider_run.go b/provider_run.go index 12baa4db9..5184029ab 100644 --- a/provider_run.go +++ b/provider_run.go @@ -330,16 +330,20 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err } func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) { - if ok, err := p.store.TableExists(ctx, conn); err != nil && !errors.Is(err, database.ErrNotSupported) { + ok, err := p.store.TableExists(ctx, conn) + if err != nil && !errors.Is(err, database.ErrNotSupported) { return err - } else if ok { - return nil } - // Fall back to the default behavior if the Store does not implement TableExists. - res, err := p.store.GetMigration(ctx, conn, 0) - if err == nil && res != nil { + if ok { return nil } + if errors.Is(err, database.ErrNotSupported) { + // Fall back to the default behavior if the Store does not implement TableExists. + res, err := p.store.GetMigration(ctx, conn, 0) + if err == nil && res != nil { + return nil + } + } return beginTx(ctx, conn, func(tx *sql.Tx) error { if err := p.store.CreateVersionTable(ctx, tx); err != nil { return err diff --git a/provider_run_test.go b/provider_run_test.go index 914597f2b..32365d370 100644 --- a/provider_run_test.go +++ b/provider_run_test.go @@ -748,19 +748,6 @@ func TestGoMigrationPanic(t *testing.T) { check.Contains(t, expected.Err.Error(), wantErrString) } -func TestCustomStoreTableExists(t *testing.T) { - t.Parallel() - - store, err := database.NewStore(database.DialectSQLite3, goose.DefaultTablename) - check.NoError(t, err) - p, err := goose.NewProvider("", newDB(t), newFsys(), - goose.WithStore(&customStoreSQLite3{store}), - ) - check.NoError(t, err) - _, err = p.Up(context.Background()) - check.NoError(t, err) -} - func TestProviderApply(t *testing.T) { t.Parallel() @@ -774,15 +761,29 @@ func TestProviderApply(t *testing.T) { check.HasError(t, err) check.Bool(t, errors.Is(err, goose.ErrNotApplied), true) } +func TestCustomStoreTableExists(t *testing.T) { + t.Parallel() + + store, err := database.NewStore(database.DialectSQLite3, goose.DefaultTablename) + check.NoError(t, err) + p, err := goose.NewProvider("", newDB(t), newFsys(), + goose.WithStore(&customStoreSQLite3{store}), + ) + check.NoError(t, err) + _, err = p.Up(context.Background()) + check.NoError(t, err) + _, err = p.Up(context.Background()) + check.NoError(t, err) +} type customStoreSQLite3 struct { database.Store } -func (s *customStoreSQLite3) TableExists(ctx context.Context, db database.DBTxConn, name string) (bool, error) { +func (s *customStoreSQLite3) TableExists(ctx context.Context, db *sql.Conn) (bool, error) { q := `SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type='table' AND name=$1) AS table_exists` var exists bool - if err := db.QueryRowContext(ctx, q, name).Scan(&exists); err != nil { + if err := db.QueryRowContext(ctx, q, s.Tablename()).Scan(&exists); err != nil { return false, err } return exists, nil