1package dialect
2
3import (
4 "context"
5 "database/sql"
6 "fmt"
7 "time"
8
9 "github.com/pressly/goose/v3/internal/dialect/dialectquery"
10)
11
12// Store is the interface that wraps the basic methods for a database dialect.
13//
14// A dialect is a set of SQL statements that are specific to a database.
15//
16// By defining a store interface, we can support multiple databases
17// with a single codebase.
18//
19// The underlying implementation does not modify the error. It is the callers
20// responsibility to assert for the correct error, such as sql.ErrNoRows.
21type Store interface {
22 // CreateVersionTable creates the version table within a transaction.
23 // This table is used to store goose migrations.
24 CreateVersionTable(ctx context.Context, tx *sql.Tx, tableName string) error
25
26 // InsertVersion inserts a version id into the version table within a transaction.
27 InsertVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error
28 // InsertVersionNoTx inserts a version id into the version table without a transaction.
29 InsertVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error
30
31 // DeleteVersion deletes a version id from the version table within a transaction.
32 DeleteVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error
33 // DeleteVersionNoTx deletes a version id from the version table without a transaction.
34 DeleteVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error
35
36 // GetMigrationRow retrieves a single migration by version id.
37 //
38 // Returns the raw sql error if the query fails. It is the callers responsibility
39 // to assert for the correct error, such as sql.ErrNoRows.
40 GetMigration(ctx context.Context, db *sql.DB, tableName string, version int64) (*GetMigrationResult, error)
41
42 // ListMigrations retrieves all migrations sorted in descending order by id.
43 //
44 // If there are no migrations, an empty slice is returned with no error.
45 ListMigrations(ctx context.Context, db *sql.DB, tableName string) ([]*ListMigrationsResult, error)
46}
47
48// NewStore returns a new Store for the given dialect.
49func NewStore(d Dialect) (Store, error) {
50 var querier dialectquery.Querier
51 switch d {
52 case Postgres:
53 querier = &dialectquery.Postgres{}
54 case Mysql:
55 querier = &dialectquery.Mysql{}
56 case Sqlite3:
57 querier = &dialectquery.Sqlite3{}
58 case Sqlserver:
59 querier = &dialectquery.Sqlserver{}
60 case Redshift:
61 querier = &dialectquery.Redshift{}
62 case Tidb:
63 querier = &dialectquery.Tidb{}
64 case Clickhouse:
65 querier = &dialectquery.Clickhouse{}
66 case Vertica:
67 querier = &dialectquery.Vertica{}
68 case Ydb:
69 querier = &dialectquery.Ydb{}
70 case Turso:
71 querier = &dialectquery.Turso{}
72 case Starrocks:
73 querier = &dialectquery.Starrocks{}
74 default:
75 return nil, fmt.Errorf("unknown querier dialect: %v", d)
76 }
77 return &store{querier: querier}, nil
78}
79
// GetMigrationResult is the single row that Store.GetMigration returns for one
// version id.
type GetMigrationResult struct {
	// IsApplied is the applied flag stored for the version.
	IsApplied bool
	// Timestamp is the first column returned by the dialect's
	// GetMigrationByVersion query (presumably when the row was recorded —
	// confirm against the dialect SQL).
	Timestamp time.Time
}
84
// ListMigrationsResult is one row produced by Store.ListMigrations.
type ListMigrationsResult struct {
	// VersionID is the migration version id stored in the version table.
	VersionID int64
	// IsApplied is the applied flag stored alongside VersionID.
	IsApplied bool
}
89
// store is the single Store implementation used for every dialect; the
// dialect-specific SQL text is obtained from querier.
type store struct {
	// querier produces the SQL statements executed by the store's methods.
	querier dialectquery.Querier
}

// Compile-time assertion that *store implements Store.
var _ Store = (*store)(nil)
95
96func (s *store) CreateVersionTable(ctx context.Context, tx *sql.Tx, tableName string) error {
97 q := s.querier.CreateTable(tableName)
98 _, err := tx.ExecContext(ctx, q)
99 return err
100}
101
102func (s *store) InsertVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error {
103 q := s.querier.InsertVersion(tableName)
104 _, err := tx.ExecContext(ctx, q, version, true)
105 return err
106}
107
108func (s *store) InsertVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error {
109 q := s.querier.InsertVersion(tableName)
110 _, err := db.ExecContext(ctx, q, version, true)
111 return err
112}
113
114func (s *store) DeleteVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error {
115 q := s.querier.DeleteVersion(tableName)
116 _, err := tx.ExecContext(ctx, q, version)
117 return err
118}
119
120func (s *store) DeleteVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error {
121 q := s.querier.DeleteVersion(tableName)
122 _, err := db.ExecContext(ctx, q, version)
123 return err
124}
125
126func (s *store) GetMigration(
127 ctx context.Context,
128 db *sql.DB,
129 tableName string,
130 version int64,
131) (*GetMigrationResult, error) {
132 q := s.querier.GetMigrationByVersion(tableName)
133 var timestamp time.Time
134 var isApplied bool
135 err := db.QueryRowContext(ctx, q, version).Scan(×tamp, &isApplied)
136 if err != nil {
137 return nil, err
138 }
139 return &GetMigrationResult{
140 IsApplied: isApplied,
141 Timestamp: timestamp,
142 }, nil
143}
144
145func (s *store) ListMigrations(ctx context.Context, db *sql.DB, tableName string) ([]*ListMigrationsResult, error) {
146 q := s.querier.ListMigrations(tableName)
147 rows, err := db.QueryContext(ctx, q)
148 if err != nil {
149 return nil, err
150 }
151 defer rows.Close()
152
153 var migrations []*ListMigrationsResult
154 for rows.Next() {
155 var version int64
156 var isApplied bool
157 if err := rows.Scan(&version, &isApplied); err != nil {
158 return nil, err
159 }
160 migrations = append(migrations, &ListMigrationsResult{
161 VersionID: version,
162 IsApplied: isApplied,
163 })
164 }
165 if err := rows.Err(); err != nil {
166 return nil, err
167 }
168 return migrations, nil
169}