1package database
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8
9 "github.com/pressly/goose/v3/internal/dialect/dialectquery"
10)
11
// Dialect names a database flavor. NewStore uses it to select the
// dialect-specific queries for the version table.
type Dialect string

// Supported dialects, listed alphabetically. Each value is the lowercase name
// of the database it targets.
const (
	// DialectClickHouse targets ClickHouse.
	DialectClickHouse Dialect = "clickhouse"
	// DialectMSSQL targets Microsoft SQL Server.
	DialectMSSQL Dialect = "mssql"
	// DialectMySQL targets MySQL.
	DialectMySQL Dialect = "mysql"
	// DialectPostgres targets PostgreSQL.
	DialectPostgres Dialect = "postgres"
	// DialectRedshift targets Amazon Redshift.
	DialectRedshift Dialect = "redshift"
	// DialectSQLite3 targets SQLite 3.
	DialectSQLite3 Dialect = "sqlite3"
	// DialectStarrocks targets StarRocks.
	DialectStarrocks Dialect = "starrocks"
	// DialectTiDB targets TiDB.
	DialectTiDB Dialect = "tidb"
	// DialectTurso targets Turso.
	DialectTurso Dialect = "turso"
	// DialectVertica targets Vertica.
	DialectVertica Dialect = "vertica"
	// DialectYdB targets YDB.
	DialectYdB Dialect = "ydb"
)
28
29// NewStore returns a new [Store] implementation for the given dialect.
30func NewStore(dialect Dialect, tablename string) (Store, error) {
31 if tablename == "" {
32 return nil, errors.New("table name must not be empty")
33 }
34 if dialect == "" {
35 return nil, errors.New("dialect must not be empty")
36 }
37 lookup := map[Dialect]dialectquery.Querier{
38 DialectClickHouse: &dialectquery.Clickhouse{},
39 DialectMSSQL: &dialectquery.Sqlserver{},
40 DialectMySQL: &dialectquery.Mysql{},
41 DialectPostgres: &dialectquery.Postgres{},
42 DialectRedshift: &dialectquery.Redshift{},
43 DialectSQLite3: &dialectquery.Sqlite3{},
44 DialectTiDB: &dialectquery.Tidb{},
45 DialectVertica: &dialectquery.Vertica{},
46 DialectYdB: &dialectquery.Ydb{},
47 DialectTurso: &dialectquery.Turso{},
48 DialectStarrocks: &dialectquery.Starrocks{},
49 }
50 querier, ok := lookup[dialect]
51 if !ok {
52 return nil, fmt.Errorf("unknown dialect: %q", dialect)
53 }
54 return &store{
55 tablename: tablename,
56 querier: dialectquery.NewQueryController(querier),
57 }, nil
58}
59
60type store struct {
61 tablename string
62 querier *dialectquery.QueryController
63}
64
65var _ Store = (*store)(nil)
66
67func (s *store) Tablename() string {
68 return s.tablename
69}
70
71func (s *store) CreateVersionTable(ctx context.Context, db DBTxConn) error {
72 q := s.querier.CreateTable(s.tablename)
73 if _, err := db.ExecContext(ctx, q); err != nil {
74 return fmt.Errorf("failed to create version table %q: %w", s.tablename, err)
75 }
76 return nil
77}
78
79func (s *store) Insert(ctx context.Context, db DBTxConn, req InsertRequest) error {
80 q := s.querier.InsertVersion(s.tablename)
81 if _, err := db.ExecContext(ctx, q, req.Version, true); err != nil {
82 return fmt.Errorf("failed to insert version %d: %w", req.Version, err)
83 }
84 return nil
85}
86
87func (s *store) Delete(ctx context.Context, db DBTxConn, version int64) error {
88 q := s.querier.DeleteVersion(s.tablename)
89 if _, err := db.ExecContext(ctx, q, version); err != nil {
90 return fmt.Errorf("failed to delete version %d: %w", version, err)
91 }
92 return nil
93}
94
95func (s *store) GetMigration(
96 ctx context.Context,
97 db DBTxConn,
98 version int64,
99) (*GetMigrationResult, error) {
100 q := s.querier.GetMigrationByVersion(s.tablename)
101 var result GetMigrationResult
102 if err := db.QueryRowContext(ctx, q, version).Scan(
103 &result.Timestamp,
104 &result.IsApplied,
105 ); err != nil {
106 if errors.Is(err, sql.ErrNoRows) {
107 return nil, fmt.Errorf("%w: %d", ErrVersionNotFound, version)
108 }
109 return nil, fmt.Errorf("failed to get migration %d: %w", version, err)
110 }
111 return &result, nil
112}
113
114func (s *store) GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error) {
115 q := s.querier.GetLatestVersion(s.tablename)
116 var version sql.NullInt64
117 if err := db.QueryRowContext(ctx, q).Scan(&version); err != nil {
118 return -1, fmt.Errorf("failed to get latest version: %w", err)
119 }
120 if !version.Valid {
121 return -1, fmt.Errorf("latest %w", ErrVersionNotFound)
122 }
123 return version.Int64, nil
124}
125
126func (s *store) ListMigrations(
127 ctx context.Context,
128 db DBTxConn,
129) ([]*ListMigrationsResult, error) {
130 q := s.querier.ListMigrations(s.tablename)
131 rows, err := db.QueryContext(ctx, q)
132 if err != nil {
133 return nil, fmt.Errorf("failed to list migrations: %w", err)
134 }
135 defer rows.Close()
136
137 var migrations []*ListMigrationsResult
138 for rows.Next() {
139 var result ListMigrationsResult
140 if err := rows.Scan(&result.Version, &result.IsApplied); err != nil {
141 return nil, fmt.Errorf("failed to scan list migrations result: %w", err)
142 }
143 migrations = append(migrations, &result)
144 }
145 if err := rows.Err(); err != nil {
146 return nil, err
147 }
148 return migrations, nil
149}
150
151//
152//
153//
// Additional methods that are not part of the core Store interface, but that are
// used by the [controller.StoreController] type.
156//
157//
158//
159
160func (s *store) TableExists(ctx context.Context, db DBTxConn) (bool, error) {
161 q := s.querier.TableExists(s.tablename)
162 if q == "" {
163 return false, errors.ErrUnsupported
164 }
165 var exists bool
166 // Note, we do not pass the table name as an argument to the query, as the query should be
167 // pre-defined by the dialect.
168 if err := db.QueryRowContext(ctx, q).Scan(&exists); err != nil {
169 return false, fmt.Errorf("failed to check if table exists: %w", err)
170 }
171 return exists, nil
172}