dialect.go

  1package database
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"errors"
  7	"fmt"
  8
  9	"github.com/pressly/goose/v3/internal/dialect/dialectquery"
 10)
 11
 12// Dialect is the type of database dialect.
 13type Dialect string
 14
 15const (
 16	DialectClickHouse Dialect = "clickhouse"
 17	DialectMSSQL      Dialect = "mssql"
 18	DialectMySQL      Dialect = "mysql"
 19	DialectPostgres   Dialect = "postgres"
 20	DialectRedshift   Dialect = "redshift"
 21	DialectSQLite3    Dialect = "sqlite3"
 22	DialectTiDB       Dialect = "tidb"
 23	DialectTurso      Dialect = "turso"
 24	DialectVertica    Dialect = "vertica"
 25	DialectYdB        Dialect = "ydb"
 26	DialectStarrocks  Dialect = "starrocks"
 27)
 28
 29// NewStore returns a new [Store] implementation for the given dialect.
 30func NewStore(dialect Dialect, tablename string) (Store, error) {
 31	if tablename == "" {
 32		return nil, errors.New("table name must not be empty")
 33	}
 34	if dialect == "" {
 35		return nil, errors.New("dialect must not be empty")
 36	}
 37	lookup := map[Dialect]dialectquery.Querier{
 38		DialectClickHouse: &dialectquery.Clickhouse{},
 39		DialectMSSQL:      &dialectquery.Sqlserver{},
 40		DialectMySQL:      &dialectquery.Mysql{},
 41		DialectPostgres:   &dialectquery.Postgres{},
 42		DialectRedshift:   &dialectquery.Redshift{},
 43		DialectSQLite3:    &dialectquery.Sqlite3{},
 44		DialectTiDB:       &dialectquery.Tidb{},
 45		DialectVertica:    &dialectquery.Vertica{},
 46		DialectYdB:        &dialectquery.Ydb{},
 47		DialectTurso:      &dialectquery.Turso{},
 48		DialectStarrocks:  &dialectquery.Starrocks{},
 49	}
 50	querier, ok := lookup[dialect]
 51	if !ok {
 52		return nil, fmt.Errorf("unknown dialect: %q", dialect)
 53	}
 54	return &store{
 55		tablename: tablename,
 56		querier:   dialectquery.NewQueryController(querier),
 57	}, nil
 58}
 59
 60type store struct {
 61	tablename string
 62	querier   *dialectquery.QueryController
 63}
 64
 65var _ Store = (*store)(nil)
 66
 67func (s *store) Tablename() string {
 68	return s.tablename
 69}
 70
 71func (s *store) CreateVersionTable(ctx context.Context, db DBTxConn) error {
 72	q := s.querier.CreateTable(s.tablename)
 73	if _, err := db.ExecContext(ctx, q); err != nil {
 74		return fmt.Errorf("failed to create version table %q: %w", s.tablename, err)
 75	}
 76	return nil
 77}
 78
 79func (s *store) Insert(ctx context.Context, db DBTxConn, req InsertRequest) error {
 80	q := s.querier.InsertVersion(s.tablename)
 81	if _, err := db.ExecContext(ctx, q, req.Version, true); err != nil {
 82		return fmt.Errorf("failed to insert version %d: %w", req.Version, err)
 83	}
 84	return nil
 85}
 86
 87func (s *store) Delete(ctx context.Context, db DBTxConn, version int64) error {
 88	q := s.querier.DeleteVersion(s.tablename)
 89	if _, err := db.ExecContext(ctx, q, version); err != nil {
 90		return fmt.Errorf("failed to delete version %d: %w", version, err)
 91	}
 92	return nil
 93}
 94
 95func (s *store) GetMigration(
 96	ctx context.Context,
 97	db DBTxConn,
 98	version int64,
 99) (*GetMigrationResult, error) {
100	q := s.querier.GetMigrationByVersion(s.tablename)
101	var result GetMigrationResult
102	if err := db.QueryRowContext(ctx, q, version).Scan(
103		&result.Timestamp,
104		&result.IsApplied,
105	); err != nil {
106		if errors.Is(err, sql.ErrNoRows) {
107			return nil, fmt.Errorf("%w: %d", ErrVersionNotFound, version)
108		}
109		return nil, fmt.Errorf("failed to get migration %d: %w", version, err)
110	}
111	return &result, nil
112}
113
114func (s *store) GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error) {
115	q := s.querier.GetLatestVersion(s.tablename)
116	var version sql.NullInt64
117	if err := db.QueryRowContext(ctx, q).Scan(&version); err != nil {
118		return -1, fmt.Errorf("failed to get latest version: %w", err)
119	}
120	if !version.Valid {
121		return -1, fmt.Errorf("latest %w", ErrVersionNotFound)
122	}
123	return version.Int64, nil
124}
125
126func (s *store) ListMigrations(
127	ctx context.Context,
128	db DBTxConn,
129) ([]*ListMigrationsResult, error) {
130	q := s.querier.ListMigrations(s.tablename)
131	rows, err := db.QueryContext(ctx, q)
132	if err != nil {
133		return nil, fmt.Errorf("failed to list migrations: %w", err)
134	}
135	defer rows.Close()
136
137	var migrations []*ListMigrationsResult
138	for rows.Next() {
139		var result ListMigrationsResult
140		if err := rows.Scan(&result.Version, &result.IsApplied); err != nil {
141			return nil, fmt.Errorf("failed to scan list migrations result: %w", err)
142		}
143		migrations = append(migrations, &result)
144	}
145	if err := rows.Err(); err != nil {
146		return nil, err
147	}
148	return migrations, nil
149}
150
151//
152//
153//
154// Additional methods that are not part of the core Store interface, but are extended by the
155// [controller.StoreController] type.
156//
157//
158//
159
160func (s *store) TableExists(ctx context.Context, db DBTxConn) (bool, error) {
161	q := s.querier.TableExists(s.tablename)
162	if q == "" {
163		return false, errors.ErrUnsupported
164	}
165	var exists bool
166	// Note, we do not pass the table name as an argument to the query, as the query should be
167	// pre-defined by the dialect.
168	if err := db.QueryRowContext(ctx, q).Scan(&exists); err != nil {
169		return false, fmt.Errorf("failed to check if table exists: %w", err)
170	}
171	return exists, nil
172}