store.go

  1package dialect
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"fmt"
  7	"time"
  8
  9	"github.com/pressly/goose/v3/internal/dialect/dialectquery"
 10)
 11
 12// Store is the interface that wraps the basic methods for a database dialect.
 13//
 14// A dialect is a set of SQL statements that are specific to a database.
 15//
 16// By defining a store interface, we can support multiple databases
 17// with a single codebase.
 18//
 19// The underlying implementation does not modify the error. It is the callers
 20// responsibility to assert for the correct error, such as sql.ErrNoRows.
 21type Store interface {
 22	// CreateVersionTable creates the version table within a transaction.
 23	// This table is used to store goose migrations.
 24	CreateVersionTable(ctx context.Context, tx *sql.Tx, tableName string) error
 25
 26	// InsertVersion inserts a version id into the version table within a transaction.
 27	InsertVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error
 28	// InsertVersionNoTx inserts a version id into the version table without a transaction.
 29	InsertVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error
 30
 31	// DeleteVersion deletes a version id from the version table within a transaction.
 32	DeleteVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error
 33	// DeleteVersionNoTx deletes a version id from the version table without a transaction.
 34	DeleteVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error
 35
 36	// GetMigrationRow retrieves a single migration by version id.
 37	//
 38	// Returns the raw sql error if the query fails. It is the callers responsibility
 39	// to assert for the correct error, such as sql.ErrNoRows.
 40	GetMigration(ctx context.Context, db *sql.DB, tableName string, version int64) (*GetMigrationResult, error)
 41
 42	// ListMigrations retrieves all migrations sorted in descending order by id.
 43	//
 44	// If there are no migrations, an empty slice is returned with no error.
 45	ListMigrations(ctx context.Context, db *sql.DB, tableName string) ([]*ListMigrationsResult, error)
 46}
 47
 48// NewStore returns a new Store for the given dialect.
 49func NewStore(d Dialect) (Store, error) {
 50	var querier dialectquery.Querier
 51	switch d {
 52	case Postgres:
 53		querier = &dialectquery.Postgres{}
 54	case Mysql:
 55		querier = &dialectquery.Mysql{}
 56	case Sqlite3:
 57		querier = &dialectquery.Sqlite3{}
 58	case Sqlserver:
 59		querier = &dialectquery.Sqlserver{}
 60	case Redshift:
 61		querier = &dialectquery.Redshift{}
 62	case Tidb:
 63		querier = &dialectquery.Tidb{}
 64	case Clickhouse:
 65		querier = &dialectquery.Clickhouse{}
 66	case Vertica:
 67		querier = &dialectquery.Vertica{}
 68	case Ydb:
 69		querier = &dialectquery.Ydb{}
 70	case Turso:
 71		querier = &dialectquery.Turso{}
 72	case Starrocks:
 73		querier = &dialectquery.Starrocks{}
 74	default:
 75		return nil, fmt.Errorf("unknown querier dialect: %v", d)
 76	}
 77	return &store{querier: querier}, nil
 78}
 79
 80type GetMigrationResult struct {
 81	IsApplied bool
 82	Timestamp time.Time
 83}
 84
 85type ListMigrationsResult struct {
 86	VersionID int64
 87	IsApplied bool
 88}
 89
 90type store struct {
 91	querier dialectquery.Querier
 92}
 93
 94var _ Store = (*store)(nil)
 95
 96func (s *store) CreateVersionTable(ctx context.Context, tx *sql.Tx, tableName string) error {
 97	q := s.querier.CreateTable(tableName)
 98	_, err := tx.ExecContext(ctx, q)
 99	return err
100}
101
102func (s *store) InsertVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error {
103	q := s.querier.InsertVersion(tableName)
104	_, err := tx.ExecContext(ctx, q, version, true)
105	return err
106}
107
108func (s *store) InsertVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error {
109	q := s.querier.InsertVersion(tableName)
110	_, err := db.ExecContext(ctx, q, version, true)
111	return err
112}
113
114func (s *store) DeleteVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error {
115	q := s.querier.DeleteVersion(tableName)
116	_, err := tx.ExecContext(ctx, q, version)
117	return err
118}
119
120func (s *store) DeleteVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error {
121	q := s.querier.DeleteVersion(tableName)
122	_, err := db.ExecContext(ctx, q, version)
123	return err
124}
125
126func (s *store) GetMigration(
127	ctx context.Context,
128	db *sql.DB,
129	tableName string,
130	version int64,
131) (*GetMigrationResult, error) {
132	q := s.querier.GetMigrationByVersion(tableName)
133	var timestamp time.Time
134	var isApplied bool
135	err := db.QueryRowContext(ctx, q, version).Scan(&timestamp, &isApplied)
136	if err != nil {
137		return nil, err
138	}
139	return &GetMigrationResult{
140		IsApplied: isApplied,
141		Timestamp: timestamp,
142	}, nil
143}
144
145func (s *store) ListMigrations(ctx context.Context, db *sql.DB, tableName string) ([]*ListMigrationsResult, error) {
146	q := s.querier.ListMigrations(tableName)
147	rows, err := db.QueryContext(ctx, q)
148	if err != nil {
149		return nil, err
150	}
151	defer rows.Close()
152
153	var migrations []*ListMigrationsResult
154	for rows.Next() {
155		var version int64
156		var isApplied bool
157		if err := rows.Scan(&version, &isApplied); err != nil {
158			return nil, err
159		}
160		migrations = append(migrations, &ListMigrationsResult{
161			VersionID: version,
162			IsApplied: isApplied,
163		})
164	}
165	if err := rows.Err(); err != nil {
166		return nil, err
167	}
168	return migrations, nil
169}